def test_custom_dataset():
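    """Sanity checks for CustomDataset.

    Covers construction with relative/absolute paths, an optional data_root
    and split file, sample loading in train and test mode, ground-truth
    retrieval, format_results, and evaluation with mIoU/mDice/mFscore from
    both full prediction maps and pre_eval results.
    """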
    img_norm_cfg = dict(mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True)
    crop_size = (512, 1024)
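    # Training pipeline: resize within a random ratio range, random crop and
    # flip, photometric distortion, normalization, padding, then bundling of
    # the image and segmentation map for the model.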
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
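    # Test-time pipeline: a single scale without flipping, with
    # MultiScaleFlipAug wrapping the shared resize/flip/normalize transforms.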
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    # with img_dir and ann_dir
    train_dataset = CustomDataset(train_pipeline,
                                  data_root=osp.join(osp.dirname(__file__),
                                                     '../data/pseudo_dataset'),
                                  img_dir='imgs/',
                                  ann_dir='gts/',
                                  img_suffix='img.jpg',
                                  seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with img_dir, ann_dir, split
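    # splits/train.txt is expected to list 4 of the 5 pseudo samples, so the
    # resulting dataset is one sample shorter than the directory-based one.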
    train_dataset = CustomDataset(train_pipeline,
                                  data_root=osp.join(osp.dirname(__file__),
                                                     '../data/pseudo_dataset'),
                                  img_dir='imgs/',
                                  ann_dir='gts/',
                                  img_suffix='img.jpg',
                                  seg_map_suffix='gt.png',
                                  split='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root, but img_dir/ann_dir given as absolute paths
    train_dataset = CustomDataset(train_pipeline,
                                  data_root=osp.join(osp.dirname(__file__),
                                                     '../data/pseudo_dataset'),
                                  img_dir=osp.abspath(
                                      osp.join(osp.dirname(__file__),
                                               '../data/pseudo_dataset/imgs')),
                                  ann_dir=osp.abspath(
                                      osp.join(osp.dirname(__file__),
                                               '../data/pseudo_dataset/gts')),
                                  img_suffix='img.jpg',
                                  seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True: annotations are not needed and a custom class tuple is passed
    test_dataset = CustomDataset(test_pipeline,
                                 img_dir=osp.join(
                                     osp.dirname(__file__),
                                     '../data/pseudo_dataset/imgs'),
                                 img_suffix='img.jpg',
                                 test_mode=True,
                                 classes=('pseudo_class', ))
    assert len(test_dataset) == 5

    # fetch a training sample
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # fetch a test sample
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)

    # get ground-truth segmentation maps
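    # With efficient_test=True, get_gt_seg_maps is expected to return a
    # generator (asserted below) so the maps can be consumed lazily.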
    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # format_results not implemented
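    # The base CustomDataset leaves format_results to dataset-specific
    # subclasses, so calling it here should raise NotImplementedError.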
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')

    pseudo_results = []
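    # Build a random per-pixel prediction with the same shape as each gt map,
    # drawing labels from the 7 pseudo classes used for CLASSES below.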
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    # legacy evaluation (from full prediction maps) without CLASSES set
    # should raise TypeError
    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results,
                                              metric=['mDice', 'mIoU'])

    # legacy evaluation (from full prediction maps) with CLASSES set
    train_dataset.CLASSES = tuple(['a'] * 7)
    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results,
                                          metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])

    # test evaluation from pre_eval results; dataset.CLASSES must be set
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
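    # pre_eval converts each prediction into per-image intermediate results
    # that evaluate() consumes directly, so full prediction maps do not have
    # to be accumulated.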
    for idx in range(len(train_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results,
                                          metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])