def __init__(self, num_classes, trunk='wrn38', criterion=None):
        super(MscaleV3Plus, self).__init__()
        self.criterion = criterion
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                          bottleneck_ch=256,
                                          output_stride=8)
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

        # Semantic segmentation prediction head
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        # Scale-attention prediction head
        scale_in_ch = 2 * (256 + 48)
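        # (a reading of the constant above, not confirmed by this snippet: the
        # attention head presumably sees the 304-channel decoder features from
        # two scales concatenated, hence 2 * (256 + 48) = 608 input channels)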

        self.scale_attn = nn.Sequential(
            nn.Conv2d(scale_in_ch, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 1, kernel_size=1, bias=False), nn.Sigmoid())

        if cfg.OPTIONS.INIT_DECODER:
            initialize_weights(self.bot_fine)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.scale_attn)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        Norm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=relu_inplace)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        Norm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=relu_inplace)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)
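
# A minimal usage sketch for the transition logic above (assumes Norm2d is
# nn.BatchNorm2d, BN_MOMENTUM = 0.1 and relu_inplace = True, as in standard
# HRNet). Going from pre-layer widths [32, 64] to [32, 64, 128], branches 0
# and 1 map to None and the new branch 2 gets a single stride-2 3x3 conv from
# the last pre-layer branch:
import torch
import torch.nn as nn

down = nn.Sequential(nn.Conv2d(64, 128, 3, 2, 1, bias=False),
                     nn.BatchNorm2d(128, momentum=0.1),
                     nn.ReLU(inplace=True))
x = torch.randn(1, 64, 32, 32)
assert down(x).shape == (1, 128, 16, 16)  # halved resolution, new width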
Example 3
    def __init__(self, num_classes, trunk='wrn38', criterion=None,
                 use_dpc=False, init_all=False):
        super(DeepV3Plus, self).__init__()
        self.criterion = criterion
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                          bottleneck_ch=256,
                                          output_stride=8,
                                          dpc=use_dpc)
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        if init_all:
            initialize_weights(self.aspp)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.bot_fine)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
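
    # Hedged sketch of the DeepLabV3+ decoder forward these layers imply (the
    # actual forward is not part of this snippet; the backbone is assumed to
    # return s2, s4 and high-level features as the get_trunk unpacking
    # suggests, and F is torch.nn.functional):
    def forward(self, x):
        x_size = x.size()
        s2_features, _s4_features, final_features = self.backbone(x)
        aspp = self.aspp(final_features)
        conv_aspp = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        conv_aspp = F.interpolate(conv_aspp, size=s2_features.size()[2:],
                                  mode='bilinear', align_corners=False)
        final = self.final(torch.cat([conv_s2, conv_aspp], 1))
        return F.interpolate(final, size=x_size[2:], mode='bilinear',
                             align_corners=False)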
    def __init__(self, in_dim, reduction_dim=256, output_stride=16,
                 rates=(6, 12, 18)):
        super(_AtrousSpatialPyramidPoolingModule, self).__init__()

        # Check if we are using distributed BN and use the nn from encoding.nn
        # library rather than using standard pytorch.nn

        if output_stride == 8:
            rates = [2 * r for r in rates]
        elif output_stride == 16:
            pass
        else:
            raise ValueError(
                'output stride of {} not supported'.format(output_stride))

        self.features = []
        # 1x1
        self.features.append(
            nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                          Norm2d(reduction_dim), nn.ReLU(inplace=True)))
        # other rates
        for r in rates:
            self.features.append(nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                          dilation=r, padding=r, bias=False),
                Norm2d(reduction_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

        # img level features
        self.img_pooling = nn.AdaptiveAvgPool2d(1)
        self.img_conv = nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
            Norm2d(reduction_dim), nn.ReLU(inplace=True))
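
    # Hedged sketch of the forward that typically pairs with this module (not
    # shown above; F is torch.nn.functional): the image-level features are
    # upsampled and concatenated with every branch output, giving
    # (len(rates) + 2) * reduction_dim = 5 * 256 = 1280 channels with the
    # defaults, which matches the nn.Conv2d(1280, 256, ...) bot_aspp
    # projections in the other examples:
    def forward(self, x):
        x_size = x.size()
        img_features = self.img_conv(self.img_pooling(x))
        img_features = F.interpolate(img_features, size=x_size[2:],
                                     mode='bilinear', align_corners=False)
        out = img_features
        for f in self.features:
            out = torch.cat((out, f(x)), 1)
        return out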
Example 5
def make_attn_head(in_ch, out_ch):
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    if cfg.MODEL.MSCALE_OLDARCH:
        return old_make_attn_head(in_ch, bot_ch, out_ch)

    od = OrderedDict([('conv0',
                       nn.Conv2d(in_ch,
                                 bot_ch,
                                 kernel_size=3,
                                 padding=1,
                                 bias=False)), ('bn0', Norm2d(bot_ch)),
                      ('re0', nn.ReLU(inplace=True))])

    if cfg.MODEL.MSCALE_INNER_3x3:
        od['conv1'] = nn.Conv2d(bot_ch,
                                bot_ch,
                                kernel_size=3,
                                padding=1,
                                bias=False)
        od['bn1'] = Norm2d(bot_ch)
        od['re1'] = nn.ReLU(inplace=True)

    if cfg.MODEL.MSCALE_DROPOUT:
        od['drop'] = nn.Dropout(0.5)

    od['conv2'] = nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False)
    od['sig'] = nn.Sigmoid()

    attn_head = nn.Sequential(od)
    # init_attn(attn_head)
    return attn_head
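
# Usage sketch: the sigmoid output of the attention head is typically used to
# blend logits from two scales (hierarchical multi-scale attention). Tensor
# names below are illustrative, not taken from this file:
import torch

attn = torch.rand(1, 1, 64, 64)           # attn_head output, values in [0, 1]
p_lo = torch.randn(1, 19, 64, 64)         # low-scale class logits
p_hi = torch.randn(1, 19, 64, 64)         # high-scale class logits
fused = attn * p_lo + (1 - attn) * p_hi   # broadcast over the class dimension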
Example 6
    def __init__(self, in_dim, reduction_dim=256, output_stride=16,
                 rates=(6, 12, 18)):
        super(AtrousSpatialPyramidPoolingModule, self).__init__()

        if output_stride == 8:
            rates = [2 * r for r in rates]
        elif output_stride == 16:
            pass
        else:
            raise ValueError(
                'output stride of {} not supported'.format(output_stride))

        self.features = []
        # 1x1
        self.features.append(
            nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1,
                                    bias=False),
                          Norm2d(reduction_dim), nn.ReLU(inplace=True)))
        # other rates
        for r in rates:
            self.features.append(nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                          dilation=r, padding=r, bias=False),
                Norm2d(reduction_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

        # img level features
        self.img_pooling = nn.AdaptiveAvgPool2d(1)
        self.img_conv = nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
            Norm2d(reduction_dim), nn.ReLU(inplace=True))
Example 7
def make_seg_head(in_ch, out_ch):
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    return nn.Sequential(
        nn.Conv2d(in_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch), nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch), nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False))
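
# For example, make_seg_head(in_ch=256 + 48, out_ch=num_classes) reproduces
# the shape of the 'final' decoder heads above, with the bottleneck width
# taken from cfg.MODEL.SEGATTN_BOT_CH instead of hard-coded 256 (an
# illustrative call, not one that appears in this snippet).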
Example 8
    def __init__(self, num_classes, trunk=None, criterion=None):
        
        super(GSCNN, self).__init__()
        self.criterion = criterion
        self.num_classes = num_classes

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        
        wide_resnet = wide_resnet.module
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        self.interpolate = F.interpolate
        del wide_resnet

        self.dsn1 = nn.Conv2d(64, 1, 1)
        self.dsn3 = nn.Conv2d(256, 1, 1)
        self.dsn4 = nn.Conv2d(512, 1, 1)
        self.dsn7 = nn.Conv2d(4096, 1, 1)

        self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
        self.d1 = nn.Conv2d(64, 32, 1)
        self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = nn.Conv2d(32, 16, 1)
        self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = nn.Conv2d(16, 8, 1)
        self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

        self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = gsc.GatedSpatialConv2d(32, 32)
        self.gate2 = gsc.GatedSpatialConv2d(16, 16)
        self.gate3 = gsc.GatedSpatialConv2d(8, 8)
         
        self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

        self.final_seg = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
        initialize_weights(self.final_seg)
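
        # Hedged reading of the extra pieces above (GSCNN's shape stream): the
        # dsn* 1x1 convs tap side outputs at several backbone depths, the
        # res*/d*/gate* chain refines a 1-channel boundary map, and self.cw
        # fuses that map with a second 1-channel edge input (a Canny edge map
        # in the original GSCNN), hence its 2 input channels. The forward
        # itself is not shown here.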
Example 9
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = Norm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = Norm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
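
    # Hedged sketch of the standard BasicBlock forward that pairs with this
    # __init__ (residual addition with ReLU after the sum, as in torchvision):
    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + residual)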
Example 10
    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(num_inchannels[j],
                                      num_inchannels[i],
                                      1,
                                      1,
                                      0,
                                      bias=False),
                            Norm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(num_inchannels[j],
                                              num_outchannels_conv3x3,
                                              3,
                                              2,
                                              1,
                                              bias=False),
                                    Norm2d(num_outchannels_conv3x3,
                                           momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(num_inchannels[j],
                                              num_outchannels_conv3x3,
                                              3,
                                              2,
                                              1,
                                              bias=False),
                                    Norm2d(num_outchannels_conv3x3,
                                           momentum=BN_MOMENTUM),
                                    nn.ReLU(inplace=relu_inplace)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)
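
    # Hedged sketch of how these fuse layers are applied in the standard HRNet
    # forward (the forward is not part of this snippet; F is
    # torch.nn.functional): branch j is 1x1-conv'd and upsampled when j > i,
    # added directly when j == i, and stride-2 downsampled when j < i, then
    # everything is summed per output branch:
    def _fuse_sketch(self, x):  # x: list of per-branch feature maps
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[i]
            for j in range(self.num_branches):
                if j == i:
                    continue
                z = self.fuse_layers[i][j](x[j])
                if j > i:
                    z = F.interpolate(z, size=x[i].shape[2:], mode='bilinear')
                y = y + z
            x_fuse.append(F.relu(y))
        return x_fuse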
Example 11
def old_make_attn_head(in_ch, bot_ch, out_ch):
    attn = nn.Sequential(
        nn.Conv2d(in_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch), nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch), nn.ReLU(inplace=True),
        # the original had kernel_size=out_ch, which only behaves like a 1x1
        # projection when out_ch == 1; a 1x1 conv matches make_seg_head above
        nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False),
        nn.Sigmoid())

    init_attn(attn)
    return attn
Example 12
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = Norm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = Norm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = Norm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)
        self.downsample = downsample
        self.stride = stride
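
    # Hedged sketch of the matching Bottleneck forward (standard 1x1-3x3-1x1
    # residual block with channel expansion, as in torchvision):
    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + residual)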
Example 13
    def __init__(self, **kwargs):
        extra = cfg.MODEL.OCR_EXTRA
        super(HighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = Norm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = Norm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)

        self.stage1_cfg = extra['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion*num_channels

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [num_channels[i] * block.expansion
                        for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [num_channels[i] * block.expansion
                        for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [num_channels[i] * block.expansion
                        for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        self.high_level_ch = int(np.sum(pre_stage_channels))
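        # Illustration of the line above (example widths assumed, not read
        # from cfg.MODEL.OCR_EXTRA): for HRNet-W48 the stage-4 branch widths
        # are (48, 96, 192, 384), so high_level_ch = 48 + 96 + 192 + 384 = 720.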
Example 14
    def __init__(self,
                 num_classes,
                 input_size,
                 trunk='WideResnet38',
                 criterion=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        logging.info("Trunk: %s", trunk)
        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        # TODO: should this even be here?
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        try:
            checkpoint = torch.load('weights/ResNet/wider_resnet38.pth.tar',
                                    map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            print("Could not load ImageNet weights.")

        wide_resnet = wide_resnet.module

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        initialize_weights(self.final)
Example 15
    def __init__(self, num_classes, trunk='WideResnet38', criterion=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        logging.info("Trunk: %s", trunk)

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        if criterion is not None:
            try:
                checkpoint = torch.load(
                    './pretrained_models/wider_resnet38.pth.tar',
                    map_location='cpu')
                wide_resnet.load_state_dict(checkpoint['state_dict'])
                del checkpoint
            except Exception:
                print("Please download the ImageNet weights of WideResNet38 "
                      "in our repo to "
                      "./pretrained_models/wider_resnet38.pth.tar.")
                raise RuntimeError(
                    "Could not load ImageNet weights of WideResNet38 network.")
        wide_resnet = wide_resnet.module

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        initialize_weights(self.final)
Example 16
    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != (num_channels[branch_index] *
                                                 block.expansion):
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                Norm2d(num_channels[branch_index] * block.expansion,
                       momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))

        return nn.Sequential(*layers)
Example 17
    def __init__(self, num_classes, trunk='wrn38', criterion=None,
                 use_dpc=False, fuse_aspp=False, attn_2b=False, attn_cls=False):
        super(MscaleV3Plus, self).__init__()
        self.criterion = criterion
        self.fuse_aspp = fuse_aspp
        self.attn_2b = attn_2b
        self.attn_cls = attn_cls
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                          bottleneck_ch=256,
                                          output_stride=8,
                                          dpc=use_dpc)
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

        # Semantic segmentation prediction head
        bot_ch = cfg.MODEL.SEGATTN_BOT_CH
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
            Norm2d(bot_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
            Norm2d(bot_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

        # Scale-attention prediction head
        if self.attn_2b:
            attn_ch = 2
        else:
            if self.attn_cls:
                attn_ch = num_classes
            else:
                attn_ch = 1

        scale_in_ch = 256 + 48

        self.scale_attn = make_attn_head(in_ch=scale_in_ch,
                                         out_ch=attn_ch)

        if cfg.OPTIONS.INIT_DECODER:
            initialize_weights(self.bot_fine)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.scale_attn)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
    def __init__(self, num_classes, trunk='wrn38', criterion=None,
                 use_dpc=False, fuse_aspp=False, attn_2b=False):
        super(MscaleV3Plus, self).__init__()
        self.criterion = criterion
        self.fuse_aspp = fuse_aspp
        self.attn_2b = attn_2b
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                          bottleneck_ch=256,
                                          output_stride=8,
                                          dpc=use_dpc,
                                          img_norm=False)
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)
        #self.asnb = ASNB(low_in_channels = 48, high_in_channels=256, out_channels=256, key_channels=64, value_channels=256, dropout=0., sizes=([1]), norm_type='batchnorm',attn_scale=0.25)
        self.adnb = ADNB(d_model=256, nhead=8, num_encoder_layers=2,
                         dim_feedforward=256, dropout=0.5, activation="relu",
                         num_feature_levels=1, enc_n_points=4)
        # Semantic segmentation prediction head
        bot_ch = cfg.MODEL.SEGATTN_BOT_CH
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
            Norm2d(bot_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
            Norm2d(bot_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

        # Scale-attention prediction head
        if self.attn_2b:
            attn_ch = 2
        else:
            attn_ch = 1

        scale_in_ch = 256 + 48

        self.scale_attn = make_attn_head(in_ch=scale_in_ch,
                                         out_ch=attn_ch)

        if cfg.OPTIONS.INIT_DECODER:
            initialize_weights(self.bot_fine)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.scale_attn)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
Example 19
    def __init__(self, in_dim, reduction_dim=256, output_stride=16,
                 rates=(6, 12, 18)):
        super(ASPP_edge, self).__init__(in_dim=in_dim,
                                        reduction_dim=reduction_dim,
                                        output_stride=output_stride,
                                        rates=rates)
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False),
            Norm2d(reduction_dim), nn.ReLU(inplace=True))
    def __init__(self,
                 num_classes,
                 trunk='wrn38',
                 criterion=None,
                 use_dpc=False,
                 init_all=False):
        super(DeepV3PlusATTN, self).__init__()
        self.criterion = criterion
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        #self.aspp, aspp_out_ch = get_aspp(high_level_ch,
        #                                  bottleneck_ch=256,
        #                                  output_stride=8,
        #                                  dpc=use_dpc)
        #self.attn = APNB(in_channels=high_level_ch, out_channels=high_level_ch, key_channels=256, value_channels=256, dropout=0.5, sizes=([1]), norm_type='batchnorm', psp_size=(1, 3, 6, 8))
        self.attn = AFNB(low_in_channels=2048,
                         high_in_channels=4096,
                         out_channels=256,
                         key_channels=64,
                         value_channels=256,
                         dropout=0.8,
                         sizes=([1]),
                         norm_type='batchnorm',
                         psp_size=(1, 3, 6, 8))
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(256, 256, kernel_size=1, bias=False)
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        if init_all:
            initialize_weights(self.attn)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.bot_fine)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
Example 21
    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                Norm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)
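
    # Channel bookkeeping example (hypothetical call): with a Bottleneck block
    # (expansion = 4), self._make_layer(Bottleneck, 64, 64, blocks=4) stacks
    # four blocks and returns a module whose output has 64 * 4 = 256 channels,
    # which is exactly what stage1_out_channel computes in the
    # HighResolutionNet example above.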
    def __init__(self, num_classes, trunk='WideResnet38', criterion=None,
                 ngf=64, input_nc=1, output_nc=2):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        self.MSEcriterion = torch.nn.MSELoss()
        logging.info("Trunk: %s", trunk)
        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        try:
            checkpoint = torch.load('/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/semantic-segmentation-master/pretrained_models/wider_resnet38.pth.tar', map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            print("Could not load ImageNet weights.")
            print("Please download the ImageNet weights of WideResNet38 "
                  "in our repo to ./pretrained_models.")

        #audio_unet = AudioNet_multitask(ngf=64,input_nc=2)
        #Acheckpoint = torch.load('/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/audio/audioSynthesis/checkpoints/synBi2Bi_16_25/3_audio.pth', map_location='cpu')
        #pretrained_dict = Acheckpoint
        #model_dict = audio_unet.state_dict()
        #pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k!='audionet_upconvlayer1.0.weight' and k!='audionet_upconvlayer5.0.weight' and k!='audionet_upconvlayer5.0.bias' and k!='conv1x1.0.weight' and k!='conv1x1.0.bias' and k!='conv1x1.1.weight' and k!='conv1x1.1.bias' and k!='conv1x1.1.running_mean' and k!='conv1x1.1.running_var'}
        #model_dict.update(pretrained_dict) 
        #audio_unet.load_state_dict(model_dict)

        self.audionet_convlayer1 = unet_conv(input_nc, ngf)
        self.audionet_convlayer2 = unet_conv(ngf, ngf * 2)
        self.audionet_convlayer3 = unet_conv(ngf * 2, ngf * 4)
        self.audionet_convlayer4 = unet_conv(ngf * 4, ngf * 8)
        self.audionet_convlayer5 = unet_conv(ngf * 8, ngf * 8)
        # fused audio-visual feature; the original comment claimed
        # "1296 = 784 (visual) + 512 (audio)", which does not match the
        # 1024-channel input
        self.audionet_upconvlayer1 = unet_upconv(1024, ngf * 8)
        self.audionet_upconvlayer2 = unet_upconv(ngf * 8, ngf *4)
        self.audionet_upconvlayer3 = unet_upconv(ngf * 4, ngf * 2)
        self.audionet_upconvlayer4 = unet_upconv(ngf * 2, ngf)
        self.audionet_upconvlayer5 = unet_upconv(ngf, output_nc, True)  # outermost layer uses a sigmoid to bound the mask
        self.conv1x1 = create_conv(4096, 2, 1, 0)

        wide_resnet = wide_resnet.module
        #self.unet= audio_unet
        #print(wide_resnet)
        '''
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        '''
        del wide_resnet
        
        self.aspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                       output_stride=8)
        self.depthaspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                            output_stride=8)

        self.bot_aud1 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
        self.bot_multiaud = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(320, 256, kernel_size=1, bias=False)
        self.bot_depthaspp = nn.Conv2d(320, 128, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
        self.final_depth = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            Norm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            Norm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, bias=False))

        initialize_weights(self.final)
        initialize_weights(self.bot_aud1)
        initialize_weights(self.bot_multiaud)
    def __init__(self,
                 num_classes,
                 trunk='wrn38',
                 criterion=None,
                 use_dpc=False,
                 fuse_aspp=False,
                 attn_2b=False,
                 bn_head=False):
        super(ASDV3P, self).__init__()
        self.criterion = criterion
        self.fuse_aspp = fuse_aspp
        self.attn_2b = attn_2b
        self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
        self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                          bottleneck_ch=256,
                                          output_stride=8,
                                          dpc=use_dpc)
        self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

        # Semantic segmentation prediction head
        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        # Scale-attention prediction head
        assert cfg.MODEL.N_SCALES is not None
        self.scales = sorted(cfg.MODEL.N_SCALES)

        num_scales = len(self.scales)
        if cfg.MODEL.ATTNSCALE_BN_HEAD or bn_head:
            self.scale_attn = nn.Sequential(
                nn.Conv2d(num_scales * (256 + 48),
                          256,
                          kernel_size=3,
                          padding=1,
                          bias=False), Norm2d(256), nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
                Norm2d(256), nn.ReLU(inplace=True),
                nn.Conv2d(256, num_scales, kernel_size=1, bias=False))
        else:
            self.scale_attn = nn.Sequential(
                nn.Conv2d(num_scales * (256 + 48),
                          512,
                          kernel_size=3,
                          padding=1,
                          bias=False), nn.ReLU(inplace=True),
                nn.Conv2d(512,
                          num_scales,
                          kernel_size=1,
                          # the original had padding=1, which with a 1x1
                          # kernel would enlarge the output map
                          padding=0,
                          bias=False))

        if cfg.OPTIONS.INIT_DECODER:
            initialize_weights(self.bot_fine)
            initialize_weights(self.bot_aspp)
            initialize_weights(self.scale_attn)
            initialize_weights(self.final)
        else:
            initialize_weights(self.final)
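
# One plausible way the num_scales attention logits could be applied (the
# forward is not shown here; softmax-normalized per-scale weighting is an
# assumption, and all tensors below are illustrative):
import torch
import torch.nn.functional as F

attn = torch.randn(2, 3, 64, 64)                    # logits for 3 scales
preds = [torch.randn(2, 19, 64, 64) for _ in range(3)]
w = F.softmax(attn, dim=1)                          # normalize across scales
fused = sum(w[:, i:i + 1] * p for i, p in enumerate(preds))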
Example 24
    def __init__(self, num_classes, trunk=None, criterion=None):

        super(GSCNN, self).__init__()
        self.criterion = criterion
        self.num_classes = num_classes

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)

        try:
            checkpoint = torch.load(
                './network/pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            print("Please download the ImageNet weights of WideResNet38 "
                  "in our repo to ./pretrained_models/wider_resnet38.pth.tar.")
            raise RuntimeError(
                "Could not load ImageNet weights of WideResNet38 network.")

        wide_resnet = wide_resnet.module
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        self.interpolate = F.interpolate
        del wide_resnet

        self.dsn1 = nn.Conv2d(64, 1, 1)
        self.dsn3 = nn.Conv2d(256, 1, 1)
        self.dsn4 = nn.Conv2d(512, 1, 1)
        self.dsn7 = nn.Conv2d(4096, 1, 1)

        self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
        self.d1 = nn.Conv2d(64, 32, 1)
        self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = nn.Conv2d(32, 16, 1)
        self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = nn.Conv2d(16, 8, 1)
        self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

        self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = gsc.GatedSpatialConv2d(32, 32)
        self.gate2 = gsc.GatedSpatialConv2d(16, 16)
        self.gate3 = gsc.GatedSpatialConv2d(8, 8)

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

        self.final_seg = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
        initialize_weights(self.final_seg)
Example 25
def BNReLU(ch):
    return nn.Sequential(Norm2d(ch), nn.ReLU())
Example 26
    def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None,
                variant='D', skip='m1', skip_num=48, args=None):
        super(DeepV3Plus, self).__init__()
        self.criterion = criterion
        self.criterion_aux = criterion_aux
        self.variant = variant
        self.args = args
        self.trunk = trunk

        if trunk == 'shufflenetv2':
            channel_1st = 3
            channel_2nd = 24
            channel_3rd = 116
            channel_4th = 232
            prev_final_channel = 464
            final_channel = 1024
            resnet = Shufflenet.shufflenet_v2_x1_0(pretrained=True, iw=self.args.wt_layer)

            class Layer0(nn.Module):
                def __init__(self, iw):
                    super(Layer0, self).__init__()
                    self.layer = nn.Sequential(resnet.conv1, resnet.maxpool)
                    self.instance_norm_layer = resnet.instance_norm_layer1
                    self.iw = iw

                def forward(self, x_tuple):
                    if len(x_tuple) == 2:
                        w_arr = x_tuple[1]
                        x = x_tuple[0]
                    else:
                        print("error in shufflnet layer 0 forward path")
                        return

                    x = self.layer[0][0](x)
                    if self.iw >= 1:
                        if self.iw == 1 or self.iw == 2:
                            x, w = self.instance_norm_layer(x)
                            w_arr.append(w)
                        else:
                            x = self.instance_norm_layer(x)
                    else:
                        x = self.layer[0][1](x)

                    x = self.layer[0][2](x)
                    x = self.layer[1](x)

                    return [x, w_arr]

            class Layer4(nn.Module):
                def __init__(self, iw):
                    super(Layer4, self).__init__()
                    self.layer = resnet.conv5
                    self.instance_norm_layer = resnet.instance_norm_layer2
                    self.iw = iw

                def forward(self, x_tuple):
                    if len(x_tuple) == 2:
                        w_arr = x_tuple[1]
                        x = x_tuple[0]
                    else:
                        print("error in shufflnet layer 4 forward path")
                        return

                    x = self.layer[0](x)
                    if self.iw >= 1:
                        if self.iw == 1 or self.iw == 2:
                            x, w = self.instance_norm_layer(x)
                            w_arr.append(w)
                        else:
                            x = self.instance_norm_layer(x)
                    else:
                        x = self.layer[1](x)
                    x = self.layer[2](x)

                    return [x, w_arr]


            self.layer0 = Layer0(iw=self.args.wt_layer[2])
            self.layer1 = resnet.stage2
            self.layer2 = resnet.stage3
            self.layer3 = resnet.stage4
            self.layer4 = Layer4(iw=self.args.wt_layer[6])

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # raise 'unknown deepv3 variant: {}'.format(self.variant)
                print("Not using Dilation ")

        elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':

            if trunk == 'mnasnet_05':
                resnet = models.mnasnet0_5(pretrained=True)
                channel_1st = 3
                channel_2nd = 16
                channel_3rd = 24
                channel_4th = 48
                prev_final_channel = 160
                final_channel = 1280

                print("# of layers", len(resnet.layers))
                self.layer0 = nn.Sequential(resnet.layers[0],resnet.layers[1],resnet.layers[2],
                                            resnet.layers[3],resnet.layers[4],resnet.layers[5],resnet.layers[6],resnet.layers[7])   # 16
                self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9]) # 24, 40
                self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])   # 48, 96
                self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13]) # 160, 320
                self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
            else:
                resnet = models.mnasnet1_0(pretrained=True)
                channel_1st = 3
                channel_2nd = 16
                channel_3rd = 40
                channel_4th = 96
                prev_final_channel = 320
                final_channel = 1280

                print("# of layers", len(resnet.layers))
                self.layer0 = nn.Sequential(resnet.layers[0],resnet.layers[1],resnet.layers[2],
                                            resnet.layers[3],resnet.layers[4],resnet.layers[5],resnet.layers[6],resnet.layers[7])   # 16
                self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9]) # 24, 40
                self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])   # 48, 96
                self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13]) # 160, 320
                self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # raise 'unknown deepv3 variant: {}'.format(self.variant)
                print("Not using Dilation ")
        elif trunk == 'mobilenetv2':
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 32
            channel_4th = 64

            # prev_final_channel = 160
            prev_final_channel = 320

            final_channel = 1280
            resnet = Mobilenet.mobilenet_v2(pretrained=True,
                    iw=self.args.wt_layer)
            self.layer0 = nn.Sequential(resnet.features[0],
                                        resnet.features[1])
            self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                        resnet.features[4], resnet.features[5], resnet.features[6])
            self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8], resnet.features[9], resnet.features[10])

            # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13], resnet.features[14], resnet.features[15], resnet.features[16])
            # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])

            self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13],
                                        resnet.features[14], resnet.features[15], resnet.features[16],
                                        resnet.features[17])
            self.layer4 = nn.Sequential(resnet.features[18])

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride==(2,2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # raise 'unknown deepv3 variant: {}'.format(self.variant)
                print("Not using Dilation ")
        else:
            channel_1st = 3
            channel_2nd = 64
            channel_3rd = 256
            channel_4th = 512
            prev_final_channel = 1024
            final_channel = 2048

            if trunk == 'resnet-18':
                channel_1st = 3
                channel_2nd = 64
                channel_3rd = 64
                channel_4th = 128
                prev_final_channel = 256
                final_channel = 512
                resnet = Resnet.resnet18(wt_layer=self.args.wt_layer)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'resnet-50':
                resnet = Resnet.resnet50(wt_layer=self.args.wt_layer)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'resnet-101': # three 3 X 3
                resnet = Resnet.resnet101(pretrained=True, wt_layer=self.args.wt_layer)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                            resnet.conv2, resnet.bn2, resnet.relu2,
                                            resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool)
            elif trunk == 'resnet-152':
                resnet = Resnet.resnet152()
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'resnext-50':
                resnet = models.resnext50_32x4d(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'resnext-101':
                resnet = models.resnext101_32x8d(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'wide_resnet-50':
                resnet = models.wide_resnet50_2(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            elif trunk == 'wide_resnet-101':
                resnet = models.wide_resnet101_2(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
            else:
                raise ValueError("Not a valid network arch")

            self.layer0 = resnet.layer0
            self.layer1, self.layer2, self.layer3, self.layer4 = \
                resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            if self.variant == 'D':
                for n, m in self.layer3.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            elif self.variant == 'D4':
                for n, m in self.layer2.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer3.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            else:
                # raise 'unknown deepv3 variant: {}'.format(self.variant)
                print("Not using Dilation ")

        if self.variant == 'D':
            os = 8
        elif self.variant == 'D4':
            os = 4
        elif self.variant == 'D16':
            os = 16
        else:
            os = 32

        self.output_stride = os
        self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256,
                                                    output_stride=os)

        self.bot_fine = nn.Sequential(
            nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
            Norm2d(48),
            nn.ReLU(inplace=True))

        self.bot_aspp = nn.Sequential(
            nn.Conv2d(1280, 256, kernel_size=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True))

        self.final1 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True))

        self.final2 = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

        self.dsn = nn.Sequential(
            nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
            Norm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )
        initialize_weights(self.dsn)

        initialize_weights(self.aspp)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.bot_fine)
        initialize_weights(self.final1)
        initialize_weights(self.final2)

        # Setting the flags
        self.eps = 1e-5
        self.whitening = False

        if trunk == 'resnet-101':
            self.three_input_layer = True
            in_channel_list = [64, 64, 128, 256, 512, 1024, 2048]   # 8128, 32640, 130816
            out_channel_list = [32, 32, 64, 128, 256,  512, 1024]
        elif trunk == 'resnet-18':
            self.three_input_layer = False
            in_channel_list = [0, 0, 64, 64, 128, 256, 512]   # 8128, 32640, 130816
            out_channel_list = [0, 0, 32, 32, 64,  128, 256]
        elif trunk == 'shufflenetv2':
            self.three_input_layer = False
            in_channel_list = [0, 0, 24, 116, 232, 464, 1024]
        elif trunk == 'mobilenetv2':
            self.three_input_layer = False
            in_channel_list = [0, 0, 16, 32, 64, 320, 1280]
        else: # ResNet-50
            self.three_input_layer = False
            in_channel_list = [0, 0, 64, 256, 512, 1024, 2048]   # 8128, 32640, 130816
            out_channel_list = [0, 0, 32, 128, 256,  512, 1024]

        self.cov_matrix_layer = []
        self.cov_type = []
        for i in range(len(self.args.wt_layer)):
            if self.args.wt_layer[i] > 0:
                self.whitening = True
                if self.args.wt_layer[i] == 1:
                    self.cov_matrix_layer.append(
                        CovMatrix_IRW(dim=in_channel_list[i],
                                      relax_denom=self.args.relax_denom))
                    self.cov_type.append(self.args.wt_layer[i])
                elif self.args.wt_layer[i] == 2:
                    self.cov_matrix_layer.append(
                        CovMatrix_ISW(dim=in_channel_list[i],
                                      relax_denom=self.args.relax_denom,
                                      clusters=self.args.clusters))
                    self.cov_type.append(self.args.wt_layer[i])
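A minimal, self-contained sketch of the statistic the CovMatrix_IRW / CovMatrix_ISW layers above are built around: the per-sample channel covariance that instance whitening pushes toward the identity. The helper below is illustrative only (not the repo's API) and assumes (B, C, H, W) features.

import torch

def channel_covariance(feat, eps=1e-5):
    # (B, C, H, W) -> per-sample (C, C) covariance across spatial positions
    b, c, h, w = feat.shape
    x = feat.view(b, c, h * w)
    x = x - x.mean(dim=2, keepdim=True)
    cov = torch.bmm(x, x.transpose(1, 2)) / (h * w - 1)
    # a small diagonal term keeps the matrix well-conditioned (cf. self.eps)
    return cov + eps * torch.eye(c, device=feat.device)

print(channel_covariance(torch.randn(2, 64, 32, 32)).shape)  # (2, 64, 64)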
Example 27
    def __init__(self,
                 num_classes,
                 trunk='resnet-101',
                 criterion=None,
                 criterion_aux=None,
                 variant='D',
                 skip='m1',
                 skip_num=48,
                 args=None):
        super(DeepV3PlusHANet, self).__init__()
        self.criterion = criterion
        self.criterion_aux = criterion_aux
        self.variant = variant
        self.args = args
        self.num_attention_layer = 0
        self.trunk = trunk

        for i in range(5):
            if args.hanet[i] > 0:
                self.num_attention_layer += 1

        print("#### HANet layers", self.num_attention_layer)

        if trunk == 'shufflenetv2':
            channel_1st = 3
            channel_2nd = 24
            channel_3rd = 116
            channel_4th = 232
            prev_final_channel = 464
            final_channel = 1024
            resnet = models.shufflenet_v2_x1_0(pretrained=True)
            self.layer0 = nn.Sequential(resnet.conv1, resnet.maxpool)
            self.layer1 = resnet.stage2
            self.layer2 = resnet.stage3
            self.layer3 = resnet.stage4
            self.layer4 = resnet.conv5

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # unknown deepv3 variant: fall back to the backbone's
                # native stride (no dilation)
                print("Not using dilation")
        elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
            if trunk == 'mnasnet_05':
                resnet = models.mnasnet0_5(pretrained=True)
                channel_1st = 3
                channel_2nd = 16
                channel_3rd = 24
                channel_4th = 48
                prev_final_channel = 160
                final_channel = 1280
            else:
                resnet = models.mnasnet1_0(pretrained=True)
                channel_1st = 3
                channel_2nd = 16
                channel_3rd = 40
                channel_4th = 96
                prev_final_channel = 320
                final_channel = 1280

            print("# of layers", len(resnet.layers))
            # The stage grouping is identical for both MNASNet widths; the
            # trailing comments give the stage output channels (0.5x, 1.0x).
            self.layer0 = nn.Sequential(*resnet.layers[0:8])    # 16
            self.layer1 = nn.Sequential(*resnet.layers[8:10])   # 24, 40
            self.layer2 = nn.Sequential(*resnet.layers[10:12])  # 48, 96
            self.layer3 = nn.Sequential(*resnet.layers[12:14])  # 160, 320
            self.layer4 = nn.Sequential(*resnet.layers[14:17])  # 1280

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # unknown deepv3 variant: fall back to the backbone's
                # native stride (no dilation)
                print("Not using dilation")
        elif trunk == 'mobilenetv2':
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 32
            channel_4th = 64

            prev_final_channel = 320
            final_channel = 1280
            resnet = models.mobilenet_v2(pretrained=True)
            self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
            self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                        resnet.features[4], resnet.features[5],
                                        resnet.features[6])
            self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                        resnet.features[9],
                                        resnet.features[10])

            self.layer3 = nn.Sequential(
                resnet.features[11], resnet.features[12], resnet.features[13],
                resnet.features[14], resnet.features[15], resnet.features[16],
                resnet.features[17])
            self.layer4 = nn.Sequential(resnet.features[18])

            if self.variant == 'D':
                for n, m in self.layer2.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer3.named_modules():
                    if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            else:
                # unknown deepv3 variant: fall back to the backbone's
                # native stride (no dilation)
                print("Not using dilation")
        else:
            channel_1st = 3
            channel_2nd = 64
            channel_3rd = 256
            channel_4th = 512
            prev_final_channel = 1024
            final_channel = 2048

            if trunk == 'resnet-50':
                resnet = Resnet.resnet50()
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            elif trunk == 'resnet-101':  # three 3 X 3
                resnet = Resnet.resnet101()
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu1, resnet.conv2,
                                              resnet.bn2, resnet.relu2,
                                              resnet.conv3, resnet.bn3,
                                              resnet.relu3, resnet.maxpool)
            elif trunk == 'resnet-152':
                resnet = Resnet.resnet152()
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            elif trunk == 'resnext-50':
                resnet = models.resnext50_32x4d(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            elif trunk == 'resnext-101':
                resnet = models.resnext101_32x8d(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            elif trunk == 'wide_resnet-50':
                resnet = models.wide_resnet50_2(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            elif trunk == 'wide_resnet-101':
                resnet = models.wide_resnet101_2(pretrained=True)
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)
            else:
                raise ValueError("Not a valid network arch")

            self.layer0 = resnet.layer0
            self.layer1, self.layer2, self.layer3, self.layer4 = \
                resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            if self.variant == 'D':
                for n, m in self.layer3.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            elif self.variant == 'D4':
                for n, m in self.layer2.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer3.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            elif self.variant == 'D16':
                for n, m in self.layer4.named_modules():
                    if 'conv2' in n:
                        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                    elif 'downsample.0' in n:
                        m.stride = (1, 1)
            else:
                # unknown deepv3 variant: fall back to the backbone's
                # native stride (no dilation)
                print("Not using dilation")

        if self.variant == 'D':
            os = 8
        elif self.variant == 'D4':
            os = 4
        elif self.variant == 'D16':
            os = 16
        else:
            os = 32

        self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel,
                                                       256,
                                                       output_stride=os)

        self.bot_fine = nn.Sequential(
            nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False), Norm2d(48),
            nn.ReLU(inplace=True))

        self.bot_aspp = nn.Sequential(
            nn.Conv2d(1280, 256, kernel_size=1, bias=False), Norm2d(256),
            nn.ReLU(inplace=True))

        self.final1 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True))

        self.final2 = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

        if self.args.aux_loss:
            self.dsn = nn.Sequential(
                nn.Conv2d(prev_final_channel,
                          512,
                          kernel_size=3,
                          stride=1,
                          padding=1), Norm2d(512), nn.ReLU(inplace=True),
                nn.Dropout2d(0.1),
                nn.Conv2d(512,
                          num_classes,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=True))
            initialize_weights(self.dsn)

        if self.args.hanet[0] == 1:
            self.hanet0 = HANet_Conv(prev_final_channel,
                                     final_channel,
                                     self.args.hanet_set[0],
                                     self.args.hanet_set[1],
                                     self.args.hanet_set[2],
                                     self.args.hanet_pos[0],
                                     self.args.hanet_pos[1],
                                     pos_rfactor=self.args.pos_rfactor,
                                     pooling=self.args.pooling,
                                     dropout_prob=self.args.dropout,
                                     pos_noise=self.args.pos_noise)
            initialize_weights(self.hanet0)

        if self.args.hanet[1] == 1:
            self.hanet1 = HANet_Conv(final_channel,
                                     1280,
                                     self.args.hanet_set[0],
                                     self.args.hanet_set[1],
                                     self.args.hanet_set[2],
                                     self.args.hanet_pos[0],
                                     self.args.hanet_pos[1],
                                     pos_rfactor=self.args.pos_rfactor,
                                     pooling=self.args.pooling,
                                     dropout_prob=self.args.dropout,
                                     pos_noise=self.args.pos_noise)
            initialize_weights(self.hanet1)

        if self.args.hanet[2] == 1:
            self.hanet2 = HANet_Conv(1280,
                                     256,
                                     self.args.hanet_set[0],
                                     self.args.hanet_set[1],
                                     self.args.hanet_set[2],
                                     self.args.hanet_pos[0],
                                     self.args.hanet_pos[1],
                                     pos_rfactor=self.args.pos_rfactor,
                                     pooling=self.args.pooling,
                                     dropout_prob=self.args.dropout,
                                     pos_noise=self.args.pos_noise)
            initialize_weights(self.hanet2)

        if self.args.hanet[3] == 1:
            self.hanet3 = HANet_Conv(304,
                                     256,
                                     self.args.hanet_set[0],
                                     self.args.hanet_set[1],
                                     self.args.hanet_set[2],
                                     self.args.hanet_pos[0],
                                     self.args.hanet_pos[1],
                                     pos_rfactor=self.args.pos_rfactor,
                                     pooling=self.args.pooling,
                                     dropout_prob=self.args.dropout,
                                     pos_noise=self.args.pos_noise)
            initialize_weights(self.hanet3)

        if self.args.hanet[4] == 1:
            self.hanet4 = HANet_Conv(256,
                                     num_classes,
                                     self.args.hanet_set[0],
                                     self.args.hanet_set[1],
                                     self.args.hanet_set[2],
                                     self.args.hanet_pos[0],
                                     self.args.hanet_pos[1],
                                     pos_rfactor=self.args.pos_rfactor,
                                     pooling='max',
                                     dropout_prob=self.args.dropout,
                                     pos_noise=self.args.pos_noise)
            initialize_weights(self.hanet4)

        initialize_weights(self.aspp)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.bot_fine)
        initialize_weights(self.final1)
        initialize_weights(self.final2)
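As a sanity check on the 'D'-variant surgery above, the same stride-to-dilation rewrite can be applied to a plain torchvision resnet50: stride-2 conv2 layers become dilated stride-1 convs, keeping the receptive field while holding the output stride at 8 instead of 32. A hedged, standalone sketch (everything outside the rewrite loop is illustrative):

import torch
from torchvision import models

net = models.resnet50(pretrained=False)
for dilation, stage in ((2, net.layer3), (4, net.layer4)):
    for n, m in stage.named_modules():
        if 'conv2' in n:
            m.dilation = (dilation, dilation)
            m.padding = (dilation, dilation)
            m.stride = (1, 1)
        elif 'downsample.0' in n:
            m.stride = (1, 1)

x = torch.randn(1, 3, 256, 256)
x = net.maxpool(net.relu(net.bn1(net.conv1(x))))       # stride 4
x = net.layer4(net.layer3(net.layer2(net.layer1(x))))  # stride 8 after rewrite
print(x.shape)  # torch.Size([1, 2048, 32, 32]) rather than (1, 2048, 8, 8)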
Example 28
    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size=3,
                 r_factor=64,
                 layer=3,
                 pos_injection=2,
                 is_encoding=1,
                 pos_rfactor=8,
                 pooling='mean',
                 dropout_prob=0.0,
                 pos_noise=0.0):
        super(HANet_Conv, self).__init__()

        self.pooling = pooling
        self.pos_injection = pos_injection
        self.layer = layer
        self.dropout_prob = dropout_prob
        self.sigmoid = nn.Sigmoid()

        if r_factor > 0:
            mid_1_channel = math.ceil(in_channel / r_factor)
        elif r_factor < 0:
            # a negative r_factor means "widen by |r_factor|" instead of reduce
            r_factor = r_factor * -1
            mid_1_channel = in_channel * r_factor
        else:
            # r_factor == 0 would leave mid_1_channel undefined below
            raise ValueError("r_factor must be non-zero")

        if self.dropout_prob > 0:
            self.dropout = nn.Dropout2d(self.dropout_prob)

        self.attention_first = nn.Sequential(
            nn.Conv1d(in_channels=in_channel,
                      out_channels=mid_1_channel,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False), Norm2d(mid_1_channel),
            nn.ReLU(inplace=True))

        if layer == 2:
            self.attention_second = nn.Sequential(
                nn.Conv1d(in_channels=mid_1_channel,
                          out_channels=out_channel,
                          kernel_size=kernel_size,
                          stride=1,
                          padding=kernel_size // 2,
                          bias=True))
        elif layer == 3:
            mid_2_channel = (mid_1_channel * 2)
            self.attention_second = nn.Sequential(
                nn.Conv1d(in_channels=mid_1_channel,
                          out_channels=mid_2_channel,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=True), Norm2d(mid_2_channel),
                nn.ReLU(inplace=True))
            self.attention_third = nn.Sequential(
                nn.Conv1d(in_channels=mid_2_channel,
                          out_channels=out_channel,
                          kernel_size=kernel_size,
                          stride=1,
                          padding=kernel_size // 2,
                          bias=True))

        if self.pooling == 'mean':
            # average-pool the map down to (128 // pos_rfactor) rows
            self.rowpool = nn.AdaptiveAvgPool2d((128 // pos_rfactor, 1))
        else:
            # max-pool the map down to (128 // pos_rfactor) rows
            self.rowpool = nn.AdaptiveMaxPool2d((128 // pos_rfactor, 1))

        if pos_rfactor > 0:
            if is_encoding == 0:
                if self.pos_injection == 1:
                    self.pos_emb1d_1st = PosEmbedding1D(pos_rfactor,
                                                        dim=in_channel,
                                                        pos_noise=pos_noise)
                elif self.pos_injection == 2:
                    self.pos_emb1d_2nd = PosEmbedding1D(pos_rfactor,
                                                        dim=mid_1_channel,
                                                        pos_noise=pos_noise)
            elif is_encoding == 1:
                if self.pos_injection == 1:
                    self.pos_emb1d_1st = PosEncoding1D(pos_rfactor,
                                                       dim=in_channel,
                                                       pos_noise=pos_noise)
                elif self.pos_injection == 2:
                    self.pos_emb1d_2nd = PosEncoding1D(pos_rfactor,
                                                       dim=mid_1_channel,
                                                       pos_noise=pos_noise)
            else:
                raise ValueError(
                    "Unsupported position encoding: is_encoding must be 0 or 1")
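The forward pass is outside this excerpt, but the flow HANet_Conv sets up above can be sketched end to end: row-pool the map to a height-indexed 1-D descriptor, run the Conv1d attention stack along the height axis, then broadcast a sigmoid gate back over the feature map. Stand-in modules below (an approximation, not the repo's forward()):

import torch
import torch.nn as nn
import torch.nn.functional as F

b, c, h, w = 2, 32, 64, 128
feat = torch.randn(b, c, h, w)

rowpool = nn.AdaptiveAvgPool2d((h // 8, 1))      # pooling='mean' case
attn = nn.Sequential(
    nn.Conv1d(c, c // 4, kernel_size=1, bias=False),   # cf. attention_first
    nn.ReLU(inplace=True),
    nn.Conv1d(c // 4, c, kernel_size=3, padding=1, bias=True))

pooled = rowpool(feat).squeeze(3)                # (B, C, H/8)
gate = torch.sigmoid(attn(pooled)).unsqueeze(3)  # (B, C, H/8, 1)
gate = F.interpolate(gate, size=(h, 1), mode='bilinear', align_corners=False)
out = feat * gate                                # height-conditioned scaling
print(out.shape)                                 # torch.Size([2, 32, 64, 128])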
Example 29
    def __init__(self, trunk='WideResnet38', criterion=None, criterion2=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        self.criterion2 = criterion2
        logging.info("Trunk: %s", trunk)

        tasks = ['semantic', 'traversability']
        wide_resnet = wider_resnet38_a2(classes=1000,
                                        dilation=True,
                                        tasks=tasks)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        if criterion is not None:
            try:
                checkpoint = torch.load(
                    './pretrained_models/wider_resnet38.pth.tar',
                    map_location='cpu')

                net_state_dict = wide_resnet.state_dict()
                loaded_dict = checkpoint['state_dict']
                new_loaded_dict = {}
                for k in net_state_dict:
                    if (k in loaded_dict and
                            net_state_dict[k].size() == loaded_dict[k].size()):
                        new_loaded_dict[k] = loaded_dict[k]
                    else:
                        logging.info("Skipped loading parameter %s", k)
                net_state_dict.update(new_loaded_dict)
                wide_resnet.load_state_dict(net_state_dict)

                del checkpoint
            except Exception:
                print("Please download the ImageNet weights of WideResNet38 "
                      "in our repo to "
                      "./pretrained_models/wider_resnet38.pth.tar.")
                raise RuntimeError(
                    "Could not load ImageNet weights of WideResNet38 network.")
        wide_resnet = wide_resnet.module

        self.task_weights = torch.nn.Parameter(
            torch.zeros(2, requires_grad=True))

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 19, kernel_size=1, bias=False))  # hard-coded 19-class head

        self.aspp2 = _AtrousSpatialPyramidPoolingModule(4096,
                                                        256,
                                                        output_stride=8)

        self.bot_fine2 = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp2 = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final2 = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 2, kernel_size=1, bias=False))
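The size-checked checkpoint loading in DeepWV3Plus above is a reusable pattern: copy only parameters whose name and shape match, so a trunk pretrained with different heads can still initialize the backbone. A generic, hedged version (illustrative helper name, assumes a torch-style checkpoint):

import torch

def load_matching_weights(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    loaded = checkpoint.get('state_dict', checkpoint)
    own = model.state_dict()
    # keep only tensors whose name and shape both match the target model
    kept = {k: v for k, v in loaded.items()
            if k in own and own[k].size() == v.size()}
    own.update(kept)
    model.load_state_dict(own)
    return sorted(set(own) - set(kept))  # parameter names that were skipped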
    def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D',
                 skip='m1', skip_num=48):
        super(DeepV3Plus, self).__init__()
        self.criterion = criterion
        self.variant = variant
        self.skip = skip
        self.skip_num = skip_num

        if trunk == 'seresnext-50':
            resnet = SEresnext.se_resnext50_32x4d()
        elif trunk == 'seresnext-101':
            resnet = SEresnext.se_resnext101_32x4d()
        elif trunk == 'resnet-50':
            resnet = Resnet.resnet50()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':
            resnet = Resnet.resnet101()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")

        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        if self.variant == 'D':
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # unknown deepv3 variant: fall back to the backbone's
            # native stride (no dilation)
            print("Not using dilation")

        self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256,
                                                       output_stride=8)

        if self.skip == 'm1':
            self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
        elif self.skip == 'm2':
            self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
        else:
            raise ValueError('Not a valid skip')

        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        initialize_weights(self.aspp)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.bot_fine)
        initialize_weights(self.final)
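All the DeepV3Plus-style decoders above share the same channel arithmetic: the ASPP branches concatenate to 1280 channels (5 x 256), bot_aspp squeezes that to 256, the result is upsampled and fused with a 48-channel low-level skip, and the final head consumes 256 + 48 = 304 channels. A standalone sketch with stand-in tensors and modules (assumed shapes, not the repo's forward()):

import torch
import torch.nn as nn
import torch.nn.functional as F

aspp_out = torch.randn(1, 1280, 32, 32)   # concatenated ASPP branch outputs
low_level = torch.randn(1, 256, 64, 64)   # finer backbone features

bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
bot_fine = nn.Conv2d(256, 48, kernel_size=1, bias=False)

x = F.interpolate(bot_aspp(aspp_out), size=low_level.shape[2:],
                  mode='bilinear', align_corners=False)
fused = torch.cat([x, bot_fine(low_level)], dim=1)
print(fused.shape)  # torch.Size([1, 304, 64, 64]) -> input to the final head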