Metric Learning 入門

f:id:copypaste_ds:20190301002935p:plain

はじめに

metric learningについて学ぶ機会があったので忘れないうちに得た知識を書き留めておきます。学んだ期間は10日程度と短く、deep learningも含めて初心者ですので疑いながら読んでいただければと思います。間違いを見つけた方はご指摘ください。本記事ではmetric learningの概要から画像データ・テーブルデータへの適用結果まで簡単に紹介します。

metric learningとは

metric learningとは データ間の計量(距離や類似度など)を学習する手法 です。直感的には、意味の近いデータは近く、意味の遠いデータは遠くなるように計量を学習 します。以下の図は靴の画像データに対してmetric learningを適応したイメージ図です。(画像はこのページから拝借しました。)

f:id:copypaste_ds:20190226193837p:plain

スニーカー同士、ブーツ同士、ヒール同士が近くに配置されていることがわかります。また、ヒールが高いほど右方向に配置されていることもわかります。このように、metric learningを応用すれば、靴の形状を考慮した距離を学習することができます。
ちなみに「計量を学習する」を「特徴量(特徴量空間)を学習する」あるいは「特徴量生成用の関数を学習する」と考えてもよいです。この考え方ついてはマハラノビス距離学習の説明時に軽く触れます。  

metric learningの利点は意味的な距離を考慮した特徴量を学習できることです。教師ラベルを上手く準備できれば、特徴量空間の作り方(考慮したいデータの意味)をある程度コントールできます。 応用範囲は多義に渡り、例えば以下のようなタスクがあります。

意味的な距離を考慮した特徴量空間を上手く学習できれば、未知クラスのデータに対してもある程度頑健に対応できる点も強力です。 先程の靴の例だと、学習データに含まれない高さのヒールも空間の右上にうまいこと配置されるイメージでしょうか。

マハラノビス距離学習

metric learningの古典的な手法にマハラノビス距離学習があります。アイディア自体はdeep learningを用いたmetric learningと似ています。
マハラノビス距離学習では以下の式の共分散行列\( M \)を学習します。

f:id:copypaste_ds:20190226195628p:plain

\( M \)が決まれば距離が定まるので距離を学習していることになります。
ところで、\( M \)が半正定値行列であればマハラノビス距離は以下のように変形できます。

f:id:copypaste_ds:20190226195818p:plain

つまりデータ間のユークリッド距離が適切になるような変換\( L \)(あるいは特徴量\(L \boldsymbol{x}, L \boldsymbol{y} \))を学習する手法 とも考えられます。 このように、「計量を学習する」を「特徴量(特徴量空間)を学習する」あるいは「特徴量生成用の関数を学習する」と考えることができます。

次に\( M \) (= \( L \) )の決め方ですが、マハラノビス距離学習では以下の最適化問題を解けば良いです。\( S \)(類似データの組)と\( D \)(非類似データの組)は事前に用意しておく必要があります。よくある準備方法として、同じクラスのデータなら似ている、異なるクラスのデータなら似ていないとする方法があります。

f:id:copypaste_ds:20190226200339p:plain

このように似ているデータは近く、似ていないデータは遠くなるように学習させるのがmetric learningの特徴です。

deep metric learningとは

最近はdeep learningを用いたmetric learning(deep metric learning??)が注目されています。deep metric learningではマハラノビス距離学習で言うところの\( L \) ( = \( M \) )をdeep learningで学習します。これによって非線形変換が可能になります。

今回は以下の手法を順に紹介します。

  • siamese network
    • 2サンプルを一組で入力するやつ(contrastive lossと一緒に紹介)
  • triplet network
    • 3サンプルを一組で入力するやつ(triplet lossと一緒に紹介)
  • L2 softmax lossを使ったnetwork(本記事ではL2 softmax networkと呼ぶことにします)
    • metric learningっぽいくないけどそれなりに強いやつ

f:id:copypaste_ds:20190226201747p:plain

siamese network

siamese networkはずいぶん前(少なくとも2006年)に提案された手法です。(論文はこちら)特徴は2サンプルを一組で入力する点とサンプル間の距離をcontrastive lossで明示的に調節する点です。図の \( f \) をdeep learningに置き換えればdeep metric learningとなります。

f:id:copypaste_ds:20190226204013p:plain

マハラノビス距離学習と同様に、contrastive lossにも類似データは近く、非類似データは遠くするような工夫がされています。図にある通り、類似サンプルの組を入力した場合は \( d_i \) が小さくなるように、非類似サンプルを入力した場合は \( -d_i \) が小さくなるよう( \( d_i \) が大きくなるように)に学習させます。距離\( d \) は好きに選べば良いですが、元論文ではユークリッド距離が採用されていました。\( m \)(定数)を足してヒンジロスに通しているのは、\( -d_i \) を限りなく小さく( \( - \inf \) など)することでlossを下げる現象を防ぐためです。 metric learningで学習した特徴量を抽出する際は1サンプルずつ関数 \( f \) に通して、lossの手前の出力を使えば良いです。特徴量の抽出方法は他の手法も同様です。

triplet network

2014年に提案された手法です。今回はこの論文で紹介されているtriplet lossを紹介しますが、triplet network自体はこの論文で提案されています。最大の特徴は3サンプルを一組で入力する点です。

f:id:copypaste_ds:20190226205158p:plain

triplet networkもsiamese networkと同様に類似サンプル同士は近く、非類似サンプルが遠くなるようにlossを設計しますが、サンプルの準備方法が異なります。サンプルの組は以下の手順で作成します。

  1. 基準となるサンプル \( a_i \)(anchor)を選択する
  2. \( a_i \) と似ているサンプル \( p_i \) 、似ていないサンプル \( n_i \) を一つずつ選択する

サンプルの組を作った後は、各サンプルを同じ関数 \( f \) に通してlossを計算すれば良いです。また、contrastive lossと同様に距離 \( d \)は好きに設定できます

サンプルの選び方と直感的理解

siamese networkと同様にtriplet lossもサンプルの組の選び方が重要です。
サンプル選択の理解を深めるためにもう少し式を眺めてみます。

f:id:copypaste_ds:20190226210529p:plain

lossは低い程良いので青枠部分がゼロ以下になるのが理想です。ただし、この関係を満たすサンプルの組は学習に役立たない(既にlossがゼロ)ので、学習時には青枠部分がゼロより大きくなるサンプルの組(学習が難しい組)を選択した方が良いそうです。

ちなみに理想的な関係式を移行して図にすると以下のように描けます。

f:id:copypaste_ds:20190227191500p:plain

図のオレンジの線(非類似サンプルまでの距離)が青線(類似サンプルまでの距離)+緑線(マージン)より長くなるように学習させているわけです。

L2 softmax network

2017年に提案された手法です。(論文はこちら)contrastive lossやtriplet lossのように明示的に距離を操作することはしません。softmax関数の手前で2つの処理をするのが特徴です。(どうやら暗黙的にコサイン類似度で調節しているという話も聞きましたが知識不足でよくわかりません)

f:id:copypaste_ds:20190226211758p:plain

処理の内容は簡単で、softmax関数に通す前にL2ノルムで割って定数倍するだけです。L2ノルムで割る(=単位ベクトル化する)ことで、softmax関数で予測容易なサンプルの予測値を限りなく1に近づける現象を抑えているそうです。また、表現力を調整する定数 \( \alpha \) はハイパーパラメータで事前に決める必要があります。(正直あまり知りません)

MNISTで実験

MNISTデータを使って3つの実験をしてみました。基本的に同じ数字は似ている、違う数字は似ていない組としてmetric learningを行います。(実験1-3だけは別ですが)

実験条件

実験条件は以下のとおりです。

  • CNNの構造
    • Convolution(kernel_size=3)+ReLuを4つ重ねた後、全結合層256次元
    • Convolutionのチャネル数は 16 -> 32 -> 64 -> 128
  • 最適化手法
    • Adam
  • 特徴量抽出のタイミング
    • lossの手前
    • L2 softmaxの場合は単位ベクトル化する前
  • データ
    • train: 60,000枚
    • test: 10,000枚
  • チューニング
    • ほぼしていません。 

実験1-1: 表現力の確認

まずはmetric learningがそれっぽく動作してくれるのか確認しました。実験手順は以下のとおりです。

  1. 10クラスで学習し、可視化
  2. metric learningで得た特徴量を用いて分類モデルを学習し、精度評価

まずは10クラスで学習した後、試験データの1000サンプルをtSNEで可視化しました。

f:id:copypaste_ds:20190227202009p:plain

no metric learning(metric learning未使用)はデータが混ざり合っており、metric learningを使った場合は上手く分かれているように見えます。また、contrastive lossとtriplet lossはデータ間の距離を考慮しているようにも見えますが、L2 softmaxは全体的にきっちりわける傾向が強いように見えます。(曖昧な表現をしているのはtSNEも多様体学習しており、2次元プロットの距離が高次元空間の距離を反映しているとは限らないためです)
次にクラスごとに特徴量ベクトルのセントロイドを算出して、セントロイド間の距離をヒートマップで可視化してみました。 距離は0~1にスケールしてあります。

f:id:copypaste_ds:20190228203728p:plain

文字が小さくて恐縮ですが、no metric learning, triplet loss, contrastive lossに関してはtSNEの可視化と同様の傾向が見て取れます。一方でL2 softmaxに関しては、tSNEの可視化では近く見えた3と5も実は遠い距離にあることがわかります。 (L2 softmaxを使用した場合にはコサイン類似度を使用したほうが良さそうですが、今回は比較のためにユークリッド距離を使用しています。そもそも直接比較はできない数値なのであくまで参考程度に)

次にmetric learningで得た特徴量を用いて分類モデルを学習させました。評価指標はaccuracyです。

f:id:copypaste_ds:20190226220958p:plain

no ML(metric learning未使用)よりもmetric learningのほうが高いaccuracyを示しています。中でもL2 softmax lossが特に優れています。これはL2 softmax networkが分類モデルを学習させている(最終的にsoftmaxでlossを計算している)ため、分類モデルにとって都合の良い特徴量を作成しているのだと思います。後段で分類モデルを使いたいならL2 softmaxが良さそうです。

実験1-2: 未知クラスの表現力を確認

次に未知クラスの特徴量をそれっぽく再現できるか実験してみました。実験手順は以下の通りです。

  1. 3, 4を除いた8クラスで学習し、可視化
  2. metric learningで得た特徴量を用いて分類モデルを学習し、精度評価

まずは8クラス(数字の3, 4以外)で学習した後、3, 4も合わせてtSNEで可視化しました。 実装はtriplet lossとl2 softmax lossで行いました。 3, 4以外は比較的分かれており、4は9の近く、3は丸みを帯びた文字の近くに位置していそうです。

f:id:copypaste_ds:20190227202049p:plain

生成した特徴量を用いてKNNを学習させました。KNNには3, 4も訓練データとして与えています。正規化済みの混同行列は以下のようになりました

f:id:copypaste_ds:20190227202058p:plain

metric learningモデルにとっての未知クラスの分類精度(オレンジ部分)はtriplet lossのほうが良いことがわかります。未知クラスを扱うときはtriplet lossが有効ということでしょうか。既知クラスの精度でL2 softmax lossが強いことは実験1-1の結果と一致します。triplet lossの精度を見るとクラス3が0.93、クラス4が0.85とそこそこ高いので未知クラスでもそれっぽい特徴量を作成できたことがわかります。

実験1-3: 奇数/偶数を学習

最後に奇数(1, 5, 7, 9)と偶数(0, 2, 6, 8)の2クラスで類似/非類似の組を作って実験してみました。 奇数/偶数で数字の形は必ずしも似ていないので少し無茶ぶりをしたつもりです。実験手順は以下の通りです。 実装にはtriplet lossを使用しました。

  • 3, 4を除いた上で奇数/偶数の特徴量を学習し、可視化
  • 3, 4も合わせて可視化し、位置を確認

まずは3, 4を抜いて学習させて可視化してみました。 左図はtSNEの可視化、右図が各クラスのセントロイド間のユークリッド距離を0~1に正規化したものです。

f:id:copypaste_ds:20190227202242p:plain

tSNEの可視化結果を見ると奇数/偶数で分離できそうに見えます。 右図の距離行列を見ても、奇数は奇数同士近く、偶数は偶数同士近く、奇数と偶数は遠く配置されていることがわかります。 無茶ぶりのつもりでしたがdeep learningにとっては簡単なタスクだったのでしょうか。
次に未知クラス(3と4)も合わせて可視化してみました。

f:id:copypaste_ds:20190227202257p:plain

4と9、 3と5が混ざり合っているように見えます。 今回の学習方法だとdeep learningは偶数/奇数の概念までは学習できないので4は奇数側(9の近く)に寄ってしまいました。 未知クラスも視野に入れるなら、metric learningで考慮したい意味(距離)は慎重に考える必要がありそうです。

天気データで実験

日本の天気データを使ってMNISTと似たような実験をしてみました。テーブルデータでもそれっぽく学習できるのか確かめることが目的です。

データの準備

気象庁のページから過去11年分の12都道府県の天気データをダウンロードしました。データは1時間おきの計測データで、クローリング失敗やそもそもの欠損などはpandasのdropnaで雑に処理をしました。必要なカラムだけ抽出して整形したデータは以下の通りです。今回は同じ都道府県ならば似ている、異なる都道府県なら似ていないとしてmetric learningしてみます。

f:id:copypaste_ds:20190227194517p:plain

入力データの次元があまりに少ないと面白くないので、1日1サンプルとして最大で12都道府県×11年×365日分のサンプルを用意しました。(実際には1/3程度が欠損で消えましたが...)都道府県ごとのサンプル数は沖縄が最大で3682サンプル、神奈川が最小で1425サンプルで若干隔たりがあります。

特徴量としては月(ダミー変数)と以下の項目を用意しました。ダミー変数を含めると52次元のデータとなります。画像データと比べると次元はかなり少ない印象です。

  • 朝/昼/夜ごとに以下の統計量を算出
    • 気温の最小・平均・最大・標準偏差
    • 降水量の平均・最大・合計
    • 降雪量の平均・最大・合計
    • 風速の最小・平均・最大・標準偏差

データの準備はとても雑ですが、ひとまずこれで準備完了です。

実験条件

実験条件は以下のとおりです。

  • MLPの構造
    • 全結合層(256次元)+ReLUを3つ
  • 最適化手法
    • Adam
  • 特徴量抽出のタイミング
    • Lossの手前
    • L2 softmaxの場合は単位ベクトル化する前
  • データセット
    • train: 9年分(2008年 ~ 2017年)
    • test: 2年分(2017年 ~ 2019年)
  • チューニング
    • ほぼできていません

実験2-1: 表現力の確認(その1)

まずは簡単そうな問題設定で動作確認をしてみました。実験手順は以下のとおりです。

  1. 4クラス(札幌、東京、大阪、沖縄)で学習し、可視化
  2. metric learningで得た特徴量を用いて分類モデルを学習し、精度評価

下図がtSNEの可視化結果です。metric learning未使用の場合は季節や月でクラスタができているように見えます。

f:id:copypaste_ds:20190227200727p:plain

札幌と沖縄の特徴量表現は比較的簡単ですが、東京と大阪の違いを学習させることは少し難しいようです。 triplet loss, L2 softmax lossは東京と大阪もある程度分離できていますが、contrastive lossは混ざり合っているように見えます。 ひとまずテーブルデータでもそれっぽく動作することが確認できました。

次にmetric learningで得た特徴量を用いて分類モデルを学習させました。評価指標はaccuracyです。

f:id:copypaste_ds:20190228224334p:plain

MNISTのときと同様に、metric learningを使ったほうが精度が高い結果となりました。 今回もL2 softmaxが最良の結果を示しています。

実験2-2: 表現力の確認(その2)

クラス数を増やして実験2-1と同様の実験をしてみました。クラス数が増えるので難易度も上がります。実験手順は以下のとおりです。

  1. 9クラス(札幌、青森、山形、東京、富山、奈良、山口、福岡、沖縄)で学習し、可視化

metric learning未使用の場合は実験2-1と同様に季節でクラスタが分かれているように見えます。 metric learningを使うことでクラスタが一つになるようですが、タスクが難しくなったせいか全体的にやや混ざり合っているように見えます。また、contrastive lossはミミズ型になりやすく少しチューニングに手間がかかりました。

f:id:copypaste_ds:20190227200911p:plain

下図は各クラスのセントロイド間のユークリッド距離を0~1に正規化したものです。 地理的な関係が反映されているかどうかはさておき、contrasitve lossとtriplet lossは距離を学習している気配があります。L2 softmaxは相変わらずはっきりと分ける傾向にあり、沖縄を除くとどの都道府県も似たような距離関係にあります。

f:id:copypaste_ds:20190228203953p:plain

実験2-3: 未知クラスの表現力を確認

最後に未知クラスの特徴量をそれっぽく再現できるか確認してみました。triplet lossとL2 softmax lossで実装しています。実験手順は以下の通りです。

  1. 9クラス(札幌、青森、山形、東京、富山、奈良、山口、福岡、沖縄)で学習し、未知クラス(秋田、神奈川、大分)も含めて可視化
  2. metric learningで得た特徴量を用いて分類モデルを学習し、精度評価

9クラスで再度学習させてtSNEで可視化してみました。左図は既知クラスのみ、右図は未知クラスも含めて可視化しています。 既知クラスのほうは辛うじてクラスごとのクラスタを確認できますが、未知クラスのほうは完全に混ざり合っているように見えます。 秋田は青森と、神奈川は東京と、大分は福岡と混ざり合ってしまったようです。チューニングすればある程度改善できると思いますが、タスクとして少し難しかったのかもしれません。

f:id:copypaste_ds:20190228204030p:plain

実験1-2と同様にmetric learningで学習した特徴量を用いて分類問題を解いてみました。正規化済みの混同行列は下図の通りです。

f:id:copypaste_ds:20190227201503p:plain

未知クラスの精度が著しく低いことがわかります。やはり既知クラスと混ざり合ってしまい上手く特徴量を表現できなかったようです。 一応2種類のlossを比較しておくと、未知クラスにはtriplet lossが強く、既知クラスにはL2 softmax loss が強いという実験1-2と同様の結果が得られました。

まとめ

  • metric learningは計量を学習する手法
  • 意味的な距離を考慮した特徴量を作成できる
  • 画像だけでなくテーブルデータにもある程度機能する
    • 実データで使えるかどうかはよくわかりません
  • 意味的な距離を考慮したいならtriplet loss
  • 未知クラスを主に扱うならtriplet loss
  • 分類タスクの特徴量生成に使うならL2 softmax loss
  • 素人のブログなのであまり信じすぎないでください

おわりに

metric learningの概要からToyデータによる実験結果まで簡単に紹介しました。冒頭でも述べた通り、私自身metric learningを学び始めて間もないですので、記事中に間違いがある可能性は非常に高いです。間違いを見つけた方はご指摘ただければと思います。感想としてはlossの設計で特徴量空間をある程度コントロールできるのは面白いと感じました。本記事で紹介していないlossやautoencoderとの組み合わせなど、データを変えながら遊んでみたいです。そういえばkaggleのタンパク質コンペ(?)の1st place solutionがmetric learningを使った解法だったそうです。今後はコンペでも積極的に使用されるかもしれませんね。

参考

[1] deep metric learningによるcross-domain画像検索 - ZOZO Technologies TECH BLOG
[2] Deep Metric Learning Using Triplet Networkの論文を流し読む – Urusu Lambda Web
[3] http://researchers.lille.inria.fr/abellet/talks/metric_learning_tutorial_CIL.pdf