示例#1
0
文件: test_io.py 项目: xx5988/vision
    def test_read_packed_b_frames_divx_file(self):
        """Check timestamps of a downloaded AVI fixture are sorted and the fps is 30.

        The fixture is fetched from ``download.pytorch.org/vision_tests/io/``.
        Per the test name, it is a DivX file with packed B-frames. If the
        download raises ``URLError``, a ``RuntimeWarning`` is emitted and the
        test is skipped rather than failed.
        """
        with get_tmp_dir() as temp_dir:
            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
            f_name = os.path.join(temp_dir, name)
            url = "https://download.pytorch.org/vision_tests/io/" + name
            # Only the network fetch may be turned into a skip. Decoding and
            # assertions stay outside the try so their errors surface as real
            # failures instead of being mistaken for a download problem.
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

            if _video_backend == "pyav":
                pts, fps = io.read_video_timestamps(f_name)
            else:
                # The video_reader path returns an info dict that carries fps.
                pts, _, info = io._read_video_timestamps_from_file(f_name)
                fps = info["video_fps"]

            # Presentation timestamps must come back in monotonic order even
            # though the stream packs B-frames.
            self.assertEqual(pts, sorted(pts))
            self.assertEqual(fps, 30)
示例#2
0
文件: test_io.py 项目: xx5988/vision
    def test_read_timestamps(self):
        """Check backend-reported pts against pts derived from PyAV stream metadata.

        A temporary video is created with ``temp_video(10, 300, 300, 5)``. Its
        timestamps are read through the active backend (``_video_backend``).
        They are then compared with an evenly spaced sequence computed from the
        first stream's ``average_rate``, ``time_base`` and ``duration``.
        """
        with temp_video(10, 300, 300, 5) as (f_name, _):
            if _video_backend == "pyav":
                pts, _ = io.read_video_timestamps(f_name)
            else:
                pts, _, _ = io._read_video_timestamps_from_file(f_name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                stream = container.streams[0]
                # Distance between consecutive frames, in time_base units.
                pts_step = int(
                    round(float(1 / (stream.average_rate * stream.time_base))))
                # Stream duration (time_base units) converted to a frame count.
                num_frames = int(
                    round(
                        float(stream.average_rate * stream.time_base *
                              stream.duration)))
            finally:
                # Release the file handle even if metadata access fails.
                container.close()
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)
示例#3
0
文件: test_io.py 项目: xx5988/vision
    def test_read_partial_video(self):
        """Check that a pts sub-range of a lossless clip yields exactly those frames.

        For start indices 0-4 and lengths 1-3, the frames read between
        ``pts[start]`` and ``pts[start + length - 1]`` must equal the matching
        slice of the source tensor. Under the pyav backend, a start pts that
        falls between frames is also checked.
        """
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            if _video_backend != "pyav":
                pts, _, _ = io._read_video_timestamps_from_file(f_name)
            else:
                pts, _ = io.read_video_timestamps(f_name)

            for start in range(5):
                for length in range(1, 4):
                    stop = start + length
                    frames, _, _ = _read_video(f_name, pts[start], pts[stop - 1])
                    expected = data[start:stop]
                    self.assertEqual(len(frames), length)
                    # Lossless encoding, so an exact element-wise match is required.
                    self.assertTrue(expected.equal(frames))

            if _video_backend == "pyav":
                # The "video_reader" backend does not decode the closest earlier
                # frame when the start pts matches no frame pts, so this
                # off-by-one start is only checked for pyav.
                frames, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
                self.assertEqual(len(frames), 4)
                self.assertTrue(data[4:8].equal(frames))
示例#4
0
文件: test_io.py 项目: xx5988/vision
    def test_read_partial_video_bframes(self):
        """Check that partial reads of a lossy, B-frame-heavy clip stay within tolerance.

        The clip is encoded with x264-style options that request B-frames. For
        starts 0, 20, 40, 60 and lengths 1-3, the decoded frames must match the
        source slice within ``self.TOLERANCE`` (max absolute difference).
        """
        # do not use lossless encoding, to test the presence of B-frames
        encoder_options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=encoder_options) as (f_name, data):
            if _video_backend != "pyav":
                pts, _, _ = io._read_video_timestamps_from_file(f_name)
            else:
                pts, _ = io.read_video_timestamps(f_name)

            def max_abs_diff(expected, actual):
                # Compare as floats so the subtraction cannot wrap around.
                return (expected.float() - actual.float()).abs().max()

            for start in range(0, 80, 20):
                for length in range(1, 4):
                    stop = start + length
                    frames, _, _ = _read_video(f_name, pts[start], pts[stop - 1])
                    expected = data[start:stop]
                    self.assertEqual(len(frames), length)
                    self.assertTrue(max_abs_diff(expected, frames) < self.TOLERANCE)

            # NOTE(review): unlike the loop above, this read calls io.read_video
            # directly instead of the _read_video helper. It may therefore not
            # follow the backend chosen by _video_backend. Confirm this is intended.
            frames, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(frames), 4)
            self.assertTrue(max_abs_diff(data[4:8], frames) < self.TOLERANCE)
示例#5
0
 def __getitem__(self, idx):
     """Return the video timestamps of ``self.x[idx]`` using the configured backend.

     ``self._backend == "pyav"`` selects ``read_video_timestamps``. Any other
     value falls back to ``_read_video_timestamps_from_file``.
     """
     path = self.x[idx]
     if self._backend != "pyav":
         return _read_video_timestamps_from_file(path)
     return read_video_timestamps(path)