Ejemplo n.º 1
0
def test_vid4_dataset():
    """Test ``SRVid4Dataset`` annotation loading and constructor validation.

    The annotation file is mocked through ``builtins.open`` with two clips,
    ``calendar`` (1 frame) and ``city`` (2 frames). Every frame is expected
    to produce one entry in ``data_infos``, keyed as ``<clip>/<frame index>``
    formatted with ``filename_tmpl``.
    """
    root_path = Path(__file__).parent / 'data'

    # One line per clip: "<clip name> <number of frames> (...)". The
    # parenthesised triple is presumably the frame shape -- not checked here.
    txt_content = ('calendar 1 (320,480,3)\ncity 2 (320,480,3)\n')
    mocked_open_function = mock_open(read_data=txt_content)

    with patch('builtins.open', mocked_open_function):
        vid4_dataset = SRVid4Dataset(lq_folder=root_path / 'lq',
                                     gt_folder=root_path / 'gt',
                                     ann_file='fake_ann_file',
                                     num_input_frames=5,
                                     pipeline=[],
                                     scale=4,
                                     test_mode=False,
                                     metric_average_mode='clip',
                                     filename_tmpl='{:08d}')

        # One entry per frame; max_frame_num is the frame count of the clip.
        assert vid4_dataset.data_infos == [
            dict(lq_path=str(root_path / 'lq'),
                 gt_path=str(root_path / 'gt'),
                 key='calendar/00000000',
                 num_input_frames=5,
                 max_frame_num=1),
            dict(lq_path=str(root_path / 'lq'),
                 gt_path=str(root_path / 'gt'),
                 key='city/00000000',
                 num_input_frames=5,
                 max_frame_num=2),
            dict(lq_path=str(root_path / 'lq'),
                 gt_path=str(root_path / 'gt'),
                 key='city/00000001',
                 num_input_frames=5,
                 max_frame_num=2),
        ]

        with pytest.raises(AssertionError):
            # num_input_frames should be odd numbers
            SRVid4Dataset(lq_folder=root_path,
                          gt_folder=root_path,
                          ann_file='fake_ann_file',
                          num_input_frames=6,
                          pipeline=[],
                          scale=4,
                          test_mode=False)

        with pytest.raises(ValueError):
            # metric_average_mode can only be either 'clip' or 'all'.
            # Use a valid (odd) num_input_frames so that metric_average_mode
            # is the only invalid argument. With 6, the odd-number check
            # above can raise AssertionError first, which would escape this
            # ValueError context and never exercise the mode check.
            SRVid4Dataset(lq_folder=root_path,
                          gt_folder=root_path,
                          ann_file='fake_ann_file',
                          num_input_frames=5,
                          pipeline=[],
                          scale=4,
                          metric_average_mode='abc',
                          test_mode=False)
Ejemplo n.º 2
0
def test_vid4_dataset():
    """Test ``SRVid4Dataset`` loading, evaluation, and argument validation.

    The annotation file is mocked through ``builtins.open`` with two clips,
    ``calendar`` (1 frame) and ``city`` (2 frames), giving three frame-level
    entries. ``evaluate`` is then checked in both ``'clip'`` mode (average
    within each clip, then across clips) and ``'all'`` mode (average over
    all frames). Invalid constructor and ``evaluate`` arguments are checked
    last.
    """
    root_path = Path(__file__).parent.parent.parent / 'data'

    # One line per clip: "<clip name> <number of frames> (...)". The
    # parenthesised triple is presumably the frame shape -- not checked here.
    txt_content = ('calendar 1 (320,480,3)\ncity 2 (320,480,3)\n')
    mocked_open_function = mock_open(read_data=txt_content)

    with patch('builtins.open', mocked_open_function):
        vid4_dataset = SRVid4Dataset(
            lq_folder=root_path / 'lq',
            gt_folder=root_path / 'gt',
            ann_file='fake_ann_file',
            num_input_frames=5,
            pipeline=[],
            scale=4,
            test_mode=False,
            metric_average_mode='clip',
            filename_tmpl='{:08d}')

        # One entry per frame; max_frame_num is the frame count of the clip.
        assert vid4_dataset.data_infos == [
            dict(
                lq_path=str(root_path / 'lq'),
                gt_path=str(root_path / 'gt'),
                key='calendar/00000000',
                num_input_frames=5,
                max_frame_num=1),
            dict(
                lq_path=str(root_path / 'lq'),
                gt_path=str(root_path / 'gt'),
                key='city/00000000',
                num_input_frames=5,
                max_frame_num=2),
            dict(
                lq_path=str(root_path / 'lq'),
                gt_path=str(root_path / 'gt'),
                key='city/00000001',
                num_input_frames=5,
                max_frame_num=2),
        ]

        # test evaluate function ('clip' mode)
        # Results are in data_infos order: calendar/0, city/0, city/1.
        results = [{
            'eval_result': {
                'PSNR': 21,
                'SSIM': 0.75
            }
        }, {
            'eval_result': {
                'PSNR': 22,
                'SSIM': 0.8
            }
        }, {
            'eval_result': {
                'PSNR': 24,
                'SSIM': 0.9
            }
        }]
        eval_results = vid4_dataset.evaluate(results)
        # Per-clip means: calendar = 21 / 0.75, city = 23 / 0.85.
        # Mean over clips: PSNR = (21 + 23) / 2, SSIM = (0.75 + 0.85) / 2.
        np.testing.assert_almost_equal(eval_results['PSNR'], 22)
        np.testing.assert_almost_equal(eval_results['SSIM'], 0.8)

        # test evaluate function ('all' mode)
        vid4_dataset = SRVid4Dataset(
            lq_folder=root_path / 'lq',
            gt_folder=root_path / 'gt',
            ann_file='fake_ann_file',
            num_input_frames=5,
            pipeline=[],
            scale=4,
            test_mode=False,
            metric_average_mode='all',
            filename_tmpl='{:08d}')
        eval_results = vid4_dataset.evaluate(results)
        # Plain mean over all three frames, ignoring clip boundaries.
        np.testing.assert_almost_equal(eval_results['PSNR'],
                                       (21 + 22 + 24) / 3)
        np.testing.assert_almost_equal(eval_results['SSIM'],
                                       (0.75 + 0.8 + 0.9) / 3)

        with pytest.raises(AssertionError):
            # num_input_frames should be odd numbers
            SRVid4Dataset(
                lq_folder=root_path,
                gt_folder=root_path,
                ann_file='fake_ann_file',
                num_input_frames=6,
                pipeline=[],
                scale=4,
                test_mode=False)

        with pytest.raises(ValueError):
            # metric_average_mode can only be either 'clip' or 'all'
            # (num_input_frames is kept valid so only the mode is wrong).
            SRVid4Dataset(
                lq_folder=root_path,
                gt_folder=root_path,
                ann_file='fake_ann_file',
                num_input_frames=5,
                pipeline=[],
                scale=4,
                metric_average_mode='abc',
                test_mode=False)

        with pytest.raises(TypeError):
            # results must be a list
            vid4_dataset.evaluate(results=5)
        with pytest.raises(AssertionError):
            # The length of results should be equal to the dataset len
            vid4_dataset.evaluate(results=[results[0]])