HyperAI超神経

ターゲット検出、インスタンス/セマンティック セグメンテーション、およびビデオ タスクをサポートするために、TorchVision Transforms API が大幅にアップグレードされました。

2年前
情報
Jiaxin Sun
特色图像

コンテンツの紹介:TorchVision Transforms API が拡張およびアップグレードされ、ターゲット検出、インスタンスおよびセマンティック セグメンテーション、ビデオ タスクがサポートされるようになりました。新しい API はまだテスト段階にあり、開発者はそれを試すことができます。

この記事は WeChat パブリック アカウントで最初に公開されました: PyTorch 開発者コミュニティ

ここに画像の説明を挿入します

TorchVision は Transforms API 用に拡張されました。 詳細は以下のとおりです。

  • 画像分類に使用されるだけでなく、ターゲット検出、インスタンスおよびセマンティック セグメンテーション、ビデオ分類などのタスクにも使用できるようになりました。
  • MixUp、CutMix、Large Scale Jitter、SimpleCopyPaste など、TorchVision からの SoTA データ拡張機能の直接インポートをサポートします。
  • ビデオ、バウンディング ボックス、セグメンテーション マスク (セグメンテーション マスク) を変換するための新しい関数変換の使用をサポートします。

電流制限を変える

TorchVision Transforms API の安定バージョン (Transforms V1 と呼ばれることが多い)単一の画像のみがサポートされるため、分類タスクにのみ適しています。

from torchvision import transforms
trans = transforms.Compose([
   transforms.ColorJitter(contrast=0.5),
   transforms.RandomRotation(30),
   transforms.CenterCrop(480),
])
imgs = trans(imgs)

上記の方法は、ラベルの使用を必要とするターゲットの検出、セグメンテーション、または分類の変換をサポートしていません。 MixUpやcutMixなど。これにより、Transforms API は分類以外のコンピューター ビジョン タスクに必要な拡張を実行できなくなります。同時に、これにより、TorchVision プリミティブを使用して高精度モデルをトレーニングすることも困難になります。

この制限を克服するには、TorchVision は、リファレンス スクリプトでカスタム実装を提供します。 すべてのタスクで拡張がどのように実行されるかを示すために使用されます。

このアプローチにより、開発者は高精度の分類、ターゲット検出、セグメンテーション モデルをトレーニングできますが、比較的大まかで、変換はまだTorchVisionバイナリにインポートできません。

新しい変換 API

Transforms V2 API は、ビデオ、バウンディング ボックス、ラベル、セグメンテーション マスクをサポートしています。 これは、多くのコンピューター ビジョン タスクに対するネイティブ サポートを提供することを意味します。新しいソリューションは、より簡単な代替案です。

from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)

新しい Transform クラスは、特定の順序や構造を強制することなく、任意の数の入力を受け入れることができます。

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow

機能 API が更新され、サイズ変更、クロッピング、アフィン変換、パディングなどの、必要な入力信号処理カーネルをすべてサポートするようになりました。

from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])

API は Tensor サブクラス化を使用して入力をラップし、有用なメタデータを追加し、正しいカーネルにディスパッチします。 TorchData データ パイプを使用した Datasets V2 の作業が完了したら、入力を手動でラップする必要はありません。現在、ユーザーは次の方法で入力を手動でラップできます。

from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])

新しい API に加えて、PyTorch は、MixUp、CutMix、Large Scale Jitter、SimpleCopyPaste、AutoAugmentation メソッド、およびいくつかの新しい幾何学的、色および型変換変換など、SoTA 研究で使用されるいくつかのデータ拡張機能の重要な実装を正式に提供します。

API は、単一イメージまたはバッチ入力イメージの PIL および Tensor バックエンドを引き続きサポートし、機能 API での JIT スクリプト機能を保持します。これにより、画像マッピングを uint8 から float に遅らせることができます。 さらなるパフォーマンスの向上をもたらします。

現在、TorchVision のプロトタイプ領域で利用可能で、夜間ビルドからのインポートをサポートしています。新しい API は、以前の実装と同じくらい正確であることが検証されています。

現在の制限

機能 API (カーネル) は JIT スクリプト可能かつ完全な BC のままであり、Transform クラスは同じインターフェイスを提供しますが、スクリプトは使用できません。

これは、Transform クラスが Tensor サブクラス化を使用し、JIT でサポートされていない任意の数の入力を受け入れるためです。この制限は、後続のバージョンで継続的に最適化されます。

エンドツーエンドのプレゼンテーション

以下は、PIL イメージとテンソルの両方で動作する新しい API の例です。

テスト写真:

ここに画像の説明を挿入します
コード例:

import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()

- 以上 -