動画からの音声抽出と動画への音声結合

testpy.hatenablog.com

上記の記事で、動画を左右反転させて、新たな動画を生成した。 生成した動画には音声がないため、元の動画から音声を抽出して、結合してみた。

コード

# coding: utf-8
import sys
import cv2
import moviepy.editor as mp

class Test:
    def __init__(self):
        # Set video names.
        self.input_video = sys.argv[1]
        self.output_video = sys.argv[2]


    def main(self):
        self.set_audio()


    def set_audio(self):
        # Extract audio from input video.
        clip_input = mp.VideoFileClip(self.input_video).subclip()
        clip_input.audio.write_audiofile('audio.mp3')

        # Add audio to output video.
        clip_output = mp.VideoFileClip(self.output_video).subclip()
        clip_output.write_videofile(self.output_video.replace('.avi', '.mp4'), audio='audio.mp3')


if __name__ == "__main__":
    Test().main()

set_audio()が本体。前半2行で音声の抽出と保存、後半の2行で音声の結合を行っている。 結合する際、動画の拡張子がmp4じゃないとコーデックエラーが出る。 でも、mp4の方が容量が少ないので、それでいいかなと。

使い方

上記のコードをtest.pyで保存して、第一引数に入力動画名、第二引数に出力動画名(aviファイル)を指定する。

> python test.py input.MOV ouput.avi
[MoviePy] Writing audio in audio.mp3
100%|###################################################################################################################| 623/623 [00:00<00:00, 733.71it/s]
[MoviePy] Done.
[MoviePy] >>>> Building video output.mp4
[MoviePy] Writing video ouput.mp4
100%|####################################################################################################################| 846/846 [00:25<00:00, 33.65it/s]
[MoviePy] Done.
[MoviePy] >>>> Video ready: output.mp4

>

動画の生成と音声の結合を同時に実行

前回行った動画の生成と、今回行った音声の結合を同時に行ってみた。 音声を結合した動画がmp4なので、生成する動画もmp4(コーデックはMP4S)にしたのだが、警告が出て、どうも上手く音声付きの動画が生成されない。 この辺は後ほど対応したい。

OpenCV: FFMPEG: tag 0x5334504d/'MP4S' is not supported with codec id 13 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x00000020/' ???'

同時に実行したコードを載せておく。 使い方は上に示したのと同じ。

# coding: utf-8
import sys
import cv2
import moviepy.editor as mp

class Test:
    def __init__(self):
        # Set video names.
        self.input_video = sys.argv[1]
        self.output_video = sys.argv[2]


    def main(self):
        self.make_video()
        self.set_audio()


    def make_video(self):
        # Get input video information.
        cap = cv2.VideoCapture(self.input_video)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Set output video infomation.
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        #fourcc = cv2.VideoWriter_fourcc(*'MP4S')
        vw = cv2.VideoWriter(self.output_video, fourcc, fps, (width, height))  # Set the above information.

        # Make the output video.
        print('Making a video...')
        while(True):
            ret, img = cap.read()
            if ret == True:
                img_flip = cv2.flip(img, 1)  # Flip horizontal
                vw.write(img_flip)  # Add frame
            else:
                break

        # Post processing.
        cap.release()
        cv2.destroyAllWindows()


    def set_audio(self):
        # Extract audio from input video.
        clip_input = mp.VideoFileClip(self.input_video).subclip()
        clip_input.audio.write_audiofile('audio.mp3')

        # Add audio to output video.
        clip_output = mp.VideoFileClip(self.output_video).subclip()
        clip_output.write_videofile(self.output_video.replace('.avi', '.mp4'), audio='audio.mp3')



if __name__ == "__main__":
    Test().main()

参考文献

OpenCVで動画を生成する

OpenCVで動画を生成できる。 動画を左右反転させて、新たな動画を作成してみた。

コード

# coding: utf-8
import sys
import cv2

class VideoMaker:
    def __init__(self):
        pass


    def main(self):
        # Set video names.
        input_video = sys.argv[1]
        output_video = input_video.replace('.MOV', '_flip.avi')

        # Get input video information.
        cap = cv2.VideoCapture(input_video)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Set output video infomation.
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        vw = cv2.VideoWriter(output_video, fourcc, fps, (width, height))  # Set the above information.

        # Make the output video.
        print('Making a video...')
        while(True):
            ret, img = cap.read()
            if ret == True:
                img_flip = cv2.flip(img, 1)  # Flip horizontal
                vw.write(img_flip)  # Add frame
            else:
                break

        # Post processing.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    VideoMaker().main()

解説

処理の大まかな流れは以下の通り。

  1. 入力動画の情報をゲット
  2. 出力動画の情報をセット
  3. 出力動画の生成・保存

FPSや画素数などを入出力で一致させている。 参考文献によると、コーデックはXVIDがおすすめとのこと。 MJPGX264も指定できるが、サイズが大きすぎたり小さすぎたりするらしい。 試してないけど。 vw.write(img_flip)で処理した画像を追加すれば、動画の生成と保存を同時に行ってくれる。 特に保存の必要はない。

使い方

video_maker.pyにコードを記載している。 コードと同じ場所に保存している動画(MOVファイル)を引数に指定すると、左右反転させた動画が生成される。

> python video_maker.py 2018_spring.MOV
Making a video...
>

参考文献

膨張差分法とキャニー法による線画の比較

アニメや漫画を線画にする際、白を膨張させてグレースケールとの差分を取る方法(以下、膨張差分法と呼ぶ)が多く用いられている。 しかし、実写に膨張差分法を適用したところ、実写の描写の細かさが影響してノイズが残りやすいことが分かった。 そこで膨張差分法とは別に、キャニー法という一般的なエッジ検出を適用して線画を生成し、両者の比較を行った。

処理コード

膨張差分法による線画生成をimage_2_linedraw_4_anime()、キャニー法による線画生成をimage_2_linedraw_4_photo()に実装。 前者は前述や参考文献でも述べられている通り、白を膨張させてグレースケールとの差分を取っている。 後者はブラーをかけたあとキャニー法を適用している。 main()を実行すると、画像ディレクトリーに入っているすべての画像に対して、2種類の手法で線画を生成する。

# coding: utf-8
import sys
import os
import random
import numpy as np
import cv2
import matplotlib
import numpy as np
import matplotlib.pyplot as plt


class Main:
    def __init__(self):
        self.blur_size = (7,7)
        self.dilate_size = (3,3)


    def main(self):
        image_dir = '../data/images/'
        image_names = [f for f in os.listdir(image_dir) if f[-4:].lower()=='.jpg' or f[-4:].lower()=='.png']
        image_names = [f for f in image_names if f.find('_linedraw.jpg')==-1]

        for image_name in image_names:
            self.image_2_linedraw_4_anime(image_dir, image_name)
            self.image_2_linedraw_4_photo(image_dir, image_name)


    def image_2_linedraw_4_anime(self, image_dir, image_name):
        img = cv2.imread(image_dir+image_name, cv2.IMREAD_GRAYSCALE) # Gray
        img_dilate = cv2.dilate(img, np.ones(self.dilate_size), iterations=1)  # Dilation
        img = cv2.absdiff(img_dilate, img)  # diff
        img = cv2.bitwise_not(img)  # Black and white inversion

        cv2.imwrite(image_dir+image_name.replace('.jpg', '_anime_linedraw.jpg'), img)  # Save


    def image_2_linedraw_4_photo(self, image_dir, image_name):
        img = cv2.imread(image_dir+image_name)
        img = cv2.GaussianBlur(img, self.blur_size, 0)  # Blur
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # BGR to Gray

        img = cv2.Canny(img, threshold1=90, threshold2=110)  # Canny method
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # Gray to BGR
        img = cv2.bitwise_not(img)  # Black and white inversion

        cv2.imwrite(image_dir+image_name.replace('.jpg', '_photo_linedraw.jpg'), img)


if __name__ == "__main__":
    Main().main()

膨張差分法とキャニー法による線画作成の比較

左から元画像、膨張差分法を用いたimage_2_linedraw_4_anime()の結果、キャニー法を用いたimage_2_linedraw_4_photo()の結果。 画像が小さいのでクリックして拡大して見てください。

アニメから線画を生成

膨張差分法の最初の4枚は素晴らしいほど、上手く線画になっている。 PaintsChainerが上手く機能するのも納得の訓練データ生成処理と言える。 最後だけ漫画と実写の共存だが、膨張差分法はTシャツの漫画が上手くできてて、キャニー法は女性の上半身が上手くできてる。 両者を融合すると良い感じになると思う。

f:id:Shoto:20180318035809j:plainf:id:Shoto:20180318035811j:plainf:id:Shoto:20180318035814j:plain
f:id:Shoto:20180318035820j:plainf:id:Shoto:20180318035825j:plainf:id:Shoto:20180318035830j:plain
f:id:Shoto:20180318035741j:plainf:id:Shoto:20180318035744j:plainf:id:Shoto:20180318035748j:plain
f:id:Shoto:20180318035757j:plainf:id:Shoto:20180318035801j:plainf:id:Shoto:20180318035804j:plain
f:id:Shoto:20180318035706j:plainf:id:Shoto:20180318035713j:plainf:id:Shoto:20180318035719j:plain

写真から線画を生成

やはり膨張差分法だと若干ノイズが大きい感じ。 かと言ってキャニー法がさらに良いかと言うと微妙だけど。 一応、最初の4枚は実写から線画の生成が個人的に上手くいったと思ったものを載せている。 ただ、最後だけは膨張差分法の方が良い感じだと思う。ノーマン・ロックウェルっぽささえある。 背景が白くて、対象物の足と靴がシンプルだからかも知れない。 というか、実写でも前処理でシンプルにすると、膨張差分法の方が上手くいくのかも。 この辺は暇ができたらリサーチしたい。

f:id:Shoto:20180318035631j:plainf:id:Shoto:20180318035635j:plainf:id:Shoto:20180318035640j:plain
f:id:Shoto:20180318035648j:plainf:id:Shoto:20180318035652j:plainf:id:Shoto:20180318035656j:plain
f:id:Shoto:20180318035839j:plainf:id:Shoto:20180318035842j:plainf:id:Shoto:20180318035846j:plain
f:id:Shoto:20180318035850j:plainf:id:Shoto:20180318035854j:plainf:id:Shoto:20180318035858j:plain
f:id:Shoto:20180318035905j:plainf:id:Shoto:20180318035909j:plainf:id:Shoto:20180318035914j:plain

まとめ

実写を線画にする際は、膨張差分法だとノイズが多くて、キャニー法だとノイズが少なすぎる傾向にある。 なので、その中間地点の線画が出せると良い感じになると思う。

参考文献

ビットコイン対円のティッカーを可視化

前回、pybitflyerを利用して bitFlyerからビットコイン対円のティッカーを2秒ごとに10分間分取得した。

testpy.hatenablog.com

ティッカーを取得、とかさらっと言っているが、実はFX初めて。 ただ株は少しやったことがあって、そうゆう人間からすると、 ティッカーを取得したら売買結果だった、というのはちょっと違和感があった。 と言うのもティッカーって、AppleならAAPLとか、Yahoo!Japanなら4689とか、企業を指すものだと思ってたから。 なので、FXで言うところのティッカーって何なのかを知るために、用語を調べたり、matplotlibで可視化してみた。

用語

まずは用語を少し調べてみたけど、こんな感じだろうか。。 間違ってたら教えて下さい。

  • best_ask: 最高買い価格
  • best_bid: 最低売り価格
  • best_ask_size: 最高買い価格の数
  • best_bid_size: 最低売り価格の数
  • ltp: 最終取引価格
  • total_ask_depth: 買い注文総数
  • total_bid_depth: 売り注文総数
  • volume_by_product: 価格ごとの出来高

best_ask, best_bid, ltpの単位は円でOKだと思うけど、 best_ask_size, best_bid_size, total_ask_depth, total_bid_depth, volume_by_productは単位はBTCかな? CSVのファイルの数値は少数第8位まであって、0.00000001BTC=1SatoshiなのでOKとは思うが。

なお、total_ask_depth, total_bid_depthのイメージは以下が分かりやすいと思う。 下のグラフの両端の数値がそれらに該当すると思われる。

f:id:Shoto:20171116004239j:plain https://en.wikipedia.org/wiki/Market_depth

コード

前回read_csv()で取得できる dfを次のplot_ticker()の引数として渡すと可視化できる。

def plot_ticker(self, df):
    fig = plt.figure(figsize=(16, 9), dpi=100)
    fig.patch.set_facecolor('white')

    ax = fig.add_subplot(3, 1, 1)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax.plot(df.index, df['best_ask'], ls='-', color='red', label='best_ask')
    ax.plot(df.index, df['best_bid'], ls='-', color='blue', label='best_bid')
    ax.plot(df.index, df['ltp'], ls='-', color='gray', label='ltp')
    plt.legend(loc='best')
    ax.grid()

    ax = fig.add_subplot(3, 1, 2)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax.plot(df.index, df['total_ask_depth'], ls='-', color='red', label='total_ask_depth')
    ax.plot(df.index, df['total_bid_depth'], ls='-', color='blue', label='total_bid_depth')
    plt.legend(loc='best')
    ax.grid()

    ax = fig.add_subplot(3, 1, 3)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    # ax.plot(df.index, df['volume'], ls='-', color='green', label='volume')
    ax.plot(df.index, df['volume_by_product'], ls='-', color='black', label='volume_by_product')
    plt.legend(loc='best')
    ax.grid()

    plt.tight_layout()
    plt.show()

結果

f:id:Shoto:20171116004242p:plain

  • 上グラフ
    best_askが上、best_bidが下のラインを支えていて, ltpがその中間にあることが確認できる。
  • 中グラフ
    観測した中ではtotal_ask_depthtotal_bid_depthをずっと下回っている。 あとtotal_ask_depthが一定なのに足して、total_bid_depthがたまに増える。
  • 下グラフ
    volume_by_productが急激に減ったと思ったら、BTCの価格が上がっている。

参考文献

bitFlyerからビットコイン対円のティッカーを取得

かなり前からビットコインが熱かったけど放置してたら完全に乗り遅れた。 今更ながらビットコインの波に乗ろうと思う。 まずは対円の情報をbitFlyerから取得しようと思う。 簡単のため、pybitflyerというPythonライブラリーを使う。

コード

api = pybitflyer.API()APIを取得して、 ticker = api.ticker(product_code='BTC_JPY')でティッカーが取得できる。 keyもsecretもいらず非常にシンプルだが、 サーバーへ2秒置きにアクセスして10分間分のティッカーを取得し、 CSVファイルに保存して、読み込みデータを表示するまでのコードを書いた。

# coding: utf-8
import pybitflyer
import pandas as pd
from time import sleep
from progressbar import ProgressBar

SLEEP_SECOND = 2  # データ取得間隔(秒)
N_MINUTE = 10  # データ取得時間(分)
DIR_DATA = '../data/'  # データ格納フォルダー
PRODUCT_CODE = 'BTC_JPY'  # 取得するデータ
FILE_BTC_JPY = 'BTC_JPY_TICKER.csv'  # TICKERファイル名


class Agent:
    def __init__(self):
        pass


    def main(self):
        self.get_ticker()
        df = self.read_ticker()


    """
  TICKER
  """
    def get_ticker(self):
        api = pybitflyer.API()
        tickers = []
        count = 60 // SLEEP_SECOND * N_MINUTE  # サーバーへのアクセス回数
        pb = ProgressBar(max_value=count)
        for i in range(count):
            ticker = api.ticker(product_code=PRODUCT_CODE)
            tickers.append(ticker)
            sleep(SLEEP_SECOND)
            pb.update(i)  # Update progressbar

        df = pd.DataFrame(tickers)
        df.to_csv(DIR_DATA+FILE_BTC_JPY, index=False)


    def read_ticker(self):
        df = pd.read_csv(DIR_DATA+FILE_BTC_JPY)
        df['timestamp'] = pd.to_datetime(df['timestamp'])  # stringからdatetimeへ

        keys = ['timestamp', 'product_code', 'tick_id', \
                'best_ask', 'best_ask_size', 'best_bid', 'best_bid_size', \
                'total_ask_depth', 'total_bid_depth', 'ltp', \
                'volume', 'volume_by_product']
        df = df[keys]  # カラムを表示したい順に並べる
        df = df.set_index('timestamp')  # timestampをindexに

        # debug
        print(df.head().to_string())
        print('')
        print(df.tail().to_string())
        print('')

        return df


if __name__ == "__main__":
    Agent().main()

結果

上記のコードをagent.pyに保存して、一つ上にdataフォルダーを作って置けば、 実行すればdataフォルダー内にCSVファイルが出力されているはず。 それを読み込んだ結果が以下の通り。 なかなか2秒ごとにデータ取得できないもんですな。

> python .\agent.py
                        product_code   tick_id  best_ask  best_ask_size  best_bid  best_bid_size  total_ask_depth  total_bid_depth       ltp         volume  volume_by_product
timestamp
2017-11-15 13:57:22.963      BTC_JPY  16622557  811140.0       1.200000  810785.0       1.822747      1781.033755      4285.326827  810785.0  207454.303454       19592.381994
2017-11-15 13:57:25.343      BTC_JPY  16622606  811139.0       0.682363  810490.0       3.907947      1782.946859      4276.717027  810490.0  207466.718872       19603.633094
2017-11-15 13:57:27.250      BTC_JPY  16622669  810678.0       0.028423  810490.0       3.907947      1784.647044      4277.266086  810679.0  207473.906758       19604.627594
2017-11-15 13:57:28.953      BTC_JPY  16622708  811068.0       0.949000  810491.0       0.005685      1785.397923      4273.360839  810490.0  207478.447690       19605.862525
2017-11-15 13:57:31.313      BTC_JPY  16622740  811066.0       0.307494  810491.0       0.012506      1784.854531      4273.544345  811067.0  207487.503203       19607.357628

                        product_code   tick_id  best_ask  best_ask_size  best_bid  best_bid_size  total_ask_depth  total_bid_depth       ltp         volume  volume_by_product
timestamp
2017-11-15 14:08:08.573      BTC_JPY  16637242  814999.0       0.800000  814900.0       0.167781      1714.455080      4223.328344  814900.0  206588.403105       19472.318694
2017-11-15 14:08:11.090      BTC_JPY  16637282  815000.0       5.638385  814697.0       0.320000      1714.380850      4225.357163  815000.0  206584.681053       19473.110742
2017-11-15 14:08:12.957      BTC_JPY  16637289  815000.0       4.738385  814698.0       1.120080      1713.758990      4225.517242  815000.0  206591.556565       19474.010742
2017-11-15 14:08:15.550      BTC_JPY  16637359  815000.0       1.123285  814800.0       0.160200      1710.768781      4219.641562  815000.0  206576.876997       19470.398019
2017-11-15 14:08:17.817      BTC_JPY  16637439  815000.0       0.784285  814800.0       0.160200      1700.492640      4216.927261  815000.0  206562.011497       19469.980019

参考文献

Twitter APIを使った検索方法

Twitter分析をすることになったため、APIを使った検索について調査検証を行った。 結論から言うと、公式のAPIは、パラメーターが少なくロクな検索ができないのだが、 クエリにパラメーターを含めることで様々な検索が可能になることが分かった。 以下にその方法を示す。

1. APIとの接続

APIへ接続する前に、CONSUMER_KEY, CONSUMER_SECRET, ACCESS_TOKEN, ACCESS_TOKEN_SECRETを取得する。 Twitter REST APIの使い方とか読み進めていけばできると思う。 僕の場合は、昔に設定してたのが残っていたので、それを使用することにした。

取得できたら、設定ファイル(今回はconfig.py)を作成して、以下のように各コードを書き留めておく('xxx'の部分を変更)。 こうすることで、検索ファイルにべた書きするより多少リスクが軽減される。

CONSUMER_KEY        = 'xxx'
CONSUMER_SECRET     = 'xxx'
ACCESS_TOKEN        = 'xxx'
ACCESS_TOKEN_SECRET = 'xxx'

APIには、OAuthを簡単にしてくれるライブラリー requests_oauthlibOAuth1Sessionを使って接続する。 検索ファイル(今回はagent.py)を作成して、以下のように設定ファイル(config.py)から 各コードを取得し、OAuth1Sessionに渡すとAPIに接続できる。

from requests_oauthlib import OAuth1Session
import config

class Agent:

    ...

    def connect_api(self):
        api = OAuth1Session(config.CONSUMER_KEY,
                            config.CONSUMER_SECRET,
                            config.ACCESS_TOKEN,
                            config.ACCESS_TOKEN_SECRET)

        return api

2. 検索条件の設定

公式ドキュメント を見ると分かるが、ほとんど条件が指定できない。 しかし、queryに以下のように記述することで、様々な条件で検索することが可能になる。

    def make_params(self):
        query = '猫 filter:images min_replies:10 min_retweets:500 min_faves:500 exclude:retweets'
        params = {
            'q': query,
            'count': 20
        }

        return params

上記の条件は以下の通り。 他の条件については、Twitterの検索APIについて が詳しい。

key value exmaple discription
filter images 画像があるツイート
min_replies 10 リプライ数が指定値以上のツイート
min_retweets 500 リツイート数が指定値以上のツイート
min_faves 500 ライク数が指定値以上のツイート
exclude retweets リツイートでないツイート?

3. tweetの検索

上記で取得したapiparamsを引数とする検索用メソッドを以下のように作る。 ツイートはstatusesに入っているのでresultとして返す。

    def search_tweet(self, api, params):
        url = 'https://api.twitter.com/1.1/search/tweets.json'
        req = api.get(url, params=params)

        result = []
        if req.status_code == 200:
            tweets = json.loads(req.text)
            result = tweets['statuses']        
        else:
            print("ERROR!: %d" % req.status_code)
            result = None

        assert(len(result) > 0)

        return result

statusに入っているツイートは各々 Tweet data dictionariesTweet Data Dictionaryのkeyを持っている。 とりあえず簡単な分析に必要なkeyを標準出力させるメソッドは次の通り。

    def output_tweets(self, result):
        for r in result:
            for k,v in r.items():
                if k in ['text', 'retweet_count', 'favorite_count', 'id', 'created_at']:
                    print(k+':')
                    print(v)
                    print('    ')
            print('-----------------------------------------------------------------')

以下が出力結果。条件通りの検索結果が返ってきていることが分かる。

id:
925369434549624832

text:
我が家のネコさま。←むかし いま→ https://t.co/qg3TXEZpUv

favorite_count:
63259

created_at:
Tue Oct 31 14:30:40 +0000 2017

retweet_count:
30089

-----------------------------------------------------------------
id:
924974322115940357

text:
クロネコ https://t.co/JNZLtFHn2z

favorite_count:
19209

created_at:
Mon Oct 30 12:20:38 +0000 2017

retweet_count:
5106

-----------------------------------------------------------------
id:
924485965237731328

text:
【猫から学ぶ女子力】真枝アキ『彼氏のネコがかわいくない!』 https://t.co/y8lmBlbHkC #ツイ4 https://t.co/3J3AxKVnvV

favorite_count:
3255

created_at:
Sun Oct 29 04:00:04 +0000 2017

retweet_count:
847
-----------------------------------------------------------------

4. 文字化け対策

なおWindows環境だと、コマンドプロンプト上で、 上記のように日本語等を出力しようとすると、文字化けが起こる可能性が高い。 理由はPythonutf-8で扱っているのに、コマンドプロンプトはcp932を扱っているから。 もし文字化けが起こった場合は、以下を実行してみるとよい。

これでちょっとだけストレスが少なくなる。

参考文献

Playerと複数生成されるSpawnerの位置を取得

UnityでFlappy Birdを作ってみたが、すぐにPlayer(Flappy Bird)を自動制御したくなったので、まずはPlayerとSpawner(土管)の位置を取得してみた。


Flappy Bird Made with Unity

Spawnerの設定

Spawnerは土管で、1秒ごとにY軸をランダムにずらして生成し、Playerに向かわせている。 Spawnerが生成された際にタグを付けることで、Player側でSpawnerの位置を取得することができる。

 public GameObject wallPrefab;  // Spawnerとして土管のprefabを設定する
    public float interval;  // 1秒ごとに生成する
    private GameObject spawner;  // 生成されたSpawnerを格納する
    public string tag;  // Spawnerにタグを付ける

    // Use this for initialization
    IEnumerator Start () {
        while (true) {
            // 画面外の座標(12, 2)をベースにY軸を0.0~4.0だけランダムにずらしてSpawnerを生成するよう設定
            transform.position = new Vector2(transform.position.x, Random.Range(0.0f, 4.0f));
            // Spawnerを生成
            spawner = Instantiate(wallPrefab, transform.position, transform.rotation);
            // 生成したSpawnerにタグ付け!
            spawner.tag = "Wall";

            yield return new WaitForSeconds(interval);
        }
    }

なお、Spawnerに指定した土管prefabには1秒ごとに5ずつX軸マイナス方法へ向かうように設定してある。

Playerの設定

Playerはフレームごとに状態を観測することができる。 このスクリプトがPlayerのインスタンスになるので、Playerの位置は簡単に取得できるのだが、 Spawnerの位置は、上記のようにタグ付けをすることで、Playerからも取得することができるようになる。

    public float jumpPower;  // Playerは強さ5でジャンプ
    private int num_wall = 2;  // SpawnerはPlayerに最も近い2つだけ位置を取得

    // Update is called once per frame
    void Update () {
        // Initialize state
        // 観測する位置は以下の5つ
        // [player_pos_y, wall1_pos_x, wall1_pos_y, wall2_pos_x, wall2_pos_y]
        List<float> state = new List<float>();  // 位置を格納するリスト

        // Set player position x
        var pos_player = this.transform.position;  // Playerの位置
        state.Add(pos_player.y);  // リストにPlayerのY座標を追加

        // Set wall position x and y
        GameObject[] walls = GameObject.FindGameObjectsWithTag("Wall");  // タグから全Spawnerを取得
        foreach (var wall in walls) {  // 全Spawnerについて生成された順に
            var pos_wall = wall.gameObject.transform.position;  // Spawnerの位置を取得
            // SpawnerがPlayerより前にいて、かつSpawnerの数がリストに2つ未満の場合
            // If a wall is behind the player & the number of wall in state is not max
            if ((pos_player.x <= pos_wall.x + 1.0f) && (state.Count < 1+2*num_wall)) {  
                state.Add(pos_wall.x);  // SpanerのX座標を格納
                state.Add(pos_wall.y);  // SpanerのY座標を格納
            }
        }

        // debug
        // 取得したPlayerとSpanerの位置を出力
        string text = "";
        foreach (var s in state) {
            text = text + s.ToString() + " ";
        }
        print(text);

    }

位置の取得結果

Playerのスクリプトで設定したコンソールへの出力は以下のようになる。 数値の内容は[player_pos_y, wall1_pos_x, wall1_pos_y, wall2_pos_x, wall2_pos_y]。 SpawnerがX座標マイナス方向にずれていくのが観測できる。 PlayerのX座標は変わらないので、Y座標だけ取得している。

f:id:Shoto:20171103161032p:plain

ソースコード

SpawnerとPlayerの全ソースコードを以下に示しておく。

  • Spawner.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class Spawner : MonoBehaviour {

    public GameObject wallPrefab;
    public float interval;
    private GameObject spawner;
    public string tag;

    // Use this for initialization
    IEnumerator Start () {
        while (true) {
            transform.position = new Vector2(transform.position.x, Random.Range(0.0f, 4.0f));
            spawner = Instantiate(wallPrefab, transform.position, transform.rotation);
            spawner.tag = "Wall";

            yield return new WaitForSeconds(interval);
        }
    }

    // Update is called once per frame
    void Update () {

    }
}
  • Player.cs
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class Player : MonoBehaviour {

    public float jumpPower;

    private int num_wall = 2;

    // Use this for initialization
    void Start () {

    }

    // Update is called once per frame
    void Update () {
        // Initialize state
        // [player_pos_y, wall1_pos_x, wall1_pos_y, wall2_pos_x, wall2_pos_y]
        List<float> state = new List<float>();

        // Set player position x
        var pos_player = this.transform.position;
        state.Add(pos_player.y);

        // Set wall position x and y
        GameObject[] walls = GameObject.FindGameObjectsWithTag("Wall");
        foreach (var wall in walls) {
            var pos_wall = wall.gameObject.transform.position;
            // If a wall is behind the player & the number of wall in state is not max
            if ((pos_player.x <= pos_wall.x + 1.0f) && (state.Count < 1+2*num_wall)) {  
                state.Add(pos_wall.x);
                state.Add(pos_wall.y);
            }
        }

        // debug
        string text = "";
        foreach (var s in state) {
            text = text + s.ToString() + " ";
        }
        print(text);

        // Manual operation
        if (Input.GetButtonDown("Jump")) {
            GetComponent<Rigidbody2D>().velocity = new Vector2(0, jumpPower);
        }
    }

    /*
   void OnCollisionEnter2D (Collision2D other) {
       // Application.LoadLevel(Application.loadedLevel);  //old function
        UnityEngine.SceneManagement.SceneManager.LoadScene(UnityEngine.SceneManagement.SceneManager.GetActiveScene().buildIndex);    
   }
   */

    void OnCollisionEnter2D (Collision2D other) {
        Invoke("Restart", 1.0f);     
    }

    void Restart () {
        // Application.LoadLevel(Application.loadedLevel);  //old function
        UnityEngine.SceneManagement.SceneManager.LoadScene(UnityEngine.SceneManagement.SceneManager.GetActiveScene().buildIndex);   
    }

}

参考文献