MNISTを距離学習
Person ReIDが必要になったので、まずはMNISTを題材に距離学習を勉強している。 あと、これまでKerasを使ってきたけど、PyTorch使えないと厳しい世の中になってきたので、 PyTorchについて色々調べつつ実装してみた。
なお今回はこちらの記事(以下、参照記事)を参考にしている。 距離学習をメインで学びたい人は本記事より参照記事を読むことをお薦めする。 本記事はPyTorch入門みたいな要素が強いので。
概要
距離学習をもの凄く簡単に言うと画像分類の拡張。 なので、処理フローはだいたい画像分類と同じで以下のようになる。
- データ準備
- モデル定義
- 損失関数定義
- 最適化関数定義
- 訓練検証
距離学習は、同じクラスは近く異なるクラスは遠くなるようにモデルを学習することで、 未知のクラスの同定を行えるのが画像分類と違うところ。 ポイントは損失関数で、今回はCenterLossというのを使っているが、 説明は参照記事が詳しい。
本記事で説明するコードはここにある。
以下のtrain_mnist_original_center.py
のmain()
を実行すると、
参照記事と同じような結果が得られるが、個人的にコード整理してみたので、
上述の処理フローに従って順に説明する。
def main(): # Arguments args = parse_args() # Dataset train_loader, test_loader, classes = mnist_loader.load_dataset(args.dataset_dir, img_show=True) # Device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Model model = Net().to(device) print(model) # Loss nllloss = nn.NLLLoss().to(device) # CrossEntropyLoss = log_softmax + NLLLoss loss_weight = 1 centerloss = CenterLoss(10, 2).to(device) # Optimizer dnn_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) sheduler = lr_scheduler.StepLR(dnn_optimizer, 20, gamma=0.8) center_optimizer = optim.SGD(centerloss.parameters(), lr =0.5) print('Start training...') for epoch in range(100): # Update parameters. epoch += 1 sheduler.step() # Train and test a model. train_acc, train_loss, feat, labels = train(device, train_loader, model, nllloss, loss_weight, centerloss, dnn_optimizer, center_optimizer) test_acc, test_loss = test(device, test_loader, model, nllloss, loss_weight, centerloss) stdout_temp = 'Epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}' print(stdout_temp.format(epoch, train_acc, train_loss, test_acc, test_loss)) # Visualize features of each class. vis_img_path = args.vis_img_path_temp.format(str(epoch).zfill(3)) visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch, vis_img_path) # Save a trained model. model_path = args.model_path_temp.format(str(epoch).zfill(3)) torch.save(model.state_dict(), model_path)
1. データ準備
先に引数の説明を少し。
# Arguments
args = parse_args()
dataset_dir
はMNISTデータの保存場所。
後述するPyTorchの機能でここにダウンロードしてくれる。
model_path_temp
は学習済みモデルのチェックポイント。
各エポック終了後に保存する。
vis_img_path_temp
はMNISTの各クラスの特徴分布を可視化した画像。
こちらも各エポック終了後に保存する。
だんだんとクラス内でまとまりクラス間が離れていく様子が確認できる。
下図は100エポック後の特徴分布。
def parse_args(): arg_parser = argparse.ArgumentParser(description="parser for focus one") arg_parser.add_argument("--dataset_dir", type=str, default='../inputs/') arg_parser.add_argument("--model_path_temp", type=str, default='../outputs/models/checkpoints/mnist_original_softmax_center_epoch_{}.pth') arg_parser.add_argument("--vis_img_path_temp", type=str, default='../outputs/visual/epoch_{}.png') args = arg_parser.parse_args() return args
では、MNISTのデータセットを取得する。
MNIST関連は、mnist_loader.py
という別ファイルを作って処理している。
# Dataset train_loader, test_loader, classes = mnist_loader.load_dataset(args.dataset_dir, img_show=True)
load_dataset
は、train_loader、test_loader、クラス名を取得するメソッド。
ここからPyTorch色が強くなるが、データ準備では次の手順を踏む。
1. 画像の前処理
torchvision
のtransform
を利用する。
ToTensor()
でPyTorchのtorch.Tensor
型に変換する。
他にも、クロップやフリップなどData Augmentation的な事を行えるが、今回は未実施。
今回はNormalize()
で正規化を行っている。
なおMNISTは自然画像ではないので、平均0.1307、標準偏差0.3081となるようにする。
from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])
2. 画像データセットを取得
torchvision
のdatasets.MNIST
を使うとMNISTが簡単に利用できる。
第1引数はMNISTデータの保存場所。
第2引数でtrain用かtest用かを選ぶ。
第3引数がTrueの場合は保存場所にMNISTデータがない場合に自動でダウンロードしてくれる。
第4引数で先に定義したtransformをセットする。
from torchvision import datasets trainset = datasets.MNIST(dataset_dir, train=True, download=True, transform=transform)
3. データローダーを定義
torch.utils.data
のDataLoader
を利用して、指定バッチ数分のデータを取得する。
第1引数は2で定義したデータセット。
第2引数はバッチサイズ。
第3引数はデータシャッフルするか否か。訓練時はTrueが妥当。
第4引数はデータロードの並列処理数。
from torch.utils.data import DataLoader train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=0)
上記のメソッドを組み合わせることで、mnist_loader.load_dataset()
は次のようになる。
def load_dataset(dataset_dir, train_batch_size=128, test_batch_size=128, img_show=False): # Dataset transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) trainset = datasets.MNIST(dataset_dir, train=True, download=True, transform=transform) train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=0) testset = datasets.MNIST(dataset_dir, train=False, download=True, transform=transform) test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=0) classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] if img_show == True: show_data(train_loader) return train_loader, test_loader, classes
show_data()
はMNISTを可視化するメソッド。
torchvision.utils.make_grid()
により train_loader のバッチを簡単に可視化できる。
def show_data(data_loader): images, labels = iter(data_loader).next() # data_loader のミニバッチの image を取得 img = torchvision.utils.make_grid(images, nrow=16, padding=1) # nrom*nrom のタイル形状の画像を作る plt.ion() plt.imshow(np.transpose(img.numpy(), (1, 2, 0))) # 画像を matplotlib 用に変換 plt.draw() plt.pause(3) # Display an image for three seconds. plt.close()
2. モデル定義
PyTorchでは処理をGPUとCPUのどちらで行うかtorch.device
で明示的に選択して、
それをモデルやデータにセットする必要がある。
モデル定義はMNIST向けのをmnist_net.py
のNet()
で別途定義している。
# Device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Model model = Net().to(device) print(model)
mnist_net.py
のNet()
は次の通り。
Define by Runでは、
__init__()
で計算グラフを幾つか定義して、ネットワーク生成時に1度だけ呼びし、
データ入力時にforward()
を呼び出す使用となっている。
6つの畳み込み層とPReLUの後、2次元空間に落とし込んだ特徴ip1と、
それをPReLUに通して10次元空間に写像したip2を出力する。
ip1が特徴分布で、ip2は画像分類に利用する。
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2) self.prelu1_1 = nn.PReLU() self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2) self.prelu1_2 = nn.PReLU() self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2) self.prelu2_1 = nn.PReLU() self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2) self.prelu2_2 = nn.PReLU() self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2) self.prelu3_1 = nn.PReLU() self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2) self.prelu3_2 = nn.PReLU() self.ip1 = nn.Linear(128*3*3, 2) self.preluip1 = nn.PReLU() self.ip2 = nn.Linear(2, 10, bias=False) def forward(self, x): x = self.prelu1_1(self.conv1_1(x)) x = self.prelu1_2(self.conv1_2(x)) x = F.max_pool2d(x,2) x = self.prelu2_1(self.conv2_1(x)) x = self.prelu2_2(self.conv2_2(x)) x = F.max_pool2d(x,2) x = self.prelu3_1(self.conv3_1(x)) x = self.prelu3_2(self.conv3_2(x)) x = F.max_pool2d(x,2) x = x.view(-1, 128*3*3) ip1 = self.preluip1(self.ip1(x)) ip2 = self.ip2(ip1) return ip1, F.log_softmax(ip2, dim=1)
なおnn.PReLU
はReLUの改良の改良。
LeakyReLU で x < 0
の時に y < 0
することで学習が進みやすくなったものの、
パラメーター α が増えたため、それを減らすべく学習させることにしたのがPReLU。
ちなみにPReLUは、"a Parametric Rectified Linear Unit" の略。
まとめると以下のようになる。
ReLU y = x (0 =< x) y = 0 (x < 0) LeakyReLU y = x (0 =< x) y = αx (x < 0), set α as a parameter PReLU y = x (0 =< x) y = αx (x < 0), learning α
またview
はnumyp.reshape
と同じ。
第一引数が-1のとき、第二引数の形に自動調整してくれる。
上記の場合だと、x の shape が (3, 3, 128) になるので、
1次元に変換している。
3. 損失関数定義
損失関数は、画像分類用のNLL Loss
にMetric Learninig用のCenter Loss
を加重加算したものを利用する。
Loss = NLL Loss + α * Center Loss, α is weight
NLL Loss
はNegative Log-Likelihood (NLL) Lossの略。
softmaxの最大値は結果の確信度を表すが、それをマイナスの対数で取った値となる。
NLL Lossにより、高い確信度であれば低いロス、低い確信度であれば高いロスを割り当てることができる。
ip2のsoftmax(定義したモデルの出力)を入力とする。
一方Center Loss
は特徴の中心の損失関数。ip1を入力する。
詳しい説明は参照記事に任せる。
ちなみに、自分は距離学習に
ArcFaceから入ったので、
Center Loss
はこの記事以外では使わないかな、と思っている。
PyTorchで、損失関数は次のように定義される。
CenterLoss()
は自作関数でクラス数と特徴数が引数となる。
# NLL Loss & Center Loss nllloss = nn.NLLLoss().to(device) # CrossEntropyLoss = log_softmax + NLLLoss loss_weight = 1 # weight centerloss = CenterLoss(10, 2).to(device) # Loss loss = nllloss(pred, labels) + loss_weight * centerloss(labels, ip1)
4. 最適化関数定義
最適化関数にはSGDを利用するが、画像分類と距離特徴の両方を行っているので、
それぞれで最適化関数を定義する。
前者については、学習率の減衰をlr_scheduler.StepLR()
で行う。
第一引数は画像分類用の最適化関数、第二引数は学習率を更新するタイミングのエポック数、
第三引数は学習率の更新率。
# Optimizer dnn_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) center_optimizer = optim.SGD(centerloss.parameters(), lr =0.5) import torch.optim.lr_scheduler as lr_scheduler sheduler = lr_scheduler.StepLR(dnn_optimizer, 20, gamma=0.8)
5. 訓練検証
これまで定義してきた変数と関数を利用して訓練を行う。
エポックごとにtrain()
を呼び出す。
train_acc, train_loss, feat, labels = train(device, train_loader, model, nllloss, loss_weight, centerloss, dnn_optimizer, center_optimizer)
train()
は一般的な機械学習。
def train(device, train_loader, model, nllloss, loss_weight, centerloss, dnn_optimizer, center_optimizer): running_loss = 0.0 pred_list = [] label_list = [] ip1_loader = [] idx_loader = [] model.train() for i,(imgs, labels) in enumerate(train_loader): # Set batch data. imgs, labels = imgs.to(device), labels.to(device) # Predict labels. ip1, pred = model(imgs) # Calculate loss. loss = nllloss(pred, labels) + loss_weight * centerloss(labels, ip1) # Initilize gradient. dnn_optimizer.zero_grad() center_optimizer.zero_grad() # Calculate gradient. loss.backward() # Update parameters. dnn_optimizer.step() center_optimizer.step() # For calculation. running_loss += loss.item() pred_list += [int(p.argmax()) for p in pred] label_list += [int(l) for l in labels] # For visualization. ip1_loader.append(ip1) idx_loader.append((labels)) # Calculate training accurary and loss. result = classification_report(pred_list, label_list, output_dict=True) train_acc = round(result['weighted avg']['f1-score'], 6) train_loss = round(running_loss / len(train_loader.dataset), 6) # Concatinate features and labels. feat = torch.cat(ip1_loader, 0) labels = torch.cat(idx_loader, 0) return train_acc, train_loss, feat, labels
PyTorchでは訓練中、lossとoptimizerはバッチごとに次の手順を踏んで、パラメーターを更新していく。
optimizer.zero_grad() # 勾配の初期化 loss.backward() # 勾配の計算 optimizer.step() # パラメータの更新
またsklearn.metrics
のclassification_report()
を利用すると、
簡単に精度が算出できる。
今回のように距離学習の損失関数を入れても、検証精度は100エポックで98.2%になっている。
訓練と検証の精度と損失の変化は以下の通り。
Epoch: 1, train acc: 0.209305, train loss: 0.019642, test acc: 0.253963, test loss: 0.018308 Epoch: 2, train acc: 0.302789, train loss: 0.017461, test acc: 0.418725, test loss: 0.016906 Epoch: 3, train acc: 0.455266, train loss: 0.015967, test acc: 0.492158, test loss: 0.015241 Epoch: 4, train acc: 0.531249, train loss: 0.014266, test acc: 0.526375, test loss: 0.013594 Epoch: 5, train acc: 0.609915, train loss: 0.012737, test acc: 0.629488, test loss: 0.012123 ... Epoch: 96, train acc: 1.0 , train loss: 0.000309, test acc: 0.982005, test loss: 0.003887 Epoch: 97, train acc: 0.999983, train loss: 0.000307, test acc: 0.981601, test loss: 0.003929 Epoch: 98, train acc: 1.0 , train loss: 0.000303, test acc: 0.981306, test loss: 0.003924 Epoch: 99, train acc: 1.0 , train loss: 0.000296, test acc: 0.981907, test loss: 0.003937 Epoch: 100, train acc: 1.0 , train loss: 0.000272, test acc: 0.981805, test loss: 0.003965
参考文献
- 【深層距離学習】Center Lossを徹底解説 - はやぶさの技術ノート
- PyTorch まずMLPを使ってみる - cedro-blog
- Normalization in the mnist example - PyTorch Forums
- ChainerのDefine by Runとは? - HELLO CYBERNETICS
- LeakyRelu活性化関数 - Thoth Children
- PRelu活性化関数 - Thoth Children
- Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification - arXiv
- Understanding softmax and the negative log-likelihood - Lj Miranda
- 実践Pytorch - Qiita
- PyTorchのSchedulerまとめ - 情弱大学生の独り言