PythonでData Augmentation

Pythonで画像の左右反転、回転、拡大を行ってみた。 Data Augmentationに使えるかなと。

f:id:Shoto:20170602001453p:plain

左右反転

scikit-imageだけで実現したかったのだが、APIを見つけられなかったのでOpenCVで実装。 でも3つの処理の中で最も簡単に書けた。 ちなみに第2引数を1ではなく0にすると上下反転になる。

# flip
img_fliped = cv2.flip(img, 1)

回転

angleには角度(°)を指定。 resizeがFalseの場合は、そのまま回転。Trueにすると元画像の画像の角が隠れないようになる。 centerをNoneにすると、中央を原点にして回転。

# rotatate
img_rotated = skimage.transform.rotate(img, angle=10, resize=False, center=None)

拡大

目的の処理が行えるシンプルなAPIがなかったので、AffineTransform()を使用。 rateがプラスの倍率。0.2の場合は20%拡大。scaleには1-rateを指定する形となる。 拡大は左上を原点にして行われるので、元画像の中央が拡大後も中央になるようにtranslationで修正。 具体的には拡大した大きさの半分だけ左上に平行移動している。

# expand
rate = 0.2
size = img.shape[0]
matrix_expanded = skimage.transform.AffineTransform(scale=(1-rate, 1-rate), translation=(size*rate/2, size*rate/2))
img_expanded = skimage.transform.warp(img, matrix_expanded)

実行

記事のトップ画像を表示するためのコードは以下の通り。

# coding: utf-8

import numpy as np
import cv2
import matplotlib.pyplot as plt
import skimage
from skimage import io
from skimage import transform

DIR = '../data/'

class Data:
    def test(self):
        # read Lenna image
        img_raw = skimage.io.imread(DIR+'lenna.jpg')

        # resize
        #img = skimage.transform.resize(img_raw, (32, 32))
        img = img_raw

        # flip
        img_fliped = cv2.flip(img, 1)

        # rotatate
        img_rotated = skimage.transform.rotate(img, angle=10, resize=False, center=None)

        # expand
        rate = 0.2
        size = img.shape[0]
        matrix_expanded = skimage.transform.AffineTransform(scale=(1-rate, 1-rate), translation=(size*rate/2, size*rate/2))
        img_expanded = skimage.transform.warp(img, matrix_expanded)

        # white background
        fig = plt.figure()
        fig.patch.set_facecolor('white')

        # display images
        plt.subplot(141)
        plt.title('raw')
        plt.imshow(img)

        plt.subplot(142)
        plt.title('flip')
        plt.imshow(img_fliped)

        plt.subplot(143)
        plt.title('rotate')
        plt.imshow(img_rotated)

        plt.subplot(144)
        plt.title('expand')
        plt.imshow(img_expanded)

        plt.show()

参考文献