Shortcuts

Source code for pytorchvideo.transforms.transforms

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Dict

import pytorchvideo.transforms.functional
import torch


class ApplyTransformToKey:
    """
    Runs a transform on a single entry of a dictionary input.

    The entry stored under ``key`` is replaced by ``transform(entry)``. The
    dictionary is updated in place and also returned, so calls can be chained.

    Args:
        key (str): dictionary key whose value is transformed.
        transform (callable): callable applied to ``x[key]``.

    Example::

        transforms.ApplyTransformToKey(
            key='video',
            transform=UniformTemporalSubsample(num_video_samples),
        )
    """

    def __init__(self, key: str, transform: Callable):
        self._key = key
        self._transform = transform

    def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Raises KeyError if the key is absent from the input dict.
        original_value = x[self._key]
        x[self._key] = self._transform(original_value)
        return x
class RemoveKey(torch.nn.Module):
    """
    Removes a key from a dictionary input, if it is present.

    The dictionary is modified in place and also returned. A missing key is
    not an error.

    Args:
        key (str): the dictionary key to remove.
    """

    def __init__(self, key: str):
        # nn.Module.__init__ must run before the module is used. Without it the
        # internal registries (_parameters, _modules, _buffers, ...) are never
        # created, and repr(), .parameters(), state_dict() and .to() all raise.
        super().__init__()
        self._key = key

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Implementing forward (rather than overriding __call__) lets
        # nn.Module.__call__ dispatch here, so registered hooks still run.
        x.pop(self._key, None)
        return x
class UniformTemporalSubsample(torch.nn.Module):
    """
    ``nn.Module`` wrapper around
    ``pytorchvideo.transforms.functional.uniform_temporal_subsample``.

    Args:
        num_samples (int): sample count passed to the functional
            implementation on every forward call.
    """

    def __init__(self, num_samples: int):
        super().__init__()
        self._num_samples = num_samples

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input passed unchanged to
                ``uniform_temporal_subsample``. The expected layout is defined
                by that function — NOTE(review): confirm against its docs.

        Returns:
            torch.Tensor: the result of ``uniform_temporal_subsample``.
        """
        functional = pytorchvideo.transforms.functional
        subsampled = functional.uniform_temporal_subsample(x, self._num_samples)
        return subsampled
class ShortSideScale(torch.nn.Module):
    """
    ``nn.Module`` wrapper around
    ``pytorchvideo.transforms.functional.short_side_scale``.

    Args:
        size (int): target size passed to ``short_side_scale`` on every call.
    """

    def __init__(self, size: int):
        super().__init__()
        self._size = size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input passed unchanged to ``short_side_scale``.

        Returns:
            torch.Tensor: the result of ``short_side_scale(x, size)``.
        """
        functional = pytorchvideo.transforms.functional
        scaled = functional.short_side_scale(x, self._size)
        return scaled
class RandomShortSideScale(torch.nn.Module):
    """
    ``nn.Module`` wrapper around
    ``pytorchvideo.transforms.functional.short_side_scale`` with a random size.

    On each forward call an integer size is drawn uniformly from the closed
    interval ``[min_size, max_size]`` using the global torch RNG.

    Args:
        min_size (int): smallest size that can be drawn (inclusive).
        max_size (int): largest size that can be drawn (inclusive).
    """

    def __init__(self, min_size: int, max_size: int):
        super().__init__()
        self._min_size = min_size
        self._max_size = max_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.randint's upper bound is exclusive, hence the +1 to make
        # max_size reachable.
        upper_exclusive = self._max_size + 1
        drawn = torch.randint(self._min_size, upper_exclusive, (1,))
        chosen_size = drawn.item()
        return pytorchvideo.transforms.functional.short_side_scale(x, chosen_size)
class UniformCropVideo(torch.nn.Module):
    """
    ``nn.Module`` wrapper around ``pytorchvideo.transforms.functional.uniform_crop``
    that operates on a dictionary input.

    ``x[video_key]`` is replaced by
    ``uniform_crop(x[video_key], size, x[aug_index_key])``. The dictionary is
    updated in place and returned.

    Args:
        size (int): crop size passed to ``uniform_crop``.
        video_key (str): key holding the video entry to crop.
        aug_index_key (str): key holding the value passed as ``uniform_crop``'s
            third argument.
    """

    def __init__(
        self, size: int, video_key: str = "video", aug_index_key: str = "aug_index"
    ):
        super().__init__()
        self._size = size
        self._video_key = video_key
        self._aug_index_key = aug_index_key

    def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Both lookups raise KeyError if the corresponding key is missing.
        video = x[self._video_key]
        aug_index = x[self._aug_index_key]
        cropped = pytorchvideo.transforms.functional.uniform_crop(
            video, self._size, aug_index
        )
        x[self._video_key] = cropped
        return x