MNISTを画像検索

前回「MNISTで距離学習」 という記事を書いたが画像分類の域を出なかった。 距離学習と言えば画像検索なので、今回はそれをMNISTで行った。

概要

今回は前回 訓練したMNISTの距離学習モデルを利用して画像検索を行う。 手順は次の通り。

  1. データ準備
  2. モデルロード
  3. 特徴抽出
  4. 距離算出

これらについて順番に説明する。 なお、コードはここに置いてある。

1. データ準備

テスト時、画像分類ではテストデータを用意するが、 画像検索では検索画像となるQueryと検索対象画像群となるGalleryが必要になる。 そこでまずはテストデータをQueryとGalleryに分ける。 また後々のことを考えて、いったん画像に保存したものを読み込むようにする。

ここではメインファイルのimage_retrieval.pymake_query_and_gallery_from_mnist()を呼び出す。 これにより、MNISTから1枚をQueryに100枚をGalleryにランダム選択して振り分ける。 各画像はQueryとGallery各々のディレクトリーに保存され、それらの情報はCSVファイルに記載される。

dataset_dirはMNISTを保存するディレクトリー、 query_dirはQuery画像を保存するディレクトリー、 gallery_dirはGallery画像を保存するディレクトリー、 anno_pathはQueryとGalleryの画像情報を記載したCSVファイル である。

make_query_and_gallery_from_mnist(args.dataset_dir, args.query_dir, args.gallery_dir, args.anno_path)

make_query_and_gallery_from_mnist()では、 まずmake_query_and_gallery()でMNISTの画像をQueryとGalleryに振り分けて、 次にmake_anno_file()でQueryとGalleryの画像情報をCSVに保存する。 これらはMNISTデータ処理専用のmnist_data.pyで行う。

def make_query_and_gallery_from_mnist(dataset_dir, query_dir, gallery_dir, anno_path):
    mnist_data.make_query_and_gallery(dataset_dir, query_dir, gallery_dir)
    mnist_data.make_anno_file(query_dir, gallery_dir, anno_path)

以下からmnist_data.pymake_query_and_gallery()は次の通り。 transformしたMNISTを取得して、Query画像1枚とGallery画像100枚をランダムに選択後、 各々のディレクトリーに画像として保存している。 保存前の画像はtransfromで正規化した後なので0~255になっていないが、 scipy.misc.imsave()を使うと0~255にして保存してくれる。

def make_query_and_gallery(dataset_dir, query_dir, gallery_dir):
    # 
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    testset = datasets.MNIST(dataset_dir, train=False, download=True, transform=transform)
    q_idx = random.choice(range(len(testset)))
    g_idxs= random.sample(range(len(testset)), 100)
    
    # Save query image.
    if os.path.exists(query_dir) == True:
        shutil.rmtree(query_dir)
    os.makedirs(query_dir)
    q_img, q_label = testset[q_idx]
    scipy.misc.imsave(query_dir+'{}_{}.png'.format(q_label, q_idx), q_img.numpy()[0])
    
    # Save gallery images.
    if os.path.exists(gallery_dir) == True:
        shutil.rmtree(gallery_dir)
    os.makedirs(gallery_dir)
    for g_idx in g_idxs:
        g_img, g_label = testset[g_idx]
        scipy.misc.imsave(gallery_dir+'{}_{}.png'.format(g_label, g_idx), g_img.numpy()[0])

make_anno_file()では、Query/Gallery、 画像名、画像パス、ラベル名、IDを記載したCSVファイルを作成して保存する。

def make_anno_file(query_dir, gallery_dir, anno_path):
    annos = []
    annos += __set_annos(query_dir, 'query')
    annos += __set_annos(gallery_dir, 'gallery')
    df = pd.DataFrame(annos)
    df.to_csv(anno_path, index=False)


def __set_annos(img_dir, data_type):
    annos = []
    for d in os.listdir(img_dir):
        dic = {}
        dic['data_type'] = data_type
        dic['img_name'] = d
        dic['img_path'] = img_dir + d
        dic['label'] = d.split('_')[0]
        dic['id'] = d.split('.')[0].split('_')[1]
        annos.append(dic)

    return annos

最後に画像検索用のDataLoaderを作成する。 先ほど作成したQueryとGalleryのCSVファイルを用いて、 各々の画像ローダーが作成できる仕様にしている。 また、画像とラベルの他に画像パスも返すようにしている。

class ReIDDataset(Dataset):
    def __init__(self, anno_path, data_type, transform=None):
        df_all = pd.read_csv(anno_path)
        self.df = df_all[df_all['data_type']==data_type].reset_index(drop=True)  # Filter data by query or gallery.
        self.transform = transform


    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'img_path']
        assert os.path.exists(img_path)
        image = io.imread(img_path)
        label = self.df.loc[idx, 'label']
        img_path = self.df.loc[idx, 'img_path']
        if self.transform:
            image = self.transform(image)
        
        return image, label, img_path

画像分類の時にtrain_loader、test_loader、classesを返していたのと同様、 画像検索の時はquery_loader、gallery_loader、classesを返す。

def load_query_and_gallery(anno_path, img_show=False):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Query
    query_dataset = ReIDDataset(anno_path, 'query', transform)
    query_loader = DataLoader(query_dataset, batch_size=len(query_dataset), shuffle=False)
        
    # Gallery
    gallery_dataset = ReIDDataset(anno_path, 'gallery', transform)
    #gallery_loader = DataLoader(gallery_dataset, batch_size=len(gallery_dataset), shuffle=True)
    gallery_loader = DataLoader(gallery_dataset, batch_size=8, shuffle=True)
    
    # Class
    classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # debug
    print('num query: {}, num gallery: {}'.format(len(query_dataset), len(gallery_dataset)))
    print('')
    if img_show == True:
        show_data(gallery_loader)

    return query_loader, gallery_loader, classes

2. モデルロード

再びメインファイルのimage_retrieval.py。 モデルのロードは、model.load_state_dict(torch.load(args.model_path))で行う。 学習済みモデルは、前回の記事を参考に生成する。 今回はテストなのでeval()で評価モードにしておく。

# Set device, GPU or CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = Net().to(device)
model.load_state_dict(torch.load(args.model_path))
model.eval()

3. 特徴抽出

1. データ準備で説明したquery_loaderを利用する。 と言っても画像は1枚だけなので、学習済みモデルに通して特徴を取得するだけ。 modelの出力は2次元特徴と予測結果なるが、今回は前者のみを利用する。

# Query
for i, (query_img, query_label, query_path) in enumerate(query_loader):
    with torch.no_grad():
        query_img = query_img.to(device)
        query_feat, pred = model(query_img)

Galleryは100枚画像があるが、 gallery_loaderのバッチサイズは8にしているので、 各情報をリストに格納して、最後にconcatinateする。

# Gallery
gallery_feats = []
gallery_labels = []
gallery_paths = []
for i, (g_imgs, g_labels, g_paths) in enumerate(gallery_loader):
    with torch.no_grad():
        g_imgs = g_imgs.to(device)
        g_feats_temp, preds_temp = model(g_imgs)
        gallery_feats.append(g_feats_temp)
        gallery_labels.append(g_labels)
        gallery_paths += list(g_paths)  # Data type of g_paths is tuple.
gallery_feats = torch.cat(gallery_feats, 0)
gallery_labels = torch.cat(gallery_labels, 0)

4. 距離算出

3. 特徴抽出でQueryの特徴query_featとGalleryの特徴gallery_featsが取得できた。 query_featgallery_featsの各特徴との距離を算出するためにコサイン類似度を利用する。

dist_matrix = cosine_similarity(query_feat, gallery_feats)

コサイン類似の実装は次の通り。

def cosine_similarity(qf, gf):
    epsilon = 0.00001
    dist_mat = qf.mm(gf.t())
    qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) #mx1
    gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) #nx1
    qg_normdot = qf_norm.mm(gf_norm.t())

    dist_mat = dist_mat.mul(1/qg_normdot).cpu().numpy()
    dist_mat = np.clip(dist_mat, -1+epsilon,1-epsilon)
    dist_mat = np.arccos(dist_mat)

    return dist_mat

QueryとGalleryの各距離、Galleryのラベル名と画像パスをセットにした DataFrameを作成して、距離でソートする。 これにより、距離の近い順にラベルが表示される。

# Organize ReID ranking.
lis = []
for i in range(len(gallery_paths)):
    dic = {}
    dic['dist'] = dist_matrix.tolist()[0][i]
    dic['label'] = np.array(gallery_labels).tolist()[i]
    dic['img_path'] = gallery_paths[i]
    lis.append(dic)
df = pd.DataFrame(lis)
df = df.sort_values(by=['dist'], ascending=True)
df = df.reset_index(drop=True)

以下は実行結果。 Queryが9で、Galleryも9がindexの0から8まで占めている。

$ python image_retrieval.py
num query: 1, num gallery: 100

Query Image Label: 9

Search Result
        dist                      img_path  label
0   0.005382  ../inputs/gallery/9_1801.png      9
1   0.018921  ../inputs/gallery/9_4237.png      9
2   0.036690   ../inputs/gallery/9_481.png      9
3   0.047976  ../inputs/gallery/9_7380.png      9
4   0.069177  ../inputs/gallery/9_8213.png      9
5   0.076138  ../inputs/gallery/9_3970.png      9
6   0.078646  ../inputs/gallery/9_2685.png      9
7   0.107746  ../inputs/gallery/9_5977.png      9
8   0.387746  ../inputs/gallery/9_4505.png      9
9   0.523175  ../inputs/gallery/3_8981.png      3
10  0.538863   ../inputs/gallery/3_927.png      3
11  0.560314   ../inputs/gallery/3_142.png      3
12  0.565455  ../inputs/gallery/3_8451.png      3
13  0.582634  ../inputs/gallery/3_4755.png      3
14  0.586750  ../inputs/gallery/3_2174.png      3
15  0.589938  ../inputs/gallery/3_9986.png      3
16  0.675965  ../inputs/gallery/1_4491.png      1
17  0.682165  ../inputs/gallery/3_8508.png      3
18  0.683414  ../inputs/gallery/3_4785.png      3
19  0.698637  ../inputs/gallery/1_1038.png      1

labelをカウントすると、ラベルが9の画像は9枚あることが分かる。 よって、Galleryにあるラベル9の全画像を検索上位に持ってくることができたことが分かる。

1    15
7    13
0    12
8    11
4    10
3    10
9     9  # <- this
6     7
2     7
5     6

参考文献