Learning to Skip the Middle Layers of Transformers

Conditional computation is a popular strategy to make Transformers more efficient. Existing methods often target individual modules (e.g., mixture-of-experts layers) or skip layers independently of one another. However, interpretability research has demonstrated that the middle layers of Transformers exhibit greater redundancy, and that early layers aggregate information into token positions. Guided by these insights, we propose a novel architecture that dynamically skips a variable number of layers from the middle outward. In particular, a learned gating mechanism determines whether to bypass a symmetric span of central blocks based on the input, and a gated attention mechanism prevents subsequent tokens from attending to skipped token positions. Residual norms are controlled with a 'sandwich' or 'perilayernorm' scheme, and gate sparsity with an adaptive regularization loss. We aimed to reduce compute requirements for 'simpler' tokens and potentially foster an emergent multi-level representational hierarchy, but, at the scales investigated, our approach does not achieve improvements in the trade-off between validation cross-entropy and estimated FLOPs compared to dense baselines with fewer layers. We release our code at https://github.com/tim-lawson/skip-middle.
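
To make the mechanism described above concrete, here is a minimal PyTorch sketch of the core idea: a per-token learned gate decides whether to bypass a symmetric span of central blocks, and an attention mask stops other tokens from attending to the skipped positions within that span. This is an illustrative assumption of how such a gate could be wired, not the authors' implementation; module names, the scalar-gate formulation, hyperparameters, and the use of a key-padding mask are all simplifications (see https://github.com/tim-lawson/skip-middle for the released code).

```python
import torch
import torch.nn as nn


class SkipMiddleTransformer(nn.Module):
    """Toy stack whose central blocks can be bypassed per token (sketch)."""

    def __init__(self, d_model=256, n_heads=4, n_layers=8, n_middle=4):
        super().__init__()
        assert n_middle <= n_layers and (n_layers - n_middle) % 2 == 0
        n_outer = (n_layers - n_middle) // 2

        def block():
            return nn.TransformerEncoderLayer(
                d_model, n_heads, dim_feedforward=4 * d_model,
                batch_first=True, norm_first=True,
            )

        self.early = nn.ModuleList(block() for _ in range(n_outer))
        self.middle = nn.ModuleList(block() for _ in range(n_middle))
        self.late = nn.ModuleList(block() for _ in range(n_outer))
        # Learned gate: maps each token's residual state to a scalar in [0, 1]
        # that decides whether that token bypasses the middle span.
        self.gate = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())

    def forward(self, x, threshold=0.5):
        # x: (batch, seq, d_model) residual stream after embedding.
        for blk in self.early:
            x = blk(x)

        g = self.gate(x)                    # (batch, seq, 1)
        skip = (g < threshold).squeeze(-1)  # True -> bypass the middle blocks

        # Inside the middle span, prevent any token from attending to the
        # skipped positions (their residuals are carried through unchanged).
        mid = x
        for blk in self.middle:
            mid = blk(mid, src_key_padding_mask=skip)

        # Skipped tokens keep their pre-middle residual; others mix in the
        # middle output via the soft gate, keeping the choice differentiable.
        x = torch.where(skip.unsqueeze(-1), x, g * mid + (1 - g) * x)

        for blk in self.late:
            x = blk(x)
        return x, g  # g can feed an adaptive gate-sparsity regularizer


if __name__ == "__main__":
    model = SkipMiddleTransformer()
    h = torch.randn(2, 16, 256)
    out, gates = model(h)
    print(out.shape, gates.mean().item())
```

In this sketch the gate is a single scalar per token applied to the whole central span; the paper's architecture instead skips a variable number of layers symmetrically from the middle outward and uses a gated attention mechanism rather than a simple key-padding mask, so this code should be read only as an orientation to the overall control flow.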