HyperAI

TorchVision Transforms API Has Been Greatly Upgraded to Support Object Detection, Instance/Semantic Segmentation, and Video Tasks

2 years ago
Information
Jiaxin Sun

In brief: the TorchVision Transforms API has been extended and upgraded to support object detection, instance and semantic segmentation, and video tasks. The new API is still at the prototype stage, and developers are welcome to try it out.

This article was first published on the PyTorch Developer Community WeChat official account.


TorchVision has extended and upgraded its Transforms API. The main changes are:

  • In addition to image classification, it can now be used for tasks such as object detection, instance and semantic segmentation, and video classification;
  • SoTA data augmentations such as MixUp, CutMix, Large Scale Jitter, and SimpleCopyPaste can be imported directly from TorchVision;
  • New functional transforms can transform videos, bounding boxes, and segmentation masks.

Current limitations of Transforms

The stable version of the TorchVision Transforms API, also known as Transforms V1, only supports single images and is therefore only suitable for classification tasks:

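import torch
from torchvision import transforms

# A single uint8 image tensor of shape (C, H, W); random data stands in for a real photo.
imgs = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs = trans(imgs)  # only the image is transformed; there is no way to pass boxes or masks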

This approach does not support object detection or segmentation, nor classification transforms that need labels, such as MixUp and CutMix. As a result, the Transforms API cannot be used to perform the augmentations required by computer vision tasks other than classification, which also makes it harder to train high-accuracy models with TorchVision primitives.
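To see why MixUp needs labels, here is a minimal plain-Python sketch, assuming images are flat lists of floats and labels are integer class indices. The helper names `one_hot` and `mixup` are hypothetical, not TorchVision APIs. The point is that the label must be mixed with the same weight as the pixels, so a transform that only receives the image cannot implement it.

```python
import random


def one_hot(label, num_classes):
    """Return a one-hot list for an integer class label."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def mixup(image_a, label_a, image_b, label_b, num_classes, lam=None, alpha=0.2):
    """Blend two images and their labels with the same mixing weight (sketch only)."""
    if lam is None:
        # MixUp draws the mixing weight from a Beta(alpha, alpha) distribution.
        lam = random.betavariate(alpha, alpha)
    mixed_image = [lam * a + (1 - lam) * b for a, b in zip(image_a, image_b)]
    ya = one_hot(label_a, num_classes)
    yb = one_hot(label_b, num_classes)
    # The labels are blended with the *same* lam as the pixels.
    mixed_label = [lam * a + (1 - lam) * b for a, b in zip(ya, yb)]
    return mixed_image, mixed_label


img, lbl = mixup([1.0, 0.0], 0, [0.0, 1.0], 2, num_classes=3, lam=0.75)
print(img)  # [0.75, 0.25]
print(lbl)  # [0.75, 0.0, 0.25]
```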

To work around this limitation, TorchVision provides custom implementations in its reference scripts that demonstrate how augmentation is performed for each task.

Although this approach lets developers train high-accuracy classification, object detection, and segmentation models, it is a hacky workaround: these transforms still cannot be imported from the TorchVision binary.

New Transforms API

The Transforms V2 API supports videos, bounding boxes, labels, and segmentation masks, which means it offers native support for many computer vision tasks. The new solution is a drop-in replacement:

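from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
# imgs, bboxes and labels must be wrapped in the prototype feature types
# (features.Image, features.BoundingBox, features.Label), as shown further below.
imgs, bboxes, labels = trans(imgs, bboxes, labels)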

The new Transform classes accept an arbitrary number of inputs without enforcing a specific order or structure:

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow
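The sketch below illustrates, in plain Python, two ideas behind this interface: the transform walks whatever structure it is given (positional arguments, tuples, lists, dicts) and rebuilds it unchanged, and any random decision is made once per call so that every input in a sample is transformed consistently. `Image` and `RandomHorizontalFlipSketch` are hypothetical stand-ins; TorchVision's real classes operate on tensor subclasses and are not implemented like this.

```python
import random
from dataclasses import dataclass


@dataclass
class Image:
    """Hypothetical stand-in for a wrapped image: a list of pixel rows."""
    rows: list


def _apply(obj, fn):
    """Rebuild dicts, lists and tuples as-is, applying fn to every leaf."""
    if isinstance(obj, dict):
        return {k: _apply(v, fn) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_apply(v, fn) for v in obj)
    return fn(obj)


class RandomHorizontalFlipSketch:
    """Flip every Image in an arbitrarily nested input with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, *inputs):
        flip = random.random() < self.p  # sampled once, shared by all inputs

        def leaf(x):
            if flip and isinstance(x, Image):
                return Image([row[::-1] for row in x.rows])
            return x  # unknown types (e.g. labels) pass through untouched

        out = _apply(inputs, leaf)
        return out[0] if len(out) == 1 else out


t = RandomHorizontalFlipSketch(p=1.0)
img, label = t(Image([[1, 2, 3]]), 7)
print(img.rows, label)  # [[3, 2, 1]] 7
print(t({"image": Image([[1, 2]]), "tag": "dog"}))  # {'image': Image(rows=[[2, 1]]), 'tag': 'dog'}
```

Inputs of unknown type, such as the integer label or the string tag, come back unchanged, which is how a geometric transform can accept labels without altering them.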

The functional API has been updated to support all the necessary signal-processing kernels (resizing, cropping, affine transforms, padding, etc.) for all input types:

from torchvision.prototype.transforms import functional as F
from torchvision.transforms import InterpolationMode
# inpt, img_tensor, img_pil, video, mask and bbox are placeholders for inputs of each type
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, size=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, size=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, size=[224, 224], interpolation=InterpolationMode.BILINEAR)
# Video kernel
F.resize_video(video, size=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, size=[224, 224])
# Bounding box kernel: also needs the spatial size of the image the boxes belong to
F.resize_bounding_box(bbox, size=[224, 224], spatial_size=[256, 256])
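The bounding-box kernel needs `spatial_size` because a box on its own does not know how large its image is, and resizing scales x and y by different factors whenever the aspect ratio changes. A minimal sketch of that arithmetic for XYXY boxes, using the hypothetical helper `resize_xyxy_boxes` and sizes given as (height, width):

```python
def resize_xyxy_boxes(boxes, spatial_size, size):
    """Scale XYXY boxes from an image of spatial_size to one of size.

    Both sizes are (height, width), the [H, W] order TorchVision uses.
    """
    old_h, old_w = spatial_size
    new_h, new_w = size
    sx, sy = new_w / old_w, new_h / old_h  # independent x and y scale factors
    return [[x1 * sx, y1 * sy, x2 * sx, y2 * sy] for x1, y1, x2, y2 in boxes]


print(resize_xyxy_boxes([[32, 64, 128, 256]], spatial_size=(256, 256), size=(224, 224)))
# [[28.0, 56.0, 112.0, 224.0]]
```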

The API uses Tensor subclassing to wrap inputs, attach useful metadata, and dispatch to the correct kernel. Once the Datasets V2 work, which builds on TorchData DataPipes, is complete, manually wrapping inputs will no longer be necessary. For now, users can wrap inputs manually as follows:

from torchvision.prototype import features
# images, videos and target stand for your own tensors and annotation dict
imgs = features.Image(images, color_space=features.ColorSpace.RGB)
vids = features.Video(videos, color_space=features.ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=features.BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])
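TorchVision implements this with Tensor subclasses. The plain-Python sketch below swaps in ordinary dataclasses and a dictionary of kernels to show the same two ideas: each wrapped input carries its metadata (here `spatial_size`), and a single dispatcher picks the kernel from the input's type. All names (`Image`, `BoundingBox`, `hflip`) are hypothetical and are not the `features` API. A horizontal flip is used because the box kernel needs the image width, which it reads from the attached metadata.

```python
from dataclasses import dataclass


@dataclass
class Image:
    rows: list  # H rows of W pixels

    @property
    def spatial_size(self):
        return (len(self.rows), len(self.rows[0]))


@dataclass
class BoundingBox:
    boxes: list          # XYXY coordinates
    spatial_size: tuple  # (H, W) of the image the boxes belong to


def _hflip_image(img):
    return Image([row[::-1] for row in img.rows])


def _hflip_boxes(bb):
    _, w = bb.spatial_size  # the width comes from the attached metadata
    flipped = [[w - x2, y1, w - x1, y2] for x1, y1, x2, y2 in bb.boxes]
    return BoundingBox(flipped, bb.spatial_size)


_KERNELS = {Image: _hflip_image, BoundingBox: _hflip_boxes}


def hflip(inpt):
    """Dispatch to the kernel registered for the input's type."""
    kernel = _KERNELS.get(type(inpt))
    if kernel is None:
        raise TypeError(f"unsupported input type: {type(inpt).__name__}")
    return kernel(inpt)


img = Image([[0, 1, 2, 3]])
bb = BoundingBox([[0, 0, 1, 1]], spatial_size=img.spatial_size)
print(hflip(img).rows)  # [[3, 2, 1, 0]]
print(hflip(bb).boxes)  # [[3, 0, 4, 1]]
```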

In addition to the new API, PyTorch now provides importable implementations of several data augmentations used in SoTA research, such as MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste, and automatic augmentation methods such as AutoAugment, as well as several new geometric, color, and type-conversion transforms.
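As an illustration of one of these augmentations, here is a minimal plain-Python sketch of CutMix, assuming images are 2-D lists of pixels and labels are one-hot lists. A patch covering roughly `1 - lam` of the image is cut from the second image and pasted into the first, and the labels are mixed in proportion to the pasted area, recomputed after the patch is clipped to the image border. The names `cutmix_box` and `cutmix` are hypothetical; this is not TorchVision's implementation, and `lam` is passed in explicitly rather than drawn from a Beta distribution as in the CutMix paper.

```python
import math
import random


def cutmix_box(height, width, lam, rng=random):
    """Sample a box covering roughly (1 - lam) of the image, clipped to its borders."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y1, x1, y2, x2


def cutmix(img_a, label_a, img_b, label_b, lam, rng=random):
    """Paste a patch of img_b into a copy of img_a; mix labels by pasted area."""
    h, w = len(img_a), len(img_a[0])
    y1, x1, y2, x2 = cutmix_box(h, w, lam, rng)
    out = [row[:] for row in img_a]  # leave the caller's image untouched
    for y in range(y1, y2):
        out[y][x1:x2] = img_b[y][x1:x2]
    # Recompute lambda from the actual (clipped) patch so labels match pixels.
    lam_adj = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    label = [lam_adj * a + (1 - lam_adj) * b for a, b in zip(label_a, label_b)]
    return out, label


rng = random.Random(0)
a = [[0] * 8 for _ in range(8)]
b = [[1] * 8 for _ in range(8)]
mixed, label = cutmix(a, [1.0, 0.0], b, [0.0, 1.0], lam=0.75, rng=rng)
print(label)
```

Because `lam` is recomputed from the clipped patch, `label[1]` always equals the fraction of pixels that came from `b`.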

The API continues to support both PIL and Tensor backends for images, with single or batched inputs, and keeps the functional API JIT-scriptable. It also allows the casting of images from uint8 to float to be deferred, which can improve performance.
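One plausible reason deferring the cast helps is memory traffic: a uint8 value takes one byte while a float32 value takes four, so transforms that run before the cast touch a quarter of the data. The arithmetic below is a rough, hypothetical illustration using Python's standard `array` module (the helpers `nbytes` and `to_float01` are ours); it does not measure TorchVision's actual speedup.

```python
from array import array


def nbytes(typecode, num_values):
    """Bytes needed to store num_values items of the given array typecode."""
    return array(typecode).itemsize * num_values


def to_float01(pixels):
    """Cast uint8 values to floats in [0, 1]; meant to run once, at the end."""
    return array("f", (p / 255 for p in pixels))


num_values = 3 * 224 * 224      # one hypothetical 224x224 RGB image
print(nbytes("B", num_values))  # 150528 (uint8: 1 byte per value)
print(nbytes("f", num_values))  # 602112 where C floats are 4 bytes, as is typical
print(list(to_float01(array("B", [0, 255]))))  # [0.0, 1.0]
```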

It is currently available in the prototype area of TorchVision and can be imported from the nightly builds. The new API has been verified to achieve the same accuracy as the previous implementation.

Current limitations

Although the functional API (the kernels) remains JIT-scriptable and fully backward compatible (BC), the Transform classes, despite offering the same interface, cannot be scripted.

This is because the Transform classes rely on Tensor subclassing and accept an arbitrary number of inputs, neither of which JIT supports. Work on this limitation will continue in subsequent releases.

An end-to-end example

Here is an example of the new API that can work with both PIL images and tensors.

Test image: COCO_val2014_000000418825.jpg (COCO 2014 validation set)
Code example:

import PIL.Image  # only needed for the commented-out PIL alternative below
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
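In this example, `RandomRotation(30)` has to move the boxes along with the pixels. One common approach, sketched below in plain Python, is to rotate the four corners of each axis-aligned box about the rotation center and take the enclosing axis-aligned box. The function `rotate_xyxy_box` is hypothetical and is not TorchVision's kernel; it assumes positive angles mean counter-clockwise rotation (the convention TorchVision documents for `rotate`) and an image y axis that points down.

```python
import math


def rotate_xyxy_box(box, angle_deg, center):
    """Rotate an XYXY box about center and return the enclosing XYXY box."""
    x1, y1, x2, y2 = box
    cx, cy = center
    a = math.radians(angle_deg)
    cos_a, sin_a = math.cos(a), math.sin(a)
    xs, ys = [], []
    for x, y in ((x1, y1), (x2, y1), (x2, y2), (x1, y2)):
        dx, dy = x - cx, y - cy
        # Counter-clockwise on screen, with the y axis pointing down.
        xs.append(cx + dx * cos_a + dy * sin_a)
        ys.append(cy - dx * sin_a + dy * cos_a)
    return [min(xs), min(ys), max(xs), max(ys)]


print([round(v, 6) for v in rotate_xyxy_box([0, 0, 10, 20], 90, center=(5, 10))])
# [-5.0, 5.0, 15.0, 15.0]
```

For angles that are not multiples of 90 degrees, the enclosing box is larger than the rotated object, so rotation tends to loosen the boxes.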
