def test_concat_ade(separate_eval):
    """Check ConcatDataset built from two ADE20K pseudo datasets.

    Verifies the concatenated length and ``format_results`` in both
    per-image mode (explicit ``indices``) and default batch mode.
    """
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    concat_dataset = ConcatDataset([test_dataset, test_dataset],
                                   separate_eval=separate_eval)
    assert len(concat_dataset) == 10

    # One pseudo 2x2 prediction per concatenated sample.
    pseudo_results = [
        np.random.randint(low=0, high=7, size=(2, 2))
        for _ in range(len(concat_dataset))
    ]

    # Test format_results one image at a time via ``indices``.
    file_paths = []
    for i, result in enumerate(pseudo_results):
        file_paths.extend(
            concat_dataset.format_results([result],
                                          '.format_ade',
                                          indices=[i]))
    assert len(file_paths) == len(concat_dataset)
    # The saved file must hold the prediction shifted by +1 (asserted below).
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')

    # Test default argument: all results formatted in one call.
    file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(concat_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')

def test_concat_cityscapes(separate_eval):
    """Concatenating Cityscapes datasets (alone or mixed with another
    dataset type) is unsupported and must raise ``NotImplementedError``."""
    data_root = osp.dirname(__file__)
    cityscape_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(data_root,
                         '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(data_root,
                         '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(cityscape_dataset) == 1

    # Two Cityscapes datasets cannot be concatenated.
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, cityscape_dataset],
                          separate_eval=separate_eval)

    ade_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(data_root, '../data/pseudo_dataset/imgs'))
    assert len(ade_dataset) == 5

    # Mixing Cityscapes with a different dataset type is rejected as well.
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, ade_dataset],
                          separate_eval=separate_eval)

def test_dataset_wrapper_basic():
    """Check ConcatDataset and RepeatDataset index mapping and lengths.

    NOTE(review): this function used to share the name
    ``test_dataset_wrapper`` with the identical, later definition in this
    file, so it was shadowed and never collected by pytest; renamed so it
    runs again. The class-level mocks below were previously commented out,
    leaving the asserts dependent on mutations from earlier tests; they are
    activated so the test is self-contained.
    """
    # Mock CustomDataset so __getitem__ echoes the index, making the
    # wrappers' index arithmetic directly observable.
    CustomDataset.load_annotations = MagicMock()
    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    # ConcatDataset: global index 25 lands in dataset_b at local index 15.
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    # RepeatDataset: indices wrap modulo the underlying dataset length.
    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)

def test_dataset_wrapper():
    """Exercise ConcatDataset/RepeatDataset index mapping and
    MultiImageMixDataset behavior (skip_type_keys, pipeline type checks)."""
    # Mock CustomDataset so __getitem__ echoes the index; previously these
    # patches were commented out, so the asserts below relied on class-level
    # mutation by earlier tests. Activated to make the test self-contained.
    CustomDataset.load_annotations = MagicMock()
    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    # ConcatDataset: global index 25 lands in dataset_b at local index 15.
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    # RepeatDataset: indices wrap modulo the underlying dataset length.
    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)

    img_scale = (60, 60)
    pipeline = [
        dict(type='RandomMosaic', prob=1, img_scale=img_scale),
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
    ]

    CustomDataset.load_annotations = MagicMock()
    # Build two random-sized samples for the mix dataset.
    results = []
    for _ in range(2):
        height = np.random.randint(10, 30)
        width = np.random.randint(10, 30)
        img = np.ones((height, width, 3))
        gt_semantic_seg = np.random.randint(5, size=(height, width))
        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))

    classes = ['0', '1', '2', '3', '4']
    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
    dataset_a = CustomDataset(img_dir=MagicMock(),
                              pipeline=[],
                              test_mode=True,
                              classes=classes,
                              palette=palette)
    len_a = 2
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a

    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
    assert len(multi_image_mix_dataset) == len(dataset_a)

    # Smoke test: the full pipeline runs on every sample.
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]

    # test skip_type_keys: with RandomFlip skipped, Resize still fixes the
    # output shape to img_scale. Fixed missing comma: ('RandomFlip') was a
    # plain string, not a 1-tuple, and only worked by substring accident.
    multi_image_mix_dataset = MultiImageMixDataset(
        dataset_a, pipeline, skip_type_keys=('RandomFlip', ))
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

    # After additionally skipping Resize, outputs are no longer img_scale.
    skip_type_keys = ('RandomFlip', 'Resize')
    multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape[:2] != img_scale

    # Pipeline entries must be dicts; anything else raises TypeError.
    with pytest.raises(TypeError):
        pipeline = [['Resize']]
        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)