MDTv2 : le Masked Diffusion Transformer est un puissant synthétiseur d'images

Malgré ses succès dans la synthèse d’images, nous observons que les modèles probabilistes de diffusion (DPM) peinent souvent à raisonner de manière contextuelle, c’est-à-dire à apprendre les relations entre les parties d’objets au sein d’une image, ce qui entraîne un processus d’apprentissage lent. Pour résoudre ce problème, nous proposons un Masked Diffusion Transformer (MDT), qui introduit un schéma de modélisation latente masquée afin d’améliorer explicitement la capacité des DPM à apprendre les relations contextuelles entre les parties sémantiques des objets dans une image. Pendant l’entraînement, MDT opère dans l’espace latent en masquant certains tokens. Ensuite, un transformateur de diffusion asymétrique est conçu pour prédire les tokens masqués à partir des tokens non masqués, tout en préservant le processus de génération par diffusion. Notre MDT est capable de reconstruire l’information complète d’une image à partir d’un input contextuel partiel, permettant ainsi d’apprendre efficacement les relations associées entre les tokens d’image. Nous améliorons ensuite MDT en proposant une architecture de réseau macro plus efficace ainsi qu’une stratégie d’entraînement optimisée, appelée MDTv2. Les résultats expérimentaux montrent que MDTv2 atteint des performances supérieures en synthèse d’images, notamment un nouveau score SOTA FID de 1,58 sur le jeu de données ImageNet, avec une vitesse d’apprentissage plus de 10 fois supérieure à celle du modèle SOTA précédent DiT. Le code source est disponible à l’adresse suivante : https://github.com/sail-sg/MDT.