The Past Is Not Past: Memory-Enhanced Dynamic Reward Shaping

Yang Liu Enxi Wang Yufei Gao Weixin Zhang Bo Wang Zhiyuan Zeng Yikai Zhang Yining Zheng Xipeng Qiu

Abstract

Despite the success of reinforcement learning for large language models (LLMs), a common failure mode is reduced sampling diversity, where the policy repeatedly generates similar erroneous behaviors. Classical entropy regularization encourages randomness under the current policy, but it does not explicitly discourage failure patterns that recur across rollouts. This paper proposes MEDS, a Memory-Enhanced Dynamic reward Shaping framework that incorporates historical behavioral signals into reward design. By storing and leveraging intermediate model representations, MEDS captures features of past rollouts and uses density-based clustering to identify frequently recurring error patterns. Rollouts assigned to more prevalent error clusters receive a higher penalty, which reduces repeated mistakes while encouraging broader exploration. Across five datasets and three base models, MEDS consistently improves average performance over existing baselines, with gains of up to 4.13 points on pass@1 and up to 4.37 points on pass@128. Additional analyses using LLM-based annotation and quantitative diversity metrics confirm that MEDS increases behavioral diversity during sampling.

One-sentence Summary

To mitigate reduced sampling diversity in reinforcement learning for large language models, the proposed Memory-Enhanced Dynamic reward Shaping (MEDS) framework leverages intermediate model representations and density-based clustering to penalize recurring error patterns, ultimately improving performance across five datasets and three base models by up to 4.37 pass@128 points.

Key Contributions

  • The paper introduces MEDS, a Memory-Enhanced Dynamic reward Shaping framework designed to penalize recurrent failure patterns by incorporating historical behavioral signals into the reward design.
  • The method utilizes layer-wise logits as lightweight representations of reasoning trajectories to perform density-based clustering, which identifies and suppresses frequently occurring error modes during reinforcement learning.
  • Experimental results across five datasets and three base models demonstrate that MEDS improves reasoning performance by up to 4.37 pass@128 points and increases behavioral diversity during sampling.

Introduction

Reinforcement learning is a critical driver for optimizing the reasoning capabilities of large language models. However, on-policy optimization often suffers from a collapse in sampling diversity, where the model becomes trapped in repetitive, erroneous reasoning patterns. While traditional entropy regularization attempts to mitigate this by increasing randomness in the current policy, it fails to explicitly discourage specific failure modes that recur across different training rollouts. The authors leverage a framework called MEDS (Memory-Enhanced Dynamic reward Shaping) to address this by incorporating historical behavioral signals into the reward design. By using layer-wise logits as lightweight representations of reasoning trajectories, the authors employ density-based clustering to identify and penalize frequently recurring error patterns, thereby encouraging broader exploration and preventing the model from entrenching itself in self-reinforcing mistakes.

Dataset

Dataset overview

Dataset Overview

The authors utilize a structured dataset designed for mathematical reasoning, organized into the following components:

  • Dataset Composition and Subsets

    • Group A: This subset consists of original answers used during the initial training phases.
    • Group B: This subset contains additional answers introduced during a later training stage to further refine model performance.
  • Data Processing and Formatting

    • Standardized Templating: All training and evaluation problems are processed through the Qwen-Math template. This ensures a consistent prompt structure that instructs the model to reason step by step and place the final result within a LaTeX boxed format.
    • Prompt Structure: The template utilizes specific control tokens to define the system instruction, the user problem, and the assistant response, creating a standardized environment for mathematical problem solving.
  • Model Usage

    • Training Stages: The data is applied in a multi-stage training approach, where Group A serves as the foundational training set and Group B is integrated during a subsequent stage to enhance reasoning capabilities.
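The prompt structure described above can be sketched as a small helper. This is a minimal illustration, not the authors' code: the control tokens (`<|im_start|>`, `<|im_end|>`) and the exact wording of the system instruction are assumptions based on the commonly published Qwen-Math chat format, and `build_prompt` is a hypothetical name.

```python
# Assumed wording of the Qwen-Math system instruction; the text above only
# says it asks for step-by-step reasoning and a LaTeX boxed final answer.
SYSTEM_INSTRUCTION = (
    "Please reason step by step, and put your final answer within \\boxed{}."
)


def build_prompt(problem: str) -> str:
    """Wrap a math problem in system, user, and assistant turns.

    The control tokens below are an assumption (ChatML-style markers).
    """
    return (
        f"<|im_start|>system\n{SYSTEM_INSTRUCTION}<|im_end|>\n"
        f"<|im_start|>user\n{problem}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


prompt = build_prompt("What is 12 * 7?")
print(prompt)
```

The prompt ends with an open assistant turn, so the model's generation starts directly with its reasoning.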

Method

The framework of MEDS operates through a three-stage process designed to mitigate the recurrence of errors by identifying and penalizing shared reasoning patterns across responses. The overall architecture, as illustrated in the diagram, consists of logic feature extraction, memory-based clustering, and reward shaping.

In the first stage, logic feature extraction, the model processes an input $x$ to generate a response $\tilde{y}$, from which a logic feature vector $f(\tilde{y})$ is derived. This vector is constructed from the layer-wise logits of the first token $y^*$ in the final answer, specifically using the latter half of the Transformer layers to capture reasoning patterns. The logits at the position of $y^*$ from each layer $n$ are aggregated into a feature vector $f(\tilde{y}) = \mathrm{concat}(l^{*(n)} \mid n = N/2, \ldots, N)$, where $l^{*(n)}$ is the logit corresponding to $y^*$ at layer $n$. This process leverages the fact that the evolution of logits across layers reflects the model's internal reasoning, as demonstrated in the figure showing logits aggregation in a 3-layer model.

Logits aggregation process
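The aggregation step can be sketched in a few lines. The sketch assumes the per-layer logit vectors at the position of $y^*$ are already available (how they are obtained from the hidden states is not specified here), treats layers as 1-indexed, and keeps layers $N/2$ through $N$ inclusive as the formula is written. `logic_feature` and the toy numbers are hypothetical.

```python
from typing import List, Sequence


def logic_feature(layer_logits: Sequence[Sequence[float]], token_id: int) -> List[float]:
    """Collect the logit of `token_id` from layers N/2..N (1-indexed, inclusive).

    layer_logits[n - 1] is the vocabulary-sized logit vector produced at
    layer n for the position of the first answer token y*.
    """
    num_layers = len(layer_logits)
    start = max(num_layers // 2, 1)  # first layer kept, 1-indexed
    return [layer_logits[n - 1][token_id] for n in range(start, num_layers + 1)]


# Toy 4-layer model with a 3-token vocabulary; y* has token id 2.
toy_logits = [
    [0.1, 0.2, 0.0],
    [0.0, 0.5, 0.4],
    [0.3, 0.1, 1.5],
    [0.2, 0.0, 3.0],
]
feature = logic_feature(toy_logits, token_id=2)
print(feature)  # [0.4, 1.5, 3.0]
```

Each entry of the feature is a single scalar per layer, so the vector stays small even for large vocabularies.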

The second stage, memory-based clustering, maintains a per-prompt error memory $G_x$ that stores the feature representations of all historical responses sampled for a given prompt. This set is then clustered using HDBSCAN to group responses with similar logic features into clusters $C_k$. The number of clusters $K$ is determined dynamically, and the clustering process identifies patterns in the reasoning trajectories.

Memory-based Clustering
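The per-prompt memory and the cluster-size lookup can be illustrated with the standard library alone. Note the substitution: the sketch below does not run HDBSCAN. It uses a much simpler stand-in that links any two points closer than a fixed radius and labels the connected components, only to show how cluster sizes $|C_k|$ fall out of the memory $G_x$. `ErrorMemory`, `radius_clusters`, and the `radius` value are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List

Feature = List[float]


class ErrorMemory:
    """Per-prompt store G_x of logic features from past responses."""

    def __init__(self) -> None:
        self._store: Dict[str, List[Feature]] = defaultdict(list)

    def add(self, prompt: str, feature: Feature) -> None:
        self._store[prompt].append(feature)

    def features(self, prompt: str) -> List[Feature]:
        return self._store[prompt]


def radius_clusters(points: List[Feature], radius: float) -> List[int]:
    """Stand-in for HDBSCAN: label connected components, where two points
    are linked if their Euclidean distance is below `radius`."""
    labels = [-1] * len(points)
    next_label = 0
    for i in range(len(points)):
        if labels[i] != -1:
            continue
        labels[i] = next_label
        stack = [i]
        while stack:
            j = stack.pop()
            for k in range(len(points)):
                if labels[k] == -1 and math.dist(points[j], points[k]) < radius:
                    labels[k] = next_label
                    stack.append(k)
        next_label += 1
    return labels


memory = ErrorMemory()
for f in ([0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0]):
    memory.add("prompt-1", f)

labels = radius_clusters(memory.features("prompt-1"), radius=0.5)
sizes = [labels.count(lab) for lab in labels]  # |C_k| for each stored response
print(labels)  # [0, 0, 0, 1]
print(sizes)   # [3, 3, 3, 1]
```

As with HDBSCAN, the number of clusters is not fixed in advance; unlike HDBSCAN, this stand-in needs a hand-picked radius and has no notion of noise points.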

In the final stage, reward shaping, the indicator function $c(\tilde{y})$ is defined as $\log(|C_k| + 1)$, where $|C_k|$ is the size of the cluster to which $f(\tilde{y})$ is assigned. The reward is adjusted by subtracting a penalty that grows with the logarithm of the cluster size and is capped at $\beta$, computed as $\min(\alpha \log(|C_k| + 1), \beta)$, resulting in the shaped reward $\tilde{r}(x, \tilde{y}) = r(x, \tilde{y}) - \min(\alpha \log(|C_k| + 1), \beta)$. This penalty discourages the policy from generating responses that follow error patterns already observed in the past, effectively shaping the reward landscape to promote diverse and correct reasoning paths. The adjusted reward is then used to update the policy, as shown in the framework diagram.
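As a quick numeric check of the shaping rule, the following sketch evaluates the capped logarithmic penalty for a few cluster sizes. The values of `alpha` and `beta` are hypothetical and chosen only to make the cap visible. The formula is applied exactly as written, with no assumption about which responses are exempt.

```python
import math


def shaped_reward(reward: float, cluster_size: int, alpha: float, beta: float) -> float:
    """Subtract min(alpha * log(|C_k| + 1), beta) from the raw reward."""
    penalty = min(alpha * math.log(cluster_size + 1), beta)
    return reward - penalty


ALPHA, BETA = 0.1, 0.3  # hypothetical hyperparameters

for size in (0, 1, 10, 100):
    print(size, round(shaped_reward(0.0, size, ALPHA, BETA), 4))
```

With these values a singleton cluster costs about 0.069, a cluster of ten about 0.24, and any cluster large enough to exceed the cap costs exactly 0.3, so very common error patterns cannot drive the reward arbitrarily low.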

Experiment

The proposed MEDS method is evaluated across three model scales using five mathematical reasoning benchmarks to compare its performance against baselines like GRPO and DAPO. The results demonstrate that MEDS consistently achieves superior mathematical reasoning capabilities, with benefits that scale alongside the base model's capacity. Behavioral and representational analyses confirm that the method enhances exploration by maintaining higher rollout diversity and utilizing logit-based clustering that effectively captures distinct reasoning patterns.

The authors evaluate their method against baselines on multiple math benchmarks using pass@k metrics. The proposed method achieves the best performance on all benchmarks at every model scale, outperforming the base model and strong reinforcement learning baselines such as DAPO and GRPO. The improvements are most pronounced on larger models, indicating that the benefit scales with model capacity.

Pass@k performance across models
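For readers unfamiliar with the metric, pass@k is often estimated from `n` sampled responses of which `c` are correct, using the unbiased combinatorial estimator. Whether the authors use exactly this estimator is an assumption. The sketch only shows how such a number is typically computed.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimate of pass@k: the probability that at least one of
    k responses drawn without replacement from n samples is correct."""
    if n - c < k:
        return 1.0  # too few incorrect samples to fill k draws
    return 1.0 - comb(n - c, k) / comb(n, k)


# 10 samples for one problem, 3 of them correct.
print(pass_at_k(10, 3, 1))
print(pass_at_k(10, 3, 5))
```

A benchmark score is then the average of this quantity over all problems, which is why pass@128 rewards policies that keep at least some correct solutions among diverse samples.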

The table compares different logit aggregation methods for clustering by their correlation with human annotations. Using the last 14 layers as clustering features yields the highest correlation, indicating better alignment with reasoning patterns. The results also suggest that better clustering quality aligns with improved downstream performance.

Logit aggregation correlation

The table compares the performance of different methods across multiple mathematical benchmarks. MEDS-14 achieves the highest average performance and outperforms DAPO and the other variants on most individual benchmarks, showing the most consistent improvements among the compared methods.

Performance comparison of methods

The authors compare their method against several baselines across multiple math benchmarks and model scales. It consistently achieves the best performance, outperforming both the base model and strong baselines, with larger gains on larger models, indicating that its benefits scale with model capability. The approach also enhances exploration, as evidenced by higher diversity metrics and lower top-1 eigen ratios than the baselines, and it makes effective use of logit-based clustering signals.

Performance comparison on math benchmarks

The heatmap visualizes logit values from model responses across layers for responses labeled A1, A2, B1, B2, and B3. Responses within a cluster show similar logit trajectories, indicating shared reasoning structures, while responses from different clusters show distinct patterns, reflecting different reasoning paths. Logit values change more in the later layers, where reasoning patterns become more apparent.

Logit heatmap of model responses

The authors evaluate their proposed method against various baselines and logit aggregation techniques across multiple mathematical benchmarks and model scales. The results demonstrate that the method consistently outperforms existing reinforcement learning approaches, with performance gains scaling effectively alongside model capacity. Furthermore, analysis of logit trajectories reveals that utilizing the final layers for clustering provides the strongest alignment with human reasoning patterns and enables more effective exploration.

