TensorBoardのEventsによる学習状況の可視化

TensorFlowに学習状況を可視化するTensorBoardというツールがあるのだが、 コスト関数と精度を可視化してみたかったので使ってみた。

実装

Deep Learning入門としてのFizzBuzz問題 のモデル学習のメソッドをベースにしている。 コメント# Logging data for TensorBoard# Write log to TensorBoardの直下のコードがTensorBoardを利用する部分のコード。

def train_model(self, data, model):
    # dataのセット
    train_X = data[0]
    train_Y = data[1]
    test_X = data[2]
    test_Y = data[3]

    # modelのセット
    X = model['X']
    Y = model['Y']
    Y_ = model['Y_']
    loss = model['loss']
    train_step = model['train_step']

    # 定義
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

    # 初期化
    sess = tf.InteractiveSession()
    tf.initialize_all_variables().run()

    # Logging data for TensorBoard
    _ = tf.scalar_summary('loss', loss)
    _ = tf.scalar_summary('accuracy', accuracy)
    writer = tf.train.SummaryWriter('./log/', graph_def=sess.graph_def)

    for epoch in range(10000+1):
        # データのランダマイズ
        p = np.random.permutation(range(len(train_X)))
        train_X, train_Y = train_X[p], train_Y[p]

        # 学習
        for start in range(0, train_X.shape[0], 100):
            end = start + 100
            sess.run(train_step, feed_dict={X: train_X[start:end], Y_: train_Y[start:end]})

        # テスト
        if epoch % 100 == 0:
            # 教師データのコスト関数
            loss_train = sess.run(loss, feed_dict={X: train_X, Y_: train_Y})
            # 教師データの精度
            accu_train = sess.run(accuracy, feed_dict={X: train_X, Y_: train_Y})
            # テストデータのコスト関数
            loss_test = sess.run(loss, feed_dict={X: test_X, Y_: test_Y})
            # テストデータの精度
            accu_test = sess.run(accuracy, feed_dict={X: test_X, Y_: test_Y})
            # 標準出力
            std_output = 'Epoch: %s, \t Train Loss: %s, \t Train Accracy: %s, \t Test Loss: %s, \t Test Accracy: %s'
            print std_output % (epoch, round(loss_train, 6), round(accu_train, 6), round(loss_test, 6), round(accu_test, 6))

        # Write log to TensorBoard
        summary_str = sess.run(tf.merge_all_summaries(), feed_dict={X: test_X, Y_: test_Y})
        writer.add_summary(summary_str, epoch)

 

解説

コスト関数は引数として渡される。 実装はコメントの通り。

# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Y, Y_))
loss = model['loss']

 

精度は以下の通り。 詳しい実装内容はこちらを参照。

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

 

ここからが、TensorBoardの実装部分。 1、2行目は可視化したいコスト関数と精度の変数を出力名と共に、tf.scalar_summary()に渡している。 3行目は出力先を指定している。

# Logging data for TensorBoard
_ = tf.scalar_summary('loss', loss)
_ = tf.scalar_summary('accuracy', accuracy)
writer = tf.train.SummaryWriter('./log/', graph_def=sess.graph_def)

 

tf.merge_all_summaries()は上記で設定した_(コスト関数と精度)を指す。 これらにfeed_dictに指定したテストデータとテストラベルを渡して、結果をTensorBoardに出力する。 出力間隔はepochとしたので、for文の中で定義している。

for epoch in range(10000+1):

    ...

    # Write log to TensorBoard
    summary_str = sess.run(tf.merge_all_summaries(), feed_dict={X: test_X, Y_: test_Y})
    writer.add_summary(summary_str, epoch)

 

学習終了したら、別ターミナルで以下のコードを実行すると、 http://0.0.0.0:6006からTensorBoardを見ることができる。 ただし、1行目はvirtualenvの場合のみ実行。

$ source ~/tensorflow/bin/activate
(tensorflow) $ python ~/tensorflow/lib/python2.7/site-packages/tensorflow/tensorboard/tensorboard.py --logdir=./log/
Starting TensorBoard 16 on port 6006
(You can navigate to http://0.0.0.0:6006)

 

結果

次のように可視化される。 accuracyがサッチってないので、バッチ数などのパラメーターを変えるか、epoch数を増やす必要があることが分かる。

f:id:Shoto:20161218214154p:plainf:id:Shoto:20161218214200p:plain

参考文献

Deep Learning入門としてのFizzBuzz問題

FizzBuzz問題とは何か。 こちらから引用させてもらう。

Fizz-Buzz問題の例はこんな感じだ。
1から100までの数をプリントするプログラムを書け。ただし3の倍数のときは数の代わりに「Fizz」と、5の倍数のときは「Buzz」とプリントし、3と5両方の倍数の場合には「FizzBuzz」とプリントすること。
ちゃんとしたプログラマであれば、これを実行するプログラムを2分とかからずに紙に書き出せるはずだ。怖い事実を聞きたい? コンピュータサイエンス学科卒業生の過半数にはそれができないのだ。自称上級プログラマが答えを書くのに10-15分もかかっているのを見たこともある。

これをDeep Learningで学習したモデルに行わせる。 なぜDeep Learning入門用ベンチマークかと言うと、 教師データの101〜1023とテストデータの0〜100の作成コードが、ラベリングも含めて簡単に実装できるし、 モデルの設計も隠れ層が1つとシンプルなので速く収束し、そこそこ高い精度が出せるから。 データ作成とDeep Learningの実装はTensorFlowコトハジメ Fizz-Buzz問題が詳しい。 以下では同様にTensorFlowを使っているが、実装やパラメーターが異なるし、説明は先のリンクの方が丁寧である。 FizzBuzz問題、DeepLearning、TensorFlowについて、すべて詳しくない方はこちらの記事を読んでおくことをお勧めする。

フロー

main関数の3ステップがフローとなる。

class DLFizzBuzz:
    def main(self):
        # 1. FizzBuzzデータを生成して取得する。
        data = DataFizzBuzz().main()

        # 2. Deep Learnigモデルを設計する。
        model = self.design_model(data)

        # 3. Deep Learningモデルを学習させる。
        self.train_model(data, model)

データ生成

生成コードは以下の通り。 main()を呼び出すと教師データ、教師ラベル、テストデータ、テストラベルをリストで取得できる。

BORDER = 101
NUM_DIGITS = 10

class DataFizzBuzz:
    def main(self):
        # Train
        train_data = np.array([self.binary_encode(i, NUM_DIGITS) for i in range(BORDER, 2**NUM_DIGITS)])
        train_label = np.array([self.fizz_buzz_encode(i) for i in range(BORDER, 2**NUM_DIGITS)])

        # Test
        test_data = np.array([self.binary_encode(i, NUM_DIGITS) for i in range(0, BORDER)])
        test_label = np.array([self.fizz_buzz_encode(i) for i in range(0, BORDER)])

        # Collect
        data = [train_data, train_label, test_data, test_label]

        return data


    def binary_encode(self, i, num_digit):
        binary = np.array([i >> d & 1 for d in range(NUM_DIGITS)])

        return binary


    def fizz_buzz_encode(self, i):
        if i % 15 == 0:
            result = np.array([0, 0, 0, 1])
        elif i % 5 == 0:
            result = np.array([0, 0, 1, 0])
        elif i % 3 == 0:
            result = np.array([0, 1, 0, 0])
        else:
            result = np.array([1, 0, 0, 0])

        return result

モデル設計

教師データは101〜1023なので、range(101, 210)となり、2進数で10桁に収まる。 またテストデータは0〜100だが、とりあえずFizz/Buzz/FizzBuzz/その他が区別できればよいので、4次元ベクトルで表現できる。 そのためノード数は、入力層が10、出力層が4となる。 また隠れ層を100とした場合、モデルの設計は以下のようになる。 なお、重み、バイアス、学習関数等は初歩的なものを用いている。

def design_model(self, data):
    # 入力層
    X  = tf.placeholder(tf.float32, [None, data[0].shape[1]])

    # 隠れ層
    W1 = tf.Variable(tf.random_normal([data[0].shape[1], 100], stddev=0.01))
    B1 = tf.Variable(tf.zeros([100]))
    H1 = tf.nn.relu(tf.matmul(X, W1) + B1)

    # 出力層
    W2 = tf.Variable(tf.random_normal([100, data[1].shape[1]], stddev=0.01))
    B2 = tf.Variable(tf.zeros([data[1].shape[1]]))
    Y = tf.matmul(H1, W2) + B2

    # 正解
    Y_ = tf.placeholder(tf.float32, [None, data[1].shape[1]])

    # 学習関数
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Y, Y_))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    model = {'X': X, 'Y': Y, 'Y_': Y_, 'loss': loss, 'train_step': train_step}

    return model

モデル学習

まずデータとモデルを引数として渡し、変数にセットする。

def train_model(self, data, model):
    # dataのセット
    train_X = data[0]
    train_Y = data[1]
    test_X = data[2]
    test_Y = data[3]

    # modelのセット
    X = model['X']
    Y = model['Y']
    Y_ = model['Y_']
    loss = model['loss']
    train_step = model['train_step']

学習中、定期的に精度を検証する。 以下の1行についてはTensorFlowによる精度計算の流れを追うで解説している。

    # 定義
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

教師データは101〜1023の923レコードあり、これらをランダマイズして、10レコードずつ学習させる。 この過程を1エポックとして10000回行う。 また100エポックごとにテストを行い、コスト、教師データの精度、テストデータの精度を算出して標準出力させる。

    # 初期化
    sess = tf.InteractiveSession()
    tf.initialize_all_variables().run()

    for epoch in range(10000+1):
        # データのランダマイズ
        p = np.random.permutation(range(len(train_X)))
        train_X, train_Y = train_X[p], train_Y[p]

        # 学習
        for start in range(0, train_X.shape[0], 10):
            end = start + 10
            sess.run(train_step, feed_dict={X: train_X[start:end], Y_: train_Y[start:end]})

        # テスト
        if epoch % 100 == 0:
            # コスト
            lo55 = sess.run(loss, feed_dict={X: train_X, Y_: train_Y})
            # 教師データの精度
            accu_train = sess.run(accuracy, feed_dict={X: train_X, Y_: train_Y})
            # テストデータの精度
            accu_test = sess.run(accuracy, feed_dict={X: test_X, Y_: test_Y})
            # 標準出力
            print 'Epoch: %s, \t Loss: %-8s, \t Train Accracy: %-8s, \t Test Accracy: %-8s' % (epoch, lo55, accu_train, accu_test)

結果と課題

最終的に、コストが0.002、教師データの精度が100%、テストデータの精度が95%となった。 初期の段階では、教師データの精度が94%にも関わらず、テストデータの精度が100%になったが、 その後、過学習が起こり精度が逆転した。 今後はDropoutを実装して過学習を防いだり、パラメーターの最適化を行う。 あと学習曲線も表示させる。

参考文献

Pythonの並列処理

最近Pythonの並列処理をよく使うのでまとめておく。

基本形

並列処理したいメソッドを別に書いてPoolから呼び出す。 multiprocessing.cpu_count()はシステムのCPU数を返す。 僕の環境では4。デュアルコアなのでスレッド数だと思う。

import multiprocessing

def f(x):
    return x*x

n = multiprocessing.cpu_count()
p = multiprocessing.Pool(n)
params = range(1,4)
result = p.map(f, params)

1から3までを二乗しているので、次のような結果になる。

>>> print result
[1, 4, 9]

複数の引数を渡す

Pool().map()には引数は1つしか渡せない。 しかし、これは複数の引数を1つの引数にまとめることで解決する。 下記ではdictを利用しているが、listでもtupleでもいい。

import multiprocessing

def f(param):
    return param['x']*param['y']

n = multiprocessing.cpu_count()
p = multiprocessing.Pool(n)
# Make one param including multi params.
params = [{'x': i, 'y': i+5} for i in range(1,4)]
result = p.map(f, params)

この結果、2つのパラメーターが渡され、 [1*(1+5), 2*(2+5), 3*(3+5)]を計算した結果が出力される。

>>> print result
[6, 14, 24]

プロセスのメモリ解放

並列処理するメソッド内で、DBを読み込む処理を行い、 更にこの並列処理の何度も繰り返し呼び出していたら、 OSError: [Errno 24] Too many open filesというエラーが出た。 プロセスを終了して、メモリを開放する必要があるが、 Pool().close()とPool().terminate()を呼び出すことでエラーが回避された。

import multiprocessing

def f(param):
    return param['x']*param['y']

n = multiprocessing.cpu_count()
p = multiprocessing.Pool(n)
params = [{'x': i, 'y': i+5} for i in range(1,4)]
result = p.map(f, params)
p.close()  # add this.
p.terminate()  # add this.

プロセス終了後もresultに結果が残っている。

>>> print result
[6, 14, 24]

ちなみに今回追加した処理の内容は次の通り。

  • close()
    これ以上プールでタスクが実行されないようにします。すべてのタスクが完了した後でワーカープロセスが終了します。

  • terminate()
    実行中の処理を完了させずにワーカープロセスをすぐに停止します。プールオブジェクトがガベージコレクトされるときに terminate() が呼び出されます。

参考文献

TensorFlowによる精度計算の流れを追う

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

Yをニューラルネットワークの出力層、Y_をその正解ラベルとした場合、学習モデルの精度を計算するとき、 TensorFlowでは上記ようにAPIを組合せて書くことがある。 式が長いので、各APIの処理の流れを、1つずつ出力しながら追ってみる。 分かれば非常にシンプルな処理の流れである。

計算の流れ

NumpyとTensorFlowをimportしておく。

>>> import tensorflow as tf
>>> import numpy as np

YとY_に[3, 4]のサンプルデータをセットする。 1行目を不正解としている。

>>> Y = np.array([
...                 [0.1, 0.2, 0.3, 0.4],
...                 [0.0, 0.8, 0.2, 0.0],
...                 [0.0, 0.4, 0.5, 0.1]
...             ])
>>> print Y
[[ 0.1  0.2  0.3  0.4]
 [ 0.   0.8  0.2  0. ]
 [ 0.   0.4  0.5  0.1]]

>>> Y_ = np.array([
...                 [0.0, 0.0, 1.0, 0.0],
...                 [0.0, 1.0, 0.0, 0.0],
...                 [0.0, 0.0, 1.0, 0.0]
...             ])
>>> print Y_
[[ 0.  0.  1.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

TensorFlowのSessionを開始。 Sessionと言えば、昨日、映画のSessionを見ましたが非常に良い映画でした。

>>> sess = tf.Session()

ここから精度計算におけるTensorFlowのAPIの説明。 まずはtf.argmax()から。 第2パラメーターに1をセットすると、行ごとに最大となる列を返す。 Yの場合、1行目が4列目の0.4、2行目が2列目の0.8、3行目が3列目の0.5が最大となる。 1行目は0からカウントされるので、以下のようになる。 Y_についても同様。 ちなみに、第2パラメーターに0をセットすると、列ごとに最大となる行を返す。

>>> print sess.run(tf.argmax(Y, 1))
[3 1 2]
>>> print sess.run(tf.argmax(Y_, 1))
[2 1 2]

続いてtf.equal()。 渡された2つのベクトルが一致しているか否かを見る。 今回は[3 1 2]と[2 1 2]を比較しているので次のようになる。

>>> eq = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
>>> print sess.run(eq)
[False  True  True]

tf.cast()では第1パラメーターを第2パラメーターのデータ・タイプに変換する。 [False True True]もfloat32に変換すると次の通り。

>>> print sess.run(tf.cast(eq, tf.float32))
[ 0.  1.  1.]

最後はtf.reduce_mean()。 np.mean()と同じで平均を計算する。 [ 0. 1. 1.]の平均なので2/3となる。

>>> print sess.run(tf.reduce_mean(tf.cast(eq, tf.float32)))
0.666667

最後に全部繋げれば、一気に計算できる。

>>> accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))
>>> print sess.run(accuracy)
0.666667

jsmでYahooファイナンスのデータを片っ端から取得してMongoDBに保存する

はじめに

jsm(Japanese Stock Market)という Yahooファイナンスをクロールして株関連データを取得できるライブラリーがある。 Brand、Finance、Priceデータが取得できるので、東証一部に絞ってデータを片っ端から取得するコードを書いた。

インストール

pipを更新してからjsmをインストールする。 スクレイピングはBeautifulSoup4で行っているので(再)インストール

$ sudo pip install --upgrade pip
$ sudo pip install jsm
$ sudo pip install beautifulsoup4 -U

MongoDBのラッパー

save, read, deleteを用意したMongoDBのラッパーファイル(db.py)を作成する。 DB名とCollection名を初期化で指定する。 saveはdictのlistをdataとして渡す。

# -*- coding: utf-8 -*-
import sys
import pymongo

reload(sys)
sys.setdefaultencoding('utf-8')


class DB:
    def __init__(self, db_name, coll_name):
        self.db_name = db_name
        self.coll_name = coll_name


    def save(self, data):
        client = pymongo.MongoClient('localhost', 27017)
        db = client[self.db_name]
        coll = db[self.coll_name]

        coll.insert(data)


    def read(self):
        client = pymongo.MongoClient('localhost', 27017)
        db = client[self.db_name]
        coll = db[self.coll_name]

        data = [d for d in coll.find()]

        return data


    def delete(self):
        client = pymongo.MongoClient('localhost', 27017)
        db = client[self.db_name]
        coll = db[self.coll_name]

        coll.drop()

前準備

クローラーファイル(clawer.py)を作成し、必要なライブラリーと初期設定を行う。 上記のdb.pyもインポートしておく。 stockデータベースに、Brand、Finance、Priceコレクションを作成して取得したデータを保存する。

# -*- coding: utf-8 -*-
import sys
import jsm
from progressbar import ProgressBar
import pandas as pd
import datetime

from db import DB

reload(sys)
sys.setdefaultencoding('utf-8')

DB_STOCK = 'stock'  # Stock DB
COLL_BRAND = 'brand'  # Brand Collection
COLL_FINANCE = 'finance'  # Finace Collection
COLL_PRICE = 'price'  # Price Collection
START_DATE = datetime.date(2014, 1, 1)  # 株価の取得開始日

Brandデータを取得

get_brand()で全銘柄が取得できるが、取得状況の進捗を見たいのでカテゴリーごとに取得する。

def get_brands(self):
    db = DB(DB_STOCK, COLL_BRAND)
    #db.delete()

    categories = [
        '0050',  # 農林・水産業
        '1050',  # 鉱業
        '2050',  # 建設業
        '3050',  # 食料品
        '3100',  # 繊維製品
        '3150',  # パルプ・紙
        '3200',  # 化学
        '3250',  # 医薬品
        '3300',  # 石油・石炭製品
        '3350',  # ゴム製品
        '3400',  # ガラス・土石製品
        '3450',  # 鉄鋼
        '3500',  # 非鉄金属
        '3550',  # 金属製品
        '3600',  # 機械
        '3650',  # 電気機器
        '3700',  # 輸送機器
        '3750',  # 精密機器
        '3800',  # その他製品
        '4050',  # 電気・ガス業
        '5050',  # 陸運業
        '5100',  # 海運業
        '5150',  # 空運業
        '5200',  # 倉庫・運輸関連業
        '5250',  # 情報・通信
        '6050',  # 卸売業
        '6100',  # 小売業
        '7050',  # 銀行業
        '7100',  # 証券業
        '7150',  # 保険業
        '7200',  # その他金融業
        '8050',  # 不動産業
        '9050'   # サービス業
    ]

    q = jsm.Quotes()
    pb = ProgressBar(maxval=len(categories)).start()
    for i in range(len(categories)):
        lis = []
        try:
            brands = q.get_brand(categories[i])
        except:
            pass
        for b in brands:
            dic = {
                    'category': categories[i],
                    'ccode': b.ccode,
                    'market': b.market,
                    'name': b.name,
                    'info': b.info
            }
            lis.append(dic)
        db.save(lis)
        pb.update(i+1)

Finaceデータを取得

Brandデータから東証一部の証券コード(ccode)のみを取得する。

def get_target_ccodes(self):
    data = DB(DB_STOCK, COLL_BRAND).read()
    df_brand = pd.DataFrame(data)
    df_brand = df_brand[df_brand['market']=='東証1部']
    ccodes = df_brand['ccode'].tolist()

    return ccodes

取得した東証一部の証券コードを引数にしてFianceデータを取得する。

def get_finances(self, ccodes):
    db = DB(DB_STOCK, COLL_FINANCE)
    #db.delete()

    q = jsm.Quotes()
    pb = ProgressBar(maxval=len(ccodes)).start()
    lis = []
    for i in range(len(ccodes)):
        try:
            f = q.get_finance(ccodes[i])
        except:
            pass
        dic = {
                'ccode': ccodes[i],
                'market_cap': f.market_cap,
                'shares_issued': f.shares_issued,
                'dividend_yield': f.dividend_yield,
                'dividend_one': f.dividend_one,
                'per': f.per,
                'pbr': f.pbr,
                'eps': f.eps,
                'bps': f.bps,
                'price_min': f.price_min,
                'round_lot': f.round_lot,
                'years_high': f.years_high,
                'years_low': f.years_low
        }
        lis.append(dic)
        pb.update(i+1)

    db.save(lis)

Priceデータを取得

Financeデータと同様、東証一部の証券コードを引数にしてPriceデータを取得する。

def get_prices(self, ccodes):
    start_date = START_DATE
    end_date = datetime.date.today()

    db = DB(DB_STOCK, COLL_PRICE)
    #db.delete()

    q = jsm.Quotes()
    pb = ProgressBar(maxval=len(ccodes)).start()
    for i in range(len(ccodes)):
        lis = []
        try:
            prices = q.get_historical_prices(ccodes[i], jsm.DAILY, start_date, end_date)
        except:
            pass
        for p in prices:
            dic = {
                    'ccode': ccodes[i],
                    'date': p.date,
                    'open': p.open,
                    'high': p.high,
                    'low': p.low,
                    'close': p.close,
                    'volume': p.volume
            }
            lis.append(dic)
        db.save(lis)
        pb.update(i+1)

crawler.pyの全ソース

db.pyと同じ階層にファイルを置いて$ python crawler.pyを実行すればクロールを開始する。

# -*- coding: utf-8 -*-
import sys
import jsm
from progressbar import ProgressBar
import pandas as pd
import datetime

from db import DB

reload(sys)
sys.setdefaultencoding('utf-8')

DB_STOCK = 'stock'
COLL_BRAND = 'brand'
COLL_PRICE = 'price'
COLL_FINANCE = 'finance'
START_DATE = datetime.date(2014, 1, 1)


class Crawler:
    """
    Refer to https://pypi.python.org/pypi/jsm/0.19
    """

    def __init__(self):
        pass


    def main(self):
        print 'getting brands...'
        self.get_brands()

        print 'getting finances...'
        ccodes = self.get_target_ccodes()
        self.get_finances(ccodes)

        print 'getting prices...'
        self.get_prices(ccodes)


    def get_brands(self):
        db = DB(DB_STOCK, COLL_BRAND)
        #db.delete()

        categories = [
            '0050',  # 農林・水産業
            '1050',  # 鉱業
            '2050',  # 建設業
            '3050',  # 食料品
            '3100',  # 繊維製品
            '3150',  # パルプ・紙
            '3200',  # 化学
            '3250',  # 医薬品
            '3300',  # 石油・石炭製品
            '3350',  # ゴム製品
            '3400',  # ガラス・土石製品
            '3450',  # 鉄鋼
            '3500',  # 非鉄金属
            '3550',  # 金属製品
            '3600',  # 機械
            '3650',  # 電気機器
            '3700',  # 輸送機器
            '3750',  # 精密機器
            '3800',  # その他製品
            '4050',  # 電気・ガス業
            '5050',  # 陸運業
            '5100',  # 海運業
            '5150',  # 空運業
            '5200',  # 倉庫・運輸関連業
            '5250',  # 情報・通信
            '6050',  # 卸売業
            '6100',  # 小売業
            '7050',  # 銀行業
            '7100',  # 証券業
            '7150',  # 保険業
            '7200',  # その他金融業
            '8050',  # 不動産業
            '9050'   # サービス業
        ]

        q = jsm.Quotes()
        pb = ProgressBar(maxval=len(categories)).start()
        for i in range(len(categories)):
            lis = []
            try:
                brands = q.get_brand(categories[i])
            except:
                pass
            for b in brands:
                dic = {
                        'category': categories[i],
                        'ccode': b.ccode, 
                        'market': b.market, 
                        'name': b.name, 
                        'info': b.info
                }
                lis.append(dic)
            db.save(lis)
            pb.update(i+1)


    def get_finances(self, ccodes):
        db = DB(DB_STOCK, COLL_FINANCE)
        #db.delete()

        q = jsm.Quotes()
        pb = ProgressBar(maxval=len(ccodes)).start()
        lis = []
        for i in range(len(ccodes)):
            try:
                f = q.get_finance(ccodes[i])
            except:
                pass
            dic = {
                    'ccode': ccodes[i],
                    'market_cap': f.market_cap,
                    'shares_issued': f.shares_issued,
                    'dividend_yield': f.dividend_yield,
                    'dividend_one': f.dividend_one,
                    'per': f.per,
                    'pbr': f.pbr,
                    'eps': f.eps,
                    'bps': f.bps,
                    'price_min': f.price_min,
                    'round_lot': f.round_lot,
                    'years_high': f.years_high,
                    'years_low': f.years_low
            }
            lis.append(dic)
            pb.update(i+1)
        
        db.save(lis)
        

    def get_prices(self, ccodes):
        start_date = START_DATE
        end_date = datetime.date.today()

        db = DB(DB_STOCK, COLL_PRICE)
        #db.delete()

        q = jsm.Quotes()
        pb = ProgressBar(maxval=len(ccodes)).start()
        for i in range(len(ccodes)):
            lis = []
            try:
                prices = q.get_historical_prices(ccodes[i], jsm.DAILY, start_date, end_date)
            except:
                pass
            for p in prices:
                dic = {
                        'ccode': ccodes[i],
                        'date': p.date, 
                        'open': p.open, 
                        'high': p.high, 
                        'low': p.low,
                        'close': p.close,
                        'volume': p.volume
                }
                lis.append(dic)
            db.save(lis)
            pb.update(i+1)


    def get_target_ccodes(self):
        data = DB(DB_STOCK, COLL_BRAND).read()
        df_brand = pd.DataFrame(data)
        df_brand = df_brand[df_brand['market']=='東証1部']
        ccodes = df_brand['ccode'].tolist()

        return ccodes


if __name__ == '__main__':
    Crawler().main()

所感

Priceデータを取得するのに15時間ぐらいかかったので、マルチスレッドにした方がいい。 上記のコードで継続的に最新のデータを収集し続けるには、もう少し改良が必要だけど、 とりあえず、サクッとデータを収集して分析したい人は使えると思う。 jsm自体は、 更新が2015年で止まっていて、GitHubからは削除されているが、ソース自体はPyPIに上がってるので、 これまで自作してた人は動かなくなっても改修できると思う。

ReactのListとKey

ReactでList(ObjectのArray)を描画する際、mapを利用して各Objectの項目をセットする。

var data = [
        {id: 1, author: "Pete Hunt", text: "This is one comment"},
        {id: 2, author: "Jordan Walke", text: "This is *another* comment"}
];

var CommentList = React.createClass({
  render: function() {
    var commentNodes = this.props.data.map(function (comment) {
      return (
        <div author={comment.author}>
          {comment.text}
        </div>
      );
    });
    return (
      <div className="commentList">
        {commentNodes}
      </div>
    );
  }
});

 

もし描画する際に、各ObjectにユニークなKeyを持たせていない場合は警告が出る。

react.js:20541 Warning: Each child in an array or iterator should have a unique "key" prop.

 

以下のように、keyに何かしらのユニークな値を持たせると、以下の警告は消える。

<div author={comment.author} key={comment.id}>

参考文献

React + Flask + Python + MongoDBで作るRSSリーダー

これまで2回に渡ってReactについて学んできた。

testpy.hatenablog.com testpy.hatenablog.com

僕は普段、Pythonを使って機械学習やデータ解析のコードを実装してるのだが、 Webアプリ化したいな、できればReactで実現できたらいいな、と思うことが度々あった。 そこでPythonistaのために、Reactを使った簡単なWebアプリ作成記事があればと、 今(2016/10/31現在)は亡きReactチュートリアル日本語版 をベースに、MongoDBに格納したRSSPythonで読み込み、Flask経由でクライアントに送り、 Reactで描画してみたので、コードを載せておく。 ただし説明はほとんどないので、バリバリのReact初心者の方は、上の記事でベースを固めてから読んでみて下さい。 ちなみに、Reactチュートリアルのソースはまだあります

見た目

こんな感じにする。

f:id:Shoto:20161030013827p:plain

ファイル構造

react-tutorial
├── gigazine_rss.py  // GIGAZINEのRSSをMongoDBに格納
├── node_modules
├── package.json
├── public
│   ├── css
│   │   └── base.css  // RSSを見やすいように加工
│   ├── index.html
│   └── scripts
│       └── example.js  // シンプルなDOMをReactで作成
└── server.py  // Flaskを用いてクライアントとサーバーを連携

データフロー

  1. index.htmlにアクセスする
  2. index.htmlがexample.jsを呼び出す
  3. example.jsがserver.jsを呼び出す
  4. server.jsがgigazine_rss.pyを呼び出す
  5. gigazine_rss.pyがGIGAZINE RSSを取得してMongoDBに格納する
  6. server.jsがgigazine_rss.py経由でMongoDBに格納したRSSを読み込む
  7. server.jsがexample.jsにRSSを渡す
  8. example.jsがRSSのDOMを作成する
  9. DOMがindex.htmlに描画される
index.html
  |
example.js
  |
server.py
  |
gigazine_rss.py <-> GIGAZINE RSS
  |
MongoDB

index.html

Reactはここではなく、exmaple.jsに記述する。

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <title>React Tutorial</title>
    <!-- Not present in the tutorial. Just for basic styling. -->
    <link rel="stylesheet" href="css/base.css" />
    <script src="https://unpkg.com/react@15.3.0/dist/react.js"></script>
    <script src="https://unpkg.com/react-dom@15.3.0/dist/react-dom.js"></script>
    <script src="https://unpkg.com/babel-standalone@6.15.0/babel.min.js"></script>
    <script src="https://unpkg.com/jquery@3.1.0/dist/jquery.min.js"></script>
    <script src="https://unpkg.com/remarkable@1.7.1/dist/remarkable.min.js"></script>
  </head>
  <body>
    <div id="content"></div>
    <script type="text/babel" src="scripts/example.js"></script>
  </body>
</html>

example.js

チュートリアルでは、CommentBox、CommentList、Commetがあったが、 ここではCommentをRssに置き換えている。 またBoxとListで十分表現可能で、公式のDocs にも、そのように書けと書いてあるので、RssBoxとRssListのみとした。

var RssBox = React.createClass({
  loadRssFromServer: function() {
    $.ajax({
      url: this.props.url,
      dataType: 'json',
      cache: false,
      success: function(data) {
        this.setState({data: data});
      }.bind(this),
      error: function(xhr, status, err) {
        console.error(this.props.url, status, err.toString());
      }.bind(this)
    });
  },
  getInitialState: function() {
    return {data: []};
  },
  componentDidMount: function() {
    this.loadRssFromServer();
    setInterval(this.loadRssFromServer, this.props.pollInterval);
  },
  render: function() {
    return (
      <div className="rssBox">
        <h1 className="siteTitle">GIGAZINE RSS</h1>
        <RssList data={this.state.data} />
      </div>
    );
  }
});

var RssList = React.createClass({
  render: function() {
    var rssNodes = this.props.data.map(function (rss) {
      return (
        <div className="rss" key={rss.id}>
          <h3 className="title">
            <a href={rss.link}>
              {rss.title}
            </a>
          </h3>
          <p className="updated">{rss.updated}</p>
          <p className="summary">{rss.summary}</p>
        </div>
      );
    });
    return (
      <div className="rssList">
        {rssNodes}
      </div>
    );
  }
});

ReactDOM.render(
  <RssBox url="/api/rss" pollInterval={2000000} />,
  document.getElementById('content')
);

server.py

gigazine_rss.pyを呼び出して、GIAZINE RSSの保存と読み込みを行っている。

import json
import os
import time
from flask import Flask, Response, request

from gigazine_rss import Gigazine_RSS

app = Flask(__name__, static_url_path='', static_folder='public')
app.add_url_rule('/', 'root', lambda: app.send_static_file('index.html'))  # client and server side with MongoDB

@app.route('/api/rss', methods=['GET', 'POST'])
def comments_handler():
    Gigazine_RSS().save()
    rss = Gigazine_RSS().read()

    return Response(
        json.dumps(rss),
        mimetype='application/json',
        headers={
            'Cache-Control': 'no-cache',
            'Access-Control-Allow-Origin': '*'
        }
    )


if __name__ == '__main__':
    app.run(port=int(os.environ.get("PORT", 3000)), debug=True)

gigazine_rss.py

testpy.hatenablog.com

この記事を1ファイルで実行できるようにした。

# -*- coding: utf-8 -*-
import sys
import json
import nltk
import numpy
import feedparser
import urllib2
from bs4 import BeautifulSoup
import re
import pymongo

reload(sys)
sys.setdefaultencoding('utf-8')

DATABASE_NAME = 'gigazine'
COLLECTION_NAME = 'rss'

class Gigazine_RSS:
    def __init__(self):
        pass


    def save(self):
        rss = self.__get_rss()
        self.__save_rss(rss)

        return rss


    def __get_rss(self):
        rss_url = 'http://feed.rssad.jp/rss/gigazine/rss_2.0'
        articles = feedparser.parse(rss_url)

        rss = []
        for e in articles.entries:
            dic = {
                'id': e.id,
                'updated': e.updated,
                'title': e.title,
                'link': e.link,
                'summary': self.__getTextOnly(BeautifulSoup(e.summary))
            }
            rss.append(dic)

        return rss


    # Extract the text from an HTML page (no tags)
    def __getTextOnly(self, soup):
        v = soup.string  # Split by tags and check whether nested tags
        if v == None:  # If tags are nested
            c = soup.contents  # Eliminate outmost tags
            resulttext = ''
            for t in c:
                subtext = self.__getTextOnly(t)
                # If the subtext is null(u''), don't append it
                if len(subtext) > 0:
                    resulttext += subtext + '\n'
            return resulttext
        else:
            return v.strip()  # Eliminate '\n'


    def __save_rss(self, data):
        client = pymongo.MongoClient('localhost', 27017)
        db = client[DATABASE_NAME]
        co = db[COLLECTION_NAME]
        co.drop()

        co.insert(data)


    def read(self):
        client = pymongo.MongoClient('localhost', 27017)
        db = client[DATABASE_NAME]
        co = db[COLLECTION_NAME]

        data = [d for d in co.find()]

        rss = []
        for c in co.find():
            c.pop('_id', None)
            rss.append(c)

        return rss


if __name__ == '__main__':
    Gigazine_RSS().save()
    #Gigazine_RSS().read()

base.css

見やすいようにセンタリングやコントラストなどの調整をした。

body {
  background: #fff;
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
  font-size: 15px;
  line-height: 1.7;
  margin: 0;
  padding: 30px;
}

a {
  color: #4183c4;
  text-decoration: none;
}

a:hover {
  text-decoration: underline;
}

code {
  background-color: #f8f8f8;
  border: 1px solid #ddd;
  border-radius: 3px;
  font-family: "Bitstream Vera Sans Mono", Consolas, Courier, monospace;
  font-size: 12px;
  margin: 0 2px;
  padding: 0 5px;
}

h1, h2, h3, h4 {
  font-weight: bold;
  margin: 0 0 15px;
  padding: 0;
}

h1 {
  font-size: 2.5em;
}

h2 {
  border-bottom: 1px solid #eee;
  font-size: 2em;
}

h3 {
  font-size: 1.5em;
}

h4 {
  font-size: 1.2em;
}

p, ul {
  margin: 15px 0;
}

ul {
  padding-left: 30px;
}

.rssBox {
  width: 600px;
  margin: auto;
}

.updated {
  color: #999;
}

.siteTitle {
  text-align: center;
  margin: 20px 0 40px;
}

.summary {
  margin: 0 0 40px;
}