Example #1
    def __init__(self,
                 root,
                 extensions=None,
                 transform=None,
                 target_transform=None,
                 is_valid_file=None):
        
        super(VideoDataset, self).__init__(root,
                                           transform=transform,
                                           target_transform=target_transform)

        videos, video_timestamps, offsets, fps = _make_dataset(
            self.root, extensions, is_valid_file)
        
        if len(videos) == 0:
            msg = 'Found 0 videos in folder: {}\n'.format(self.root)
            if extensions is not None:
                msg += 'Supported extensions are: {}'.format(
                    ','.join(extensions))
            raise RuntimeError(msg)

        self.extensions = extensions

        backend = torchvision.get_video_backend()
        self.video_loaders = [
            VideoLoader(video, timestamps, backend=backend)
            for video, timestamps in zip(videos, video_timestamps)
        ]

        self.videos = videos
        self.video_timestamps = video_timestamps
        # offsets[i] indicates the index of the first frame of the i-th video.
        # e.g. for two videos of length 10 and 20, the offsets will be [0, 10].
        self.offsets = offsets
        self.fps = fps
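
Because offsets[i] holds the index of the first frame of the i-th video, a flat frame index can be mapped back to its owning video with a binary search. A minimal sketch of that lookup, assuming offsets built as described above (the helper name is hypothetical):

import bisect

def _video_and_frame(offsets, frame_index):
    # the owning video is the rightmost offset <= frame_index
    video_idx = bisect.bisect_right(offsets, frame_index) - 1
    frame_in_video = frame_index - offsets[video_idx]
    return video_idx, frame_in_video

# e.g. with offsets [0, 10] (videos of length 10 and 20),
# flat frame 12 is frame 2 of video 1:
assert _video_and_frame([0, 10], 12) == (1, 2)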
Example #2
    def __init__(self,
                 video_paths,
                 clip_length_in_frames=16,
                 frames_between_clips=1,
                 frame_rate=None,
                 _precomputed_metadata=None,
                 num_workers=0,
                 _video_width=0,
                 _video_height=0,
                 _video_min_dimension=0,
                 _audio_samples=0):
        from torchvision import get_video_backend

        self.video_paths = video_paths
        self.num_workers = num_workers
        self._backend = get_video_backend()
        self._video_width = _video_width
        self._video_height = _video_height
        self._video_min_dimension = _video_min_dimension
        self._audio_samples = _audio_samples

        if _precomputed_metadata is None:
            self._compute_frame_pts()
        else:
            self._init_from_metadata(_precomputed_metadata)
        self.compute_clips(clip_length_in_frames, frames_between_clips,
                           frame_rate)
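
For context, this is the constructor of torchvision's VideoClips. A hedged usage sketch follows; the file paths are hypothetical, and only the public arguments are passed:

from torchvision.datasets.video_utils import VideoClips

video_paths = ["clips/a.mp4", "clips/b.mp4"]  # hypothetical files
video_clips = VideoClips(video_paths,
                         clip_length_in_frames=16,
                         frames_between_clips=8,
                         num_workers=2)
print(video_clips.num_clips())  # total clips across both videos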
Example #3
    def test_read_partial_video_pts_unit_sec(self, start, offset):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            lv, _, _ = io.read_video(f_name,
                                     pts[start],
                                     pts[start + offset - 1],
                                     pts_unit='sec')
            s_data = data[start:(start + offset)]
            assert len(lv) == offset
            assert_equal(s_data, lv)

            with av.open(f_name) as container:
                stream = container.streams[0]
                lv, _, _ = io.read_video(f_name,
                                         int(pts[4] *
                                             (1.0 / stream.time_base) + 1) *
                                         stream.time_base,
                                         pts[7],
                                         pts_unit='sec')
            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                assert len(lv) == 4
                assert_equal(data[4:8], lv)
Example #4
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        if get_video_backend() == "pyav":
            video_codec = 'libx264'
        else:
            # for the video_reader backend, default to libx264rgb, which
            # accepts RGB pixel formats as input instead of YUV
            video_codec = 'libx264rgb'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.close()
        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
        yield f.name, data
    os.unlink(f.name)
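
Given the yield statement, this helper is presumably decorated with @contextlib.contextmanager in the original file, so tests consume it as a context manager. A usage sketch under that assumption:

with temp_video(num_frames=10, height=64, width=64, fps=5, lossless=True) as (f_name, data):
    lv, _, info = io.read_video(f_name)
    assert len(lv) == 10  # lossless round-trip keeps every frame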
Example #5
    def test_read_partial_video_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name,
                                             pts[start],
                                             pts[start + length - 1],
                                             pts_unit='sec')
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            container = av.open(f_name)
            stream = container.streams[0]
            lv, _, _ = io.read_video(
                f_name,
                int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base,
                pts[7],
                pts_unit='sec')
            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))
            container.close()
Example #6
    def __init__(
        self,
        dataset: Any,
        split: str,
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Callable,
        num_samples: Optional[int],
        clips_per_video: int,
    ):
        """The constructor method of ClassyVideoDataset.

        Args:
            dataset: the underlying video dataset from either TorchVision or other
                source. It should have an attribute *video_clips* of type
                `torchvision.datasets.video_utils.VideoClips <https://github.com/
                pytorch/vision/blob/master/torchvision/datasets/
                video_utils.py#L46/>`_
            split: dataset split. Must be either "train" or "test"
            batchsize_per_replica: batch size per model replica
            shuffle: If true, shuffle video clips.
            transform: callable function to transform video clip sample from
                ClassyVideoDataset
            num_samples: If provided, return at most `num_samples` video clips
            clips_per_video: The number of clips sampled from each video

        """
        super(ClassyVideoDataset, self).__init__(
            dataset, batchsize_per_replica, shuffle, transform, num_samples
        )
        # Assignments:
        self.clips_per_video = clips_per_video
        self.split = split
        self.video_backend = get_video_backend()
Example #7
    def test_video_clips(self):
        _backend = get_video_backend()
        with get_list_of_videos(num_videos=3) as video_list:
            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
            self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
            for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0),
                                                (2, 1), (2, 2)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
            self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
            for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
            self.assertEqual(video_clips.num_clips(),
                             0 + (10 - 6 + 1) + (15 - 6 + 1))
            for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0),
                                    (6, 2, 1)]:
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)
Example #8
def read_video_timestamps(filename, pts_unit="pts"):
    """
    List the video frame timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts'
        List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
    video_fps : int
        the frame rate for the video

    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_frames = []
    video_fps = None

    try:
        container = av.open(filename, metadata_errors="ignore")
    except av.AVError:
        # TODO add a warning
        pass
    else:
        if container.streams.video:
            video_stream = container.streams.video[0]
            video_time_base = video_stream.time_base
            if _can_read_timestamps_from_packets(container):
                # fast path
                video_frames = [
                    x for x in container.demux(video=0) if x.pts is not None
                ]
            else:
                video_frames = _read_from_stream(container, 0, float("inf"),
                                                 pts_unit, video_stream,
                                                 {"video": 0})
            video_fps = float(video_stream.average_rate)
        container.close()

    pts = [x.pts for x in video_frames]

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
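
A short usage sketch ("video.mp4" is a hypothetical path). With pts_unit='sec' the returned timestamps are Fractions, so they can be passed straight back to read_video:

pts, fps = read_video_timestamps("video.mp4", pts_unit="sec")
if pts:
    # decode only the span covering the first ten frames
    frames, _, _ = read_video("video.mp4", pts[0], pts[min(9, len(pts) - 1)],
                              pts_unit="sec")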
Example #9
    def get_clip(self, idx):
        """
        Gets a subclip from a list of videos.

        Arguments:
            idx (int): index of the subclip. Must be between 0 and num_clips().

        Returns:
            video (Tensor)
            audio (Tensor)
            info (Dict)
        """
        if idx >= self.num_clips():
            raise IndexError("Index {} out of range "
                             "({} number of clips)".format(
                                 idx, self.num_clips()))
        video_path = self.video_paths[idx]
        clip_pts = self.clips[idx]

        from torchvision import get_video_backend

        backend = get_video_backend()

        if backend == "pyav":
            # check for invalid options
            if self._video_width != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_width != 0")
            if self._video_height != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_height != 0")
            if self._video_min_dimension != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_min_dimension != 0")
            if self._video_max_dimension != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_max_dimension != 0")
            if self._audio_samples != 0:
                raise ValueError(
                    "pyav backend doesn't support _audio_samples != 0")

        if backend == "pyav":
            assert len(clip_pts) > 0
            start_pts = clip_pts[0].item()
            end_pts = clip_pts[-1].item()
            video, audio, info = read_video(video_path, start_pts, end_pts)
        else:
            raise NotImplementedError(f"backend {backend} is not implemented.")

        resampling_idx = self.resampling_idxs[idx]
        if isinstance(resampling_idx, torch.Tensor):
            resampling_idx = resampling_idx - resampling_idx[0]
        video = video[resampling_idx]
        info["video_fps"] = self.frame_rate
        assert len(video) == self.num_frames, "{} x {}".format(
            video.shape, self.num_frames)
        return video, audio, info
Example #10
def read_video_timestamps(
        filename: str,
        pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frame timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts'
        List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
    video_fps : float, optional
        the frame rate for the video

    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(
                        f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except av.AVError:
        # TODO add a warning
        pass

    pts.sort()

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
Example #11
 def test_video_sampler(self):
     _backend = get_video_backend()
     with get_list_of_videos(num_videos=3, sizes=[25, 25,
                                                  25]) as video_list:
         video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
         sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
         self.assertEqual(len(sampler), 3 * 3)
         indices = torch.tensor(list(iter(sampler)))
         videos = indices // 5
         v_idxs, count = torch.unique(videos, return_counts=True)
         self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
         self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
Example #12
def read_video_timestamps(
        filename: str,
        pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frame timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned,
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(
                        f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
Example #13
    def test_read_partial_video(self, start, offset):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)

            lv, _, _ = io.read_video(f_name, pts[start],
                                     pts[start + offset - 1])
            s_data = data[start:(start + offset)]
            assert len(lv) == offset
            assert_equal(s_data, lv)

            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
                assert len(lv) == 4
                assert_equal(data[4:8], lv)
Example #14
 def test_video_clips_custom_fps(self):
     _backend = get_video_backend()
     with get_list_of_videos(num_videos=3,
                             sizes=[12, 12, 12],
                             fps=[3, 4, 6]) as video_list:
         num_frames = 4
         for fps in [1, 3, 4, 10]:
             video_clips = VideoClips(video_list,
                                      num_frames,
                                      num_frames,
                                      fps,
                                      _backend=_backend)
             for i in range(video_clips.num_clips()):
                 video, audio, info, video_idx = video_clips.get_clip(i)
                 self.assertEqual(video.shape[0], num_frames)
                 self.assertEqual(info["video_fps"], fps)
Example #15
    def test_read_partial_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))
Example #16
 def test_video_sampler_unequal(self):
     _backend = get_video_backend()
     with get_list_of_videos(num_videos=3, sizes=[10, 25,
                                                  25]) as video_list:
         video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
         sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
         self.assertEqual(len(sampler), 2 + 3 + 3)
         indices = list(iter(sampler))
         self.assertIn(0, indices)
         self.assertIn(1, indices)
         # remove elements of the first video, to simplify testing
         indices.remove(0)
         indices.remove(1)
         indices = torch.tensor(indices) - 2
         videos = indices // 5
         v_idxs, count = torch.unique(videos, return_counts=True)
         self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
         self.assertTrue(count.equal(torch.tensor([3, 3])))
Example #17
    def test_read_partial_video_bframes(self, start, offset):
        # do not use lossless encoding, to test the presence of B-frames
        options = {"bframes": "16", "keyint": "10", "min-keyint": "4"}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)

            lv, _, _ = io.read_video(f_name, pts[start],
                                     pts[start + offset - 1])
            s_data = data[start:(start + offset)]
            assert len(lv) == offset
            assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            # TODO fix this
            if get_video_backend() == "pyav":
                assert len(lv) == 4
                assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
            else:
                assert len(lv) == 3
                assert_equal(data[5:8], lv, rtol=0.0, atol=self.TOLERANCE)
Example #18
    def test_read_partial_video_bframes(self):
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for offset in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            # TODO fix this
            if get_video_backend() == 'pyav':
                self.assertEqual(len(lv), 4)
                assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
            else:
                self.assertEqual(len(lv), 3)
                assert_equal(data[5:8], lv, rtol=0.0, atol=self.TOLERANCE)
Example #19
    def test_read_partial_video_bframes(self):
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            # TODO fix this
            if get_video_backend() == 'pyav':
                self.assertEqual(len(lv), 4)
                self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
            else:
                self.assertEqual(len(lv), 3)
                self.assertTrue((data[5:8].float() - lv.float()).abs().max() < self.TOLERANCE)
Example #20
 def test_read_video_partially_corrupted_file(self):
     with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
         with open(f_name, 'r+b') as f:
             size = os.path.getsize(f_name)
             bytes_to_overwrite = size // 10
             # seek to the middle of the file
             f.seek(5 * bytes_to_overwrite)
             # corrupt 10% of the file from the middle
             f.write(b'\xff' * bytes_to_overwrite)
         # this exercises the container.decode assertion check
         video, audio, info = io.read_video(f.name, pts_unit='sec')
         # check that size is not equal to 5, but 3
         # TODO fix this
         if get_video_backend() == 'pyav':
             self.assertEqual(len(video), 3)
         else:
             self.assertEqual(len(video), 4)
         # but the valid decoded content is still correct
         self.assertTrue(video[:3].equal(data[:3]))
         # and the last few frames are wrong
         self.assertFalse(video.equal(data))
Example #21
 def test_read_video_partially_corrupted_file(self):
     with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
         with open(f_name, "r+b") as f:
             size = os.path.getsize(f_name)
             bytes_to_overwrite = size // 10
             # seek to the middle of the file
             f.seek(5 * bytes_to_overwrite)
             # corrupt 10% of the file from the middle
             f.write(b"\xff" * bytes_to_overwrite)
         # this exercises the container.decode assertion check
         video, audio, info = io.read_video(f.name, pts_unit="sec")
         # check that size is not equal to 5, but 3
         # TODO fix this
         if get_video_backend() == "pyav":
             assert len(video) == 3
         else:
             assert len(video) == 4
         # but the valid decoded content is still correct
         assert_equal(video[:3], data[:3])
         # and the last few frames are wrong
         with pytest.raises(AssertionError):
             assert_equal(video, data)
Example #22
            video_codec = 'libx264'
        else:
            # for the video_reader backend, default to libx264rgb, which
            # accepts RGB pixel formats as input instead of YUV
            video_codec = 'libx264rgb'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.close()
        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
        yield f.name, data
    os.unlink(f.name)

@unittest.skipIf(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
                 "video_reader backend not available")
@unittest.skipIf(av is None, "PyAV unavailable")
class Tester(unittest.TestCase):
    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def test_write_read_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)
            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_file(self):
Example #23
def _make_dataset(directory,
                  extensions=None,
                  is_valid_file=None,
                  pts_unit='sec'):
    """Returns a list of all video files, timestamps, and offsets.

    Args:
        directory:
            Root directory path (should not contain subdirectories).
        extensions:
            Tuple of valid extensions.
        is_valid_file:
            Used to find valid files.
        pts_unit:
            Unit of the timestamps.

    Returns:
        A list of video files, timestamps, frame offsets, and fps.

    """

    if extensions is None and is_valid_file is None:
        raise ValueError('Both extensions and is_valid_file cannot be None')

    # use the filename extension to find valid files
    if extensions is not None:
        def _is_valid_file(filename):
            return filename.lower().endswith(extensions)

    # an explicit is_valid_file overrides the extension check
    if is_valid_file is not None:
        _is_valid_file = is_valid_file

    # find all instances (no subdirectories)
    instances = []
    for fname in os.listdir(directory):

        # skip invalid files
        if not _is_valid_file(fname):
            continue

        # keep track of valid files
        path = os.path.join(directory, fname)
        instances.append(path)

    # get timestamps
    timestamps, fpss = [], []
    for instance in instances:

        if AV_AVAILABLE and torchvision.get_video_backend() == 'pyav':
            # This is a hacky solution to estimate the timestamps.
            # When using the video_reader this approach fails because the 
            # estimated timestamps are not correct.
            with av.open(instance) as av_video:
                stream = av_video.streams.video[0]
                duration = stream.duration * stream.time_base
                fps = stream.base_rate
                n_frames = int(int(duration) * fps)

            timestamps.append([Fraction(i, fps) for i in range(n_frames)])
            fpss.append(fps)
        else:
            ts, fps = io.read_video_timestamps(instance, pts_unit=pts_unit)
            timestamps.append(ts)
            fpss.append(fps)


    # get frame offsets
    offsets = [len(ts) for ts in timestamps]
    offsets = [0] + offsets[:-1]
    for i in range(1, len(offsets)):
        offsets[i] = offsets[i-1] + offsets[i] # cumsum

    return instances, timestamps, offsets, fpss
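
The offsets loop at the end is a plain cumulative sum over the per-video frame counts. An equivalent formulation with itertools.accumulate, shown as a sketch with example lengths:

from itertools import accumulate

lengths = [10, 20, 15]  # frames per video (example values)
offsets = [0] + list(accumulate(lengths))[:-1]
assert offsets == [0, 10, 30]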
Example #24
def read_video(
    filename: str,
    start_pts: int = 0,
    end_pts: Optional[float] = None,
    pts_unit: str = "pts"
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the
    audio frames

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the start presentation time of the video
    end_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the end presentation time
    pts_unit : str, optional
        unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)
    """

    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(
                             start_pts, end_pts))

    info = {}
    video_frames = []
    audio_frames = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.video[0],
                    {"video": 0},
                )
                video_fps = container.streams.video[0].average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.audio[0],
                    {"audio": 0},
                )
                info["audio_fps"] = container.streams.audio[0].rate

    except av.AVError:
        # TODO raise a warning?
        pass

    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts,
                                      end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info
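
A minimal usage sketch ("video.mp4" is a hypothetical path), decoding the first two seconds:

vframes, aframes, info = read_video("video.mp4", start_pts=0, end_pts=2.0,
                                    pts_unit="sec")
print(vframes.shape)  # Tensor[T, H, W, C], uint8
print(info)           # e.g. {'video_fps': 30.0, 'audio_fps': 44100}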
Example #25
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.close()
        io.write_video(f.name,
                       data,
                       fps=fps,
                       video_codec=video_codec,
                       options=options)
        yield f.name, data
    os.unlink(f.name)


@pytest.mark.skipif(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
                    reason="video_reader backend not available")
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
class TestVideo:
    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def test_write_read_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)
            assert_equal(data, lv)
            assert info["video_fps"] == 5

    @pytest.mark.skipif(not io._HAS_VIDEO_OPT,
                        reason="video_reader backend is not chosen")
Example #26
from . import utils
from .utils import LightLogger
from .train_utils import train_one_epoch, evaluate

import os
import time
import datetime

import torch
import torchvision
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

print(args)
print("torch version: ", torch.__version__)
print("torchvision version: ", torchvision.__version__)
print("torchvision video backend: ", torchvision.get_video_backend())

device = torch.device(args.device)
num_epoch = args.num_epoch
print_freq = args.print_freq

torch.backends.cudnn.benchmark = True


print("Creating model")
model = get_model(args.model)
model.to(device)
count_gpu = torch.cuda.device_count()
if count_gpu > 1:
    if args.disable_dist:
        print(f"{count_gpu} GPUs detected, but distribute is disabled")
Example #27
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the
    audio frames

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(
            f"output_format should be either 'THWC' or 'TCHW', got {output_format}."
        )

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
        )

    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = _video_opt.default_timebase

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.audio:
                audio_timebase = container.streams.audio[0].time_base
            if container.streams.video:
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.video[0],
                    {"video": 0},
                )
                video_fps = container.streams.video[0].average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.audio[0],
                    {"audio": 0},
                )
                info["audio_fps"] = container.streams.audio[0].rate

    except av.AVError:
        # TODO raise a warning?
        pass

    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list), dtype=torch.uint8)
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes, dtype=torch.float32)
        if pts_unit == "sec":
            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts,
                                      end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
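
Usage sketch for the output_format argument introduced in this version (the path is hypothetical):

vframes, _, _ = read_video("video.mp4", pts_unit="sec", output_format="TCHW")
# frames arrive channels-first, matching what 3D-conv models expect
assert vframes.ndim == 4 and vframes.shape[1] == 3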
Example #28
    def get_clip(self, idx):
        """
        Gets a subclip from a list of videos.

        Args:
            idx (int): index of the subclip. Must be between 0 and num_clips().

        Returns:
            video (Tensor)
            audio (Tensor)
            info (Dict)
            video_idx (int): index of the video in `video_paths`
        """
        if idx >= self.num_clips():
            raise IndexError("Index {} out of range "
                             "({} number of clips)".format(
                                 idx, self.num_clips()))
        video_idx, clip_idx = self.get_clip_location(idx)
        video_path = self.video_paths[video_idx]
        clip_pts = self.clips[video_idx][clip_idx]

        from torchvision import get_video_backend

        backend = get_video_backend()

        if backend == "pyav":
            # check for invalid options
            if self._video_width != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_width != 0")
            if self._video_height != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_height != 0")
            if self._video_min_dimension != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_min_dimension != 0")
            if self._video_max_dimension != 0:
                raise ValueError(
                    "pyav backend doesn't support _video_max_dimension != 0")
            if self._audio_samples != 0:
                raise ValueError(
                    "pyav backend doesn't support _audio_samples != 0")

        if backend == "pyav":
            start_pts = clip_pts[0].item()
            end_pts = clip_pts[-1].item()
            video, audio, info = read_video(video_path, start_pts, end_pts)
        else:
            info = _probe_video_from_file(video_path)
            video_fps = info.video_fps
            audio_fps = None

            video_start_pts = clip_pts[0].item()
            video_end_pts = clip_pts[-1].item()

            audio_start_pts, audio_end_pts = 0, -1
            audio_timebase = Fraction(0, 1)
            video_timebase = Fraction(info.video_timebase.numerator,
                                      info.video_timebase.denominator)
            if info.has_audio:
                audio_timebase = Fraction(info.audio_timebase.numerator,
                                          info.audio_timebase.denominator)
                audio_start_pts = pts_convert(video_start_pts, video_timebase,
                                              audio_timebase, math.floor)
                audio_end_pts = pts_convert(video_end_pts, video_timebase,
                                            audio_timebase, math.ceil)
                audio_fps = info.audio_sample_rate
            video, audio, info = _read_video_from_file(
                video_path,
                video_width=self._video_width,
                video_height=self._video_height,
                video_min_dimension=self._video_min_dimension,
                video_max_dimension=self._video_max_dimension,
                video_pts_range=(video_start_pts, video_end_pts),
                video_timebase=video_timebase,
                audio_samples=self._audio_samples,
                audio_channels=self._audio_channels,
                audio_pts_range=(audio_start_pts, audio_end_pts),
                audio_timebase=audio_timebase,
            )

            info = {"video_fps": video_fps}
            if audio_fps is not None:
                info["audio_fps"] = audio_fps

        if self.frame_rate is not None:
            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
            if isinstance(resampling_idx, torch.Tensor):
                resampling_idx = resampling_idx - resampling_idx[0]
            video = video[resampling_idx]
            info["video_fps"] = self.frame_rate
        assert len(video) == self.num_frames, "{} x {}".format(
            video.shape, self.num_frames)
        return video, audio, info, video_idx
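
The pts_convert calls above rescale a timestamp from the video timebase into the audio timebase. A minimal sketch of such a helper, assuming it behaves like torchvision's internal pts_convert (the exact implementation may differ):

import math
from fractions import Fraction

def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
    # express the same instant in a different timebase:
    # seconds = pts * timebase_from, then divide by the target timebase
    return round_func(Fraction(pts, 1) * timebase_from / timebase_to)

# 150 video ticks of 1/30 s == 5 s == 220500 audio ticks of 1/44100 s
assert pts_convert(150, Fraction(1, 30), Fraction(1, 44100)) == 220500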
Example #29
'''
Here we assume videos have been resized to the values appearing in T.Resize.
'''

import torchvision
import video_yyz.transforms as T
from torchvision import get_video_backend

video_backend = get_video_backend()

normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                        std=[0.22803, 0.22145, 0.216989])

transform_train = torchvision.transforms.Compose([
    T.ToFloatTensorInZeroOne(),
    T.Resize((128, 228)),
    T.RandomHorizontalFlip(), normalize,
    T.RandomCrop((112, 112))
])

transform_test = torchvision.transforms.Compose([
    T.ToFloatTensorInZeroOne(),
    T.Resize((128, 228)), normalize,
    T.CenterCrop((112, 112))
])
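
A hedged usage sketch: these pipelines expect a uint8 video tensor in [T, H, W, C] layout, which is what torchvision.io.read_video returns by default. Assuming video_yyz.transforms mirrors the torchvision video-reference transforms (ToFloatTensorInZeroOne moves channels first), with a synthetic tensor:

import torch

clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # fake clip
out = transform_train(clip)
print(out.shape)  # expected [3, 16, 112, 112] after the final crop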
Example #30
def _make_dataset(directory,
                  extensions=None,
                  is_valid_file=None,
                  pts_unit='sec'):
    """Returns a list of all video files, timestamps, and offsets.

    Args:
        directory:
            Root directory path (should not contain subdirectories).
        extensions:
            Tuple of valid extensions.
        is_valid_file:
            Used to find valid files.
        pts_unit:
            Unit of the timestamps.

    Returns:
        A list of video files, timestamps, frame offsets, and fps.

    """

    if extensions is None:
        if is_valid_file is None:
            raise ValueError('Both extensions and is_valid_file cannot be None')
        else:
            _is_valid_file = is_valid_file
    else:

        def is_valid_file_extension(filepath):
            return filepath.lower().endswith(extensions)

        if is_valid_file is None:
            _is_valid_file = is_valid_file_extension
        else:

            def _is_valid_file(filepath):
                return is_valid_file_extension(filepath) and is_valid_file(
                    filepath)

    # find all video instances (no subdirectories)
    video_instances = []

    def on_error(error):
        raise error

    for root, _, files in os.walk(directory, onerror=on_error):

        for fname in files:
            # skip invalid files
            if not _is_valid_file(os.path.join(root, fname)):
                continue

            # keep track of valid files
            path = os.path.join(root, fname)
            video_instances.append(path)

    # get timestamps
    timestamps, fpss = [], []
    for instance in video_instances[:]:  # video_instances[:] creates a copy

        if AV_AVAILABLE and torchvision.get_video_backend() == 'pyav':
            # This is a hacky solution to estimate the timestamps.
            # When using the video_reader this approach fails because the
            # estimated timestamps are not correct.
            with av.open(instance) as av_video:
                stream = av_video.streams.video[0]

                # check if we can extract the video duration
                if not stream.duration:
                    print(
                        f'Video {instance} has no timestamp and will be skipped...'
                    )
                    video_instances.remove(
                        instance)  # remove from original list (not copy)
                    continue  # skip this broken video

                duration = stream.duration * stream.time_base
                fps = stream.base_rate
                n_frames = int(int(duration) * fps)

            timestamps.append([Fraction(i, fps) for i in range(n_frames)])
            fpss.append(fps)
        else:
            ts, fps = io.read_video_timestamps(instance, pts_unit=pts_unit)
            timestamps.append(ts)
            fpss.append(fps)

    # get frame offsets
    offsets = [len(ts) for ts in timestamps]
    offsets = [0] + offsets[:-1]
    for i in range(1, len(offsets)):
        offsets[i] = offsets[i - 1] + offsets[i]  # cumsum

    return video_instances, timestamps, offsets, fpss