Shanghai Jiao Tong University Researchers Develop OctoThinker: A Two-Stage Mid-Training Strategy to Enhance Llama Models for Reinforcement Learning
Researchers from Shanghai Jiao Tong University have proposed OctoThinker, an approach designed to improve the scalability of reinforcement learning (RL) in large language models (LLMs). The method combines chain-of-thought (CoT) prompting with a two-stage mid-training strategy to strengthen the RL capabilities of base models, particularly those from the Llama family, which have historically struggled with RL scaling.

Background and Context

LLMs have achieved remarkable advances in complex reasoning tasks by combining CoT prompting with large-scale RL. Models such as DeepSeek-R1-Zero have exhibited strong reasoning skills through RL applied directly to their base models, and methods such as SimpleRL and Open-Reasoner-Zero have improved the reasoning abilities of smaller models, notably those in the Qwen series. However, achieving consistent results across different base model families remains a challenge: while RL has worked well with the Qwen series, it has faced difficulties with the Llama series, raising questions about which underlying factors determine RL performance.

Limitations of RL Scaling on Llama Models

Large-scale RL advances in models such as OpenAI's o1 and o3 and DeepSeek's R1 have focused primarily on competition-level mathematics problems. These successes have motivated the exploration of RL in smaller models with fewer than 100 billion parameters, yet most of these studies are confined to the Qwen model family, making the results hard to replicate on other families such as Llama. The lack of transparency in pre-training pipelines has further hindered understanding of how pre-training influences RL scaling. In addition, one-shot prompting, which benefits Qwen models, offers minimal advantage to Llama models. Efforts to curate high-quality mathematical pre-training corpora, such as OpenWebMath, MathPile, InfiMM-Web-Math, and FineMath, have shown promise but remain limited in scale to under 100 billion tokens.

Exploring Mid-Training with the Stable-then-Decay Strategy

To address these challenges, the researchers examined how mid-training strategies influence RL dynamics, focusing on both Qwen and Llama models. Their study revealed several key insights:

- High-quality mathematical corpora: Using advanced mathematical datasets, such as MegaMath-Web-Pro, significantly boosts both base model and RL outcomes.
- QA-style data: Incorporating long CoT reasoning into QA-style data further enhances RL results.
- Verbosity and instability: Introducing long CoT data can, however, increase verbosity and instability during RL training.
- Mid-training scaling: Scaling up mid-training improves downstream RL performance.

Building on these findings, they introduced a two-stage mid-training strategy called Stable-then-Decay. In the first stage, the base models are trained on 200 billion tokens, establishing a stable pre-training foundation. In the second stage, the models undergo 20 billion tokens of training across three CoT-focused branches, yielding the OctoThinker models. These branches are designed to probe different aspects of CoT reasoning during the RL process.

RL Configuration and Benchmark Evaluation

The researchers used the MATH8K dataset for RL training prompts, with a global training batch size of 128, 16 rollout responses per query, and a PPO mini-batch size of 64. The experiments were conducted on the Llama-3.2-3B-Base and Qwen2.5-3B-Base models.
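To make the reported RL setup concrete, the sketch below collects the stated hyperparameters into a plain Python configuration object. It is an illustrative reconstruction, not the authors' released code: the RLConfig dataclass, its field names, and the rollouts_per_step helper are hypothetical, while the numeric values (batch size 128, 16 rollouts per query, PPO mini-batch size 64) and the dataset and model names come from the description above.

```python
from dataclasses import dataclass, field
from typing import List

# Illustrative sketch only: the dataclass and helper are hypothetical, but the
# hyperparameter values mirror the RL setup described in the article.
@dataclass
class RLConfig:
    prompt_dataset: str = "MATH8K"            # source of RL training prompts
    base_models: List[str] = field(default_factory=lambda: [
        "Llama-3.2-3B-Base",                  # Llama baseline used in the experiments
        "Qwen2.5-3B-Base",                    # Qwen baseline used for comparison
    ])
    global_batch_size: int = 128              # prompts per RL training step
    rollouts_per_query: int = 16              # sampled responses per prompt
    ppo_mini_batch_size: int = 64             # mini-batch size for PPO updates

    def rollouts_per_step(self) -> int:
        """Total responses generated and scored per training step."""
        return self.global_batch_size * self.rollouts_per_query

if __name__ == "__main__":
    cfg = RLConfig()
    # 128 prompts x 16 rollouts = 2,048 responses per step under this setup.
    print(f"Responses per step: {cfg.rollouts_per_step()}")
```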
For evaluation, the base language models were tested with few-shot prompting, while the RL-tuned models were evaluated with zero-shot prompts across several benchmark tasks, including GSM8K, MATH500, OlympiadBench, and AMC23. During RL training, Qwen models exhibited reasonable response lengths that increased gradually, whereas Llama models showed abnormal behavior, with average response lengths skyrocketing to 4,096 tokens. The evaluation results showed that RL-tuned Qwen2.5-3B achieved notable improvements across benchmarks, while Llama-3.2-3B showed only marginal gains.

Performance of OctoThinker Models

Each OctoThinker branch demonstrated 10%-20% improvements over the original Llama base model and consistent gains over the stable-stage model across all sizes. The OctoThinker-Zero families displayed diverse thinking behaviors during RL scaling, with the OctoThinker-Long variant performing best. When comparing three 3B-scale base models during RL training, OctoThinker-Long-3B outperformed the original Llama-3.2-3B and matched the performance of Qwen2.5-3B, a model known for its strong reasoning capabilities and extensive pre-training.

Conclusion and Future Work

This research provides valuable insights into why base models such as Llama and Qwen exhibit divergent behaviors during RL for reasoning tasks. The two-stage mid-training strategy effectively transforms Llama models into RL-ready foundation models, producing the OctoThinker family. Future research directions include further refining the mid-training stages, exploring applications in other domains, and enhancing the scalability and robustness of these models across different tasks.

Industry Insights and Company Profiles

The findings are significant for the field of AI, as they offer a practical way around the limitations of RL in certain base models. Industry observers have praised the approach, highlighting the potential for OctoThinker to bridge the gap between different LLM families and bring more consistency to the use of reinforcement learning in AI development. Companies investing in AI research and development, such as Meta and Google, may find this methodology useful for improving the reasoning capabilities of their models and accelerating their AI initiatives.