コード例 #1
0
def construct_toy_data(poly2mask=True):
    """Build a tiny ``results`` dict used as input for augmentation tests.

    The image is a 2x4 grid holding the values 1..8, replicated over three
    channels (shape ``(2, 4, 3)``, dtype ``uint8``).  One ground-truth box,
    one ignored box, one label, one instance mask and a semantic
    segmentation map (the first image channel) are attached.

    Args:
        poly2mask (bool): If truthy, ``gt_masks`` is a ``BitmapMasks`` built
            from a ``(1, 2, 4)`` uint8 array; otherwise it is a
            ``PolygonMasks`` built from a single flat coordinate array.

    Returns:
        dict: The populated ``results`` dictionary.
    """
    img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
    img = np.stack([img, img, img], axis=-1)
    results = dict()
    # image
    results['img'] = img
    results['img_shape'] = img.shape
    results['img_fields'] = ['img']
    # bboxes
    results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
    results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
    results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
                                           dtype=np.float32)
    # labels
    results['gt_labels'] = np.array([1], dtype=np.int64)
    # masks
    results['mask_fields'] = ['gt_masks']
    if poly2mask:
        gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
                            dtype=np.uint8)[None, :, :]
        results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
    else:
        # ``np.float`` was only an alias of the builtin ``float`` and was
        # removed in NumPy 1.24; ``np.float64`` is the same dtype.
        raw_masks = [[
            np.array([1, 0, 2, 0, 2, 1, 1, 1], dtype=np.float64)
        ]]
        results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
    # segmentations
    results['seg_fields'] = ['gt_semantic_seg']
    # NOTE: this is a view of the first image channel, not a copy.
    results['gt_semantic_seg'] = img[..., 0]
    return results
コード例 #2
0
ファイル: utils.py プロジェクト: shinya7y/UniverseNet
def create_full_masks(gt_bboxes, img_w, img_h):
    """Rasterize every box into a full-image binary mask.

    Args:
        gt_bboxes: 2-D array-like with one row per box; the first four
            entries of a row are read as ``xmin, ymin, xmax, ymax`` and
            truncated to integers with ``int()``.
        img_w (int): Width of the produced masks.
        img_h (int): Height of the produced masks.

    Returns:
        BitmapMasks: Wraps a ``uint8`` array of shape
        ``(len(gt_bboxes), img_h, img_w)`` that is 1 inside
        ``[ymin:ymax, xmin:xmax]`` of each box and 0 elsewhere.
    """
    gt_masks = np.zeros((len(gt_bboxes), img_h, img_w), dtype=np.uint8)
    # Iterating ``gt_masks`` yields per-box views, so writes land in the
    # parent array.  Coordinates are converted one scalar at a time: calling
    # ``int()`` on a 1-element 1-D slice (as ``gt_bboxes[:, 0:1][i]`` would
    # give) is deprecated since NumPy 1.25.
    for mask, box in zip(gt_masks, gt_bboxes):
        xmin, ymin, xmax, ymax = (int(coord) for coord in box[:4])
        mask[ymin:ymax, xmin:xmax] = 1
    return BitmapMasks(gt_masks, img_h, img_w)
コード例 #3
0
def _load_masks(results, poly2mask=True):
    """Attach instance masks to ``results`` in place.

    Reads the image size from ``results['img_info']`` and the raw masks from
    ``results['ann_info']['masks']``, then stores the wrapped masks under
    ``results['gt_masks']`` and sets ``results['mask_fields']``.

    Args:
        results (dict): Must provide ``img_info`` (with ``height`` and
            ``width``) and ``ann_info`` (with ``masks``).
        poly2mask (bool): If truthy, each raw mask goes through
            ``_poly2mask`` and the result is wrapped in ``BitmapMasks``;
            otherwise each entry goes through ``_process_polygons`` and the
            result is wrapped in ``PolygonMasks``.
    """
    img_info = results['img_info']
    height, width = img_info['height'], img_info['width']
    raw_masks = results['ann_info']['masks']

    if not poly2mask:
        processed = [_process_polygons(polys) for polys in raw_masks]
        wrapped = PolygonMasks(processed, height, width)
    else:
        bitmaps = [_poly2mask(raw, height, width) for raw in raw_masks]
        wrapped = BitmapMasks(bitmaps, height, width)

    results['gt_masks'] = wrapped
    results['mask_fields'] = ['gt_masks']
コード例 #4
0
ファイル: test_loading.py プロジェクト: OceanPang/mmdetection
def test_filter_annotations(target, kwargs):
    """Check how many boxes ``FilterAnnotations`` keeps.

    Two boxes are fed in: a 2x4 one and a 0.1x0.1 one.  Only the first has a
    non-empty bitmap mask.

    Args:
        target: Expected number of remaining boxes, or ``None`` when the
            transform is expected to return ``None``.
        kwargs (dict): Keyword arguments forwarded to ``FilterAnnotations``.
    """
    transform = FilterAnnotations(**kwargs)

    gt_bboxes = np.array([[2., 10., 4., 14.], [2., 10., 2.1, 10.1]])
    mask_array = np.zeros((2, 24, 24))
    mask_array[0, 10:14, 2:4] = 1
    sample = dict(
        gt_bboxes=gt_bboxes, gt_masks=BitmapMasks(mask_array, 24, 24))

    filtered = transform(sample)
    # A ``None`` output is compared against ``target`` as-is.
    if filtered is None:
        num_kept = None
    else:
        num_kept = filtered['gt_bboxes'].shape[0]
    assert num_kept == target
コード例 #5
0
def test_maskformer_head_loss():
    """Tests head loss when truth is empty and non-empty.

    Also runs ``forward_train`` and ``simple_test`` once on the same
    features to make sure they execute.
    """
    width = 64

    # The factories below return a fresh dict on every call so that the
    # encoder and decoder configs never share a mutable object.
    def attention_cfg():
        return {
            'type': 'MultiheadAttention',
            'embed_dims': width,
            'num_heads': 8,
            'attn_drop': 0.1,
            'proj_drop': 0.1,
            'dropout_layer': None,
            'batch_first': False,
        }

    def ffn_cfg():
        return {
            'embed_dims': width,
            'feedforward_channels': width * 8,
            'num_fcs': 2,
            'act_cfg': {'type': 'ReLU', 'inplace': True},
            'ffn_drop': 0.1,
            'dropout_layer': None,
            'add_identity': True,
        }

    def sine_encoding_cfg():
        return {
            'type': 'SinePositionalEncoding',
            'num_feats': width // 2,
            'normalize': True,
        }

    # batch_input_shape = (128, 160)
    shape_pairs = (((126, 160, 3), (63, 80, 3)), ((120, 160, 3), (60, 80,
                                                                  3)))
    img_metas = [{
        'batch_input_shape': (128, 160),
        'pad_shape': (128, 160, 3),
        'img_shape': img_shape,
        'ori_shape': ori_shape,
    } for img_shape, ori_shape in shape_pairs]

    # Random features are drawn before the head is built, in level order.
    feats = []
    for level in range(4):
        spatial = 2**(3 - level)
        feats.append(
            torch.rand((2, 64 * 2**level, 4 * spatial, 5 * spatial)))

    things = 80
    stuff = 53
    total_classes = things + stuff

    encoder_layer = {
        'type': 'BaseTransformerLayer',
        'attn_cfgs': attention_cfg(),
        'ffn_cfgs': ffn_cfg(),
        'operation_order': ('self_attn', 'norm', 'ffn', 'norm'),
        'norm_cfg': {'type': 'LN'},
        'init_cfg': None,
        'batch_first': False,
    }
    pixel_decoder = {
        'type': 'TransformerEncoderPixelDecoder',
        'norm_cfg': {'type': 'GN', 'num_groups': 32},
        'act_cfg': {'type': 'ReLU'},
        'encoder': {
            'type': 'DetrTransformerEncoder',
            'num_layers': 6,
            'transformerlayers': encoder_layer,
            'init_cfg': None,
        },
        'positional_encoding': sine_encoding_cfg(),
    }
    decoder_layer = {
        'type': 'DetrTransformerDecoderLayer',
        'attn_cfgs': attention_cfg(),
        'ffn_cfgs': ffn_cfg(),
        # the following parameter was not used,
        # just make current api happy
        'feedforward_channels': width * 8,
        'operation_order': ('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm'),
    }
    transformer_decoder = {
        'type': 'DetrTransformerDecoder',
        'return_intermediate': True,
        'num_layers': 6,
        'transformerlayers': decoder_layer,
        'init_cfg': None,
    }
    loss_cls = {
        'type': 'CrossEntropyLoss',
        'use_sigmoid': False,
        'loss_weight': 1.0,
        'reduction': 'mean',
        'class_weight': [1.0] * total_classes + [0.1],
    }
    loss_mask = {
        'type': 'FocalLoss',
        'use_sigmoid': True,
        'gamma': 2.0,
        'alpha': 0.25,
        'reduction': 'mean',
        'loss_weight': 20.0,
    }
    loss_dice = {
        'type': 'DiceLoss',
        'use_sigmoid': True,
        'activate': True,
        'reduction': 'mean',
        'naive_dice': True,
        'eps': 1.0,
        'loss_weight': 1.0,
    }
    train_cfg = {
        'assigner': {
            'type': 'MaskHungarianAssigner',
            'cls_cost': {'type': 'ClassificationCost', 'weight': 1.0},
            'mask_cost': {
                'type': 'FocalLossCost',
                'weight': 20.0,
                'binary_input': True,
            },
            'dice_cost': {
                'type': 'DiceCost',
                'weight': 1.0,
                'pred_act': True,
                'eps': 1.0,
            },
        },
        'sampler': {'type': 'MaskPseudoSampler'},
    }
    config = ConfigDict({
        'type': 'MaskFormerHead',
        'in_channels': [width * 2**level for level in range(4)],
        'feat_channels': width,
        'out_channels': width,
        'num_things_classes': things,
        'num_stuff_classes': stuff,
        'num_queries': 100,
        'pixel_decoder': pixel_decoder,
        'enforce_decoder_input_project': False,
        'positional_encoding': sine_encoding_cfg(),
        'transformer_decoder': transformer_decoder,
        'loss_cls': loss_cls,
        'loss_mask': loss_mask,
        'loss_dice': loss_dice,
        'train_cfg': train_cfg,
        'test_cfg': {'object_mask_thr': 0.8, 'iou_thr': 0.8},
    })

    head = MaskFormerHead(**config)
    head.init_weights()
    all_cls_scores, all_mask_preds = head.forward(feats, img_metas)

    # Test that empty ground truth encourages the network to predict
    # background
    no_labels = [torch.LongTensor([]) for _ in range(2)]
    no_masks = [
        torch.zeros((0, 128, 160), dtype=torch.long) for _ in range(2)
    ]
    empty_gt_losses = head.loss(all_cls_scores, all_mask_preds, no_labels,
                                no_masks, img_metas)
    # When there is no truth, the cls loss should be nonzero but there should
    # be no mask loss.
    for name, value in empty_gt_losses.items():
        if 'cls' in name:
            assert value.item() > 0, 'cls loss should be non-zero'
            continue
        if 'mask' in name:
            assert value.item() == 0, \
                'there should be no mask loss when there are no true mask'
            continue
        if 'dice' in name:
            assert value.item() == 0, \
                'there should be no dice loss when there are no true mask'

    # when truth is non-empty then both cls, mask, dice loss should be nonzero
    # random inputs
    pair_labels = [
        torch.tensor([10, 100], dtype=torch.long),
        torch.tensor([100, 10], dtype=torch.long),
    ]
    rows_split = torch.zeros((2, 128, 160), dtype=torch.long)
    rows_split[0, :50] = 1
    rows_split[1, 50:] = 1
    cols_split = torch.zeros((2, 128, 160), dtype=torch.long)
    cols_split[0, :, :50] = 1
    cols_split[1, :, 50:] = 1
    two_gt_losses = head.loss(all_cls_scores, all_mask_preds, pair_labels,
                              [rows_split, cols_split], img_metas)
    for value in two_gt_losses.values():
        assert value.item() > 0, 'all loss should be non-zero'

    # test forward_train
    gt_labels = [torch.tensor([10], dtype=torch.long) for _ in range(2)]
    thing_top = np.zeros((1, 128, 160), dtype=np.int32)
    thing_top[0, :50] = 1
    thing_right = np.zeros((1, 128, 160), dtype=np.int32)
    thing_right[0, :, 50:] = 1
    gt_masks = [
        BitmapMasks(thing_top, 128, 160),
        BitmapMasks(thing_right, 128, 160),
    ]
    seg_rows = torch.zeros((1, 128, 160), dtype=torch.long)
    seg_rows[0, :50] = 10
    seg_rows[0, 50:] = 100
    seg_cols = torch.zeros((1, 128, 160), dtype=torch.long)
    seg_cols[0, :, 50:] = 10
    seg_cols[0, :, :50] = 100
    gt_semantic_seg = [seg_rows, seg_cols]

    # ``None`` is passed positionally in the gt_bboxes slot.
    head.forward_train(feats, img_metas, None, gt_labels, gt_masks,
                       gt_semantic_seg)

    # test inference mode
    head.simple_test(feats, img_metas)
コード例 #6
0
def test_shear():
    """Test the ``Shear`` pipeline transform.

    Covers invalid-argument handling, the no-op cases (``level=0`` and
    ``prob=0``), horizontal and vertical shearing of images, boxes, bitmap
    and polygon masks and semantic segmentation, and use inside
    ``AutoAugment``.  Expected outputs are compared with ``check_shear``.
    """
    # test assertion for invalid type of max_shear_magnitude
    with pytest.raises(AssertionError):
        transform = dict(type='Shear', level=1, max_shear_magnitude=(0.5, ))
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid value of max_shear_magnitude
    with pytest.raises(AssertionError):
        transform = dict(type='Shear', level=2, max_shear_magnitude=1.2)
        build_from_cfg(transform, PIPELINES)

    # test ValueError for invalid type of img_fill_val
    with pytest.raises(ValueError):
        transform = dict(type='Shear', level=2, img_fill_val=[128])
        build_from_cfg(transform, PIPELINES)

    results = construct_toy_data()
    # test case when no shear aug (level=0, direction='horizontal')
    img_fill_val = (104, 116, 124)
    seg_ignore_label = 255
    transform = dict(
        type='Shear',
        level=0,
        prob=1.,
        img_fill_val=img_fill_val,
        seg_ignore_label=seg_ignore_label,
        direction='horizontal')
    shear_module = build_from_cfg(transform, PIPELINES)
    results_wo_shear = shear_module(copy.deepcopy(results))
    check_shear(results, results_wo_shear)

    # test case when no shear aug (level=0, direction='vertical')
    transform = dict(
        type='Shear',
        level=0,
        prob=1.,
        img_fill_val=img_fill_val,
        seg_ignore_label=seg_ignore_label,
        direction='vertical')
    shear_module = build_from_cfg(transform, PIPELINES)
    results_wo_shear = shear_module(copy.deepcopy(results))
    check_shear(results, results_wo_shear)

    # test case when no shear aug (prob<=0)
    transform = dict(
        type='Shear',
        level=10,
        prob=0.,
        img_fill_val=img_fill_val,
        direction='vertical')
    shear_module = build_from_cfg(transform, PIPELINES)
    results_wo_shear = shear_module(copy.deepcopy(results))
    check_shear(results, results_wo_shear)

    # test shear horizontally, magnitude=1
    transform = dict(
        type='Shear',
        level=10,
        prob=1.,
        img_fill_val=img_fill_val,
        direction='horizontal',
        max_shear_magnitude=1.,
        random_negative_prob=0.)
    shear_module = build_from_cfg(transform, PIPELINES)
    results_sheared = shear_module(copy.deepcopy(results))
    results_gt = copy.deepcopy(results)
    img_s = np.array([[1, 2, 3, 4], [0, 5, 6, 7]], dtype=np.uint8)
    img_s = np.stack([img_s, img_s, img_s], axis=-1)
    # the pixel shifted in from outside the image takes the fill value
    img_s[1, 0, :] = np.array(img_fill_val)
    results_gt['img'] = img_s
    results_gt['gt_bboxes'] = np.array([[0., 0., 3., 1.]], dtype=np.float32)
    results_gt['gt_bboxes_ignore'] = np.array([[2., 0., 4., 1.]],
                                              dtype=np.float32)
    gt_masks = np.array([[0, 1, 1, 0], [0, 0, 1, 0]],
                        dtype=np.uint8)[None, :, :]
    results_gt['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
    results_gt['gt_semantic_seg'] = np.array(
        [[1, 2, 3, 4], [255, 5, 6, 7]], dtype=results['gt_semantic_seg'].dtype)
    check_shear(results_gt, results_sheared)

    # test PolygonMasks with shear horizontally, magnitude=1
    results = construct_toy_data(poly2mask=False)
    results_sheared = shear_module(copy.deepcopy(results))
    # np.float64 replaces the ``np.float`` alias removed in NumPy 1.24
    gt_masks = [[np.array([1, 0, 2, 0, 3, 1, 2, 1], dtype=np.float64)]]
    results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
    check_shear(results_gt, results_sheared)

    # test shear vertically, magnitude=-1
    img_fill_val = 128
    results = construct_toy_data()
    transform = dict(
        type='Shear',
        level=10,
        prob=1.,
        img_fill_val=img_fill_val,
        direction='vertical',
        max_shear_magnitude=1.,
        random_negative_prob=1.)
    shear_module = build_from_cfg(transform, PIPELINES)
    results_sheared = shear_module(copy.deepcopy(results))
    results_gt = copy.deepcopy(results)
    img_s = np.array([[1, 6, img_fill_val, img_fill_val],
                      [5, img_fill_val, img_fill_val, img_fill_val]],
                     dtype=np.uint8)
    img_s = np.stack([img_s, img_s, img_s], axis=-1)
    results_gt['img'] = img_s
    results_gt['gt_bboxes'] = np.empty((0, 4), dtype=np.float32)
    results_gt['gt_labels'] = np.empty((0, ), dtype=np.int64)
    results_gt['gt_bboxes_ignore'] = np.empty((0, 4), dtype=np.float32)
    gt_masks = np.array([[0, 1, 0, 0], [0, 0, 0, 0]],
                        dtype=np.uint8)[None, :, :]
    results_gt['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
    results_gt['gt_semantic_seg'] = np.array(
        [[1, 6, 255, 255], [5, 255, 255, 255]],
        dtype=results['gt_semantic_seg'].dtype)
    check_shear(results_gt, results_sheared)

    # test PolygonMasks with shear vertically, magnitude=-1
    results = construct_toy_data(poly2mask=False)
    results_sheared = shear_module(copy.deepcopy(results))
    gt_masks = [[np.array([1, 0, 2, 0, 2, 0, 1, 0], dtype=np.float64)]]
    results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
    check_shear(results_gt, results_sheared)

    results = construct_toy_data()
    # same mask for BitmapMasks and PolygonMasks
    results['gt_masks'] = BitmapMasks(
        np.array([[0, 1, 1, 0], [0, 1, 1, 0]], dtype=np.uint8)[None, :, :], 2,
        4)
    results['gt_bboxes'] = np.array([[1., 0., 2., 1.]], dtype=np.float32)
    results_sheared_bitmap = shear_module(copy.deepcopy(results))
    # compared against the polygon result from the previous section
    check_shear(results_sheared_bitmap, results_sheared)

    # test AutoAugment equipped with Shear
    policies = [[dict(type='Shear', level=10, prob=1.)]]
    autoaug = dict(type='AutoAugment', policies=policies)
    autoaug_module = build_from_cfg(autoaug, PIPELINES)
    autoaug_module(copy.deepcopy(results))

    policies = [[
        dict(type='Shear', level=10, prob=1.),
        dict(
            type='Shear',
            level=8,
            img_fill_val=img_fill_val,
            direction='vertical',
            max_shear_magnitude=1.)
    ]]
    autoaug = dict(type='AutoAugment', policies=policies)
    autoaug_module = build_from_cfg(autoaug, PIPELINES)
    autoaug_module(copy.deepcopy(results))
コード例 #7
0
ファイル: test_rotate.py プロジェクト: xvjiarui/mmdetection
def test_rotate():
    """Test the ``Rotate`` pipeline transform.

    Covers invalid-argument handling, the no-op cases (``level=0`` and
    ``prob=0``), clockwise and counter-clockwise 90-degree rotation of
    images, boxes, bitmap and polygon masks and semantic segmentation, and
    use inside ``AutoAugment``.  Expected outputs are compared with
    ``check_result_same``.
    """
    # test assertion for invalid type of max_rotate_angle
    with pytest.raises(AssertionError):
        transform = dict(type='Rotate', level=1, max_rotate_angle=(30, ))
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid type of scale
    with pytest.raises(AssertionError):
        transform = dict(type='Rotate', level=2, scale=(1.2, ))
        build_from_cfg(transform, PIPELINES)

    # test ValueError for invalid type of img_fill_val
    with pytest.raises(ValueError):
        transform = dict(type='Rotate', level=2, img_fill_val=[
            128,
        ])
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid number of elements in center
    with pytest.raises(AssertionError):
        transform = dict(type='Rotate', level=2, center=(0.5, ))
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid type of center
    with pytest.raises(AssertionError):
        transform = dict(type='Rotate', level=2, center=[0, 0])
        build_from_cfg(transform, PIPELINES)

    # test case when no rotate aug (level=0)
    results = construct_toy_data()
    img_fill_val = (104, 116, 124)
    seg_ignore_label = 255
    transform = dict(
        type='Rotate',
        level=0,
        prob=1.,
        img_fill_val=img_fill_val,
        seg_ignore_label=seg_ignore_label,
    )
    rotate_module = build_from_cfg(transform, PIPELINES)
    results_wo_rotate = rotate_module(copy.deepcopy(results))
    check_result_same(results, results_wo_rotate)

    # test case when no rotate aug (prob<=0)
    transform = dict(type='Rotate',
                     level=10,
                     prob=0.,
                     img_fill_val=img_fill_val,
                     scale=0.6)
    rotate_module = build_from_cfg(transform, PIPELINES)
    results_wo_rotate = rotate_module(copy.deepcopy(results))
    check_result_same(results, results_wo_rotate)

    # test clockwise rotation with angle 90
    results = construct_toy_data()
    img_fill_val = 128
    transform = dict(
        type='Rotate',
        level=10,
        max_rotate_angle=90,
        img_fill_val=img_fill_val,
        # set random_negative_prob to 0 for clockwise rotation
        random_negative_prob=0.,
        prob=1.)
    rotate_module = build_from_cfg(transform, PIPELINES)
    results_rotated = rotate_module(copy.deepcopy(results))
    img_r = np.array([[img_fill_val, 6, 2, img_fill_val],
                      [img_fill_val, 7, 3, img_fill_val]]).astype(np.uint8)
    img_r = np.stack([img_r, img_r, img_r], axis=-1)
    results_gt = copy.deepcopy(results)
    results_gt['img'] = img_r
    results_gt['gt_bboxes'] = np.array([[1., 0., 2., 1.]], dtype=np.float32)
    results_gt['gt_bboxes_ignore'] = np.empty((0, 4), dtype=np.float32)
    gt_masks = np.array([[0, 1, 1, 0], [0, 0, 1, 0]],
                        dtype=np.uint8)[None, :, :]
    results_gt['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
    results_gt['gt_semantic_seg'] = np.array(
        [[255, 6, 2, 255], [255, 7, 3,
                            255]]).astype(results['gt_semantic_seg'].dtype)
    check_result_same(results_gt, results_rotated)

    # test clockwise rotation with angle 90, PolygonMasks
    results = construct_toy_data(poly2mask=False)
    results_rotated = rotate_module(copy.deepcopy(results))
    # np.float64 replaces the ``np.float`` alias removed in NumPy 1.24
    gt_masks = [[np.array([2, 0, 2, 1, 1, 1, 1, 0], dtype=np.float64)]]
    results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
    check_result_same(results_gt, results_rotated)

    # test counter-clockwise rotation with angle 90,
    # and specify the rotation center
    img_fill_val = (104, 116, 124)
    transform = dict(
        type='Rotate',
        level=10,
        max_rotate_angle=90,
        center=(0, 0),
        img_fill_val=img_fill_val,
        # set random_negative_prob to 1 for counter-clockwise rotation
        random_negative_prob=1.,
        prob=1.)
    results = construct_toy_data()
    rotate_module = build_from_cfg(transform, PIPELINES)
    results_rotated = rotate_module(copy.deepcopy(results))
    results_gt = copy.deepcopy(results)
    h, w = results['img'].shape[:2]
    # start from an image that is entirely fill colour, then set the two
    # pixels expected to keep original image content
    img_r = np.stack([
        np.ones((h, w)) * img_fill_val[0],
        np.ones((h, w)) * img_fill_val[1],
        np.ones((h, w)) * img_fill_val[2]
    ],
                     axis=-1).astype(np.uint8)
    img_r[0, 0, :] = 1
    img_r[0, 1, :] = 5
    results_gt['img'] = img_r
    results_gt['gt_bboxes'] = np.empty((0, 4), dtype=np.float32)
    results_gt['gt_bboxes_ignore'] = np.empty((0, 4), dtype=np.float32)
    results_gt['gt_labels'] = np.empty((0, ), dtype=np.int64)
    gt_masks = np.empty((0, h, w), dtype=np.uint8)
    results_gt['gt_masks'] = BitmapMasks(gt_masks, h, w)
    gt_seg = (np.ones((h, w)) * 255).astype(results['gt_semantic_seg'].dtype)
    gt_seg[0, 0], gt_seg[0, 1] = 1, 5
    results_gt['gt_semantic_seg'] = gt_seg
    check_result_same(results_gt, results_rotated)

    # same rotation, with the center given as a single number.
    # NOTE: ``(0)`` is the int 0, not a 1-tuple.
    transform = dict(type='Rotate',
                     level=10,
                     max_rotate_angle=90,
                     center=(0),
                     img_fill_val=img_fill_val,
                     random_negative_prob=1.,
                     prob=1.)
    rotate_module = build_from_cfg(transform, PIPELINES)
    results_rotated = rotate_module(copy.deepcopy(results))
    check_result_same(results_gt, results_rotated)

    # test counter-clockwise rotation with angle 90,
    # and specify the rotation center, PolygonMasks
    results = construct_toy_data(poly2mask=False)
    results_rotated = rotate_module(copy.deepcopy(results))
    gt_masks = [[np.array([0, 0, 0, 0, 1, 0, 1, 0], dtype=np.float64)]]
    results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
    check_result_same(results_gt, results_rotated)

    # test AutoAugment equipped with Rotate
    policies = [[dict(type='Rotate', level=10, prob=1.)]]
    autoaug = dict(type='AutoAugment', policies=policies)
    autoaug_module = build_from_cfg(autoaug, PIPELINES)
    autoaug_module(copy.deepcopy(results))

    policies = [[
        dict(type='Rotate', level=10, prob=1.),
        dict(type='Rotate',
             level=8,
             max_rotate_angle=90,
             center=(0),
             img_fill_val=img_fill_val)
    ]]
    autoaug = dict(type='AutoAugment', policies=policies)
    autoaug_module = build_from_cfg(autoaug, PIPELINES)
    autoaug_module(copy.deepcopy(results))
コード例 #8
0
def test_mask2former_head_loss(num_stuff_classes, label_num):
    """Tests head loss when truth is empty and non-empty.

    Tests head loss as Panoptic Segmentation and Instance Segmentation. Tests
    forward_train and simple_test with masks and None as gt_semantic_seg

    Args:
        num_stuff_classes: Forwarded to ``_init_model`` to build the head.
        label_num: Class index used as the second label in the non-empty
            ground-truth case.
    """
    head = _init_model(num_stuff_classes)

    shape_pairs = (((126, 160, 3), (63, 80, 3)), ((120, 160, 3), (60, 80,
                                                                  3)))
    img_metas = [{
        'batch_input_shape': (128, 160),
        'pad_shape': (128, 160, 3),
        'img_shape': img_shape,
        'ori_shape': ori_shape,
    } for img_shape, ori_shape in shape_pairs]

    # Random features are drawn after the head is built, in level order.
    feats = []
    for level in range(4):
        spatial = 2**(3 - level)
        feats.append(
            torch.rand((2, 64 * 2**level, 4 * spatial, 5 * spatial)))

    all_cls_scores, all_mask_preds = head.forward(feats, img_metas)

    # Test that empty ground truth encourages the network to predict
    # background
    no_labels = [torch.LongTensor([]) for _ in range(2)]
    no_masks = [
        torch.zeros((0, 128, 160), dtype=torch.long) for _ in range(2)
    ]
    empty_gt_losses = head.loss(all_cls_scores, all_mask_preds, no_labels,
                                no_masks, img_metas)
    # When there is no truth, the cls loss should be nonzero but there should
    # be no mask loss.
    for name, value in empty_gt_losses.items():
        if 'cls' in name:
            assert value.item() > 0, 'cls loss should be non-zero'
            continue
        if 'mask' in name:
            assert value.item() == 0, \
                'there should be no mask loss when there are no true mask'
            continue
        if 'dice' in name:
            assert value.item() == 0, \
                'there should be no dice loss when there are no true mask'

    # when truth is non-empty then both cls, mask, dice loss should be nonzero
    # random inputs
    pair_labels = [
        torch.tensor([10, label_num]).long(),
        torch.tensor([label_num, 10]).long(),
    ]
    rows_split = torch.zeros((2, 128, 160), dtype=torch.long)
    rows_split[0, :50] = 1
    rows_split[1, 50:] = 1
    cols_split = torch.zeros((2, 128, 160), dtype=torch.long)
    cols_split[0, :, :50] = 1
    cols_split[1, :, 50:] = 1
    two_gt_losses = head.loss(all_cls_scores, all_mask_preds, pair_labels,
                              [rows_split, cols_split], img_metas)
    for value in two_gt_losses.values():
        assert value.item() > 0, 'all loss should be non-zero'

    # test forward_train
    gt_labels = [torch.tensor([10], dtype=torch.long) for _ in range(2)]
    thing_top = np.zeros((1, 128, 160), dtype=np.int32)
    thing_top[0, :50] = 1
    thing_right = np.zeros((1, 128, 160), dtype=np.int32)
    thing_right[0, :, 50:] = 1
    gt_masks = [
        BitmapMasks(thing_top, 128, 160),
        BitmapMasks(thing_right, 128, 160),
    ]
    seg_rows = torch.zeros((1, 128, 160), dtype=torch.long)
    seg_rows[0, :50] = 10
    seg_rows[0, 50:] = 100
    seg_cols = torch.zeros((1, 128, 160), dtype=torch.long)
    seg_cols[0, :, 50:] = 10
    seg_cols[0, :, :50] = 100

    # ``None`` is passed positionally in the gt_bboxes slot.
    head.forward_train(feats, img_metas, None, gt_labels, gt_masks,
                       [seg_rows, seg_cols])

    # test when gt_semantic_seg is None
    head.forward_train(feats, img_metas, None, gt_labels, gt_masks, None)

    # test inference mode
    head.simple_test(feats, img_metas)