def _make_fuse_layers(self, bn_type, bn_momentum=0.1):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1, 
                                1, 
                                0, 
                                bias=False
                            ),
                            ModuleHelper.BatchNorm2d(bn_type=bn_type)(num_inchannels[i], momentum=bn_momentum),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    ModuleHelper.BatchNorm2d(bn_type=bn_type)(num_outchannels_conv3x3, momentum=bn_momentum)
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    ModuleHelper.BatchNorm2d(bn_type=bn_type)(num_outchannels_conv3x3, momentum=bn_momentum),
                                    nn.ReLU(inplace=False)
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)
 def __init__(self, inplanes, planes, stride=1, downsample=None, bn_type=None, bn_momentum=0.1):
     """Create the block's layers: two ``conv3x3`` layers, each with a norm.

     Args:
         inplanes: input channel count of the first conv.
         planes: output channel count of both convs.
         stride: stride handed to the first ``conv3x3``; also stored.
         downsample: stored unchanged on ``self.downsample`` (may be ``None``).
         bn_type: norm selector forwarded to ``ModuleHelper.BatchNorm2d``.
         bn_momentum: momentum passed to each norm layer.
     """
     super(BasicBlock, self).__init__()

     def _norm(channels):
         # Look the norm factory up on every call; nothing is cached here.
         return ModuleHelper.BatchNorm2d(bn_type=bn_type)(channels, momentum=bn_momentum)

     self.stride = stride

     # Sub-modules are assigned in a fixed order: conv1, bn1, relu, relu_in,
     # conv2, bn2, downsample.
     self.conv1 = conv3x3(inplanes, planes, stride)
     self.bn1 = _norm(planes)
     self.relu = nn.ReLU(inplace=False)
     self.relu_in = nn.ReLU(inplace=True)
     self.conv2 = conv3x3(planes, planes)
     self.bn2 = _norm(planes)
     self.downsample = downsample
 def __init__(self, inplanes, planes, stride=1, downsample=None, bn_type=None, bn_momentum=0.1):
     """Create the 1x1 -> 3x3 -> 1x1 conv stack, each conv with a norm layer.

     Args:
         inplanes: input channel count of the first 1x1 conv.
         planes: channel count of the two inner convs; the last conv emits
             ``planes * 4`` channels.
         stride: stride of the 3x3 conv; also stored on ``self.stride``.
         downsample: stored unchanged on ``self.downsample`` (may be ``None``).
         bn_type: norm selector forwarded to ``ModuleHelper.BatchNorm2d``.
         bn_momentum: momentum passed to each norm layer.
     """
     super(Bottleneck, self).__init__()
     widened = planes * 4

     def _norm(channels):
         # Look the norm factory up on every call; nothing is cached here.
         return ModuleHelper.BatchNorm2d(bn_type=bn_type)(channels, momentum=bn_momentum)

     self.stride = stride

     self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
     self.bn1 = _norm(planes)
     self.conv2 = nn.Conv2d(
         planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
     )
     self.bn2 = _norm(planes)
     self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
     self.bn3 = _norm(widened)
     self.relu = nn.ReLU(inplace=False)
     self.relu_in = nn.ReLU(inplace=True)
     self.downsample = downsample
    def _make_head(self, pre_stage_channels, bn_type, bn_momentum):
        """Build the three parts of the head.

        Args:
            pre_stage_channels: channel count of each incoming branch; entry
                ``i`` is paired with ``head_channels[i]``.
            bn_type: norm selector forwarded to ``ModuleHelper.BatchNorm2d``
                and to ``self._make_layer``.
            bn_momentum: momentum passed to every norm layer.

        Returns:
            ``(incre_modules, downsamp_modules, final_layer)``: a ModuleList
            with one ``Bottleneck`` layer per branch, a ModuleList of
            stride-2 3x3 conv + norm + ReLU modules linking consecutive head
            widths, and a 1x1 conv + norm + ReLU producing 2048 channels.
        """
        head_block = Bottleneck
        head_channels = [32, 64, 128, 256]

        Log.info("pre_stage_channels: {}".format(pre_stage_channels))
        Log.info("head_channels: {}".format(head_channels))

        # Widen each resolution
        # (original note: from C, 2C, 4C, 8C to 128, 256, 512, 1024).
        incre_modules = nn.ModuleList([
            self._make_layer(
                head_block,
                branch_channels,
                head_channels[idx],
                1,
                bn_type=bn_type,
                bn_momentum=bn_momentum,
            )
            for idx, branch_channels in enumerate(pre_stage_channels)
        ])

        # Head widths after the block's channel expansion.
        expanded = [width * head_block.expansion for width in head_channels]

        # One downsampling module between each pair of consecutive branches.
        downsamp_modules = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=expanded[idx],
                          out_channels=expanded[idx + 1],
                          kernel_size=3,
                          stride=2,
                          padding=1),
                ModuleHelper.BatchNorm2d(bn_type=bn_type)(expanded[idx + 1], momentum=bn_momentum),
                nn.ReLU(inplace=False),
            )
            for idx in range(len(pre_stage_channels) - 1)
        ])

        final_layer = nn.Sequential(
            nn.Conv2d(in_channels=expanded[3],
                      out_channels=2048,
                      kernel_size=1,
                      stride=1,
                      padding=0),
            ModuleHelper.BatchNorm2d(bn_type=bn_type)(2048, momentum=bn_momentum),
            nn.ReLU(inplace=False),
        )
        return incre_modules, downsamp_modules, final_layer
    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer, bn_type, bn_momentum):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3, 
                                1, 
                                1, 
                                bias=False
                            ),
                            ModuleHelper.BatchNorm2d(bn_type=bn_type)(num_channels_cur_layer[i], momentum=bn_momentum),
                            nn.ReLU(inplace=False)
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, 
                                outchannels, 
                                3, 
                                2, 
                                1, 
                                bias=False
                            ),
                            ModuleHelper.BatchNorm2d(bn_type=bn_type)(outchannels, momentum=bn_momentum),
                            nn.ReLU(inplace=False)
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)
    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 use_gt=False,
                 use_bg=False,
                 use_oc=True,
                 fetch_attention=False,
                 bn_type=None):
        """Create the attention block and the 1x1 fusion conv that follows it.

        Args:
            in_channels: channel count handed to ``ObjectAttentionBlock2D``.
            key_channels: key width handed to ``ObjectAttentionBlock2D``.
            out_channels: output channels of the fusion conv.
            scale: forwarded to ``ObjectAttentionBlock2D``.
            dropout: probability used by the trailing ``nn.Dropout2d``.
            use_gt, use_bg, fetch_attention: stored on the instance and
                forwarded to ``ObjectAttentionBlock2D``.
            use_oc: stored on the instance; together with ``use_bg`` it
                selects the fusion conv's input width.
            bn_type: norm selector forwarded to the sub-modules.
        """
        super(SpatialOCR_Module, self).__init__()
        self.use_gt = use_gt
        self.use_bg = use_bg
        self.use_oc = use_oc
        self.fetch_attention = fetch_attention

        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale, use_gt, use_bg, fetch_attention,
            bn_type)

        # The fusion conv takes 3x in_channels only when both use_bg and
        # use_oc are set; every other combination uses 2x in_channels.
        multiplier = 3 if (self.use_bg and self.use_oc) else 2
        _in_channels = multiplier * in_channels

        fusion_layers = [
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0),
            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
            nn.Dropout2d(dropout),
        ]
        self.conv_bn_dropout = nn.Sequential(*fusion_layers)
 def __init__(self,
              in_channels,
              key_channels,
              scale=1,
              use_gt=False,
              use_bg=False,
              fetch_attention=False,
              bn_type=None):
     """Store the configuration and create the block's projection layers.

     Creates ``pool`` (max pool with a ``scale`` x ``scale`` kernel),
     ``f_pixel`` and ``f_object`` (two stacked 1x1 conv + BNReLU pairs each,
     ``in_channels`` -> ``key_channels`` -> ``key_channels``), ``f_down``
     (one pair, ``in_channels`` -> ``key_channels``) and ``f_up`` (one pair,
     ``key_channels`` -> ``in_channels``).

     Args:
         in_channels: channel count of the block's input features.
         key_channels: channel count of the projected features.
         scale: kernel size of the pooling layer; also stored.
         use_gt, use_bg, fetch_attention: stored on the instance only.
         bn_type: norm selector forwarded to ``ModuleHelper.BNReLU``.
     """
     super(_ObjectAttentionBlock, self).__init__()
     self.scale = scale
     self.in_channels = in_channels
     self.key_channels = key_channels
     self.use_gt = use_gt
     self.use_bg = use_bg
     self.fetch_attention = fetch_attention

     def _project(c_in, c_out):
         # One 1x1 conv followed by BNReLU, returned as a list so that
         # several pairs can be concatenated into a single Sequential.
         return [
             nn.Conv2d(in_channels=c_in,
                       out_channels=c_out,
                       kernel_size=1,
                       stride=1,
                       padding=0),
             ModuleHelper.BNReLU(c_out, bn_type=bn_type),
         ]

     self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
     self.f_pixel = nn.Sequential(
         *(_project(in_channels, key_channels)
           + _project(key_channels, key_channels))
     )
     self.f_object = nn.Sequential(
         *(_project(in_channels, key_channels)
           + _project(key_channels, key_channels))
     )
     self.f_down = nn.Sequential(*_project(in_channels, key_channels))
     self.f_up = nn.Sequential(*_project(key_channels, in_channels))
    def __init__(self, cfg, bn_type='torchbn', **kwargs):
        """Create the stem and four stages, each preceded by a transition.

        Args:
            cfg: mapping with ``'STAGE1'`` .. ``'STAGE4'`` entries; each stage
                entry supplies ``'NUM_CHANNELS'`` and ``'BLOCK'`` (a key into
                ``blocks_dict``) and is also handed to ``self._make_stage``.
            bn_type: norm selector forwarded to every builder.
            **kwargs: accepted but not used in this method.
        """
        super(HighResolutionNext, self).__init__()

        def _expanded_channels(stage_cfg):
            # Per-branch widths multiplied by the stage block's expansion.
            widths = stage_cfg['NUM_CHANNELS']
            unit = blocks_dict[stage_cfg['BLOCK']]
            return [width * unit.expansion for width in widths]

        # stem net
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = ModuleHelper.BatchNorm2d(bn_type=bn_type)(64)
        self.relu = nn.ReLU(relu_inplace)

        # Stage 1 is fed straight from the 64-channel stem.
        self.stage1_cfg = cfg['STAGE1']
        stage_channels = _expanded_channels(self.stage1_cfg)
        self.transition0 = self._make_transition_layer(
            [64], stage_channels, bn_type=bn_type)
        self.stage1, pre_stage_channels = self._make_stage(
            self.stage1_cfg, stage_channels, bn_type=bn_type)

        self.stage2_cfg = cfg['STAGE2']
        stage_channels = _expanded_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer(
            pre_stage_channels, stage_channels, bn_type=bn_type)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, stage_channels, bn_type=bn_type)

        self.stage3_cfg = cfg['STAGE3']
        stage_channels = _expanded_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, stage_channels, bn_type=bn_type)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, stage_channels, bn_type=bn_type)

        # The last stage is the only one that passes multi_scale_output
        # explicitly.
        self.stage4_cfg = cfg['STAGE4']
        stage_channels = _expanded_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, stage_channels, bn_type=bn_type)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, stage_channels, multi_scale_output=True, bn_type=bn_type)
    def _make_layer(self, block, inplanes, planes, blocks, stride=1, bn_type=None, bn_momentum=0.1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                ModuleHelper.BatchNorm2d(bn_type=bn_type)(planes * block.expansion, momentum=bn_momentum)
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample, bn_type=bn_type, bn_momentum=bn_momentum))

        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes, bn_type=bn_type, bn_momentum=bn_momentum))

        return nn.Sequential(*layers)
    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1, bn_type=None, bn_momentum=0.1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                ModuleHelper.BatchNorm2d(bn_type=bn_type)(
                    num_channels[branch_index] * block.expansion, 
                    momentum=bn_momentum
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample,
                bn_type=bn_type,
                bn_momentum=bn_momentum
            )
        )
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index], 
                    bn_type=bn_type,
                    bn_momentum=bn_momentum
                )
            )

        return nn.Sequential(*layers)
 def __init__(self,
              features,
              hidden_features=256,
              out_features=512,
              dilations=(12, 24, 36),
              num_classes=19,
              bn_type=None,
              dropout=0.1):
     """Create five parallel branches, the fusion conv and the object head.

     Branches (each maps ``features`` -> ``hidden_features`` channels):
     ``context`` (3x3 conv + BNReLU + ``SpatialOCR_Context``), ``conv2``
     (1x1 conv + BNReLU) and ``conv3``/``conv4``/``conv5`` (3x3 convs whose
     padding and dilation are ``dilations[0]``/``[1]``/``[2]``, + BNReLU).
     ``conv_bn_dropout`` takes ``hidden_features * 5`` channels to
     ``out_features`` through a 1x1 conv, BNReLU and ``nn.Dropout2d``.

     Args:
         features: input channel count of every branch.
         hidden_features: output channel count of every branch.
         out_features: output channel count of the fusion conv.
         dilations: three dilation rates, indexed 0..2.
         num_classes: argument handed to ``SpatialGather_Module``.
         bn_type: norm selector forwarded to the sub-modules.
         dropout: probability used by ``nn.Dropout2d``.
     """
     super(SpatialOCR_ASP_Module, self).__init__()
     from lib.models.modules.spatial_ocr_block import SpatialOCR_Context

     def _conv_bn_relu(kernel_size, padding, dilation):
         # Conv from `features` to `hidden_features` followed by BNReLU,
         # returned as a list so that a caller can append further layers.
         return [
             nn.Conv2d(features,
                       hidden_features,
                       kernel_size=kernel_size,
                       padding=padding,
                       dilation=dilation,
                       bias=True),
             ModuleHelper.BNReLU(hidden_features, bn_type=bn_type),
         ]

     # The context module is created after its conv and BNReLU.
     context_layers = _conv_bn_relu(3, 1, 1)
     context_layers.append(
         SpatialOCR_Context(in_channels=hidden_features,
                            key_channels=hidden_features // 2,
                            scale=1,
                            bn_type=bn_type)
     )
     self.context = nn.Sequential(*context_layers)

     self.conv2 = nn.Sequential(*_conv_bn_relu(1, 0, 1))
     # Padding equals dilation for the three dilated 3x3 branches.
     self.conv3 = nn.Sequential(*_conv_bn_relu(3, dilations[0], dilations[0]))
     self.conv4 = nn.Sequential(*_conv_bn_relu(3, dilations[1], dilations[1]))
     self.conv5 = nn.Sequential(*_conv_bn_relu(3, dilations[2], dilations[2]))

     fusion_layers = [
         nn.Conv2d(hidden_features * 5,
                   out_features,
                   kernel_size=1,
                   padding=0,
                   dilation=1,
                   bias=True),
         ModuleHelper.BNReLU(out_features, bn_type=bn_type),
         nn.Dropout2d(dropout),
     ]
     self.conv_bn_dropout = nn.Sequential(*fusion_layers)
     self.object_head = SpatialGather_Module(num_classes)
            out = self._cat_each(feat1, feat2, feat3, feat4, feat5)
        else:
            raise RuntimeError('unknown input type')

        output = self.conv_bn_dropout(out)
        return output


# Manual smoke-test setup: builds random inputs and the OCR modules on GPU 0.
# Requires a CUDA device because tensors and modules are moved with .cuda().
if __name__ == "__main__":
    # Restrict the process to the first GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Random stand-ins: a 19-channel map and a 2048-channel feature map,
    # both 1 x C x 128 x 128.
    probs = torch.randn((1, 19, 128, 128)).cuda()
    feats = torch.randn((1, 2048, 128, 128)).cuda()

    # 3x3 conv + BNReLU reducing 2048 channels to 512.
    conv_3x3 = nn.Sequential(
        nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
        ModuleHelper.BNReLU(512, bn_type='inplace_abn'),
    )

    ocp_gather_infer = SpatialGather_Module(19)
    ocp_distr_infer = SpatialOCR_Module(in_channels=512,
                                        key_channels=256,
                                        out_channels=512,
                                        scale=1,
                                        dropout=0,
                                        bn_type='inplace_abn')

    # Put all three modules in eval mode.
    ocp_gather_infer.eval()
    ocp_distr_infer.eval()
    conv_3x3.eval()
    # NOTE(review): only the two OCR modules are moved to the GPU in the lines
    # visible here; conv_3x3 is not — confirm against the rest of the script.
    ocp_gather_infer.cuda()
    ocp_distr_infer.cuda()
    def __init__(self, cfg, bn_type, bn_momentum, **kwargs):
        # Builds the stem, layer1, stages 2-4 with their transitions and,
        # optionally, the classification-style head.
        #
        # cfg: mapping with 'STAGE2'..'STAGE4' entries, each supplying
        #      'NUM_CHANNELS' and 'BLOCK' (a key into blocks_dict).
        # bn_type / bn_momentum: forwarded to every norm layer and builder.
        # kwargs: accepted but not used in the lines visible here.
        #
        # Two environment variables change what is built: 'full_res_stem'
        # and 'keep_imagenet_head' (any non-empty value enables them).

        # Set before the parent initialiser runs.
        self.inplanes = 64
        super(HighResolutionNet, self).__init__()

        if os.environ.get('full_res_stem'):
            # Stem variant: a single stride-1 conv; conv2/bn2 are not created.
            Log.info("using full-resolution stem with stride=1")
            stem_stride = 1
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=stem_stride, padding=1,
                                bias=False)
            self.bn1 = ModuleHelper.BatchNorm2d(bn_type=bn_type)(64, momentum=bn_momentum)
            self.relu = nn.ReLU(inplace=False)
            self.layer1 = self._make_layer(Bottleneck, 64, 64, 4, bn_type=bn_type, bn_momentum=bn_momentum)        
        else:
            # Default stem: two stride-2 3x3 convs, each followed by a norm.
            stem_stride = 2
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=stem_stride, padding=1,
                                bias=False)
            self.bn1 = ModuleHelper.BatchNorm2d(bn_type=bn_type)(64, momentum=bn_momentum)
            self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=stem_stride, padding=1,
                                bias=False)
            self.bn2 = ModuleHelper.BatchNorm2d(bn_type=bn_type)(64, momentum=bn_momentum)
            self.relu = nn.ReLU(inplace=False)
            self.layer1 = self._make_layer(Bottleneck, 64, 64, 4, bn_type=bn_type, bn_momentum=bn_momentum)

        # Stage 2: per-branch widths scaled by the block's expansion factor.
        self.stage2_cfg = cfg['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]

        # NOTE(review): 256 is presumably layer1's output width
        # (64 * Bottleneck.expansion) — confirm against the Bottleneck class.
        self.transition1 = self._make_transition_layer([256], num_channels, bn_type=bn_type, bn_momentum=bn_momentum)
        
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels, bn_type=bn_type, bn_momentum=bn_momentum)

        # Stage 3: transition from the widths reported by stage 2.
        self.stage3_cfg = cfg['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels, bn_type=bn_type, bn_momentum=bn_momentum)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels, bn_type=bn_type, bn_momentum=bn_momentum)

        # Stage 4: the only stage that passes multi_scale_output explicitly.
        self.stage4_cfg = cfg['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels, bn_type=bn_type, bn_momentum=bn_momentum)

        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True, bn_type=bn_type, bn_momentum=bn_momentum)

        # Optional head, built from the widths reported by stage 4.
        if os.environ.get('keep_imagenet_head'):
            self.incre_modules, self.downsamp_modules, \
                self.final_layer = self._make_head(pre_stage_channels, bn_type=bn_type, bn_momentum=bn_momentum)