
Early Exiting Predictive Coding Neural Networks for Edge AI

Alaa Zniber, Mounir Ghogho, Ouassim Karrakchou, Mehdi Zakroum

Abstract

The Internet of Things is transforming various fields, with sensors increasingly embedded in wearables, smart buildings, and connected equipment. While deep learning enables valuable insights from IoT data, conventional models are too computationally demanding for resource-limited edge devices. Moreover, privacy concerns and real-time processing needs make local computation a necessity over cloud-based solutions. Inspired by the brain's energy efficiency, we propose a shallow bidirectional predictive coding network with early exiting, dynamically halting computations once a performance threshold is met. This reduces the memory footprint and computational overhead while maintaining high accuracy. We validate our approach using the CIFAR-10 dataset. Our model achieves performance comparable to deep networks with significantly fewer parameters and lower computational complexity, demonstrating the potential of biologically inspired architectures for efficient edge AI.

One-sentence Summary

Researchers from the International University of Rabat, University Mohammed VI Polytechnic, and the University of Leeds propose EE-PCN, a shallow bidirectional predictive coding network with early exiting that dynamically halts computation to achieve deep-network accuracy with minimal memory and FLOPs for extreme edge AI.

Key Contributions

  • The paper introduces a new derivation of predictive coding cycling rules for bidirectional networks that effectively implements both feedback and feedforward update mechanisms.
  • A shallow predictive coding network is designed to achieve accuracy comparable to deeper models while significantly reducing the memory footprint for deployment on extreme-edge devices.
  • The method incorporates a dynamic early-exiting mechanism and knowledge distillation across cycles to adaptively adjust the number of operations, thereby improving inference efficiency and the performance of early exits.

Introduction

The rise of IoT in sectors like health monitoring and smart cities demands real-time data processing on resource-constrained edge devices, yet conventional deep learning models are too computationally heavy and memory-intensive for these environments. While Predictive Coding Networks (PCNs) offer biologically inspired efficiency, prior implementations often double parameter counts compared to standard models and lack adaptive mechanisms, forcing them to perform unnecessary computations on simple inputs. To address these challenges, the authors propose a shallow bidirectional PCN that integrates an early exiting mechanism to dynamically halt inference once a performance threshold is met. This approach leverages knowledge distillation across cycles to maintain high accuracy while drastically reducing memory footprint and computational overhead, making it suitable for extreme edge deployment.

Dataset

  • The authors use the CIFAR-10 dataset, which contains 60,000 32×32 RGB images evenly distributed across 10 classes; its low resolution makes it a reasonable proxy for IoT applications such as surveillance and smart farming.
  • The dataset is split into a training set of 50,000 images and a test set of 10,000 images.
  • Data augmentation is applied to the training set using random translation and horizontal flipping.
  • Training is performed on mini-batches of 128 images.
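The augmentation pipeline above can be sketched with NumPy. The 4-pixel shift budget is an assumption (a common choice for CIFAR-10), not a detail stated in the summary:

```python
import numpy as np

def augment(img, rng, max_shift=4):
    """Randomly flip and translate one 32x32x3 image (pad-and-crop shift)."""
    # Horizontal flip with probability 0.5.
    if rng.random() < 0.5:
        img = img[:, ::-1, :]
    # Random translation: zero-pad by max_shift, then take a random 32x32 crop.
    h, w, c = img.shape
    padded = np.zeros((h + 2 * max_shift, w + 2 * max_shift, c), dtype=img.dtype)
    padded[max_shift:max_shift + h, max_shift:max_shift + w] = img
    dy, dx = rng.integers(0, 2 * max_shift + 1, size=2)
    return padded[dy:dy + h, dx:dx + w]

rng = np.random.default_rng(0)
# One mini-batch of 128 augmented images, as used during training.
batch = np.stack([augment(np.ones((32, 32, 3), dtype=np.float32), rng)
                  for _ in range(128)])
```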

Method

The authors propose a Predictive Coding Network (PCN) model enhanced with early exiting capabilities to optimize inference efficiency. The architecture consists of a shared backbone serving as a feature extractor, along with multiple downstream task classifiers. The backbone is designed as a bidirectional hierarchy of convolutional and deconvolutional layers.

As shown in the architecture figure, blue arrows denote the forward convolutional pass, while red arrows indicate the feedback deconvolutions used to reduce local errors. During inference, the model performs a variable number of cycles $t \leq T$ over the backbone to iteratively minimize local prediction errors across all layers. Once the cycling process concludes, the final-layer feature vector is passed to the classifier corresponding to the current cycle count $t$ (green arrow). The classification confidence is then compared against a predefined user threshold: if the confidence exceeds the threshold, inference terminates and a response is returned; otherwise, another cycle is initiated, followed by another classification and threshold comparison.
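The cycle-and-exit loop can be sketched in plain Python. Here `cycle_fn` and the per-cycle classifiers are hypothetical stand-ins for the backbone pass and the linear exit heads:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def early_exit_inference(features, cycle_fn, classifiers, threshold, T):
    """Run PC cycles until classifier confidence exceeds `threshold`
    or the cycle budget T is exhausted. Returns (prediction, cycles used)."""
    for t in range(1, T + 1):
        features = cycle_fn(features)                  # one feedback + feed-forward pass
        probs = softmax(classifiers[t - 1](features))  # classifier dedicated to cycle t
        if probs.max() >= threshold:                   # confident enough: exit early
            return int(probs.argmax()), t
    return int(probs.argmax()), T

# Toy demo: each "cycle" doubles the logits, so confidence grows with t
# and a higher threshold triggers a later exit.
cycle_fn = lambda f: 2.0 * f
clfs = [lambda f: f for _ in range(5)]
pred, cycles = early_exit_inference(np.array([0.2, 0.1]), cycle_fn, clfs, 0.9, 5)
```

The per-cycle classifier indexing mirrors the paper's design of one exit head per cycle count.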

The architecture employs $T$ distinct classifiers rather than a single classifier shared across all cycles. This decision is driven by the evolving nature of feature representations throughout the iterative process: since feature vectors are continuously refined from one cycle to the next, a classifier trained on the representations of a five-cycle model could not accurately interpret the patterns extracted by a one-cycle model for the same input.

To derive the PC update rules, the authors apply gradient descent to minimize the local errors at each pass. Let $\mathbf{r}_l(t)$ denote the feature representation at convolution layer $l$ and cycle $t$. The representation at layer $l = 0$ is fixed as the input image. For $t = 0$, all feature representations are initialized through a standard feedforward pass:

$$\mathbf{r}_l(0) = \phi\left(\mathbf{W}_{l-1,l}\,\mathbf{r}_{l-1}(0)\right), \qquad l = 1, \dots, L$$

where $\phi$ is a nonlinear activation function, taken to be ReLU in the experiments.

The feedback pass update rule governs a process in which the higher-layer representation $\mathbf{r}_{l+1}(t)$ generates a top-down prediction of the lower-layer representation $\mathbf{r}_l(t)$, denoted $\mathbf{p}_l(t)$:

$$\mathbf{p}_l(t) = \phi\left[\mathbf{W}_{l+1,l}\,\mathbf{r}_{l+1}(t)\right]$$

The update is carried out by minimizing the local error $\epsilon_l(t) = \frac{1}{2}\,\lVert \mathbf{r}_l(t) - \mathbf{p}_l(t) \rVert_2^2$. The feedback update rule, computed at the midpoint $t + 1/2$, is:

$$\mathbf{r}_l(t+1/2) = (1 - \alpha_l)\,\mathbf{r}_l(t) + \alpha_l\,\phi\left[\mathbf{W}_{l+1,l}\,\mathbf{r}_{l+1}(t)\right]$$

By design, the representation of the last layer remains unaffected during the feedback pass.

The feed-forward pass update rule governs the complementary process, in which the lower-layer representation generates a bottom-up prediction used to update the upper-layer representation:

$$\mathbf{p}_l(t+1/2) = \phi\left[\mathbf{W}_{l-1,l}\,\mathbf{r}_{l-1}(t+1/2)\right]$$

This yields the feed-forward update rule:

$$\mathbf{r}_l(t+1) = (1 - \beta_l)\,\mathbf{r}_l(t+1/2) + \beta_l\,\phi\left[\mathbf{W}_{l-1,l}\,\mathbf{r}_{l-1}(t+1/2)\right]$$

Unlike prior formulations that rely solely on feedback convolution weight matrices, this formulation integrates both top-down and bottom-up predictions, leading to a more comprehensive update mechanism.
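A minimal NumPy sketch of one cycle, with dense matrices standing in for the paper's convolutions and deconvolutions, and a toy three-layer hierarchy whose widths and shared step sizes α, β are illustrative assumptions:

```python
import numpy as np

relu = lambda x: np.maximum(x, 0.0)

def pc_cycle(r, W_fwd, W_fb, alpha, beta):
    """One predictive-coding cycle: a top-down (feedback) half-step
    followed by a bottom-up (feed-forward) step. Dense matrices stand
    in for (de)convolutions; r[0] is the fixed input."""
    L = len(r) - 1
    # Feedback: r_l(t+1/2) = (1-a) r_l(t) + a phi(W_{l+1,l} r_{l+1}(t));
    # the top layer (l = L) is left unchanged by design.
    for l in range(1, L):
        r[l] = (1 - alpha) * r[l] + alpha * relu(W_fb[l] @ r[l + 1])
    # Feed-forward: r_l(t+1) = (1-b) r_l(t+1/2) + b phi(W_{l-1,l} r_{l-1}(t+1/2))
    for l in range(1, L + 1):
        r[l] = (1 - beta) * r[l] + beta * relu(W_fwd[l - 1] @ r[l - 1])
    return r

rng = np.random.default_rng(0)
dims = [8, 6, 4]                     # toy widths: input, hidden, top layer
W_fwd = [rng.normal(size=(dims[l + 1], dims[l])) for l in range(2)]
W_fb = {1: rng.normal(size=(dims[1], dims[2]))}  # W_{2,1} predicts layer 1
x = rng.normal(size=dims[0])
# t = 0: standard feedforward initialization r_l(0) = phi(W_{l-1,l} r_{l-1}(0))
r = [x]
for l in range(2):
    r.append(relu(W_fwd[l] @ r[-1]))
r = pc_cycle(r, W_fwd, W_fb, alpha=0.1, beta=0.1)
```

Running `pc_cycle` repeatedly reproduces the iterative refinement that the early-exit classifiers tap into.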

Regarding training, the classification task is formulated as a multi-objective optimization problem in which $T$ losses $\mathcal{L}_i$ compete over the shared weights. The authors address this through scalarization, reducing the problem to a single objective via a weighted average. They further incorporate a Kullback-Leibler (KL) divergence term, denoted $\mathcal{KD}$, between intermediate logits and the final-cycle logits to enable knowledge distillation: the deepest network acts as the teacher, while the preceding shallow sub-networks serve as students. The total loss is:

$$\mathcal{L}_{\mathrm{tot}} = \rho \sum_{i=1}^{T} \lambda_i \mathcal{L}_i + (1 - \rho) \sum_{i=1}^{T-1} \mathcal{KD}\left(\hat{\mathbf{y}}_i, \hat{\mathbf{y}}_T\right)$$

where $\lambda_i$ is a positive weighting factor for loss $\mathcal{L}_i$, $\hat{\mathbf{y}}_i$ is the logit vector from classifier $i$, and $\rho$ is a balancing factor.
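Under stated assumptions (cross-entropy for the per-cycle task losses, KL taken from teacher to student, no distillation temperature — details the summary leaves open), the total loss can be sketched as:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cross_entropy(logits, label):
    return -np.log(softmax(logits)[label])

def kl_div(student_logits, teacher_logits):
    p, q = softmax(teacher_logits), softmax(student_logits)
    return float(np.sum(p * np.log(p / q)))  # KL(teacher || student)

def total_loss(logits_per_cycle, label, lambdas, rho):
    """L_tot = rho * sum_i lambda_i * L_i
             + (1 - rho) * sum_{i<T} KD(y_i, y_T); classifier T is the teacher."""
    T = len(logits_per_cycle)
    task = sum(lambdas[i] * cross_entropy(logits_per_cycle[i], label)
               for i in range(T))
    distill = sum(kl_div(logits_per_cycle[i], logits_per_cycle[T - 1])
                  for i in range(T - 1))
    return rho * task + (1 - rho) * distill

# Toy example: three cycles, two classes, all lambda_i = 1.
logits = [np.array([1.0, 0.0]), np.array([2.0, 0.0]), np.array([3.0, 0.0])]
loss = total_loss(logits, 0, [1.0, 1.0, 1.0], 0.7)
```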

The model design leverages PC dynamics to develop shallow networks capable of running on extreme edge devices. The models are based on VGG-like architectures where all convolutions use a 3×3 kernel with a stride of 1 and are followed by a ReLU activation function. Whenever the number of channels changes, max-pooling is applied in the feed-forward direction or upsampling in the feedback direction with a 2×2 kernel. Finally, the early exit classifiers are implemented as simple linear layers to ensure minimal overhead.
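As a rough illustration of the memory argument (not the authors' actual configuration), one can count parameters for a hypothetical three-conv bidirectional backbone; the channel widths below are invented for the example:

```python
def conv_params(c_in, c_out, k=3):
    """Parameters of one k x k convolution: weights plus biases."""
    return c_in * c_out * k * k + c_out

def count_backbone(channels):
    """Parameter count for a hypothetical shallow VGG-like stack of 3x3
    feed-forward convs plus mirrored 3x3 feedback deconvs."""
    fwd = sum(conv_params(channels[i], channels[i + 1])
              for i in range(len(channels) - 1))
    fb = sum(conv_params(channels[i + 1], channels[i])
             for i in range(len(channels) - 1))
    return fwd + fb

# Hypothetical 3-conv backbone over RGB input; widths are illustrative only.
params = count_backbone([3, 32, 64, 128])
```

Even with the feedback weights included, such a backbone stays in the sub-megabyte range, which is the regime frugal microcontrollers require.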

Experiment

  • Experiments validate that recursive processing with PC update rules in shallow models achieves competitive performance on extreme edge devices, outperforming edge-specific baselines and approaching VGG-11 accuracy with significantly fewer parameters.
  • Results demonstrate that additional processing cycles enhance model expressivity, allowing shallow architectures to better learn complex patterns and distinguish difficult classes.
  • Integrating an early exiting mechanism significantly reduces computational load and energy consumption, with high-confidence thresholds enabling the model to exit early for most inputs while maintaining accuracy.
  • The proposed models meet strict memory constraints of frugal microcontrollers, and their recursive nature ensures lower FLOP counts than deep networks for a large portion of the dataset, facilitating extended battery life.
  • Comparisons confirm that predictive coding rules combining top-down and bottom-up predictions outperform equivalent feed-forward CNNs.
