def __init__(self, inplanes, planes, stride=1, downsample=None):
     super(Resnet_BasicBlock, self).__init__()
     self.conv1 = ws.Conv2dws(inplanes, planes, kernel_size=3, stride=stride, padding=1)  # first conv carries the block's stride
     self.bn1 = sn.SwitchNorm2d(planes)
     self.conv2 = ws.Conv2dws(planes, planes, kernel_size=3, stride=1, padding=1)
     self.bn2 = sn.SwitchNorm2d(planes)
     self.relu = nn.ReLU(inplace=True)
     self.downsample = downsample
     self.stride = stride
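Only the constructor is shown above; a forward pass consistent with these attributes would look roughly like the following sketch (the original forward is not part of this example):

def forward(self, x):
    # conv -> norm -> relu, conv -> norm, add the (optionally downsampled) skip, final relu
    identity = x
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    if self.downsample is not None:
        identity = self.downsample(x)
    return self.relu(out + identity)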
Example #2
    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:  # input branch j is coarser than output i: 1x1 conv to match channels (upsampling happens in forward)
                    fuse_layer.append(
                        nn.Sequential(
                            ws.Conv2dws(num_inchannels[j],
                                        num_inchannels[i],
                                        1,
                                        1,
                                        0,
                                        bias=False),
                            SwitchNorm2d(num_inchannels[i])))
                elif j == i:  # same branch: keep as-is
                    fuse_layer.append(None)
                else:  # input branch j is finer than output i: downsample with (i - j) stride-2 3x3 convs
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    ws.Conv2dws(num_inchannels[j],
                                                num_outchannels_conv3x3,
                                                3,
                                                2,
                                                1,
                                                bias=False),
                                    SwitchNorm2d(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    ws.Conv2dws(num_inchannels[j],
                                                num_outchannels_conv3x3,
                                                3,
                                                2,
                                                1,
                                                bias=False),
                                    SwitchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)
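For context, here is a sketch of how fuse layers like these are typically applied in the forward pass of an HRNet-style module (assumed, since the forward method is not shown; it relies on self.fuse_layers, self.num_branches, self.relu and torch.nn.functional as F): coarser branches are 1x1-convolved and bilinearly upsampled, finer branches are reduced by the strided 3x3 stacks, and the results are summed per output branch.

    def forward(self, x):  # x: list of per-branch feature maps, finest resolution first
        x_fuse = []
        for i, fuse_layer in enumerate(self.fuse_layers):
            y = x[0] if i == 0 else fuse_layer[0](x[0])
            for j in range(1, self.num_branches):
                if j == i:
                    y = y + x[j]
                elif j > i:  # coarser input: channels matched above, upsample to branch i's size
                    y = y + F.interpolate(fuse_layer[j](x[j]), size=x[i].shape[2:],
                                          mode='bilinear', align_corners=False)
                else:        # finer input: already downsampled by the strided convs
                    y = y + fuse_layer[j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse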
Example #3
    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                ws.Conv2dws(self.num_inchannels[branch_index],
                            num_channels[branch_index] * block.expansion,
                            kernel_size=1,
                            stride=stride,
                            bias=False),
                SwitchNorm2d(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))

        return nn.Sequential(*layers)
Example #4
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False):

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(GatedSpatialConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, 
                                                 False, _pair(0), groups, bias, 'zeros')

        self.gate_conv = nn.Sequential(
            sn.SwitchNorm2d(in_channels+1),
            ws.Conv2dws(in_channels+1, in_channels+1, 1),
            nn.ReLU(), 
            ws.Conv2dws(in_channels+1, 1, 1),
            sn.SwitchNorm2d(1),
            nn.Sigmoid()
        )
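A sketch of how this gated convolution is usually applied (assumed from the attributes above; the forward is not shown, and torch plus torch.nn.functional as F are taken as imported): the features are concatenated with a one-channel gating map, gate_conv turns that into a sigmoid attention map, the features are reweighted, and the parent Conv2d is applied.

    def forward(self, input_features, gating_features):
        # input_features: (N, C, H, W); gating_features: (N, 1, H, W) edge/shape cue
        alphas = self.gate_conv(torch.cat([input_features, gating_features], dim=1))
        input_features = input_features * (alphas + 1)  # residual gating
        return F.conv2d(input_features, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)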
Example #5
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return ws.Conv2dws(in_planes,
                       out_planes,
                       kernel_size=3,
                       stride=stride,
                       padding=1,
                       bias=False)
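ws.Conv2dws is used throughout these examples but never defined here; it is presumably a weight-standardized Conv2d, i.e. the kernel weights are normalized to zero mean and unit variance per output channel before each convolution. A minimal self-contained sketch of such a layer (the real ws.Conv2dws may differ in details):

import torch.nn as nn
import torch.nn.functional as F

class Conv2dws(nn.Conv2d):
    """Conv2d with Weight Standardization (illustrative sketch only)."""
    def forward(self, x):
        w = self.weight
        mean = w.mean(dim=(1, 2, 3), keepdim=True)                          # per-output-channel mean
        std = w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5     # per-output-channel std
        return F.conv2d(x, (w - mean) / std, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)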
Example #6
 def __init__(self, inplanes, planes, stride=1, downsample=None):
     super(Bottleneck, self).__init__()
     self.conv1 = ws.Conv2dws(inplanes, planes, kernel_size=1, bias=False)
     self.sn1 = SwitchNorm2d(planes)
     self.conv2 = ws.Conv2dws(planes,
                              planes,
                              kernel_size=3,
                              stride=stride,
                              padding=1,
                              bias=False)
     self.sn2 = SwitchNorm2d(planes)
     self.conv3 = ws.Conv2dws(planes,
                              planes * self.expansion,
                              kernel_size=1,
                              bias=False)
     self.sn3 = SwitchNorm2d(planes * self.expansion)
     self.relu = nn.ReLU(inplace=True)
     self.downsample = downsample
     self.stride = stride
Example #7
 def __init__(self, inplanes, planes, kernel_size, padding, dilation,
              BatchNorm):
     super(ASPP, self).__init__()
     self.atrous_conv = ws.Conv2dws(inplanes,
                                    planes,
                                    kernel_size=kernel_size,
                                    stride=1,
                                    padding=padding,
                                    dilation=dilation,
                                    bias=False)
     self.bn = BatchNorm(planes)
     self.relu = nn.ReLU()
     self.init_weight()
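The matching forward pass for this single ASPP branch is presumably just conv → norm → ReLU (a sketch; init_weight is likewise not shown in this example):

 def forward(self, x):
     # one atrous (dilated) branch of the ASPP head
     return self.relu(self.bn(self.atrous_conv(x)))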
Example #8
    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:  # branch already exists in the previous stage
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:  # channel width changes: adapt with a 3x3 conv
                    transition_layers.append(
                        nn.Sequential(
                            ws.Conv2dws(num_channels_pre_layer[i],
                                        num_channels_cur_layer[i],
                                        3,
                                        1,
                                        1,
                                        bias=False),
                            SwitchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:  # new, coarser branch: build it from the last previous branch with stride-2 convs
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            ws.Conv2dws(inchannels,
                                        outchannels,
                                        3,
                                        2,
                                        1,
                                        bias=False), SwitchNorm2d(outchannels),
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)
Example #9
def ConvSnRelu(in_channels,
               out_channels,
               using_movavg=True,
               using_bn=True,
               last_gamma=False,
               kernel_size=(3, 3),
               stride=1,
               padding=1):
    conv = ws.Conv2dws(in_channels, out_channels, kernel_size, stride, padding)
    norm = sn.SwitchNorm2d(out_channels,
                           using_movavg=using_movavg,
                           using_bn=using_bn,
                           last_gamma=last_gamma)
    relu = nn.ReLU()
    return nn.Sequential(conv, norm, relu)
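A quick usage sketch (assuming torch is imported and the ws/sn modules above are available): with the default 3x3 kernel and padding=1 the returned nn.Sequential preserves spatial size.

import torch

block = ConvSnRelu(in_channels=3, out_channels=64)
y = block(torch.randn(2, 3, 128, 128))
print(y.shape)  # torch.Size([2, 64, 128, 128])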
Example #10
    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1):
        super(SpatialOCR_Module, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale)
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            ws.Conv2dws(_in_channels,
                        out_channels,
                        kernel_size=1,
                        padding=0,
                        bias=False), SwitchNorm2d(out_channels), nn.ReLU(),
            nn.Dropout2d(dropout))
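A sketch of the corresponding forward pass (assumed; it also explains why _in_channels is 2 * in_channels): the object-context features are concatenated with the input features before the 1x1 fusion conv.

    def forward(self, feats, proxy_feats):
        # attend pixels to object-region representations, then fuse with the originals
        context = self.object_context_block(feats, proxy_feats)
        return self.conv_bn_dropout(torch.cat([context, feats], dim=1))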
Example #11
 def __init__(self, in_channels, key_channels, scale=1):
     super(_ObjectAttentionBlock, self).__init__()
     self.scale = scale
     self.in_channels = in_channels
     self.key_channels = key_channels
     self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
     self.f_pixel = nn.Sequential(
         ws.Conv2dws(in_channels=self.in_channels,
                     out_channels=self.key_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.key_channels),
         nn.ReLU(),
         ws.Conv2dws(in_channels=self.key_channels,
                     out_channels=self.key_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.key_channels),
         nn.ReLU())
     self.f_object = nn.Sequential(
         ws.Conv2dws(in_channels=self.in_channels,
                     out_channels=self.key_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.key_channels),
         nn.ReLU(),
         ws.Conv2dws(in_channels=self.key_channels,
                     out_channels=self.key_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.key_channels),
         nn.ReLU())
     self.f_down = nn.Sequential(
         ws.Conv2dws(in_channels=self.in_channels,
                     out_channels=self.key_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.key_channels),
         nn.ReLU())
     self.f_up = nn.Sequential(
         ws.Conv2dws(in_channels=self.key_channels,
                     out_channels=self.in_channels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=False), SwitchNorm2d(self.in_channels), nn.ReLU())
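For reference, a sketch of the forward pass such an object-attention block typically implements (assumed; torch and torch.nn.functional as F are taken as imported): f_pixel produces per-pixel queries, f_object and f_down produce keys and values from the object-region features, and f_up projects the attended context back to in_channels.

 def forward(self, x, proxy):
     batch_size, h, w = x.size(0), x.size(2), x.size(3)
     if self.scale > 1:
         x = self.pool(x)
     # queries from pixels, keys/values from object-region representations
     query = self.f_pixel(x).view(batch_size, self.key_channels, -1).permute(0, 2, 1)
     key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
     value = self.f_down(proxy).view(batch_size, self.key_channels, -1).permute(0, 2, 1)
     sim_map = F.softmax(torch.matmul(query, key) * (self.key_channels ** -0.5), dim=-1)
     context = torch.matmul(sim_map, value).permute(0, 2, 1).contiguous()
     context = self.f_up(context.view(batch_size, self.key_channels, *x.size()[2:]))
     if self.scale > 1:
         context = F.interpolate(context, size=(h, w), mode='bilinear', align_corners=True)
     return context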
Example #12
    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                ws.Conv2dws(inplanes,
                            planes * block.expansion,
                            kernel_size=1,
                            stride=stride,
                            bias=False),
                SwitchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)
Example #13
    def __init__(self, n_channels, n_filters, n_class, using_movavg, using_bn):

        super(GSCNN, self).__init__()

        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # down1
        self.convd1_1 = ConvSnRelu(n_channels, n_filters, using_movavg,
                                   using_bn)
        self.convd1_2 = ConvSnRelu(n_filters, n_filters, using_movavg,
                                   using_bn)

        # down2
        self.convd2_1 = ConvSnRelu(n_filters, n_filters * 2, using_movavg,
                                   using_bn)
        self.convd2_2 = ConvSnRelu(n_filters * 2, n_filters * 2, using_movavg,
                                   using_bn)

        # down3
        self.convd3_1 = ConvSnRelu(n_filters * 2, n_filters * 4, using_movavg,
                                   using_bn)
        self.convd3_2 = ConvSnRelu(n_filters * 4, n_filters * 4, using_movavg,
                                   using_bn)

        # down4
        self.convd4_1 = ConvSnRelu(n_filters * 4, n_filters * 8, using_movavg,
                                   using_bn)
        self.convd4_2 = ConvSnRelu(n_filters * 8, n_filters * 8, using_movavg,
                                   using_bn)

        ## center
        self.conv5_1 = ConvSnRelu(n_filters * 8, n_filters * 16, using_movavg,
                                  using_bn)
        self.conv5_2 = ConvSnRelu(n_filters * 16, n_filters * 16, using_movavg,
                                  using_bn)

        # down6
        #self.convd6_1 = ConvSnRelu(n_filters*16, n_filters*32, using_movavg, using_bn)
        #self.convd6_2 = ConvSnRelu(n_filters*32, n_filters*32, using_movavg, using_bn)

        ## center
        #self.conv7_1 = ConvSnRelu(n_filters*32, n_filters*64, using_movavg, using_bn)
        #self.conv7_2 = ConvSnRelu(n_filters*64, n_filters*64, using_movavg, using_bn)

        self.dsn1 = ws.Conv2dws(64, 1, 1)
        self.dsn3 = ws.Conv2dws(256, 1, 1)
        self.dsn4 = ws.Conv2dws(512, 1, 1)
        self.dsn5 = ws.Conv2dws(1024, 1, 1)

        self.res1 = Resnet_BasicBlock(64, 64, stride=1, downsample=None)
        self.d1 = ws.Conv2dws(64, 32, 1)
        self.res2 = Resnet_BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = ws.Conv2dws(32, 16, 1)
        self.res3 = Resnet_BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = ws.Conv2dws(16, 8, 1)
        self.fuse = ws.Conv2dws(8, 1, kernel_size=1, padding=0, bias=False)
        self.cw = ws.Conv2dws(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = GatedSpatialConv2d(32, 32)
        self.gate2 = GatedSpatialConv2d(16, 16)
        self.gate3 = GatedSpatialConv2d(8, 8)

        ## ------------------aspp------------------
        self.aspp1 = ASPP(
            1024, 256, 1, padding=0, dilation=1,
            BatchNorm=nn.BatchNorm2d)  # 0 1    1024  nn.BatchNorm2d
        self.aspp2 = ASPP(1024,
                          256,
                          3,
                          padding=6,
                          dilation=6,
                          BatchNorm=nn.BatchNorm2d)  # 2
        self.aspp3 = ASPP(1024,
                          256,
                          3,
                          padding=12,
                          dilation=12,
                          BatchNorm=nn.BatchNorm2d)  # 6
        self.aspp4 = ASPP(1024,
                          256,
                          3,
                          padding=18,
                          dilation=18,
                          BatchNorm=nn.BatchNorm2d)  # 12

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(1024, 256, 1, stride=1, bias=False), nn.BatchNorm2d(256),
            nn.ReLU())
        self.edge_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(1, 256, 1, stride=1, bias=False), nn.BatchNorm2d(256),
            nn.ReLU())

        self.bot_aspp = ws.Conv2dws(1280 + 256, 256, kernel_size=1, bias=False)
        # ConvSnRelu(1536, 256, using_movavg=1, using_bn=1, kernel_size=1, padding=0)
        self.bot_fine = ws.Conv2dws(128, 48, kernel_size=1, bias=False)
        self.final_seg = nn.Sequential(
            ws.Conv2dws(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            sn.SwitchNorm2d(256), nn.ReLU(inplace=True),
            ws.Conv2dws(256, 256, kernel_size=3, padding=1, bias=False),
            sn.SwitchNorm2d(256), nn.ReLU(inplace=True),
            ws.Conv2dws(256, n_class, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
Example #14
    def __init__(self, n_channels, n_filters, n_class, using_movavg, using_bn):
        super(UNet3p_res_edge2_aspp_SNws, self).__init__()

        ## ------------------encoder------------------
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.maxPool4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.maxPool8 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.upSample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=False)
        self.upSample4 = nn.Upsample(scale_factor=4,
                                     mode='bilinear',
                                     align_corners=False)
        self.upSample8 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=False)
        self.upSample16 = nn.Upsample(scale_factor=16,
                                      mode='bilinear',
                                      align_corners=False)

        # down1
        self.convd1_1 = ConvSnRelu(n_channels, n_filters, using_movavg,
                                   using_bn)
        #self.convd1_2 = ConvSnRelu(n_filters, n_filters, using_movavg, using_bn)
        self.convd1_2 = Resnet_BasicBlock(n_filters,
                                          n_filters,
                                          stride=1,
                                          downsample=None)

        # down2
        self.convd2_1 = ConvSnRelu(n_filters, n_filters * 2, using_movavg,
                                   using_bn)
        #self.convd2_2 = ConvSnRelu(n_filters*2, n_filters*2, using_movavg, using_bn)
        self.convd2_2 = Resnet_BasicBlock(n_filters * 2,
                                          n_filters * 2,
                                          stride=1,
                                          downsample=None)

        # down3
        self.convd3_1 = ConvSnRelu(n_filters * 2, n_filters * 4, using_movavg,
                                   using_bn)
        #self.convd3_2 = ConvSnRelu(n_filters*4, n_filters*4, using_movavg, using_bn)
        self.convd3_2 = Resnet_BasicBlock(n_filters * 4,
                                          n_filters * 4,
                                          stride=1,
                                          downsample=None)

        # down4
        self.convd4_1 = ConvSnRelu(n_filters * 4, n_filters * 8, using_movavg,
                                   using_bn)
        #self.convd4_2 = ConvSnRelu(n_filters*8, n_filters*8, using_movavg, using_bn)
        self.convd4_2 = Resnet_BasicBlock(n_filters * 8,
                                          n_filters * 8,
                                          stride=1,
                                          downsample=None)

        ## center
        self.conv5_1 = ConvSnRelu(n_filters * 8, n_filters * 16, using_movavg,
                                  using_bn)
        #self.conv5_2 = ConvSnRelu(n_filters*16, n_filters*16, using_movavg, using_bn)
        self.conv5_2 = Resnet_BasicBlock(n_filters * 16,
                                         n_filters * 16,
                                         stride=1,
                                         downsample=None)

        ## ------------------edge------------------
        self.dsn3 = ws.Conv2dws(256, 1, 1)
        self.dsn4 = ws.Conv2dws(512, 1, 1)
        self.dsn5 = ws.Conv2dws(1024, 1, 1)

        self.d1 = ws.Conv2dws(64, 32, 1)
        self.res2 = Resnet_BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = ws.Conv2dws(32, 16, 1)
        self.res3 = Resnet_BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = ws.Conv2dws(16, 8, 1)
        self.fuse = ws.Conv2dws(8, 1, kernel_size=1, padding=0, bias=False)
        self.cw = ws.Conv2dws(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = GatedSpatialConv2d(32, 32)
        self.gate2 = GatedSpatialConv2d(16, 16)
        self.gate3 = GatedSpatialConv2d(8, 8)
        self.edge_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(1, 64, 1, stride=1, bias=False), nn.BatchNorm2d(64),
            nn.ReLU())
        self.sigmoid = nn.Sigmoid()

        ## ------------------aspp------------------

        self.aspp1 = ASPP(
            320, 64, 1, padding=0, dilation=1,
            BatchNorm=sn.SwitchNorm2d)  # 0 1    1024  nn.BatchNorm2d
        self.aspp2 = ASPP(320,
                          64,
                          3,
                          padding=2,
                          dilation=2,
                          BatchNorm=sn.SwitchNorm2d)  # 2
        self.aspp3 = ASPP(320,
                          64,
                          3,
                          padding=3,
                          dilation=3,
                          BatchNorm=sn.SwitchNorm2d)  # 6
        self.aspp4 = ASPP(320,
                          64,
                          3,
                          padding=6,
                          dilation=6,
                          BatchNorm=sn.SwitchNorm2d)  # 12
        self.aspp5 = ASPP(320,
                          64,
                          3,
                          padding=12,
                          dilation=12,
                          BatchNorm=sn.SwitchNorm2d)  # 18
        #self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
        #                                     nn.Conv2d(1024,256,1,stride=1,bias=False),
        #                                     nn.BatchNorm2d(256),
        #                                     nn.ReLU())

        #self.conva = nn.Sequential(nn.Conv2d(1280, n_filters, 1, bias=False),
        #                           nn.BatchNorm2d(n_filters),
        #                           nn.ReLU(),
        #                           nn.Conv2d(n_filters, n_filters, 3, bias=False)),
        #                           nn.BatchNorm2d(n_filters),
        #                           nn.ReLU())
        self.conva = ConvSnRelu(320 + 64,
                                n_filters,
                                using_movavg,
                                using_bn,
                                kernel_size=1,
                                padding=0)
        #                           ConvSnRelu(n_filters, n_filters, using_movavg, using_bn))

        #self.conv1 = nn.Conv2d(1280, 512, 1, bias=False)
        #self.bn1 = nn.BatchNorm2d(512)
        #self.bn1 = sn.SwitchNorm2d(1024)
        #self.relu = nn.ReLU()
        #self.dropout = nn.Dropout(0.5)

        ## ------------------unet3+_decoder------------------

        self.Cat_channels = n_filters
        self.Fusion_channels = 5 * n_filters  # 5 * n_filters

        # up
        self.convd1_cat = ConvSnRelu(n_filters, self.Cat_channels,
                                     using_movavg, using_bn)
        self.convd2_cat = ConvSnRelu(n_filters * 2, self.Cat_channels,
                                     using_movavg, using_bn)
        self.convd3_cat = ConvSnRelu(n_filters * 4, self.Cat_channels,
                                     using_movavg, using_bn)
        self.convd4_cat = ConvSnRelu(n_filters * 8, self.Cat_channels,
                                     using_movavg, using_bn)
        self.conv5_cat = ConvSnRelu(n_filters * 16, self.Cat_channels,
                                    using_movavg, using_bn)
        self.convu_cat = ConvSnRelu(self.Fusion_channels, self.Cat_channels,
                                    using_movavg, using_bn)
        self.conv_fusion = ConvSnRelu(self.Fusion_channels,
                                      self.Fusion_channels, using_movavg,
                                      using_bn)

        self.output_seg_map = nn.Conv2d(n_filters,
                                        n_class,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
Example #15
    def __init__(self, n_channels, n_filters, num_classes, config, **kwargs):
        extra = config['EXTRA']
        super(HighResolutionNet_OCR_SNws, self).__init__()
        ALIGN_CORNERS = True

        # stem net
        self.conv1 = ws.Conv2dws(n_channels,
                                 n_filters,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 bias=False)
        self.sn1 = SwitchNorm2d(n_filters)
        self.conv2 = ws.Conv2dws(n_filters,
                                 n_filters,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 bias=False)
        self.sn2 = SwitchNorm2d(n_filters)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = extra['STAGE1']
        # num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        num_channels = self.stage1_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        # num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
        num_blocks = self.stage1_cfg['NUM_BLOCKS']
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([stage1_out_channel],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        last_inp_channels = int(np.sum(pre_stage_channels))  # total channels across all output branches
        ocr_mid_channels = 512
        ocr_key_channels = 256

        self.conv3x3_ocr = nn.Sequential(
            ws.Conv2dws(last_inp_channels,
                        ocr_mid_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1), SwitchNorm2d(ocr_mid_channels),
            nn.ReLU(inplace=True))
        self.ocr_gather_head = SpatialGather_Module(num_classes)

        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=ocr_mid_channels,
            key_channels=ocr_key_channels,
            out_channels=ocr_mid_channels,
            scale=1,
            dropout=0.05,
        )
        self.cls_head = ws.Conv2dws(ocr_mid_channels,
                                    num_classes,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    bias=True)

        self.aux_head = nn.Sequential(
            ws.Conv2dws(last_inp_channels,
                        last_inp_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), SwitchNorm2d(last_inp_channels),
            nn.ReLU(inplace=True),
            ws.Conv2dws(last_inp_channels,
                        num_classes,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=True))