Example #1
0
    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        """Create the sub-layers of a Mix-FFN block.

        Layers built: ``fc1`` (1x1 conv, embed_dims -> feedforward_channels),
        ``dwconv`` (3x3 depth-wise conv), ``act`` (activation), ``fc2``
        (1x1 conv, feedforward_channels -> embed_dims) and ``drop``.

        Args:
            embed_dims (int): Channel count of the block input/output.
            feedforward_channels (int): Hidden channel count.
            act_cfg (dict): Config for ``build_activation_layer``.
                Default: ``dict(type='GELU')``.
            ffn_drop (float): Probability used by ``nn.Dropout``. Default: 0.
            init_cfg (dict, optional): Initialization config forwarded to
                the base class. Default: None.
        """
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        hidden = feedforward_channels
        self.fc1 = Conv2d(embed_dims, hidden, kernel_size=1)
        # One group per channel makes this a depth-wise conv; with stride 1,
        # padding 1 keeps the spatial size of a 3x3 kernel unchanged.
        dw_kwargs = dict(kernel_size=3,
                         stride=1,
                         padding=1,
                         bias=True,
                         groups=hidden)
        self.dwconv = Conv2d(hidden, hidden, **dw_kwargs)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(hidden, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)
Example #2
0
    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        """Create the layers of the spatial-attention block.

        Args:
            embed_dims (int): Channel count kept by every layer.
            act_cfg (dict): Config for ``build_activation_layer``.
                Default: ``dict(type='GELU')``.
            init_cfg (dict, optional): Initialization config forwarded to
                the base class. Default: None.
        """
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)

        def _pointwise():
            # 1x1 conv that maps embed_dims -> embed_dims.
            return Conv2d(in_channels=embed_dims,
                          out_channels=embed_dims,
                          kernel_size=1)

        self.proj_1 = _pointwise()
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = _pointwise()
Example #3
0
    def __init__(self,
                 num_convs=2,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 polygon_size=None,
                 conv_cfg=None,
                 norm_cfg=None):
        """Create a stack of ``ConvModule`` layers and a 1-channel logit conv.

        Args:
            num_convs (int): Number of stacked ``ConvModule`` layers.
            in_channels (int): Input channels of the first conv.
            conv_kernel_size (int): Kernel size of every stacked conv.
            conv_out_channels (int): Output channels of every stacked conv.
            polygon_size (optional): Stored as ``self.polygon_size``.
            conv_cfg (dict, optional): Forwarded to each ``ConvModule``.
            norm_cfg (dict, optional): Forwarded to each ``ConvModule``.
        """
        super(VertexHead, self).__init__()
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels

        # Identical for every layer; keeps spatial size for odd kernels.
        same_pad = (conv_kernel_size - 1) // 2
        self.convs = nn.ModuleList()
        for idx in range(num_convs):
            # The first conv reads the head input, later ones chain outputs.
            layer_in = in_channels if idx == 0 else conv_out_channels
            self.convs.append(
                ConvModule(layer_in,
                           conv_out_channels,
                           conv_kernel_size,
                           padding=same_pad,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg))
        self.conv_logits = Conv2d(conv_out_channels, 1, 1)
        self.polygon_size = polygon_size
Example #4
0
    def __init__(self,
                 embed_dims,
                 num_heads,
                 norm_cfg=dict(type='LN'),
                 qkv_bias=True,
                 sr_ratio=1,
                 **kwargs):
        """Replace the parent's fused ``qkv`` with separate ``q``/``kv``
        projections and optionally build a spatial-reduction conv + norm.

        Args:
            embed_dims (int): Output width of the ``q`` projection (``kv``
                produces twice that).
            num_heads (int): Forwarded to the parent constructor.
            norm_cfg (dict): Config for ``build_norm_layer``, used only when
                ``sr_ratio > 1``. Default: ``dict(type='LN')``.
            qkv_bias (bool): Whether ``q`` and ``kv`` have a bias.
            sr_ratio (int): Kernel size and stride of the reduction conv;
                no reduction layers are built when it is <= 1.
            **kwargs: Forwarded to the parent constructor.
        """
        super(GlobalSubsampledAttention,
              self).__init__(embed_dims, num_heads, **kwargs)

        self.qkv_bias = qkv_bias
        source_dims = self.input_dims
        self.q = nn.Linear(source_dims, embed_dims, bias=qkv_bias)
        self.kv = nn.Linear(source_dims, embed_dims * 2, bias=qkv_bias)
        # The fused projection created by the parent is superseded by q/kv.
        delattr(self, 'qkv')

        self.sr_ratio = sr_ratio
        if sr_ratio <= 1:
            return
        # Non-overlapping reduction: kernel_size equals stride.
        self.sr = Conv2d(in_channels=embed_dims,
                         out_channels=embed_dims,
                         kernel_size=sr_ratio,
                         stride=sr_ratio)
        # build_norm_layer returns (name, layer); keep only the layer.
        _, self.norm = build_norm_layer(norm_cfg, embed_dims)
Example #5
0
    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        """Initialise the parent attention and, when ``sr_ratio > 1``, a
        spatial-reduction conv followed by a norm layer.

        Args:
            embed_dims (int): Embedding width.
            num_heads (int): Number of attention heads.
            attn_drop (float): Forwarded to the parent. Default: 0.
            proj_drop (float): Forwarded to the parent. Default: 0.
            dropout_layer (dict, optional): Forwarded to the parent.
            init_cfg (dict, optional): Forwarded to the parent.
            batch_first (bool): Forwarded to the parent. Default: True.
            qkv_bias (bool): Forwarded to the parent as ``bias``.
            norm_cfg (dict): Config for ``build_norm_layer``.
            sr_ratio (int): Kernel size and stride of the reduction conv.
        """
        parent_kwargs = dict(dropout_layer=dropout_layer,
                             init_cfg=init_cfg,
                             batch_first=batch_first,
                             bias=qkv_bias)
        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
                         **parent_kwargs)

        self.sr_ratio = sr_ratio
        if sr_ratio <= 1:
            return
        self.sr = Conv2d(in_channels=embed_dims,
                         out_channels=embed_dims,
                         kernel_size=sr_ratio,
                         stride=sr_ratio)
        # build_norm_layer returns (name, layer); keep only the layer.
        _, self.norm = build_norm_layer(norm_cfg, embed_dims)
Example #6
0
    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 use_conv=False,
                 init_cfg=None):
        """Assemble the Mix-FFN layers into ``self.layers``.

        The sequence is ``fc1 [, dw_conv], act, drop, fc2, drop``. The same
        ``nn.Dropout`` instance appears twice.

        Args:
            embed_dims (int): Channel count of the block input/output.
            feedforward_channels (int): Hidden channel count.
            act_cfg (dict): Config for ``build_activation_layer``.
            ffn_drop (float): Probability of the inner dropout.
            dropout_layer (dict, optional): Config for ``build_dropout``;
                ``torch.nn.Identity`` is used when falsy.
            use_conv (bool): Insert a 3x3 depth-wise conv after ``fc1``.
            init_cfg (dict, optional): Forwarded to the base class.
        """
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        hidden = feedforward_channels
        layers = [
            Conv2d(in_channels=embed_dims,
                   out_channels=hidden,
                   kernel_size=1,
                   stride=1,
                   bias=True)
        ]
        if use_conv:
            # 3x3 depth wise conv to provide positional encode information
            layers.append(
                Conv2d(in_channels=hidden,
                       out_channels=hidden,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       bias=True,
                       groups=hidden))
        fc2 = Conv2d(in_channels=hidden,
                     out_channels=embed_dims,
                     kernel_size=1,
                     stride=1,
                     bias=True)
        drop = nn.Dropout(ffn_drop)
        layers += [activate, drop, fc2, drop]
        self.layers = Sequential(*layers)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()
Example #7
0
    def __init__(self,
                 num_convs=4,
                 num_fcs=2,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 num_classes=80,
                 loss_iou=dict(type='MSELoss', loss_weight=0.5),
                 init_cfg=[
                     dict(type='Kaiming', override=dict(name='convs')),
                     dict(type='Caffe2Xavier', override=dict(name='fcs')),
                     dict(
                         type='Normal',
                         std=0.01,
                         override=dict(name='fc_mask_iou'))
                 ]):
        """Create the conv stack, FC stack and per-class IoU predictor.

        Args:
            num_convs (int): Number of 3x3 convs. The last one uses stride 2.
            num_fcs (int): Number of FC layers after the convs.
            roi_feat_size (int | tuple[int]): RoI feature size, halved per
                side to get the flattened area seen by the first FC.
            in_channels (int): Mask feature channels (one extra input
                channel is added for the concatenated mask prediction).
            conv_out_channels (int): Output channels of every conv.
            fc_out_channels (int): Output width of every FC.
            num_classes (int): Output width of ``fc_mask_iou``.
            loss_iou (dict): Config for ``build_loss``.
            init_cfg (list[dict]): Forwarded to the base class.
        """
        super(MaskIoUHead, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.num_classes = num_classes
        self.fp16_enabled = False

        last_conv = num_convs - 1
        self.convs = nn.ModuleList()
        for idx in range(num_convs):
            # concatenation of mask feature and mask prediction
            conv_in = (self.in_channels + 1
                       if idx == 0 else self.conv_out_channels)
            self.convs.append(
                Conv2d(conv_in,
                       self.conv_out_channels,
                       3,
                       stride=2 if idx == last_conv else 1,
                       padding=1))

        feat_h, feat_w = _pair(roi_feat_size)
        # NOTE(review): assumes the conv stack halves each spatial side
        # (exact for even sizes with the stride-2 last conv) — confirm.
        pooled_area = (feat_h // 2) * (feat_w // 2)
        self.fcs = nn.ModuleList()
        for idx in range(num_fcs):
            fc_in = (self.conv_out_channels * pooled_area
                     if idx == 0 else self.fc_out_channels)
            self.fcs.append(Linear(fc_in, self.fc_out_channels))

        self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes)
        self.relu = nn.ReLU()
        self.max_pool = MaxPool2d(2, 2)
        self.loss_iou = build_loss(loss_iou)
Example #8
0
    def __init__(self, embed_dims, init_cfg=None):
        """Create the three convs of the large-kernel attention unit.

        Args:
            embed_dims (int): Channel count kept by every conv.
            init_cfg (dict, optional): Forwarded to the base class.
        """
        super(LKA, self).__init__(init_cfg=init_cfg)

        local_k = 5
        dilated_k, dilation = 7, 3
        # 'Same' padding for stride 1: dilation * (kernel - 1) // 2,
        # i.e. 2 for the 5x5 conv and 9 for the 7x7 conv with dilation 3.
        local_pad = (local_k - 1) // 2
        dilated_pad = dilation * (dilated_k - 1) // 2

        # a spatial local convolution (depth-wise convolution)
        self.DW_conv = Conv2d(in_channels=embed_dims,
                              out_channels=embed_dims,
                              kernel_size=local_k,
                              padding=local_pad,
                              groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution)
        self.DW_D_conv = Conv2d(in_channels=embed_dims,
                                out_channels=embed_dims,
                                kernel_size=dilated_k,
                                stride=1,
                                padding=dilated_pad,
                                groups=embed_dims,
                                dilation=dilation)

        # point-wise 1x1 conv
        self.conv1 = Conv2d(in_channels=embed_dims,
                            out_channels=embed_dims,
                            kernel_size=1)
Example #9
0
 def _init_layers(self):
     """Initialize layers of the transformer head.

     Builds the input projection, the classifier, the regression FFN and
     head, and the learned query embedding, all sized by
     ``self.embed_dims``.
     """
     dims = self.embed_dims
     self.input_proj = Conv2d(self.in_channels, dims, kernel_size=1)
     self.fc_cls = Linear(dims, self.cls_out_channels)
     self.reg_ffn = FFN(dims,
                        dims,
                        self.num_fcs,
                        self.act_cfg,
                        dropout=0.0,
                        add_residual=False)
     # 4 regression outputs per query (presumably box coordinates).
     self.fc_reg = Linear(dims, 4)
     self.query_embedding = nn.Embedding(self.num_query, dims)
 def __init__(self,
              img_size=224,
              patch_size=16,
              in_channels=3,
              embed_dim=768):
     """Create a conv that embeds non-overlapping square patches.

     Args:
         img_size (int | tuple): Image size; an int means a square image.
         patch_size (int): Side length (and stride) of each patch.
         in_channels (int): Channels of the input image.
         embed_dim (int): Output channels of the patch projection.

     Raises:
         TypeError: If ``img_size`` is neither an int nor a tuple.
     """
     super(PatchEmbed, self).__init__()
     if isinstance(img_size, int):
         img_size = (img_size, img_size)
     elif not isinstance(img_size, tuple):
         raise TypeError('img_size must be type of int or tuple')
     self.img_size = img_size
     height, width = self.img_size
     self.patch_size = (patch_size, patch_size)
     # Patches that fit entirely inside the image along each axis.
     self.num_patches = (height // patch_size) * (width // patch_size)
     self.proj = Conv2d(in_channels,
                        embed_dim,
                        kernel_size=patch_size,
                        stride=patch_size)
Example #11
0
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        """Create lateral/output convs per input level plus the final convs.

        Args:
            in_channels (list[int]): Channels of each input level; the last
                entry feeds ``last_feat_conv``.
            feat_channels (int): Internal channel count.
            out_channels (int): Output channels of ``mask_feature``.
            norm_cfg (dict, optional): Norm config; conv bias is enabled only
                when this is None.
            act_cfg (dict): Activation config for the 3x3 ``ConvModule``s.
            init_cfg (dict, optional): Forwarded to the base class.
        """
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_inputs = len(in_channels)
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        self.use_bias = norm_cfg is None

        shared = dict(bias=self.use_bias, norm_cfg=norm_cfg)
        # Every level except the last gets a 1x1 lateral and a 3x3 output conv.
        for level_channels in in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(level_channels,
                           feat_channels,
                           kernel_size=1,
                           act_cfg=None,
                           **shared))
            self.output_convs.append(
                ConvModule(feat_channels,
                           feat_channels,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           act_cfg=act_cfg,
                           **shared))

        self.last_feat_conv = ConvModule(in_channels[-1],
                                         feat_channels,
                                         kernel_size=3,
                                         padding=1,
                                         stride=1,
                                         act_cfg=act_cfg,
                                         **shared)
        self.mask_feature = Conv2d(feat_channels,
                                   out_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
Example #12
0
    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        """Initialise the parent attention, an optional spatial-reduction
        conv + norm, and fall back to ``legacy_forward`` on old mmcv.

        Args:
            embed_dims (int): Embedding width.
            num_heads (int): Number of attention heads.
            attn_drop (float): Forwarded to the parent. Default: 0.
            proj_drop (float): Forwarded to the parent. Default: 0.
            dropout_layer (dict, optional): Forwarded to the parent.
            batch_first (bool): Forwarded to the parent. Default: True.
            qkv_bias (bool): Forwarded to the parent as ``bias``.
            norm_cfg (dict): Config for ``build_norm_layer``.
            sr_ratio (int): Kernel size and stride of the reduction conv;
                no reduction layers are built when it is <= 1.
            init_cfg (dict, optional): Forwarded to the parent.
        """
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import mmcv_version, digit_version
        if mmcv_version < digit_version('1.3.17'):
            # Adjacent string literals are joined verbatim, so every
            # fragment needs its own trailing space.
            warnings.warn('The legacy version of forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward
Example #13
0
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='ReLU'),
                 encoder=None,
                 positional_encoding=dict(type='SinePositionalEncoding',
                                          num_feats=128,
                                          normalize=True),
                 init_cfg=None):
        """Extend the parent pixel decoder with a transformer encoder.

        The parent's ``last_feat_conv`` is dropped and replaced by an input
        projection, the encoder, a positional encoding and an output conv.

        Args:
            in_channels (list[int]): Channels of each input level; the last
                entry feeds ``encoder_in_proj``.
            feat_channels (int): Internal channel count; must equal the
                encoder's ``embed_dims``.
            out_channels (int): Forwarded to the parent.
            norm_cfg (dict): Forwarded to the parent and ``encoder_out_proj``.
            act_cfg (dict): Forwarded to the parent and ``encoder_out_proj``.
            encoder (dict): Config for ``build_transformer_layer_sequence``.
            positional_encoding (dict): Config for
                ``build_positional_encoding``.
            init_cfg (dict, optional): Forwarded to the parent.

        Raises:
            AssertionError: If the encoder width differs from
                ``feat_channels``.
        """
        super(TransformerEncoderPixelDecoder, self).__init__(in_channels,
                                                             feat_channels,
                                                             out_channels,
                                                             norm_cfg,
                                                             act_cfg,
                                                             init_cfg=init_cfg)
        # The encoder path replaces the parent's last-level conv.
        self.last_feat_conv = None

        self.encoder = build_transformer_layer_sequence(encoder)
        self.encoder_embed_dims = self.encoder.embed_dims
        # Each placeholder now shows the value it names (the original
        # message passed the two values in swapped order).
        assert self.encoder_embed_dims == feat_channels, (
            f'embed_dims({self.encoder_embed_dims}) of transformer encoder '
            f'must equal to feat_channels({feat_channels})')
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.encoder_in_proj = Conv2d(in_channels[-1],
                                      feat_channels,
                                      kernel_size=1)
        self.encoder_out_proj = ConvModule(feat_channels,
                                           feat_channels,
                                           kernel_size=3,
                                           stride=1,
                                           padding=1,
                                           bias=self.use_bias,
                                           norm_cfg=norm_cfg,
                                           act_cfg=act_cfg)
Example #14
0
    def __init__(self,
                 num_convs=2,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=128,
                 hidden_channels=64,
                 num_layers=2,
                 feat_size=7,
                 polygon_size=None,
                 max_time_step=10,
                 use_attention=False,
                 attention_type=1,
                 use_coord=False,
                 coord_type=1,
                 use_bn=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 sample_vertex=None,
                 beam_step=0,
                 use_mask_pred=False,
                 weight_kernel_params=dict(kernel_size=1, type='constant'),
                 loss_type=0,
                 act_test='softmax',
                 with_offset=False,
                 dilation_params=dict(with_dilation=False,
                                      dilations=[3, 3, 3, 3],
                                      num_convs=4),
                 vertex_edge_params=dict(vertex_channels=64,
                                         edge_channels=64,
                                         type=0)):
        """Build the conv trunk, the recurrent conv layers and output heads.

        Args:
            num_convs (int): Number of ``ConvModule`` layers at the start of
                ``self.convs``.
            in_channels (int): Base input channels of the first trunk conv;
                extra channels are added depending on ``use_coord`` and
                ``vertex_edge_params['type']``.
            conv_kernel_size (int): Kernel size of the trunk convs and of the
                recurrent ``conv_x`` / ``conv_h`` convs.
            conv_out_channels (int): Output channels of every trunk conv.
            hidden_channels (int): Hidden channels of each recurrent layer;
                ``conv_x`` / ``conv_h`` output ``4 * hidden_channels``.
            num_layers (int): Number of recurrent layers.
            feat_size (int): Spatial side length assumed when flattening the
                hidden state for ``fc_out`` / ``fc_offset``.
            polygon_size (int, optional): Side length of the output grid;
                defaults to ``feat_size``. ``fc_out`` has
                ``polygon_size**2 + 1`` outputs.
            max_time_step (int): With ``use_bn``, each recurrent layer gets
                ``max_time_step - 1`` BatchNorm layers per branch.
            use_attention (bool): Together with ``attention_type != 3``,
                builds a 1-channel ``conv_atten``.
            attention_type (int): Selects ``self.attention_<attention_type>``.
                Type 3 builds ``conv_atten`` (to ``conv_out_channels``) and
                ``atten_hidden`` regardless of ``use_attention``.
            use_coord (bool): Adds 2 input channels to the first trunk conv.
            coord_type (int): With ``use_coord``, value 2 also adds 2 input
                channels to the first recurrent layer.
            use_bn (bool): Build per-time-step BatchNorm layers.
            conv_cfg (dict, optional): Forwarded to the trunk ``ConvModule``s.
            norm_cfg (dict, optional): Forwarded to the trunk ``ConvModule``s.
            sample_vertex: Not referenced in this constructor.
            beam_step (int): If non-zero, ``self.forward`` is replaced by
                ``self.forward_beam`` on this instance.
            use_mask_pred (bool): Only stored as an attribute here.
            weight_kernel_params (dict): Stored as an attribute; presumably
                consumed by ``init_kernel`` (not shown) — TODO confirm.
            loss_type (int): Only stored as an attribute here.
            act_test (str): Only stored as an attribute here.
            with_offset (bool): Also build ``fc_offset`` (2 outputs).
            dilation_params (dict): If ``with_dilation`` is true, one
                ``Bottleneck`` per entry of ``dilations`` (or ``num_convs``
                entries of 1 when ``dilations`` is None) is appended to
                ``self.convs``.
            vertex_edge_params (dict): ``type`` selects extra input channels:
                1/3 add ``vertex_channels + edge_channels`` and 4/6 add 2 to
                the first trunk conv; 2/3 add ``vertex_channels +
                edge_channels`` and 5/6 add 2 to the first recurrent layer.

        NOTE(review): the dict defaults are stored by reference
        (``self.dilation_params``, ``self.weight_kernel_params``,
        ``self.vertex_edge_params``), so they are shared across instances
        built with defaults; avoid mutating them.
        """
        super(PolyRnnHead, self).__init__()

        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

        self.max_time_step = max_time_step
        self.use_attention = use_attention
        self.use_coord = use_coord
        self.coord_type = coord_type
        self.use_bn = use_bn
        # The output grid defaults to the feature-map size.
        if polygon_size is None:
            polygon_size = feat_size
        self.feat_size = feat_size
        self.polygon_size = polygon_size
        self.dilation_params = dilation_params
        self.beam_step = beam_step
        self.use_mask_pred = use_mask_pred
        self.weight_kernel_params = weight_kernel_params
        self.loss_type = loss_type
        self.act_test = act_test
        self.vertex_edge_params = vertex_edge_params
        self.with_offset = with_offset
        # Defined elsewhere (not shown); runs before any submodule below is
        # created.
        self.init_kernel()

        # Conv trunk: the first layer takes the head input plus optional
        # extra channels; later layers chain conv_out_channels.
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            in_channels = (self.in_channels
                           if i == 0 else self.conv_out_channels)
            if i == 0 and use_coord:
                in_channels += 2
            if i == 0 and (self.vertex_edge_params['type'] == 1
                           or self.vertex_edge_params['type'] == 3):
                in_channels += (self.vertex_edge_params['vertex_channels'] +
                                self.vertex_edge_params['edge_channels'])
            if i == 0 and (self.vertex_edge_params['type'] == 4
                           or self.vertex_edge_params['type'] == 6):
                in_channels += 2
            # Keeps the spatial size for odd kernel sizes.
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(in_channels,
                           self.conv_out_channels,
                           self.conv_kernel_size,
                           padding=padding,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg))
        if dilation_params.get('with_dilation', False):
            dilations = dilation_params.get('dilations', None)
            # Rebinds the local `num_convs` only; self.num_convs is unchanged.
            num_convs = dilation_params.get('num_convs', 4)
            if dilations is None:
                dilations = [1 for _ in range(num_convs)]
            else:
                assert len(dilations) == num_convs
            # NOTE(review): the recurrent layers below assume the trunk
            # outputs conv_out_channels; confirm Bottleneck (not shown)
            # preserves that width with planes = conv_out_channels // 2.
            for dilation in dilations:
                self.convs.append(
                    Bottleneck(self.conv_out_channels,
                               self.conv_out_channels // 2,
                               dilation=dilation))

        # Per-layer recurrent convs and (optionally) per-time-step norms.
        self.conv_x = nn.ModuleList()
        self.conv_h = nn.ModuleList()
        self.bn_x = nn.ModuleList()
        self.bn_h = nn.ModuleList()
        self.bn_c = nn.ModuleList()

        # Equals (k - 1) // 2 for odd k; differs from the trunk padding for
        # even kernel sizes.
        padding = conv_kernel_size // 2
        for l in range(self.num_layers):
            if l != 0:
                in_channels = self.hidden_channels
            else:
                # NOTE(review): the meaning of the 3 extra input channels is
                # not visible here — confirm against forward.
                in_channels = self.conv_out_channels + 3
                if self.use_coord and self.coord_type == 2:
                    in_channels += 2
                if self.vertex_edge_params[
                        'type'] == 2 or self.vertex_edge_params['type'] == 3:
                    in_channels += (
                        self.vertex_edge_params['vertex_channels'] +
                        self.vertex_edge_params['edge_channels'])
                if self.vertex_edge_params[
                        'type'] == 5 or self.vertex_edge_params['type'] == 6:
                    in_channels += 2
            # Input-to-state and state-to-state convs. The 4 * hidden_channels
            # width looks like stacked LSTM gates — confirm in forward.
            self.conv_x.append(
                Conv2d(in_channels,
                       4 * hidden_channels,
                       kernel_size=conv_kernel_size,
                       padding=padding))
            self.conv_h.append(
                Conv2d(hidden_channels,
                       4 * hidden_channels,
                       kernel_size=conv_kernel_size,
                       padding=padding))

            # One BatchNorm per time step (max_time_step - 1 of them) for the
            # x, h and c branches of this layer.
            if self.use_bn:
                self.bn_x.append(
                    nn.ModuleList([
                        nn.BatchNorm2d(4 * hidden_channels)
                        for i in range(max_time_step - 1)
                    ]))
                self.bn_h.append(
                    nn.ModuleList([
                        nn.BatchNorm2d(4 * hidden_channels)
                        for i in range(max_time_step - 1)
                    ]))
                self.bn_c.append(
                    nn.ModuleList([
                        nn.BatchNorm2d(hidden_channels)
                        for i in range(max_time_step - 1)
                    ]))

        # Single-channel attention map over the concatenated hidden states.
        if self.use_attention and attention_type != 3:
            self.conv_atten = ConvModule(hidden_channels * num_layers,
                                         1,
                                         kernel_size=1,
                                         act_cfg=None)

        # Heads over the flattened hidden state (feat_size x feat_size).
        self.fc_out = nn.Linear(self.feat_size**2 * hidden_channels,
                                self.polygon_size**2 + 1)
        if self.with_offset:
            self.fc_offset = nn.Linear(self.feat_size**2 * hidden_channels, 2)
        # Evaluated even when use_attention is False; raises AttributeError
        # if no attention_<attention_type> method exists.
        self.attention = getattr(self, 'attention_%d' % attention_type)
        if attention_type == 3:
            self.conv_atten = ConvModule(hidden_channels * num_layers,
                                         self.conv_out_channels,
                                         kernel_size=1,
                                         act_cfg=None)
            self.atten_hidden = ConvModule(self.conv_out_channels,
                                           1,
                                           kernel_size=1,
                                           act_cfg=None)
        # Route calls through beam search on this instance.
        if beam_step != 0:
            self.forward = self.forward_beam
Example #15
0
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(type='CrossEntropyLoss',
                               bg_cls_weight=0.1,
                               use_sigmoid=False,
                               loss_weight=1.0,
                               class_weight=1.0),
                 loss_mask=dict(type='FocalLoss',
                                use_sigmoid=True,
                                gamma=2.0,
                                alpha=0.25,
                                loss_weight=20.0),
                 loss_dice=dict(type='DiceLoss',
                                use_sigmoid=True,
                                activate=True,
                                naive_dice=True,
                                loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        """Build the pixel decoder, transformer decoder, embeddings, the
        optional assigner/sampler and the losses.

        Args:
            in_channels (list[int]): Channels of each input level.
            feat_channels (int): Forwarded to the pixel decoder; input width
                of ``cls_embed`` and ``mask_embed``.
            out_channels (int): Forwarded to the pixel decoder; width of
                ``query_embed`` and of the last ``mask_embed`` layer.
            num_things_classes (int): Number of "thing" classes.
            num_stuff_classes (int): Number of "stuff" classes.
            num_queries (int): Number of learned queries.
            pixel_decoder (dict): Plugin config; a copy is updated with the
                channel arguments above.
            enforce_decoder_input_project (bool): Force a 1x1 input projection
                for a ``PixelDecoder`` even when the widths already match.
            transformer_decoder (dict): Config for
                ``build_transformer_layer_sequence``.
            positional_encoding (dict): Config for
                ``build_positional_encoding``.
            loss_cls (dict): Classification loss config. For
                ``MaskFormerHead`` exactly, a float ``class_weight`` is
                expanded to a per-class tensor whose last (VOID) entry is
                ``bg_cls_weight``.
            loss_mask (dict): Mask loss config.
            loss_dice (dict): Dice loss config.
            train_cfg (dict, optional): Must contain ``assigner`` when set.
            test_cfg (dict, optional): Stored as ``self.test_cfg``.
            init_cfg (dict, optional): Forwarded to the base module.
            **kwargs: Accepted and ignored here.
        """
        # Deliberately skips AnchorFreeHead.__init__ and initialises its base.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        # Copy before updating so the caller's config is left untouched.
        pixel_decoder = pixel_decoder.copy()
        pixel_decoder.update(in_channels=in_channels,
                             feat_channels=feat_channels,
                             out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        pixel_decoder_type = pixel_decoder.get('type')
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(in_channels[-1],
                                             self.decoder_embed_dims,
                                             kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        # One extra logit for the VOID (no-object) class.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            self.assigner = build_assigner(assigner)
            sampler_cfg = dict(type='MaskPseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        # Work on a copy: the expansion below replaces 'class_weight' and
        # drops 'bg_cls_weight'. Doing that on the shared default dict (or a
        # caller's config) would make the next construction fail the
        # float assertion, because class_weight would already be a tensor.
        loss_cls = loss_cls.copy()
        self.bg_cls_weight = 0
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is MaskFormerHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official MaskFormerHead repo, bg_cls_weight
            # means relative classification weight of the VOID class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(self.num_classes + 1) * class_weight
            # set VOID class as the last indice
            class_weight[self.num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            loss_cls.pop('bg_cls_weight', None)
            self.bg_cls_weight = bg_cls_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)
Example #16
0
    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 num_classes=80,
                 class_agnostic=False,
                 upsample_cfg=dict(type='deconv', scale_factor=2),
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False,
                 loss_mask=dict(type='CrossEntropyLoss',
                                use_mask=True,
                                loss_weight=1.0)):
        """Build the mask conv stack, the optional upsampler and logit conv.

        Args:
            num_convs (int): Number of convs (or 2x the residual blocks when
                ``conv_to_res`` is set).
            roi_feat_size (int | tuple[int]): Stored only (reserved).
            in_channels (int): Input channels of the first conv.
            conv_kernel_size (int): Kernel size of the plain convs; must be 3
                when ``conv_to_res`` is set.
            conv_out_channels (int): Output channels of every conv.
            num_classes (int): Logit channels unless ``class_agnostic``.
            class_agnostic (bool): Predict a single mask channel.
            upsample_cfg (dict): Upsample config; ``type`` must be one of
                None, 'deconv', 'nearest', 'bilinear', 'carafe'.
            conv_cfg (dict, optional): Forwarded to the conv layers.
            norm_cfg (dict, optional): Forwarded to the conv layers.
            conv_to_res (bool): Use a ``ResLayer`` of
                ``SimplifiedBasicBlock``s instead of plain convs.
            loss_mask (dict): Config for ``build_loss``.

        Raises:
            ValueError: If ``upsample_cfg['type']`` is not supported.
        """
        super(FCNMaskHead, self).__init__()
        self.upsample_cfg = upsample_cfg.copy()
        supported = [None, 'deconv', 'nearest', 'bilinear', 'carafe']
        if self.upsample_cfg['type'] not in supported:
            raise ValueError(
                f'Invalid upsample method {self.upsample_cfg["type"]}, '
                'accepted methods are "deconv", "nearest", "bilinear", '
                '"carafe"')
        self.num_convs = num_convs
        # WARN: roi_feat_size is reserved and not used
        self.roi_feat_size = _pair(roi_feat_size)
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = self.upsample_cfg.get('type')
        # Removed from the stored cfg; passed back explicitly per method.
        self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False
        self.loss_mask = build_loss(loss_mask)

        if conv_to_res:
            assert conv_kernel_size == 3
            self.num_res_blocks = num_convs // 2
            self.convs = ResLayer(SimplifiedBasicBlock,
                                  in_channels,
                                  self.conv_out_channels,
                                  self.num_res_blocks,
                                  conv_cfg=conv_cfg,
                                  norm_cfg=norm_cfg)
        else:
            same_pad = (self.conv_kernel_size - 1) // 2
            self.convs = nn.ModuleList()
            for idx in range(self.num_convs):
                layer_in = (self.in_channels
                            if idx == 0 else self.conv_out_channels)
                self.convs.append(
                    ConvModule(layer_in,
                               self.conv_out_channels,
                               self.conv_kernel_size,
                               padding=same_pad,
                               conv_cfg=conv_cfg,
                               norm_cfg=norm_cfg))

        if self.num_convs > 0:
            upsample_in_channels = self.conv_out_channels
        else:
            upsample_in_channels = self.in_channels

        method = self.upsample_method
        layer_cfg = self.upsample_cfg.copy()
        if method is None:
            self.upsample = None
        else:
            if method == 'deconv':
                layer_cfg.update(in_channels=upsample_in_channels,
                                 out_channels=self.conv_out_channels,
                                 kernel_size=self.scale_factor,
                                 stride=self.scale_factor)
            elif method == 'carafe':
                layer_cfg.update(channels=upsample_in_channels,
                                 scale_factor=self.scale_factor)
            else:
                # suppress warnings: 'nearest' takes no align_corners value
                layer_cfg.update(
                    scale_factor=self.scale_factor,
                    mode=method,
                    align_corners=None if method == 'nearest' else False)
            self.upsample = build_upsample_layer(layer_cfg)

        out_channels = 1 if self.class_agnostic else self.num_classes
        # A deconv changes the width to conv_out_channels; other upsamplers
        # keep their input width.
        if method == 'deconv':
            logits_in_channel = self.conv_out_channels
        else:
            logits_in_channel = upsample_in_channels
        self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(type='CrossEntropyLoss',
                               use_sigmoid=False,
                               loss_weight=1.0,
                               class_weight=[1.0] * 133 + [0.1]),
                 loss_mask=dict(type='FocalLoss',
                                use_sigmoid=True,
                                gamma=2.0,
                                alpha=0.25,
                                loss_weight=20.0),
                 loss_dice=dict(type='DiceLoss',
                                use_sigmoid=True,
                                activate=True,
                                naive_dice=True,
                                loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        """Build the pixel decoder, transformer decoder, embeddings and losses.

        Args:
            in_channels (Sequence[int]): Channels of the input feature
                levels. Forwarded to the pixel decoder; the last entry is
                compared with the transformer decoder width to decide
                whether an input projection is needed.
            feat_channels (int): Forwarded to the pixel decoder. Also the
                input width of the class and mask embedding layers.
            out_channels (int): Forwarded to the pixel decoder. Also the
                query embedding width and the mask embedding output width.
            num_things_classes (int): Number of "thing" classes.
            num_stuff_classes (int): Number of "stuff" classes. The two
                counts are summed into ``self.num_classes``.
            num_queries (int): Number of learned queries.
            pixel_decoder (dict): Pixel decoder config. It is deep-copied
                before the channel arguments are filled in, so the
                caller's dict is left unchanged. Required; ``None`` is not
                supported.
            enforce_decoder_input_project (bool): Force a 1x1 input
                projection for a ``'PixelDecoder'`` even when the widths
                already match.
            transformer_decoder (dict): Transformer decoder config.
            positional_encoding (dict): Positional encoding config.
            loss_cls (dict): Classification loss config. It must contain a
                ``'class_weight'`` entry, which is stored as
                ``self.class_weight``.
            loss_mask (dict): Mask loss config.
            loss_dice (dict): Dice loss config.
            train_cfg (dict | None): When truthy, it must provide
                ``'assigner'`` and ``'sampler'`` configs.
            test_cfg (dict | None): Stored as ``self.test_cfg``.
            init_cfg (dict | None): Forwarded to the base-class initializer.
            **kwargs: Accepted and ignored here.
        """
        import copy

        # ``super(AnchorFreeHead, self)`` resolves to the class after
        # AnchorFreeHead in the MRO, so AnchorFreeHead.__init__ itself is
        # bypassed. NOTE(review): presumably this avoids building
        # AnchorFreeHead's own layers; confirm against the class definition.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        # Work on a private copy so the caller's config is never mutated.
        # A shared or reused config would otherwise keep the channel values
        # of whichever head was built last.
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels,
                              feat_channels=feat_channels,
                              out_channels=out_channels)
        # build_plugin_layer returns a (name, layer) pair; keep the layer.
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        # Only the plain 'PixelDecoder' type gets a 1x1 projection from the
        # last input level to the decoder width. Every other decoder type,
        # or matching widths without enforcement, uses an identity.
        pixel_decoder_type = pixel_decoder_.get('type')
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(in_channels[-1],
                                             self.decoder_embed_dims,
                                             kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        # One logit per class plus one extra slot.
        # NOTE(review): presumably the extra slot is the "no object" class.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        # Three-layer MLP that maps query features to mask embeddings.
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            # Item access works for both plain dicts and mmcv ConfigDicts.
            self.assigner = build_assigner(train_cfg['assigner'])
            self.sampler = build_sampler(train_cfg['sampler'], context=self)

        # Item access (not ``loss_cls.class_weight``) so the default plain
        # ``dict`` works as well as a ConfigDict coming from a config file.
        self.class_weight = loss_cls['class_weight']
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)
    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 num_transformer_feat_level=3,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=None,
                 loss_mask=None,
                 loss_dice=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        """Build the multi-level pixel/transformer decoders, embeddings and losses.

        Args:
            in_channels: Forwarded to the pixel decoder config.
            feat_channels (int): Forwarded to the pixel decoder. Also the
                width of the query, level and class/mask embedding layers,
                and the input width of each decoder input projection.
            out_channels (int): Forwarded to the pixel decoder. Also the
                mask embedding output width.
            num_things_classes (int): Number of "thing" classes.
            num_stuff_classes (int): Number of "stuff" classes. The two
                counts are summed into ``self.num_classes``.
            num_queries (int): Number of learned queries.
            num_transformer_feat_level (int): Number of feature levels fed
                to the transformer decoder. It must equal the pixel decoder
                encoder's ``attn_cfgs['num_levels']``.
            pixel_decoder (dict): Pixel decoder config. It is deep-copied
                before the channel arguments are filled in, so the
                caller's dict is left unchanged.
            enforce_decoder_input_project (bool): Force 1x1 input
                projections even when the widths already match.
            transformer_decoder (dict): Transformer decoder config. Its
                ``num_layers`` and ``transformerlayers['attn_cfgs']
                ['num_heads']`` are read directly.
            positional_encoding (dict): Positional encoding config.
            loss_cls (dict): Classification loss config. It must contain
                ``'class_weight'``. Required in practice; ``None`` is not
                handled here.
            loss_mask (dict): Mask loss config (required in practice).
            loss_dice (dict): Dice loss config (required in practice).
            train_cfg (dict | None): When truthy, it must provide
                ``'assigner'`` and ``'sampler'``. ``'num_points'``,
                ``'oversample_ratio'`` and ``'importance_sample_ratio'``
                are optional, defaulting to 12544, 3.0 and 0.75.
            test_cfg (dict | None): Stored as ``self.test_cfg``.
            init_cfg (dict | None): Forwarded to the base-class initializer.
            **kwargs: Accepted and ignored here.

        Raises:
            ValueError: If the pixel decoder encoder's ``num_levels`` does
                not match ``num_transformer_feat_level``.
        """
        # ``super(AnchorFreeHead, self)`` bypasses AnchorFreeHead.__init__
        # and calls the next initializer in the MRO.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level

        # Item access works for plain dicts as well as mmcv ConfigDicts
        # (attribute access only works for the latter).
        decoder_attn_cfg = transformer_decoder['transformerlayers'][
            'attn_cfgs']
        self.num_heads = decoder_attn_cfg['num_heads']
        self.num_transformer_decoder_layers = transformer_decoder['num_layers']

        # Explicit check instead of ``assert`` so that it is not stripped
        # when Python runs with -O.
        encoder_num_levels = pixel_decoder['encoder']['transformerlayers'][
            'attn_cfgs']['num_levels']
        if encoder_num_levels != num_transformer_feat_level:
            raise ValueError(
                'pixel_decoder encoder num_levels '
                f'({encoder_num_levels}) must equal '
                f'num_transformer_feat_level ({num_transformer_feat_level})')

        # Private copy so the caller's config is not mutated by update().
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels,
                              feat_channels=feat_channels,
                              out_channels=out_channels)
        # build_plugin_layer returns a (name, layer) pair; keep the layer.
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        # The same decision applies to every level, so compute it once.
        need_input_proj = (self.decoder_embed_dims != feat_channels
                           or enforce_decoder_input_project)
        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if need_input_proj:
                self.decoder_input_projs.append(
                    Conv2d(feat_channels,
                           self.decoder_embed_dims,
                           kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(
            positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        # One logit per class plus one extra slot.
        # NOTE(review): presumably the extra slot is the "no object" class.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        # Three-layer MLP that maps query features to mask embeddings.
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg['assigner'])
            self.sampler = build_sampler(self.train_cfg['sampler'],
                                         context=self)
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        # Item access so a plain-dict loss config works too.
        self.class_weight = loss_cls['class_weight']
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)