def test_custom_dataset_custom_palette():
    """A palette passed to the constructor is what ``PALETTE`` exposes."""
    expected_palette = ([100, 100, 100], [200, 200, 200])
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        # Hand over fresh lists so the expected value above stays untouched.
        palette=[list(rgb) for rgb in expected_palette],
        test_mode=True)
    assert tuple(dataset.PALETTE) == expected_palette
def test_custom_dataset_random_palette_is_generated():
    """With no ``palette`` argument, ``PALETTE`` still has one entry per class.

    Each entry must hold three values, all within 0..255.
    """
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        test_mode=True)
    generated = dataset.PALETTE
    assert len(generated) == 2
    for rgb in generated:
        assert len(rgb) == 3
        assert all(0 <= channel <= 255 for channel in rgb)
def test_dataset_wrapper():
    """Check indexing and length of ``ConcatDataset`` and ``RepeatDataset``.

    NOTE(review): the index assertions expect ``wrapper[i]`` to hand back an
    integer index, so this presumably relies on
    ``CustomDataset.load_annotations`` / ``CustomDataset.__getitem__`` being
    mocked outside this function (e.g. by decorators not visible here) --
    TODO confirm.
    NOTE(review): a second function with this same name is defined later in
    this source; if both live in one module, that later definition replaces
    this one and this test is never collected.
    """

    def _dataset_of_length(length):
        # CustomDataset whose ``img_infos`` is a mock reporting ``length``.
        ds = CustomDataset(img_dir=MagicMock(), pipeline=[])
        infos = MagicMock()
        infos.__len__.return_value = length
        ds.img_infos = infos
        return ds

    dataset_a = _dataset_of_length(10)
    dataset_b = _dataset_of_length(20)

    # Concatenation: (index into wrapper, value expected back).
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    for index, expected in ((5, 5), (25, 15)):
        assert concat_dataset[index] == expected
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    # Repetition: (index into wrapper, value expected back).
    repeat_dataset = RepeatDataset(dataset_a, 10)
    for index, expected in ((5, 5), (15, 5), (27, 7)):
        assert repeat_dataset[index] == expected
    assert len(repeat_dataset) == 10 * len(dataset_a)
def test_dataset_wrapper():
    """Check ``ConcatDataset``, ``RepeatDataset`` and ``MultiImageMixDataset``.

    The first half verifies indexing/length of the concat and repeat wrappers;
    the second half runs a mosaic/flip/resize pipeline through
    ``MultiImageMixDataset`` and checks ``skip_type_keys`` handling and
    pipeline validation.

    NOTE(review): the first half expects ``wrapper[i]`` to hand back an integer
    index, so it presumably relies on ``CustomDataset.load_annotations`` /
    ``CustomDataset.__getitem__`` being mocked outside this function (e.g. by
    decorators not visible here) -- TODO confirm.
    NOTE(review): an earlier function with this same name appears in this
    source; if both live in one module, this later definition is the one that
    stays bound.
    """
    # Local import: only this block needs ``patch``.
    from unittest.mock import patch

    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)

    img_scale = (60, 60)
    pipeline = [
        dict(type='RandomMosaic', prob=1, img_scale=img_scale),
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
    ]

    # Two fake samples with random spatial size in [10, 30).
    results = []
    for _ in range(2):
        height = np.random.randint(10, 30)
        width = np.random.randint(10, 30)
        img = np.ones((height, width, 3))
        gt_semantic_seg = np.random.randint(5, size=(height, width))
        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))

    classes = ['0', '1', '2', '3', '4']
    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]

    # Scope the class-level mocks to this section so they are undone on exit
    # instead of leaking into tests that need the real CustomDataset methods.
    with patch.object(CustomDataset, 'load_annotations', MagicMock()), \
            patch.object(CustomDataset, '__getitem__',
                         MagicMock(side_effect=lambda idx: results[idx])):
        dataset_a = CustomDataset(img_dir=MagicMock(),
                                  pipeline=[],
                                  test_mode=True,
                                  classes=classes,
                                  palette=palette)
        len_a = 2
        dataset_a.img_infos = MagicMock()
        dataset_a.img_infos.__len__.return_value = len_a

        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
        assert len(multi_image_mix_dataset) == len(dataset_a)

        # Smoke check: every sample must run through the full pipeline.
        for idx in range(len_a):
            multi_image_mix_dataset[idx]

        # test skip_type_keys (a real 1-tuple; ``('RandomFlip')`` is a str)
        multi_image_mix_dataset = MultiImageMixDataset(
            dataset_a, pipeline, skip_type_keys=('RandomFlip', ))
        for idx in range(len_a):
            results_ = multi_image_mix_dataset[idx]
            assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

        skip_type_keys = ('RandomFlip', 'Resize')
        multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
        for idx in range(len_a):
            results_ = multi_image_mix_dataset[idx]
            assert results_['img'].shape[:2] != img_scale

        # test pipeline: entries that are not dicts must be rejected
        invalid_pipeline = [['Resize']]
        with pytest.raises(TypeError):
            MultiImageMixDataset(dataset_a, invalid_pipeline)
def test_custom_dataset():
    """Cover ``CustomDataset`` construction, item access and evaluation.

    Builds datasets from the pseudo dataset directory next to this file under
    several path configurations, fetches one train and one test sample, reads
    the ground-truth maps, and runs ``evaluate`` for mIoU / mDice / mFscore on
    raw predictions as well as on ``pre_eval`` output.

    NOTE(review): another function with this same name is defined later in
    this source; if both live in one module, that later definition replaces
    this one and this test is never collected.
    """
    this_dir = osp.dirname(__file__)
    pseudo_root = osp.join(this_dir, '../data/pseudo_dataset')
    pseudo_imgs = osp.join(this_dir, '../data/pseudo_dataset/imgs')
    pseudo_gts = osp.join(this_dir, '../data/pseudo_dataset/gts')
    suffix_kwargs = dict(img_suffix='img.jpg', seg_map_suffix='gt.png')

    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    multi_scale_transforms = [
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='MultiScaleFlipAug',
             img_scale=(128, 256),
             flip=False,
             transforms=multi_scale_transforms),
    ]

    # Relative img_dir / ann_dir resolved against data_root.
    relative_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir='imgs/',
        ann_dir='gts/',
        **suffix_kwargs)
    assert len(relative_dataset) == 5

    # Same layout, restricted by a split file.
    split_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir='imgs/',
        ann_dir='gts/',
        split='splits/train.txt',
        **suffix_kwargs)
    assert len(split_dataset) == 4

    # Directories given directly, without data_root.
    rootless_dataset = CustomDataset(
        train_pipeline,
        img_dir=pseudo_imgs,
        ann_dir=pseudo_gts,
        **suffix_kwargs)
    assert len(rootless_dataset) == 5

    # data_root given, but img_dir / ann_dir are absolute paths.
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir=osp.abspath(pseudo_imgs),
        ann_dir=osp.abspath(pseudo_gts),
        **suffix_kwargs)
    assert len(train_dataset) == 5

    # test_mode=True needs no annotation directory here.
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=pseudo_imgs,
        img_suffix='img.jpg',
        test_mode=True,
        classes=('pseudo_class', ))
    assert len(test_dataset) == 5

    # Both datasets must hand back dict samples.
    assert isinstance(train_dataset[0], dict)
    assert isinstance(test_dataset[0], dict)

    # Ground-truth maps arrive lazily when efficient_test=True.
    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # format_results is expected to be unimplemented.
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')

    # One random prediction (labels 0..6) per ground-truth map.
    pseudo_results = [
        np.random.randint(low=0, high=7, size=(h, w))
        for h, w in (seg_map.shape for seg_map in gt_seg_maps)
    ]

    # Before CLASSES is assigned, evaluate() is expected to raise TypeError.
    for metric in (['mIoU'], 'mDice', ['mDice', 'mIoU']):
        with pytest.raises(TypeError):
            train_dataset.evaluate(pseudo_results, metric=metric)

    iou_keys = ('mIoU', 'mAcc', 'aAcc')
    dice_keys = ('mDice', 'mAcc', 'aAcc')
    fscore_keys = ('mRecall', 'mPrecision', 'mFscore', 'aAcc')
    combined_keys = ('mIoU', 'mDice', 'mAcc', 'aAcc', 'mFscore',
                     'mPrecision', 'mRecall')

    def _run_metric_suite(predictions, first_metric):
        # Evaluate each metric in turn, check the reported keys, then make
        # sure the combined run produced no NaN values.
        plan = (
            (first_metric, iou_keys),
            ('mDice', dice_keys),
            ('mFscore', fscore_keys),
            (['mIoU', 'mDice', 'mFscore'], combined_keys),
        )
        for metric, keys in plan:
            scores = train_dataset.evaluate(predictions, metric=metric)
            assert isinstance(scores, dict)
            for key in keys:
                assert key in scores
        # ``scores`` now refers to the combined-metric run.
        for key in combined_keys:
            assert not np.isnan(scores[key])

    # Evaluation on raw predictions once CLASSES is set.
    train_dataset.CLASSES = tuple(['a'] * 7)
    _run_metric_suite(pseudo_results, 'mIoU')

    # Evaluation on pre_eval() output; CLASSES is required here as well.
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
    for idx, seg_map in enumerate(gt_seg_maps):
        h, w = seg_map.shape
        prediction = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.extend(train_dataset.pre_eval(prediction, idx))
    _run_metric_suite(pseudo_results, ['mIoU'])
def test_custom_dataset():
    """Cover ``CustomDataset`` construction, item access and ``evaluate``.

    Builds datasets from the pseudo dataset directory next to this file under
    several path configurations, fetches one train and one test sample, and
    runs ``evaluate`` with its default metric both before and after
    ``CLASSES`` is assigned.

    NOTE(review): an earlier function with this same name appears in this
    source; if both live in one module, this later definition is the one that
    stays bound.
    """
    this_dir = osp.dirname(__file__)
    pseudo_root = osp.join(this_dir, '../data/pseudo_dataset')
    pseudo_imgs = osp.join(this_dir, '../data/pseudo_dataset/imgs')
    pseudo_gts = osp.join(this_dir, '../data/pseudo_dataset/gts')
    suffix_kwargs = dict(img_suffix='img.jpg', seg_map_suffix='gt.png')

    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    multi_scale_transforms = [
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='MultiScaleFlipAug',
             img_scale=(128, 256),
             flip=False,
             transforms=multi_scale_transforms),
    ]

    # Relative img_dir / ann_dir resolved against data_root.
    relative_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir='imgs/',
        ann_dir='gts/',
        **suffix_kwargs)
    assert len(relative_dataset) == 5

    # Same layout, restricted by a split file.
    split_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir='imgs/',
        ann_dir='gts/',
        split='splits/train.txt',
        **suffix_kwargs)
    assert len(split_dataset) == 4

    # Directories given directly, without data_root.
    rootless_dataset = CustomDataset(
        train_pipeline,
        img_dir=pseudo_imgs,
        ann_dir=pseudo_gts,
        **suffix_kwargs)
    assert len(rootless_dataset) == 5

    # data_root given, but img_dir / ann_dir are absolute paths.
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=pseudo_root,
        img_dir=osp.abspath(pseudo_imgs),
        ann_dir=osp.abspath(pseudo_gts),
        **suffix_kwargs)
    assert len(train_dataset) == 5

    # test_mode=True with only an image directory.
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=pseudo_imgs,
        img_suffix='img.jpg',
        test_mode=True)
    assert len(test_dataset) == 5

    # Both datasets must hand back dict samples.
    assert isinstance(train_dataset[0], dict)
    assert isinstance(test_dataset[0], dict)

    # One ground-truth map per training sample.
    gt_seg_maps = train_dataset.get_gt_seg_maps()
    assert len(gt_seg_maps) == 5

    # One random prediction (labels 0..6) per ground-truth map.
    pseudo_results = [
        np.random.randint(low=0, high=7, size=(h, w))
        for h, w in (seg_map.shape for seg_map in gt_seg_maps)
    ]

    def _assert_default_report():
        # evaluate() with its default metric must report these three keys.
        report = train_dataset.evaluate(pseudo_results)
        assert isinstance(report, dict)
        for key in ('mIoU', 'mAcc', 'aAcc'):
            assert key in report

    # Evaluation before CLASSES is assigned.
    _assert_default_report()

    # Evaluation again once CLASSES holds seven entries.
    train_dataset.CLASSES = tuple(['a'] * 7)
    _assert_default_report()