Shortcuts

Source code for pytorchvideo.models.head

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Tuple

import torch
import torch.nn as nn
from pytorchvideo.layers.utils import set_attributes


def create_res_basic_head(
    *,
    # Projection configs.
    in_features: int,
    out_features: int,
    # Pooling configs.
    pool: Callable = nn.AvgPool3d,
    output_size: Tuple[int, int, int] = (1, 1, 1),
    pool_kernel_size: Tuple[int, int, int] = (1, 7, 7),
    pool_stride: Tuple[int, int, int] = (1, 1, 1),
    pool_padding: Tuple[int, int, int] = (0, 0, 0),
    # Dropout configs.
    dropout_rate: float = 0.5,
    # Activation configs.
    activation: Callable = None,
    # Output configs.
    output_with_global_average: bool = True,
) -> nn.Module:
    """
    Creates ResNet basic head. This layer performs an optional pooling operation
    followed by an optional dropout, a fully-connected projection, an optional
    activation layer and an optional global spatiotemporal averaging.

    ::

                                        Pooling
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging

    Activation examples include: ReLU, Softmax, LogSoftmax, Sigmoid, and None.
    Pool3d examples include: AvgPool3d, MaxPool3d, AdaptiveAvgPool3d,
    AdaptiveMaxPool3d, and None.

    Args:
        in_features: input channel size of the resnet head.
        out_features: output channel size of the resnet head.

        pool (callable): a callable that constructs resnet head pooling layer,
            examples include: nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d,
            nn.AdaptiveMaxPool3d, and None (not applying pooling).
        pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
            pooling.
        pool_stride (tuple): pooling stride size(s) when not using adaptive pooling.
        pool_padding (tuple): pooling padding size(s) when not using adaptive
            pooling.
        output_size (tuple): spatial temporal output size when using adaptive
            pooling (nn.AdaptiveAvgPool3d or nn.AdaptiveMaxPool3d).

        activation (callable): a callable that constructs resnet head activation
            layer, examples include: nn.ReLU, nn.Softmax, nn.LogSoftmax, nn.Sigmoid,
            and None (not applying activation). nn.Softmax, nn.LogSoftmax and
            nn.Softmin are constructed with dim=1.

        dropout_rate (float): dropout rate. A rate <= 0 disables dropout.

        output_with_global_average (bool): if True, perform global averaging on
            temporal and spatial dimensions and reshape output to
            batch_size x out_features.

    Returns:
        nn.Module: a ResNetBasicHead wired with the modules built above.
    """
    # Normalizing activations need their reduction dim set explicitly.
    # ResNetBasicHead applies the activation after moving channels back to
    # dim 1, so dim=1 normalizes over the projected channels. Passing it
    # explicitly also avoids torch's deprecated implicit-dim inference.
    normalizing_activations = (nn.Softmax, nn.LogSoftmax, nn.Softmin)
    if activation is None:
        activation_model = None
    elif activation in normalizing_activations:
        activation_model = activation(dim=1)
    else:
        activation_model = activation()

    # Adaptive pools take a target output size, not kernel/stride/padding.
    # Passing the latter to them raises TypeError.
    adaptive_pools = (nn.AdaptiveAvgPool3d, nn.AdaptiveMaxPool3d)
    if pool is None:
        pool_model = None
    elif pool in adaptive_pools:
        pool_model = pool(output_size)
    else:
        pool_model = pool(
            kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
        )

    if output_with_global_average:
        # Collapse every temporal/spatial position to a single value per channel.
        output_pool = nn.AdaptiveAvgPool3d(1)
    else:
        output_pool = None

    return ResNetBasicHead(
        proj=nn.Linear(in_features, out_features),
        activation=activation_model,
        pool=pool_model,
        dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
        output_pool=output_pool,
    )
class ResNetBasicHead(nn.Module):
    """
    ResNet basic head. This layer performs an optional pooling operation followed
    by an optional dropout, a fully-connected projection, an optional activation
    layer and an optional global spatiotemporal averaging.

    ::

                                        Pool3d
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging

    The builder can be found in `create_res_basic_head`.
    """

    def __init__(
        self,
        pool: nn.Module = None,
        dropout: nn.Module = None,
        proj: nn.Module = None,
        activation: nn.Module = None,
        output_pool: nn.Module = None,
    ) -> None:
        """
        Args:
            pool (torch.nn.modules): pooling module.
            dropout(torch.nn.modules): dropout module.
            proj (torch.nn.modules): project module.
            activation (torch.nn.modules): activation module.
            output_pool (torch.nn.Module): pooling module for output.
        """
        super().__init__()
        # set_attributes is expected to store each constructor argument on self
        # (forward reads self.pool, self.dropout, ...). It receives locals(), so
        # no other local variables may be introduced before this call.
        set_attributes(self, locals())
        assert self.proj is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 5-D input whose channel dimension is dim 1. The
                projection is applied along that dimension.

        Returns:
            torch.Tensor: if output_pool is set, a (batch_size, -1) tensor that
            flattens the pooled output; otherwise the 5-D projected tensor with
            channels in dim 1.
        """
        # Performs pooling.
        if self.pool is not None:
            x = self.pool(x)
        # Performs dropout.
        if self.dropout is not None:
            x = self.dropout(x)
        # Performs projection. Channels are moved last so the projection
        # (e.g. nn.Linear) acts on them, then moved back to dim 1.
        x = x.permute((0, 2, 3, 4, 1))
        x = self.proj(x)
        x = x.permute((0, 4, 1, 2, 3))
        # Performs activation.
        if self.activation is not None:
            x = self.activation(x)
        if self.output_pool is not None:
            # Performs global averaging.
            x = self.output_pool(x)
            # reshape rather than view: after the permute above, x is generally
            # not contiguous. An output_pool that keeps a spatial extent (or
            # passes the tensor through) would make view() raise. reshape
            # returns a view whenever one is possible.
            x = x.reshape(x.shape[0], -1)
        return x