Die TorchVision Transforms API Wurde Erheblich Aktualisiert, Um Objekterkennung, Instanz-/semantische Segmentierung Und Videoaufgaben Zu Unterstützen

Inhaltseinführung: Die TorchVision Transforms API wurde erweitert und aktualisiert, um Objekterkennung, Instanz- und semantische Segmentierung sowie Videoaufgaben zu unterstützen. Die neue API befindet sich noch in der Testphase und kann von Entwicklern ausprobiert werden.
Dieser Artikel wurde zuerst auf WeChat veröffentlicht: PyTorch Developer Community

TorchVision wurde um die Transforms API erweitert. Die Einzelheiten sind wie folgt:
- Zusätzlich zur Bildklassifizierung kann es jetzt für Aufgaben wie Objekterkennung, Instanz- und semantische Segmentierung und Videoklassifizierung verwendet werden.
- Unterstützung für den Import von SoTA-Datenerweiterungen direkt aus TorchVision, wie z. B. MixUp, CutMix, Large Scale Jitter und SimpleCopyPaste.
- Unterstützt die Verwendung neuer funktionaler Transformationen zum Transformieren von Videos, Begrenzungsrahmen und Segmentierungsmasken.
Aktuelle Einschränkungen von Transforms
Die stabile Version der TorchVision Transforms API, auch bekannt als Transforms V1,Unterstützt nur ein einzelnes Bild und ist daher nur für Klassifizierungsaufgaben geeignet:
from torchvision import transforms
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs = trans(imgs)
Die oben genannten Methoden unterstützen keine Transformationen zur Objekterkennung, Segmentierung oder Klassifizierung, die die Verwendung von Labels erfordern. Wie MixUp und cutMix. Dies macht es unmöglich, die erforderlichen Erweiterungen mithilfe der Transforms-API für andere Computer Vision-Aufgaben als die Klassifizierung durchzuführen. gleichzeitig,Dies erschwert auch das Trainieren hochpräziser Modelle mithilfe von TorchVision-Primitiven.
Um diese Einschränkung zu überwinden,TorchVision bietet in seinem Referenzskript eine benutzerdefinierte Implementierung. Wird verwendet, um zu demonstrieren, wie Verbesserungen bei allen Aufgaben durchgeführt werden.
Obwohl dieser Ansatz es Entwicklern ermöglicht, hochpräzise Klassifizierungs-, Objekterkennungs- und Segmentierungsmodelle zu trainieren, handelt es sich dabei um einen groben Ansatz.Transformationen können in die TorchVision-Binärdatei immer noch nicht importiert werden.
Neue Transforms-API
Die Transforms V2 API unterstützt Video, Begrenzungsrahmen, Beschriftungen und Segmentierungsmasken. Dies bedeutet, dass es native Unterstützung für viele Computer Vision-Aufgaben bietet. Die neue Lösung ist eine direktere Alternative:
from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)
Die neue Transform-Klasse kann eine beliebige Anzahl von Eingaben akzeptieren, ohne eine bestimmte Reihenfolge oder Struktur zu erzwingen:
# Already supported:
trans(imgs) # Image Classification
trans(videos) # Video Tasks
trans(imgs_or_videos, labels) # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels) # Object Detection
trans(imgs, bboxes, masks, labels) # Instance Segmentation
trans(imgs, masks) # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels}) # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints) # Keypoint Detection
trans(stereo_images, disparities, masks) # Depth Perception
trans(image1, image2, optical_flows, masks) # Optical Flow
Die funktionale API wurde aktualisiert, um alle notwendigen Kerneleingaben zur Signalverarbeitung zu unterstützen, wie z. B. Größenänderung, Zuschneiden, affine Transformationen, Auffüllen usw.:
from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])
Die API verwendet Tensor-Unterklassen, um die Eingabe zu verpacken, nützliche Metadaten anzuhängen und an den richtigen Kernel zu versenden. Sobald Sie Ihre Arbeit mit Datasets V2 unter Verwendung der TorchData Data Pipe abgeschlossen haben, müssen Sie Ihre Eingaben nicht mehr manuell umschließen. Derzeit können Benutzer Eingaben auf folgende Weise manuell umbrechen:
from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])
Zusätzlich zur neuen API stellten die Verantwortlichen von PyTorch auch wichtige Implementierungen für einige in der SoTA-Forschung verwendete Datenerweiterungen bereit, wie etwa MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste, AutoAugmentation-Methoden und einige neue geometrische, Farb- und Typkonvertierungstransformationen.
Die API unterstützt weiterhin PIL- und Tensor-Backends für einzelne Bilder oder gestapelte Eingabebilder und behält die JIT-Skriptfähigkeit der funktionalen API bei.Dadurch kann die Bildzuordnung von uint8 auf float verschoben werden. Hat zu einer weiteren Leistungssteigerung geführt.
Es ist derzeit im Prototypenbereich von TorchVision verfügbar und kann aus Nightly Builds importiert werden.Es wurde überprüft, dass die neue API mit der Genauigkeit der vorherigen Implementierung übereinstimmt.
Aktuelle Einschränkungen
Die funktionale API (Kernel) bleibt JIT-skriptfähig und vollständig BC, und die Transform-Klasse bietet dieselbe Schnittstelle, kann aber nicht geskriptet werden.
Dies liegt daran, dass die Transform-Klasse Tensor-Subclassing verwendet und eine beliebige Anzahl von Eingaben akzeptiert, was von JIT nicht unterstützt wird. Diese Einschränkung wird in den Folgeversionen kontinuierlich optimiert.
Eine End-to-End
Hier ist ein Beispiel der neuen API, die sowohl mit PIL-Bildern als auch mit Tensoren arbeiten kann.
Testbild:
Codebeispiel:
import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
[[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
[148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
[422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
[435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
[469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
[452, 39, 463, 63], [424, 38, 429, 50]],
format=features.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
[
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480),
]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
-- über--