ResNet18+ArcFaceでCIFAR10を距離学習

以前、「簡易モデルでMNISTを距離学習」と 「ResNet18でCIFAR10を画像分類」 を実施した。 今回はこれらを組み合わせて「ResNet18+ArcFaceでCIFAR10を距離学習」を行った。

基本的には「ResNet18でCIFAR10を画像分類」 で実施した内容と同じになる。 異なるのはResNet18の最終層の前で特徴抽出して、それをメトリックであるArcFaceに通してから、損失関数に入力している点である。 なので、コード全体の説明は「ResNet18でCIFAR10を画像分類」 に譲るとして、ここではメトリックの周辺の実装について説明する。 なお、今回利用するメトリックはArcFaceで、上記で述べたように画像分類モデルに付け足すだけの優れものである。 しかも非常に精度が高い。 なぜ精度が高くなるのかは 「モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace - Qiita」 が詳しいので一読することをお薦めする。

なお、今回説明するコードは ここ に置いてある。

概要

実行手順は次の通り。

  1. データの取得
  2. モデルの定義
  3. メトリックの定義
  4. 損失関数と最適化関数の定義
  5. 学習と検証

「3. メトリックの定義」が今回新たに実装された部分である。 上記の手順はmain()で次のように実行される。

def main():
    # Parse arguments.
    args = parse_args()
    
    # Set device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load dataset.
    train_loader, test_loader, class_names = cifar10.load_data(args.data_dir)
    
    # Set a model.
    model = get_model(args.model_name, args.n_feats)
    model = model.to(device)
    print(model)

    # Set a metric
    metric = metrics.ArcMarginProduct(args.n_feats, len(class_names), s=args.norm, m=args.margin, easy_margin=args.easy_margin)
    metric.to(device)

    # Set loss function and optimization function.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}],
                          lr=args.lr, 
                          weight_decay=args.weight_decay)

    # Train and test.
    for epoch in range(args.n_epoch):
        # Train and test a model.
        train_acc, train_loss = train(device, train_loader, model, metric, criterion, optimizer)
        test_acc, test_loss = test(device, test_loader, model, metric, criterion)
        
        # Output score.
        stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}'
        print(stdout_temp.format(epoch+1, train_acc, train_loss, test_acc, test_loss))

        # Save a model checkpoint.
        model_ckpt_path = args.model_ckpt_path_temp.format(args.dataset_name, args.model_name, epoch+1)
        torch.save(model.state_dict(), model_ckpt_path)
        print('Saved a model checkpoint at {}'.format(model_ckpt_path))
        print('')

1. データの取得

torchvisionからCIFAR10を取得。 詳しくは こちら を参照。

2. モデルの定義

画像分類モデルはget_model()内で呼び出している。 この時、モデル名と出力する特徴数を引数として渡す。 今回、後者は512にしている。

model = get_model(args.model_name, args.n_feats)

今回はResNet18を使う。 出力クラス数が定義できるので、そこに特徴数を入れる。

from models.resnet import ResNet18

def get_model(model_name, num_classes=512):
    ...
    elif model_name == 'ResNet18':
        model = ResNet18(num_classes)

modelディレクトリーのresnet.pyを呼び出す。 以下が中身。 最終層で512次元の特徴を返す。

def ResNet18(n_feats):
    return ResNetFace(BasicBlock, [2,2,2,2], num_classes=n_feats)


class ResNetFace(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNetFace, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

3. メトリックの定義

同じクラスを近くに、異なるクラスを遠くに置くようにするための損失関数を指す。 ArcFaceは、簡単にいうと下図のように円弧上(実際には超球面上)にクラスが適切に分布するように角度を学習する損失関数となる。

f:id:Shoto:20200118113944p:plain

コードは これ がちゃんと動いた。 ArcMarginProduct()がArcFace。

metric = metrics.ArcMarginProduct(args.n_feats, len(class_names), s=args.norm, m=args.margin, easy_margin=args.easy_margin)

第1引数のargs.n_featsは入力する特徴数。 これはモデルの出力数でもある。 第2引数のlen(class_names)はクラス数。 sは下図のLogit前のFeature Re-scaleに当たる。 Logitの値がsoftmaxで機能するよう適切な値にスケールする。 mはAdditive Anguler Mergin Penaltyにあたる。 つまりマージンによるペナルティーで、 クラスの重みW_{y_i}と画像の特徴ベクトルx_iのなす角度\theta _{y_i}を最小化するのだが、 その際マージンmをペナルティーとして加えることで、同じクラスを近くに、異なるクラスを遠くに置くようにするための効果が増す。

f:id:Shoto:20200118114003p:plain

easy_marginは以下のような処理を行っているが、理解しきれなかったので暇ができたら後で調べる。 ちなみにeasy_marginTrueにしないと全然学習しなかったので注意。

if self.easy_margin:
        phi = torch.where(cosine > 0, phi, cosine)
    else:
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

図ではSoftmaxにかけたあと、CrossEntropyLossを算出している。 PyTorchではcriterionに指定したCrossEntropyLossSoftmaxも内包されているため、 特に記述する必要がない。

criterion = nn.CrossEntropyLoss()

features = model(inputs)
outputs = metric_fc(features, targets)
loss = criterion(outputs, targets)

4. 損失関数と最適化関数の定義

損失関数は上記で述べた通り。 最適化関数はSGDを使い、paramsにモデルとメトリックの両方を渡す。

optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}],
                      lr=args.lr, 
                      weight_decay=args.weight_decay)

5. 学習と検証

画像分類と同様に訓練を実施した結果、 100エポックでテスト精度が82.4%になった。

epoch: 100, train acc: 0.947966, train loss: 0.003992, test acc: 0.824237, test loss: 0.027159

ちなみにResNet18の最終層をコメントアウトしても512次元の特徴が取れるのだが、 このケースではテスト精度は84.7%になった。 CIFAR10ではResNet18でもDeepなのかも。

def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = out.view(out.size(0), -1)
    #out = self.linear(out)
    return out
epoch: 100, train acc: 0.955092, train loss: 0.004396, test acc: 0.847445, test loss: 0.028565

所感

中身が完全に理解できていないのと、モデルとメトリックのパラメーター調整がもっと必要な気がしている。 時間があったら、もっと突っ込んでやりたい。

参考文献