def test_load_img(self):
        results = dict(img_prefix=self.data_prefix,
                       img_info=dict(filename='color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['filename'] == osp.join(self.data_prefix, 'color.jpg')
        assert results['ori_filename'] == 'color.jpg'
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8
        assert results['img_shape'] == (288, 512, 3)
        assert results['ori_shape'] == (288, 512, 3)
        assert results['pad_shape'] == (288, 512, 3)
        assert results['scale_factor'] == 1.0
        np.testing.assert_equal(results['img_norm_cfg']['mean'],
                                np.zeros(3, dtype=np.float32))
        assert repr(transform) == transform.__class__.__name__ + \
            "(to_float32=False,color_type='color',imdecode_backend='cv2')"

        # no img_prefix
        results = dict(img_prefix=None,
                       img_info=dict(filename='tests/data/color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['filename'] == 'tests/data/color.jpg'
        assert results['ori_filename'] == 'tests/data/color.jpg'
        assert results['img'].shape == (288, 512, 3)

        # to_float32
        transform = LoadImageFromFile(to_float32=True)
        results = transform(copy.deepcopy(results))
        assert results['img'].dtype == np.float32

        # gray image
        results = dict(img_prefix=self.data_prefix,
                       img_info=dict(filename='gray.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8

        transform = LoadImageFromFile(color_type='unchanged')
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512)
        assert results['img'].dtype == np.uint8
        np.testing.assert_equal(results['img_norm_cfg']['mean'],
                                np.zeros(1, dtype=np.float32))
    def test_load_seg_custom_classes(self):

        test_img = np.random.rand(10, 10)
        test_gt = np.zeros_like(test_img)
        test_gt[2:4, 2:4] = 1
        test_gt[2:4, 6:8] = 2
        test_gt[6:8, 2:4] = 3
        test_gt[6:8, 6:8] = 4

        tmp_dir = tempfile.TemporaryDirectory()
        img_path = osp.join(tmp_dir.name, 'img.jpg')
        gt_path = osp.join(tmp_dir.name, 'gt.png')

        mmcv.imwrite(test_img, img_path)
        mmcv.imwrite(test_gt, gt_path)

        # test only train with label with id 3
        results = dict(img_info=dict(filename=img_path),
                       ann_info=dict(seg_map=gt_path),
                       label_map={
                           0: 0,
                           1: 0,
                           2: 0,
                           3: 1,
                           4: 0
                       },
                       seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 1

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test only train with label with id 4 and 3
        results = dict(img_info=dict(filename=img_path),
                       ann_info=dict(seg_map=gt_path),
                       label_map={
                           0: 0,
                           1: 0,
                           2: 0,
                           3: 2,
                           4: 1
                       },
                       seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 2
        true_mask[6:8, 6:8] = 1

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test no custom classes
        results = dict(img_info=dict(filename=img_path),
                       ann_info=dict(seg_map=gt_path),
                       seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, test_gt)

        tmp_dir.cleanup()