Example #1
def test_transformer_encoder_pixel_decoder():
    base_channels = 64
    pixel_decoder_cfg = ConfigDict(
        dict(type='TransformerEncoderPixelDecoder',
             in_channels=[base_channels * 2**i for i in range(4)],
             feat_channels=base_channels,
             out_channels=base_channels,
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='ReLU'),
             encoder=dict(
                 type='DetrTransformerEncoder',
                 num_layers=6,
                 transformerlayers=dict(
                     type='BaseTransformerLayer',
                     attn_cfgs=dict(type='MultiheadAttention',
                                    embed_dims=base_channels,
                                    num_heads=8,
                                    attn_drop=0.1,
                                    proj_drop=0.1,
                                    dropout_layer=None,
                                    batch_first=False),
                     ffn_cfgs=dict(embed_dims=base_channels,
                                   feedforward_channels=base_channels * 8,
                                   num_fcs=2,
                                   act_cfg=dict(type='ReLU', inplace=True),
                                   ffn_drop=0.1,
                                   dropout_layer=None,
                                   add_identity=True),
                     operation_order=('self_attn', 'norm', 'ffn', 'norm'),
                     norm_cfg=dict(type='LN'),
                     init_cfg=None,
                     batch_first=False),
                 init_cfg=None),
             positional_encoding=dict(type='SinePositionalEncoding',
                                      num_feats=base_channels // 2,
                                      normalize=True)))
    self = build_plugin_layer(pixel_decoder_cfg)[1]
    img_metas = [{
        'batch_input_shape': (128, 160),
        'img_shape': (120, 160, 3),
    }, {
        'batch_input_shape': (128, 160),
        'img_shape': (125, 160, 3),
    }]
    feats = [
        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
        for i in range(4)
    ]
    mask_feature, memory = self(feats, img_metas)

    assert memory.shape[-2:] == feats[-1].shape[-2:]
    assert mask_feature.shape == feats[0].shape
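The test snippets on this page assume a few imports that the listing omits; a minimal sketch, assuming the mmcv 1.x API these tests were written against:

import torch
from mmcv import ConfigDict
from mmcv.cnn import build_plugin_layer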
Example #2
def test_msdeformattn_pixel_decoder():
    base_channels = 64
    pixel_decoder_cfg = ConfigDict(
        dict(type='MSDeformAttnPixelDecoder',
             in_channels=[base_channels * 2**i for i in range(4)],
             strides=[4, 8, 16, 32],
             feat_channels=base_channels,
             out_channels=base_channels,
             num_outs=3,
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='ReLU'),
             encoder=dict(
                 type='DetrTransformerEncoder',
                 num_layers=6,
                 transformerlayers=dict(
                     type='BaseTransformerLayer',
                     attn_cfgs=dict(type='MultiScaleDeformableAttention',
                                    embed_dims=base_channels,
                                    num_heads=8,
                                    num_levels=3,
                                    num_points=4,
                                    im2col_step=64,
                                    dropout=0.0,
                                    batch_first=False,
                                    norm_cfg=None,
                                    init_cfg=None),
                     ffn_cfgs=dict(type='FFN',
                                   embed_dims=base_channels,
                                   feedforward_channels=base_channels * 4,
                                   num_fcs=2,
                                   ffn_drop=0.0,
                                   act_cfg=dict(type='ReLU', inplace=True)),
                     operation_order=('self_attn', 'norm', 'ffn', 'norm')),
                 init_cfg=None),
             positional_encoding=dict(type='SinePositionalEncoding',
                                      num_feats=base_channels // 2,
                                      normalize=True),
             init_cfg=None))
    self = build_plugin_layer(pixel_decoder_cfg)[1]
    feats = [
        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
        for i in range(4)
    ]
    mask_feature, multi_scale_features = self(feats)

    assert mask_feature.shape == feats[0].shape
    assert len(multi_scale_features) == 3
    multi_scale_features = multi_scale_features[::-1]
    for i in range(3):
        assert multi_scale_features[i].shape[-2:] == feats[i + 1].shape[-2:]
Example #3
def test_pixel_decoder():
    base_channels = 64
    pixel_decoder_cfg = ConfigDict(
        dict(type='PixelDecoder',
             in_channels=[base_channels * 2**i for i in range(4)],
             feat_channels=base_channels,
             out_channels=base_channels,
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='ReLU')))
    self = build_plugin_layer(pixel_decoder_cfg)[1]
    img_metas = [{}, {}]
    feats = [
        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
        for i in range(4)
    ]
    mask_feature, memory = self(feats, img_metas)

    assert (memory == feats[-1]).all()
    assert mask_feature.shape == feats[0].shape
Example #4
 def make_block_plugins(self, in_channels, plugins):
     """make plugins for block.
     Args:
         in_channels (int): Input channels of plugin.
         plugins (list[dict]): List of plugins cfg to build.
     Returns:
         list[str]: List of the names of plugin.
     """
     assert isinstance(plugins, list)
     plugin_names = []
     for plugin in plugins:
         plugin = plugin.copy()
         name, layer = build_plugin_layer(plugin,
                                          in_channels=in_channels,
                                          postfix=plugin.pop('postfix', ''))
         assert not hasattr(self, name), f'duplicate plugin {name}'
         self.add_module(name, layer)
         plugin_names.append(name)
     return plugin_names
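For context, a hedged usage sketch of the ``plugins`` argument this method expects: each item is a cfg whose ``type`` is registered in mmcv's PLUGIN_LAYERS registry, with an optional ``postfix`` key that the loop above pops off. The plugin types, postfixes, and channel count below are illustrative, not taken from the original code:

plugins = [
    dict(type='ContextBlock', ratio=1. / 16, postfix='1'),
    dict(type='NonLocal2d', postfix='2'),
]
# Registers submodules (e.g. 'context_block1', 'nonlocal_block2') on `self`
# and returns their names.
names = self.make_block_plugins(in_channels=256, plugins=plugins)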
Example #5
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(type='CrossEntropyLoss',
                               bg_cls_weight=0.1,
                               use_sigmoid=False,
                               loss_weight=1.0,
                               class_weight=1.0),
                 loss_mask=dict(type='FocalLoss',
                                use_sigmoid=True,
                                gamma=2.0,
                                alpha=0.25,
                                loss_weight=20.0),
                 loss_dice=dict(type='DiceLoss',
                                use_sigmoid=True,
                                activate=True,
                                naive_dice=True,
                                loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        pixel_decoder.update(in_channels=in_channels,
                             feat_channels=feat_channels,
                             out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        pixel_decoder_type = pixel_decoder.get('type')
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(in_channels[-1],
                                             self.decoder_embed_dims,
                                             kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            self.assigner = build_assigner(assigner)
            sampler_cfg = dict(type='MaskPseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.bg_cls_weight = 0
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is MaskFormerHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official MaskFormerHead repo, bg_cls_weight
            # means relative classification weight of the VOID class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(self.num_classes + 1) * class_weight
            # set VOID class as the last index
            class_weight[self.num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)
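The ``bg_cls_weight`` branch above boils down to a small tensor rewrite; a numeric sketch of what it produces with the defaults (``num_things_classes=80``, ``num_stuff_classes=53``):

import torch

num_classes = 80 + 53  # things + stuff = 133
class_weight = torch.ones(num_classes + 1) * 1.0  # scalar class_weight broadcast
class_weight[num_classes] = 0.1  # bg_cls_weight for the VOID class, last index
# loss_cls is then built with this 134-element weight vector in place of the
# scalar class_weight/bg_cls_weight pair.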
Example #6
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 num_transformer_feat_level=3,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=None,
                 loss_mask=None,
                 loss_dice=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.\
            attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.\
            attn_cfgs.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels,
                              feat_channels=feat_channels,
                              out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(feat_channels,
                           self.decoder_embed_dims,
                           kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(
            positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)
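For reference, a hedged sketch of a ``train_cfg`` that provides every key read above; the assigner and sampler cfgs mirror mmdet's published Mask2Former configs but are illustrative here, not part of this snippet:

train_cfg = ConfigDict(
    dict(assigner=dict(type='MaskHungarianAssigner',
                       cls_cost=dict(type='ClassificationCost', weight=2.0),
                       mask_cost=dict(type='CrossEntropyLossCost',
                                      weight=5.0,
                                      use_sigmoid=True),
                       dice_cost=dict(type='DiceCost',
                                      weight=5.0,
                                      pred_act=True,
                                      eps=1.0)),
         sampler=dict(type='MaskPseudoSampler'),
         num_points=12544,
         oversample_ratio=3.0,
         importance_sample_ratio=0.75))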
Example #7
    def _make_stage_plugins(self, plugins, stage_idx):
        """Make plugins for ResNet ``stage_idx`` th stage.

        Currently we support inserting ``nn.MaxPool2d`` and
        ``mmcv.cnn.ConvModule`` into the backbone. Originally designed
        for ResNet31-like architectures.

        Examples:
            >>> plugins = [
            ...     dict(cfg=dict(type='Maxpooling', arg=(2, 2)),
            ...          stages=(True, True, False, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(type='Maxpooling', arg=(2, 1)),
            ...          stages=(False, False, True, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(
            ...              type='ConvModule',
            ...              kernel_size=3,
            ...              stride=1,
            ...              padding=1,
            ...              norm_cfg=dict(type='BN'),
            ...              act_cfg=dict(type='ReLU')),
            ...          stages=(True, True, True, True),
            ...          position='after_stage')]

        Suppose ``stage_idx=1``; the structure of the stage would then be:

        .. code-block:: none

            Maxpooling -> A set of Basicblocks -> ConvModule

        Args:
            plugins (list[dict]): List of plugin cfgs to build.
            stage_idx (int): Index of the stage to build plugins for.

        Note:
            The built plugins are registered as submodules, and their names
            are recorded in ``self.plugin_ahead_names`` and
            ``self.plugin_after_names``; the method itself returns ``None``.
        """
        in_channels = self.arch_channels[stage_idx]
        self.plugin_ahead_names.append([])
        self.plugin_after_names.append([])
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            position = plugin.pop('position', None)
            assert stages is None or len(stages) == self.num_stages
            # `stages is None` means the plugin applies to every stage.
            if stages is None or stages[stage_idx]:
                if position == 'before_stage':
                    name, layer = build_plugin_layer(
                        plugin['cfg'],
                        f'_before_stage_{stage_idx+1}',
                        in_channels=in_channels,
                        out_channels=in_channels)
                    self.plugin_ahead_names[stage_idx].append(name)
                    self.add_module(name, layer)
                elif position == 'after_stage':
                    name, layer = build_plugin_layer(
                        plugin['cfg'],
                        f'_after_stage_{stage_idx+1}',
                        in_channels=in_channels,
                        out_channels=in_channels)
                    self.plugin_after_names[stage_idx].append(name)
                    self.add_module(name, layer)
                else:
                    raise ValueError('incorrect plugin position')
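A hedged sketch of wiring this method into a backbone's ``__init__`` (``self.num_stages`` and the ``plugins`` list come from the surrounding class; the loop itself is illustrative):

# Illustrative: the per-stage name lists must exist before the loop, since
# _make_stage_plugins appends one sublist per stage.
self.plugin_ahead_names = []
self.plugin_after_names = []
for stage_idx in range(self.num_stages):
    self._make_stage_plugins(plugins, stage_idx)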
Example #8
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(type='CrossEntropyLoss',
                               use_sigmoid=False,
                               loss_weight=1.0,
                               class_weight=[1.0] * 133 + [0.1]),
                 loss_mask=dict(type='FocalLoss',
                                use_sigmoid=True,
                                gamma=2.0,
                                alpha=0.25,
                                loss_weight=20.0),
                 loss_dice=dict(type='DiceLoss',
                                use_sigmoid=True,
                                activate=True,
                                naive_dice=True,
                                loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        pixel_decoder.update(in_channels=in_channels,
                             feat_channels=feat_channels,
                             out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        pixel_decoder_type = pixel_decoder.get('type')
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(in_channels[-1],
                                             self.decoder_embed_dims,
                                             kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(train_cfg.assigner)
            self.sampler = build_sampler(train_cfg.sampler, context=self)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)