HyperAI

L'API TorchVision Transforms a Été Considérablement Mise À Niveau Pour Prendre En Charge La Détection D'objets, La Segmentation D'instance/sémantique Et Les Tâches Vidéo

il y a 2 ans
Information
Jiaxin Sun
特色图像

Introduction au contenu : l'API TorchVision Transforms a été étendue et mise à niveau pour prendre en charge la détection d'objets, la segmentation d'instances et sémantique et les tâches vidéo. La nouvelle API est encore en phase de test et les développeurs peuvent l'essayer.

Cet article a été publié pour la première fois sur WeChat : PyTorch Developer Community

insérer la description de l'image ici

TorchVision a été étendu avec l'API Transforms. Les détails sont les suivants :

  • Outre la classification d’images, il peut désormais être utilisé pour des tâches telles que la détection d’objets, la segmentation d’instances et sémantique et la classification vidéo ;
  • Prise en charge de l'importation d'augmentation de données SoTA directement depuis TorchVision, telles que MixUp, CutMix, Large Scale Jitter et SimpleCopyPaste.
  • Prend en charge l'utilisation de nouvelles transformations fonctionnelles pour transformer des vidéos, des cadres de délimitation et des masques de segmentation.

Limitations actuelles des transformations

La version stable de TorchVision Transforms API, également connue sous le nom de Transforms V1,Ne prend en charge qu'une seule image et ne convient donc qu'aux tâches de classification :

from torchvision import transforms
trans = transforms.Compose([
   transforms.ColorJitter(contrast=0.5),
   transforms.RandomRotation(30),
   transforms.CenterCrop(480),
])
imgs = trans(imgs)

Les méthodes ci-dessus ne prennent pas en charge la détection d'objets, la segmentation ou la classification des transformations qui nécessitent l'utilisation d'étiquettes. Comme MixUp et cutMix. Cela rend impossible l'exécution des extensions nécessaires à l'aide de l'API Transforms pour les tâches de vision par ordinateur autres que la classification. en même temps,Cela rend également plus difficile la formation de modèles de haute précision à l’aide de primitives TorchVision.

Pour surmonter cette limitation,TorchVision fournit une implémentation personnalisée dans son script de référence. Utilisé pour démontrer comment les améliorations sont effectuées dans toutes les tâches.

Bien que cette approche permette aux développeurs de former des modèles de classification, de détection d’objets et de segmentation de haute précision, il s’agit d’une approche grossière.Les transformations ne sont toujours pas importables dans le binaire TorchVision.

Nouvelle API de transformations

L'API Transforms V2 prend en charge la vidéo, le cadre de délimitation, l'étiquette et le masque de segmentation. Cela signifie qu'il fournit un support natif pour de nombreuses tâches de vision par ordinateur. La nouvelle solution est une alternative plus directe :

from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)

La nouvelle classe Transform peut accepter n'importe quel nombre d'entrées sans imposer un ordre ou une structure spécifique :

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow

L'API fonctionnelle a été mise à jour pour prendre en charge toutes les entrées nécessaires du noyau de traitement du signal, telles que le redimensionnement, le recadrage, les transformations affines, le remplissage, etc. :

from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])

L'API utilise la sous-classification Tensor pour envelopper l'entrée, joindre des métadonnées utiles et les envoyer au noyau approprié. Une fois que vous avez terminé votre travail avec Datasets V2 à l'aide du pipeline de données TorchData, vous n'avez plus besoin d'encapsuler manuellement vos entrées. Actuellement, les utilisateurs peuvent encapsuler manuellement les entrées des manières suivantes :

from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])

En plus de la nouvelle API, les responsables de PyTorch ont également fourni des implémentations importantes pour certaines améliorations de données utilisées dans la recherche SoTA, telles que MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste, les méthodes AutoAugmentation et certaines nouvelles transformations de conversion géométrique, de couleur et de type.

L'API continue de prendre en charge les backends PIL et Tensor pour une image unique ou une image d'entrée par lots, et conserve la possibilité de script JIT sur l'API fonctionnelle.Cela permet de différer le mappage d'image de uint8 à float. A permis une amélioration supplémentaire des performances.

Il est actuellement disponible dans la zone prototype de TorchVision et peut être importé à partir de builds nocturnes.La nouvelle API a été vérifiée comme étant cohérente avec la précision de l’implémentation précédente.

Limitations actuelles

L'API fonctionnelle (noyau) reste scriptable JIT et entièrement BC, et la classe Transform fournit la même interface, mais ne peut pas être scriptée.

Cela est dû au fait que la classe Transform utilise le sous-classement Tensor et accepte un nombre arbitraire d'entrées, ce qui n'est pas pris en charge par JIT. Cette limitation sera continuellement optimisée dans les versions ultérieures.

Un bout à bout

Voici un exemple de la nouvelle API qui peut fonctionner à la fois avec les images PIL et les tenseurs.

Image de test :

insérer la description de l'image ici
Exemple de code :

import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()

-- sur--