10日前

長尾分類における良質なモーメンタム因果効果の保持と悪質なモーメンタム因果効果の除去

Kaihua Tang, Jianqiang Huang, Hanwang Zhang
長尾分類における良質なモーメンタム因果効果の保持と悪質なモーメンタム因果効果の除去
要約

クラス数が増加するにつれて、データが長尾分布(long-tailed)であるという性質上、多数のクラスにわたりバランスの取れたデータセットを維持することは困難である。特に、一つの収集単位(例:一枚の画像内に複数の視覚的インスタンスが存在する場合)に注目対象のサンプルが重複して含まれる状況では、その維持はそもそも不可能となる。したがって、スケールに応じたディープラーニングにおける鍵となる課題は、長尾分類(long-tailed classification)の解決である。しかし、従来の手法は主に再重み付け(re-weighting)や再サンプリング(re-sampling)に基づくヒューリスティックなアプローチであり、根本的な理論的基盤に欠ける。本論文では、因果推論(causal inference)の枠組みを構築することで、従来手法の背後にある「なぜ」を解明するとともに、新たな理論的根拠に基づく解決策を導出する。具体的には、本理論により、SGDのモーメンタム(momentum)が長尾分類において本質的に交絡要因(confounder)であることが示される。一方で、モーメンタムは、尾部(tail)の予測を頭部(head)に偏らせる有害な因果効果を持つ一方で、その誘導する中間経路(mediation)は表現学習および頭部予測に有益な効果をもたらす。本フレームワークは、入力サンプルが引き起こす直接的な因果効果に注目することで、モーメンタムの矛盾した効果を洗練的に分離する。特に、学習フェーズでは因果介入(causal intervention)を、推論フェーズでは反事実的推論(counterfactual reasoning)を用いることで、「悪」な影響を除去しつつ、「善」な影響を保持する。本手法は、長尾視覚認識の3つのベンチマークにおいて、新たなSOTA(state-of-the-art)を達成した。具体的には、長尾CIFAR-10/-100、ImageNet-LT(画像分類)およびLVIS(インスタンスセグメンテーション)において、性能の向上が確認された。