17 天前

基于隐式梯度的元学习

Aravind Rajeswaran, Chelsea Finn, Sham Kakade, Sergey Levine
基于隐式梯度的元学习
摘要

智能系统的一项核心能力是能够基于先前经验快速学习新任务。基于梯度(或优化)的元学习方法近年来已成为少样本学习(few-shot learning)中一种高效的技术路径。在该框架下,元参数在外部循环中进行学习,而针对特定任务的模型则在内部循环中通过仅使用当前任务的少量数据进行学习。然而,这类方法在扩展性方面面临一个关键挑战:需要对内部循环的学习过程进行反向传播求导,这往往带来巨大的计算开销和内存负担。为此,我们借鉴隐式微分(implicit differentiation)的思想,提出了一种新的隐式MAML(implicit MAML)算法。该方法仅依赖于内部优化问题的最终解,而无需依赖内部优化器在迭代过程中所走的路径。这一设计实现了元梯度计算与内部优化器选择之间的有效解耦。因此,我们的方法对内部优化器的选择具有完全的无关性,能够在不出现梯度消失或内存瓶颈的情况下,平稳处理大量内部优化步骤。理论上,我们证明了隐式MAML能够以极低的内存开销计算出高精度的元梯度,其内存占用量(仅受小常数因子影响)不超过计算单次内部循环梯度所需内存,且整体计算成本并未增加。实验结果表明,隐式MAML在少样本图像识别基准任务上显著提升了性能,充分验证了其理论优势在实际应用中的有效性。