Shortcuts

Source code for pytorchvideo.data.encoded_video_torchvision

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
from typing import BinaryIO, Dict, Optional

import numpy as np
import torch

from .utils import pts_to_secs, secs_to_pts, thwc_to_cthw
from .video import Video


logger = logging.getLogger(__name__)


[docs]class EncodedVideoTorchVision(Video): """ Accessing clips from an encoded video using Torchvision video reading API (torch.ops.video_reader.read_video_from_memory) as the decoding backend. """ """ av_seek_frame is imprecise so seek to a timestamp earlier by a margin The unit of margin is second """ SEEK_FRAME_MARGIN = 0.25 def __init__( self, file: BinaryIO, video_name: Optional[str] = None, decode_audio: bool = True, ) -> None: self._video_tensor = torch.tensor( np.frombuffer(file.getvalue(), dtype=np.uint8) ) self._video_name = video_name self._decode_audio = decode_audio ( self._video, self._video_time_base, self._video_start_pts, video_duration, self._audio, self._audio_time_base, self._audio_start_pts, audio_duration, ) = self._torch_vision_decode_video() # Take the largest duration of either video or duration stream. if audio_duration is not None and video_duration is not None: self._duration = max( pts_to_secs( video_duration, self._video_time_base, self._video_start_pts ), pts_to_secs( audio_duration, self._audio_time_base, self._audio_start_pts ), ) elif video_duration is not None: self._duration = pts_to_secs( video_duration, self._video_time_base, self._video_start_pts ) elif audio_duration is not None: self._duration = pts_to_secs( audio_duration, self._audio_time_base, self._audio_start_pts ) @property def name(self) -> Optional[str]: """ Returns: name: the name of the stored video if set. """ return self._video_name @property def duration(self) -> float: """ Returns: duration: the video's duration/end-time in seconds. """ return self._duration def close(self): pass
[docs] def get_clip( self, start_sec: float, end_sec: float ) -> Dict[str, Optional[torch.Tensor]]: """ Retrieves frames from the encoded video at the specified start and end times in seconds (the video always starts at 0 seconds). Args: start_sec (float): the clip start time in seconds end_sec (float): the clip end time in seconds Returns: clip_data: A dictionary mapping the entries at "video" and "audio" to a tensors. "video": A tensor of the clip's RGB frames with shape: (channel, time, height, width). The frames are of type torch.float32 and in the range [0 - 255]. "audio": A tensor of the clip's audio samples with shape: (samples). The samples are of type torch.float32 and in the range [0 - 255]. Returns None if no video or audio found within time range. """ video_frames = None if self._video is not None: video_start_pts = secs_to_pts( start_sec, self._video_time_base, self._video_start_pts ) video_end_pts = secs_to_pts( end_sec, self._video_time_base, self._video_start_pts ) video_frames = [ f for f, pts in self._video if pts >= video_start_pts and pts <= video_end_pts ] audio_samples = None if self._decode_audio and self._audio: audio_start_pts = secs_to_pts( start_sec, self._audio_time_base, self._audio_start_pts ) audio_end_pts = secs_to_pts( end_sec, self._audio_time_base, self._audio_start_pts ) audio_samples = [ f for f, pts in self._audio if pts >= audio_start_pts and pts <= audio_end_pts ] audio_samples = torch.cat(audio_samples, axis=0) audio_samples = audio_samples.to(torch.float32) if video_frames is None or len(video_frames) == 0: logger.warning( f"No video found within {start_sec} and {end_sec} seconds. " f"Video starts at time 0 and ends at {self.duration}." ) video_frames = None if video_frames is not None: video_frames = thwc_to_cthw(torch.stack(video_frames)).to(torch.float32) return { "video": video_frames, "audio": audio_samples, }
def _torch_vision_decode_video( self, start_pts: int = 0, end_pts: int = -1 ) -> float: """ Decode the video in the PTS range [start_pts, end_pts] """ video_and_pts = None audio_and_pts = None width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = start_pts, end_pts video_timebase_num, video_timebase_den = 0, 1 samples, channels = 0, 0 audio_start_pts, audio_end_pts = start_pts, end_pts audio_timebase_num, audio_timebase_den = 0, 1 try: tv_result = torch.ops.video_reader.read_video_from_memory( self._video_tensor, self.SEEK_FRAME_MARGIN, # Set getPtsOnly=0, i.e., read full video rather than just header 0, # Read video stream 1, width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, # Read audio stream self._decode_audio, samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) except Exception as e: logger.warning(f"Failed to decode video of name {self.video_name}. {e}") raise e ( vframes, vframes_pts, vtimebase, _, vduration, aframes, aframe_pts, atimebase, _, aduration, ) = tv_result if vduration < 0: # No header information to infer video duration video_duration = float(vframes_pts[-1]) else: video_duration = float(vduration) video_and_pts = list(zip(vframes, vframes_pts)) video_start_pts = int(vframes_pts[0]) video_time_base = float(vtimebase[0] / vtimebase[1]) audio_and_pts = None audio_time_base = None audio_start_pts = None audio_duration = None if self._decode_audio: if aduration < 0: # No header information to infer audio duration audio_duration = float(aframe_pts[-1]) else: audio_duration = float(aduration) audio_and_pts = list(zip(aframes, aframe_pts)) audio_start_pts = int(aframe_pts[0]) audio_time_base = float(atimebase[0] / atimebase[1]) return ( video_and_pts, video_time_base, video_start_pts, video_duration, audio_and_pts, audio_time_base, audio_start_pts, audio_duration, )