ResNet18+ArcFaceでCIFAR10を距離学習
以前、「簡易モデルでMNISTを距離学習」と 「ResNet18でCIFAR10を画像分類」 を実施した。 今回はこれらを組み合わせて「ResNet18+ArcFaceでCIFAR10を距離学習」を行った。
基本的には「ResNet18でCIFAR10を画像分類」 で実施した内容と同じになる。 異なるのはResNet18の最終層の前で特徴抽出して、それをメトリックであるArcFaceに通してから、損失関数に入力している点である。 なので、コード全体の説明は「ResNet18でCIFAR10を画像分類」 に譲るとして、ここではメトリックの周辺の実装について説明する。 なお、今回利用するメトリックはArcFaceで、上記で述べたように画像分類モデルに付け足すだけの優れものである。 しかも非常に精度が高い。 なぜ精度が高くなるのかは 「モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace - Qiita」 が詳しいので一読することをお薦めする。
なお、今回説明するコードは ここ に置いてある。
概要
実行手順は次の通り。
- データの取得
- モデルの定義
- メトリックの定義
- 損失関数と最適化関数の定義
- 学習と検証
「3. メトリックの定義」が今回新たに実装された部分である。
上記の手順はmain()
で次のように実行される。
def main(): # Parse arguments. args = parse_args() # Set device. device = 'cuda' if torch.cuda.is_available() else 'cpu' # Load dataset. train_loader, test_loader, class_names = cifar10.load_data(args.data_dir) # Set a model. model = get_model(args.model_name, args.n_feats) model = model.to(device) print(model) # Set a metric metric = metrics.ArcMarginProduct(args.n_feats, len(class_names), s=args.norm, m=args.margin, easy_margin=args.easy_margin) metric.to(device) # Set loss function and optimization function. criterion = nn.CrossEntropyLoss() optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}], lr=args.lr, weight_decay=args.weight_decay) # Train and test. for epoch in range(args.n_epoch): # Train and test a model. train_acc, train_loss = train(device, train_loader, model, metric, criterion, optimizer) test_acc, test_loss = test(device, test_loader, model, metric, criterion) # Output score. stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}' print(stdout_temp.format(epoch+1, train_acc, train_loss, test_acc, test_loss)) # Save a model checkpoint. model_ckpt_path = args.model_ckpt_path_temp.format(args.dataset_name, args.model_name, epoch+1) torch.save(model.state_dict(), model_ckpt_path) print('Saved a model checkpoint at {}'.format(model_ckpt_path)) print('')
1. データの取得
torchvisionからCIFAR10を取得。 詳しくは こちら を参照。
2. モデルの定義
画像分類モデルはget_model()
内で呼び出している。
この時、モデル名と出力する特徴数を引数として渡す。
今回、後者は512にしている。
model = get_model(args.model_name, args.n_feats)
今回はResNet18を使う。 出力クラス数が定義できるので、そこに特徴数を入れる。
from models.resnet import ResNet18 def get_model(model_name, num_classes=512): ... elif model_name == 'ResNet18': model = ResNet18(num_classes)
model
ディレクトリーのresnet.py
を呼び出す。
以下が中身。
最終層で512次元の特徴を返す。
def ResNet18(n_feats): return ResNetFace(BasicBlock, [2,2,2,2], num_classes=n_feats) class ResNetFace(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(ResNetFace, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512*block.expansion, num_classes) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1]*(num_blocks-1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.linear(out) return out
3. メトリックの定義
同じクラスを近くに、異なるクラスを遠くに置くようにするための損失関数を指す。 ArcFaceは、簡単にいうと下図のように円弧上(実際には超球面上)にクラスが適切に分布するように角度を学習する損失関数となる。
コードは
これ
がちゃんと動いた。
ArcMarginProduct()
がArcFace。
metric = metrics.ArcMarginProduct(args.n_feats, len(class_names), s=args.norm, m=args.margin, easy_margin=args.easy_margin)
第1引数のargs.n_feats
は入力する特徴数。
これはモデルの出力数でもある。
第2引数のlen(class_names)
はクラス数。
s
は下図のLogit前のFeature Re-scaleに当たる。
Logitの値がsoftmaxで機能するよう適切な値にスケールする。
m
はAdditive Anguler Mergin Penaltyにあたる。
つまりマージンによるペナルティーで、
クラスの重みと画像の特徴ベクトルのなす角度を最小化するのだが、
その際マージンをペナルティーとして加えることで、同じクラスを近くに、異なるクラスを遠くに置くようにするための効果が増す。
easy_margin
は以下のような処理を行っているが、理解しきれなかったので暇ができたら後で調べる。
ちなみにeasy_margin
はTrue
にしないと全然学習しなかったので注意。
if self.easy_margin: phi = torch.where(cosine > 0, phi, cosine) else: phi = torch.where(cosine > self.th, phi, cosine - self.mm)
図ではSoftmax
にかけたあと、CrossEntropyLoss
を算出している。
PyTorchではcriterion
に指定したCrossEntropyLoss
にSoftmax
も内包されているため、
特に記述する必要がない。
criterion = nn.CrossEntropyLoss() features = model(inputs) outputs = metric_fc(features, targets) loss = criterion(outputs, targets)
4. 損失関数と最適化関数の定義
損失関数は上記で述べた通り。
最適化関数はSGDを使い、params
にモデルとメトリックの両方を渡す。
optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}], lr=args.lr, weight_decay=args.weight_decay)
5. 学習と検証
画像分類と同様に訓練を実施した結果、 100エポックでテスト精度が82.4%になった。
epoch: 100, train acc: 0.947966, train loss: 0.003992, test acc: 0.824237, test loss: 0.027159
ちなみにResNet18の最終層をコメントアウトしても512次元の特徴が取れるのだが、 このケースではテスト精度は84.7%になった。 CIFAR10ではResNet18でもDeepなのかも。
def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) #out = self.linear(out) return out
epoch: 100, train acc: 0.955092, train loss: 0.004396, test acc: 0.847445, test loss: 0.028565
所感
中身が完全に理解できていないのと、モデルとメトリックのパラメーター調整がもっと必要な気がしている。 時間があったら、もっと突っ込んでやりたい。