TorchVision Transforms API가 객체 감지, 인스턴스/의미 분할 및 비디오 작업을 지원하도록 대폭 업그레이드되었습니다.

콘텐츠 소개: TorchVision Transforms API가 확장 및 업그레이드되어 객체 감지, 인스턴스 및 의미 분할, 비디오 작업을 지원하게 되었습니다. 새로운 API는 아직 테스트 단계이므로 개발자는 사용해 볼 수 있습니다.
이 기사는 WeChat: PyTorch 개발자 커뮤니티에 처음 게시되었습니다.

TorchVision이 Transforms API로 확장되었습니다. 자세한 내용은 다음과 같습니다.
- 이제 이미지 분류뿐만 아니라 객체 감지, 인스턴스 및 의미 분할, 비디오 분류와 같은 작업에도 사용할 수 있습니다.
- MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste 등 TorchVision에서 직접 SoTA 데이터 증강을 가져오는 기능을 지원합니다.
- 비디오, 경계 상자 및 분할 마스크를 변환하기 위해 새로운 기능 변환을 사용할 수 있도록 지원합니다.
Transforms의 현재 제한 사항
TorchVision Transforms API의 안정적인 버전(Transforms V1이라고도 함)단일 이미지만 지원하므로 분류 작업에만 적합합니다.
from torchvision import transforms
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs = trans(imgs)
위의 방법은 레이블을 사용해야 하는 객체 감지, 분할 또는 분류 변환을 지원하지 않습니다. 예를 들어 MixUp과 cutMix 등이 있습니다. 이로 인해 분류 이외의 컴퓨터 비전 작업에 대해 Transforms API를 사용하여 필요한 확장을 수행하는 것이 불가능해집니다. 동시에,이로 인해 TorchVision 기본 요소를 사용하여 고정밀 모델을 훈련하는 것이 더 어려워집니다.
이러한 한계를 극복하기 위해,TorchVision은 참조 스크립트에서 사용자 정의 구현을 제공합니다. 모든 작업에서 개선이 어떻게 수행되는지 보여주는 데 사용됩니다.
이러한 접근 방식을 사용하면 개발자가 고정밀 분류, 객체 감지 및 세분화 모델을 훈련할 수 있지만, 이는 엉성한 접근 방식입니다.변환은 여전히 TorchVision 바이너리로 가져올 수 없습니다.
새로운 변환 API
Transforms V2 API는 비디오, 경계 상자, 레이블 및 분할 마스크를 지원합니다. 즉, 다양한 컴퓨터 비전 작업에 대한 기본 지원을 제공합니다. 새로운 솔루션은 더 직접적인 대안입니다.
from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)
새로운 변환 클래스는 특정 순서나 구조를 적용하지 않고도 아무리 많은 입력을 받아들일 수 있습니다.
# Already supported:
trans(imgs) # Image Classification
trans(videos) # Video Tasks
trans(imgs_or_videos, labels) # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels) # Object Detection
trans(imgs, bboxes, masks, labels) # Instance Segmentation
trans(imgs, masks) # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels}) # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints) # Keypoint Detection
trans(stereo_images, disparities, masks) # Depth Perception
trans(image1, image2, optical_flows, masks) # Optical Flow
크기 조정, 자르기, 아핀 변환, 패딩 등 모든 필수 신호 처리 커널 입력을 지원하도록 기능적 API가 업데이트되었습니다.
from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])
API는 텐서 서브클래싱을 사용하여 입력을 래핑하고, 유용한 메타데이터를 첨부하고, 올바른 커널로 전송합니다. TorchData Data Pipe를 사용하여 Datasets V2 작업을 완료하면 더 이상 수동으로 입력을 래핑할 필요가 없습니다. 현재 사용자는 다음과 같은 방법으로 입력을 수동으로 래핑할 수 있습니다.
from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])
PyTorch 임원진은 새로운 API 외에도 SoTA 연구에 사용되는 일부 데이터 향상 기능(예: MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste, AutoAugmentation 방법 및 몇 가지 새로운 Geometric, Colour, Type Conversion 변환)에 대한 중요한 구현을 제공했습니다.
API는 단일 이미지나 일괄 입력 이미지에 대한 PIL 및 Tensor 백엔드를 계속 지원하고 기능적 API에서 JIT 스크립팅 기능을 유지합니다.이를 통해 이미지 매핑을 uint8에서 float로 연기할 수 있습니다. 성능이 더욱 향상되었습니다.
현재 TorchVision의 프로토타입 영역에서 사용할 수 있으며 야간 빌드에서 가져올 수 있습니다.새로운 API는 이전 구현의 정확성과 일관성이 있는 것으로 검증되었습니다.
현재 제한 사항
기능적 API(커널)는 JIT 스크립팅 가능하고 완전한 BC를 유지하며, 변환 클래스는 동일한 인터페이스를 제공하지만 스크립팅할 수 없습니다.
이는 Transform Class가 Tensor Subclassing을 사용하고 임의의 개수의 입력을 허용하기 때문인데, 이는 JIT에서 지원되지 않습니다. 이런 제한은 이후 버전에서 지속적으로 최적화될 예정입니다.
엔드투엔드
다음은 PIL 이미지와 텐서 모두에서 작동할 수 있는 새로운 API의 예입니다.
테스트 이미지:
코드 예:
import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
[[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
[148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
[422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
[435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
[469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
[452, 39, 463, 63], [424, 38, 429, 50]],
format=features.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
[
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480),
]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
-- 위에--