Source code for pytorchvideo.data.frame_video
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import annotations
import logging
import time
from typing import Callable, Dict, List, Optional
import numpy as np
import torch
import torch.utils.data
from iopath.common.file_io import g_pathmgr
from pytorchvideo.data.utils import optional_threaded_foreach
from .utils import thwc_to_cthw
from .video import Video
try:
import cv2
except ImportError:
_HAS_CV2 = False
else:
_HAS_CV2 = True
logger = logging.getLogger(__name__)
class FrameVideo(Video):
    """
    FrameVideo is an abstraction for accessing clips based on their start and end
    time for a video where each frame is stored as an image. PathManager is used for
    frame image reading, allowing non-local URIs to be used.
    """

    def __init__(
        self,
        duration: float,
        fps: float,
        video_frame_to_path_fn: Optional[Callable[[int], str]] = None,
        video_frame_paths: Optional[List[str]] = None,
        multithreaded_io: bool = False,
    ) -> None:
        """
        Exactly one of ``video_frame_to_path_fn`` or ``video_frame_paths`` must be given.

        Args:
            duration (float): the duration of the video in seconds.
            fps (float): the target fps for the video. This is needed to link the frames
                to a second timestamp in the video.
            video_frame_to_path_fn (Callable[[int], str]): a function that maps from a frame
                index integer to the file path where the frame is located.
            video_frame_paths (List[str]): list of frame paths, one for each frame index
                of the video.
            multithreaded_io (bool): controls whether parallelizable io operations are
                performed across multiple threads.

        Raises:
            ImportError: if OpenCV (``cv2``) could not be imported.
            AssertionError: if both or neither of the frame path sources are provided.
        """
        if not _HAS_CV2:
            raise ImportError(
                "opencv2 is required to use FrameVideo. Please "
                "install with 'pip install opencv-python'"
            )

        self._duration = duration
        self._fps = fps

        # XOR on "is None": fails both when two sources are given and when none is.
        assert (video_frame_to_path_fn is None) != (
            video_frame_paths is None
        ), "Exactly one of video_frame_to_path_fn or video_frame_paths must be provided"
        self._video_frame_to_path_fn = video_frame_to_path_fn
        self._video_frame_paths = video_frame_paths

        self._multithreaded_io = multithreaded_io

    @classmethod
    def from_frame_paths(
        cls,
        video_frame_paths: List[str],
        fps: float = 30.0,
        multithreaded_io: bool = False,
    ) -> "FrameVideo":
        """
        Builds a FrameVideo whose duration is derived from the number of frame paths
        and the given fps (``len(video_frame_paths) / fps``).

        Args:
            video_frame_paths (List[str]): a list of paths to each frame in the video.
            fps (float): the target fps for the video. This is needed to link the frames
                to a second timestamp in the video.
            multithreaded_io (bool): controls whether parallelizable io operations are
                performed across multiple threads.

        Raises:
            AssertionError: if ``video_frame_paths`` is empty.
        """
        assert len(video_frame_paths) != 0, "video_frame_paths is empty"

        return cls(
            len(video_frame_paths) / fps,
            fps,
            video_frame_paths=video_frame_paths,
            multithreaded_io=multithreaded_io,
        )

    @property
    def duration(self) -> float:
        """
        Returns:
            duration: the video's duration/end-time in seconds.
        """
        return self._duration

    def _get_frame_index_for_time(self, time_sec: float) -> int:
        """Maps a timestamp in seconds to the nearest frame index at the video's fps."""
        return int(np.round(self._fps * time_sec))

    def _warn_no_frames(self, start_sec: float, end_sec: float) -> None:
        """Logs that the requested time range does not contain any frame."""
        logger.warning(
            f"No frames found within {start_sec} and {end_sec} seconds. Video starts "
            f"at time 0 and ends at {self._duration}."
        )

    def get_clip(
        self,
        start_sec: float,
        end_sec: float,
        frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
    ) -> Optional[Dict[str, Optional[torch.Tensor]]]:
        """
        Retrieves frames from the stored video at the specified start and end times
        in seconds (the video always starts at 0 seconds). Given that PathManager may
        be fetching the frames from network storage, to handle transient errors, frame
        reading is retried N times.

        Args:
            start_sec (float): the clip start time in seconds
            end_sec (float): the clip end time in seconds. Values past the video's
                duration are clamped to the duration. The frame at the end time is
                not included in the clip.
            frame_filter (Optional[Callable[List[int], List[int]]]):
                function to subsample frames in a clip before loading.
                If None, no subsampling is performed.

        Returns:
            clip_data:
                "video": A tensor of the clip's RGB frames with shape:
                (channel, time, height, width). The frames are of type torch.float32 and
                in the range [0 - 255]. Raises an exception if unable to load images.

                "frame_indices": A list of indices for each frame relative to all frames in the
                video.

            Returns None if no frames are found.
        """
        if start_sec < 0 or start_sec > self._duration:
            self._warn_no_frames(start_sec, end_sec)
            return None

        clamped_end_sec = min(end_sec, self._duration)

        start_frame_index = self._get_frame_index_for_time(start_sec)
        end_frame_index = self._get_frame_index_for_time(clamped_end_sec)
        if self._video_frame_paths is not None:
            # Never index past the stored paths, even if the declared duration
            # and fps imply more frames than were actually provided.
            end_frame_index = min(end_frame_index, len(self._video_frame_paths))
        frame_indices = list(range(start_frame_index, end_frame_index))

        # Frame filter function to allow for subsampling before loading
        if len(frame_indices) > 0 and frame_filter:
            frame_indices = frame_filter(frame_indices)

        # An empty selection (e.g. start_sec == duration or end_sec <= start_sec)
        # cannot be stacked into a tensor, so honor the documented contract
        # and report "no frames" instead of failing inside the image loader.
        if len(frame_indices) == 0:
            self._warn_no_frames(start_sec, end_sec)
            return None

        clip_paths = [self._video_frame_to_path(i) for i in frame_indices]
        clip_frames = _load_images_with_retries(
            clip_paths, multithreaded=self._multithreaded_io
        )
        clip_frames = thwc_to_cthw(clip_frames).to(torch.float32)
        return {"video": clip_frames, "frame_indices": frame_indices}

    def _video_frame_to_path(self, frame_index: int) -> str:
        """
        Resolves a frame index to its image path using whichever path source the
        video was constructed with.

        Raises:
            Exception: if neither path source is set.
        """
        # Compare against None (as __init__ does) rather than relying on truthiness.
        if self._video_frame_to_path_fn is not None:
            return self._video_frame_to_path_fn(frame_index)
        if self._video_frame_paths is not None:
            return self._video_frame_paths[frame_index]
        raise Exception(
            "One of _video_frame_to_path_fn or _video_frame_paths must be set"
        )
def _load_images_with_retries(
    image_paths: List[str], num_retries: int = 10, multithreaded: bool = True
) -> torch.Tensor:
    """
    Loads the given image paths using PathManager, decodes them as RGB images and
    returns them as a stacked tensor.

    Args:
        image_paths (List[str]): a list of paths to images.
        num_retries (int): number of times to retry image reading to handle transient error.
        multithreaded (bool): if images are fetched via multiple threads in parallel.

    Returns:
        A tensor of the clip's RGB frames with shape:
        (time, height, width, channel). The frames are of type torch.uint8 and
        in the range [0 - 255].

    Raises:
        Exception: if any image could not be decoded after ``num_retries`` attempts.
        ValueError: raised by ``np.stack`` if ``image_paths`` is empty.
    """
    # Same logger a module-level ``logging.getLogger(__name__)`` would return.
    log = logging.getLogger(__name__)
    imgs = [None] * len(image_paths)

    def fetch_image(image_index: int, image_path: str) -> None:
        for attempt in range(num_retries):
            with g_pathmgr.open(image_path, "rb") as f:
                img_str = np.frombuffer(f.read(), np.uint8)

            # Skip decoding an empty read; treat it as a failed attempt.
            img_bgr = None
            if img_str.size > 0:
                img_bgr = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)

            # imdecode signals failure by returning None; converting None with
            # cvtColor would raise and abort the retry loop, so check first.
            if img_bgr is not None:
                # Each call writes only its own slot of the shared result list.
                imgs[image_index] = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                return

            log.warning(f"Reading attempt {attempt + 1}/{num_retries} failed.")
            time.sleep(1e-6)

    optional_threaded_foreach(fetch_image, enumerate(image_paths), multithreaded)

    if any(img is None for img in imgs):
        raise Exception("Failed to load images from {}".format(image_paths))

    return torch.as_tensor(np.stack(imgs))