def __init__(self, num_classes, trunk='wrn38', criterion=None):
    """Multi-scale DeepLabV3+ with a sigmoid scale-attention head.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name passed to get_trunk
    :param criterion: training loss (stored only; not used here)
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8)
    # 1x1 projections: stride-2 skip features -> 48 ch, ASPP output -> 256 ch
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Semantic segmentation prediction head
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head: input is 2 * (256 + 48) channels,
    # presumably the decoder features of two scales concatenated in
    # forward() -- confirm against the caller. Emits one sigmoid map.
    scale_in_ch = 2 * (256 + 48)
    self.scale_attn = nn.Sequential(
        nn.Conv2d(scale_in_ch, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 1, kernel_size=1, bias=False),
        nn.Sigmoid())

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.bot_fine)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.scale_attn)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk='resnet-50', criterion=None,
             use_dpc=False, init_all=False, output_stride=8):
    """DeepV3 head where the ASPP module is replaced by an AFNB
    attention block feeding a plain segmentation head.

    NOTE(review): use_dpc and init_all are accepted but unused here
    (the ASPP construction is commented out), and the AFNB channel
    counts (2048/4096) are hard-coded instead of being derived from
    high_level_ch -- confirm they match the trunks this class is
    instantiated with.
    """
    super(DeepV3ATTN, self).__init__()
    self.criterion = criterion
    self.backbone, _s2_ch, _s4_ch, high_level_ch = \
        get_trunk(trunk, output_stride=output_stride)
    # self.aspp, aspp_out_ch = get_aspp(high_level_ch,
    #                                   bottleneck_ch=256,
    #                                   output_stride=output_stride,
    #                                   dpc=use_dpc)
    # self.attn = APNB(in_channels=high_level_ch,
    #                  out_channels=high_level_ch, key_channels=256,
    #                  value_channels=256, dropout=0.5, sizes=([1]),
    #                  norm_type='batchnorm', psp_size=(1,3,6,8))
    self.attn = AFNB(low_in_channels=2048, high_in_channels=4096,
                     out_channels=2048, key_channels=1024,
                     value_channels=2048, dropout=0.5, sizes=([1]),
                     norm_type='batchnorm', psp_size=(1, 3, 6, 8))
    self.final = make_seg_head(in_ch=high_level_ch, out_ch=num_classes)
    initialize_weights(self.attn)
    initialize_weights(self.final)
def __init__(self, num_classes, trunk='hrnetv2', criterion=None):
    """Minimal model: backbone trunk followed by a segmentation head.

    :param num_classes: number of output classes for the seg head
    :param trunk: backbone name handed to get_trunk
    :param criterion: training loss (stored only; not used here)
    """
    super(Basic, self).__init__()
    self.criterion = criterion
    # Only the high-level channel count from the trunk is needed below.
    self.backbone, _, _, feats_ch = get_trunk(trunk_name=trunk,
                                              output_stride=8)
    self.seg_head = make_seg_head(in_ch=feats_ch, out_ch=num_classes)
    initialize_weights(self.seg_head)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, init_all=False):
    """DeepLabV3+ decoder: ASPP output fused with a low-level skip.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name handed to get_trunk
    :param criterion: training loss (stored only; not used here)
    :param use_dpc: forwarded to get_aspp
    :param init_all: when True, also re-initialize ASPP and both
        bottleneck projections, not just the final head
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.backbone, s2_ch, _s4_ch, feats_ch = get_trunk(trunk)
    self.aspp, aspp_ch = get_aspp(feats_ch, bottleneck_ch=256,
                                  output_stride=8, dpc=use_dpc)
    # 1x1 projections: low-level skip -> 48 ch, ASPP output -> 256 ch.
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_ch, 256, kernel_size=1, bias=False)
    # Prediction head over the 256 + 48 fused channels.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
    if init_all:
        initialize_weights(self.aspp)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.bot_fine)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             fuse_aspp=False, attn_2b=False):
    """Multi-scale 'deeper' decoder: ASPP output refined through two
    skip fusions (stride-4 then stride-2), plus a scale-attention head.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name passed to get_trunk
    :param criterion: training loss (stored only)
    :param fuse_aspp: stored flag, presumably consumed in forward()
    :param attn_2b: attention head emits 2 channels when True, else 1
    """
    super(MscaleDeeper, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.backbone, s2_ch, s4_ch, high_level_ch = get_trunk(
        trunk_name=trunk, output_stride=8)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch, bottleneck_ch=256,
                                      output_stride=8)
    # 1x1 projections for the two skip connections.
    self.convs2 = nn.Conv2d(s2_ch, 32, kernel_size=1, bias=False)
    self.convs4 = nn.Conv2d(s4_ch, 64, kernel_size=1, bias=False)
    self.conv_up1 = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)
    # Each upsampling stage concatenates a projected skip (64 / 32 ch).
    self.conv_up2 = ConvBnRelu(256 + 64, 256, kernel_size=5, padding=2)
    self.conv_up3 = ConvBnRelu(256 + 32, 256, kernel_size=5, padding=2)
    self.conv_up5 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False)

    # Scale-attention prediction head
    if self.attn_2b:
        attn_ch = 2
    else:
        attn_ch = 1
    self.scale_attn = make_attn_head(in_ch=256, out_ch=attn_ch)

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.convs2, self.convs4, self.conv_up1,
                           self.conv_up2, self.conv_up3, self.conv_up5,
                           self.scale_attn)
def __init__(self, num_classes, trunk=None, criterion=None):
    """Gated-SCNN: WideResNet38 trunk plus a gated shape stream fused
    into a DeepLabV3+-style segmentation head.

    :param num_classes: number of segmentation classes
    :param trunk: unused; the trunk is hard-coded to wider_resnet38_a2
    :param criterion: training loss (stored only)
    """
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # NOTE(review): wrapped in DataParallel then immediately unwrapped --
    # looks like a no-op kept for checkpoint-key compatibility; confirm.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    wide_resnet = wide_resnet.module
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # 1-channel side outputs tapped from trunk features at several
    # depths (channel counts match the trunk stages) -- presumably used
    # for the edge/shape-stream supervision in forward().
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)

    # Shape stream: residual block + 1x1 reduction + gated conv,
    # repeated while narrowing 64 -> 32 -> 16 -> 8 channels.
    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

    # Collapses a 2-channel input to the single edge map.
    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    # ASPP output (1280) concatenated with an extra 256-channel input.
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)
    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)
def __init__(self, num_classes, trunk='hrnetv2', criterion=None):
    """Backbone -> ASPP -> 1x1 bottleneck -> segmentation head.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name handed to get_trunk
    :param criterion: training loss (stored only; not used here)
    """
    super(ASPP, self).__init__()
    self.criterion = criterion
    self.backbone, _, _, feats_ch = get_trunk(trunk)
    self.aspp, aspp_ch = get_aspp(feats_ch,
                                  bottleneck_ch=cfg.MODEL.ASPP_BOT_CH,
                                  output_stride=8)
    # Squeeze the ASPP output down to 256 channels for the head.
    self.bot_aspp = nn.Conv2d(aspp_ch, 256, kernel_size=1, bias=False)
    self.final = make_seg_head(in_ch=256, out_ch=num_classes)
    initialize_weights(self.final, self.bot_aspp, self.aspp)
def __init__(self, num_classes, trunk='WideResnet38', criterion=None):
    """DeepLabV3+ decoder on a WideResNet38 trunk.

    When a criterion is supplied (i.e. training), ImageNet weights for
    the trunk are loaded from ./pretrained_models; failure to load is
    fatal.

    :param num_classes: number of segmentation classes
    :param trunk: logged only; the trunk is hard-coded to
        wider_resnet38_a2
    :param criterion: training loss; also gates checkpoint loading
    :raises RuntimeError: if the pretrained checkpoint cannot be loaded
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # DataParallel wrapping presumably matches the 'module.'-prefixed
    # keys stored in the checkpoint -- confirm against the .tar file.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    if criterion is not None:
        try:
            checkpoint = torch.load(
                './pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        # Fix: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; catch Exception only.
        except Exception:
            print(
                "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
            )
            raise RuntimeError(
                "=====================Could not load ImageNet weights of WideResNet38 network.======================="
            )
    wide_resnet = wide_resnet.module
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)
    # 1x1 projections: low-level skip -> 48 ch, ASPP output -> 256 ch.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
    initialize_weights(self.final)
def __init__(self, num_classes, input_size, trunk='WideResnet38',
             criterion=None):
    """DeepLabV3+ decoder on a WideResNet38 trunk (best-effort
    checkpoint load variant).

    :param num_classes: number of segmentation classes
    :param input_size: accepted for API compatibility; unused here
    :param trunk: logged only; the trunk is hard-coded to
        wider_resnet38_a2
    :param criterion: training loss (stored only)
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # TODO: Should this be even here ?
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load('weights/ResNet/wider_resnet38.pth.tar',
                                map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    # Fix: was a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit. The deliberate best-effort behavior
    # (warn and continue with random trunk weights) is preserved.
    except Exception:
        print(
            "=====================Could not load imagenet weights======================="
        )
    wide_resnet = wide_resnet.module
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)
    # 1x1 projections: low-level skip -> 48 ch, ASPP output -> 256 ch.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
    initialize_weights(self.final)
def __init__(self, num_classes, trunk='resnet-50', criterion=None,
             use_dpc=False, init_all=False, output_stride=8):
    """DeepLabV3 (no decoder skip): backbone -> ASPP -> seg head.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name handed to get_trunk
    :param criterion: training loss (stored only; not used here)
    :param use_dpc: forwarded to get_aspp
    :param init_all: accepted for API compatibility; unused here
    :param output_stride: forwarded to both trunk and ASPP
    """
    super(DeepV3, self).__init__()
    self.criterion = criterion
    self.backbone, _s2_ch, _s4_ch, feats_ch = get_trunk(
        trunk, output_stride=output_stride)
    self.aspp, aspp_ch = get_aspp(feats_ch,
                                  bottleneck_ch=256,
                                  output_stride=output_stride,
                                  dpc=use_dpc)
    self.final = make_seg_head(in_ch=aspp_ch, out_ch=num_classes)
    initialize_weights(self.aspp)
    initialize_weights(self.final)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, fuse_aspp=False, attn_2b=False,
             attn_cls=False):
    """Multi-scale DeepLabV3+ with a configurable scale-attention head.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name passed to get_trunk
    :param criterion: training loss (stored only)
    :param use_dpc: forwarded to get_aspp
    :param fuse_aspp: stored flag, presumably consumed in forward()
    :param attn_2b: attention head emits 2 channels when True
    :param attn_cls: attention head emits num_classes channels when
        True (ignored if attn_2b is set); default is 1 channel
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.attn_cls = attn_cls
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc)
    # 1x1 projections: stride-2 skip -> 48 ch, ASPP output -> 256 ch.
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Semantic segmentation prediction head
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head
    if self.attn_2b:
        attn_ch = 2
    else:
        if self.attn_cls:
            attn_ch = num_classes
        else:
            attn_ch = 1
    scale_in_ch = 256 + 48
    self.scale_attn = make_attn_head(in_ch=scale_in_ch, out_ch=attn_ch)

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.bot_fine)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.scale_attn)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, high_level_ch):
    """OCR (Object-Contextual Representations) decoder block.

    Builds the 3x3 feature transform, the spatial gather/distribute OCR
    pair, the main classification head and an auxiliary head.

    :param high_level_ch: channel count of the incoming trunk features
    """
    super(OCR_block, self).__init__()
    ocr_mid_channels = cfg.MODEL.OCR.MID_CHANNELS
    ocr_key_channels = cfg.MODEL.OCR.KEY_CHANNELS
    num_classes = cfg.DATASET.NUM_CLASSES
    self.conv3x3_ocr = nn.Sequential(
        nn.Conv2d(high_level_ch, ocr_mid_channels,
                  kernel_size=3, stride=1, padding=1),
        BNReLU(ocr_mid_channels),
    )
    self.ocr_gather_head = SpatialGather_Module(num_classes)
    self.ocr_distri_head = SpatialOCR_Module(
        in_channels=ocr_mid_channels,
        key_channels=ocr_key_channels,
        out_channels=ocr_mid_channels,
        scale=1,
        dropout=0.05,
    )
    self.cls_head = nn.Conv2d(ocr_mid_channels, num_classes,
                              kernel_size=1, stride=1, padding=0,
                              bias=True)
    # Auxiliary head predicts directly from the trunk features --
    # presumably for auxiliary (deep) supervision during training.
    self.aux_head = nn.Sequential(
        nn.Conv2d(high_level_ch, high_level_ch,
                  kernel_size=1, stride=1, padding=0),
        BNReLU(high_level_ch),
        nn.Conv2d(high_level_ch, num_classes,
                  kernel_size=1, stride=1, padding=0, bias=True))
    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.conv3x3_ocr,
                           self.ocr_gather_head,
                           self.ocr_distri_head,
                           self.cls_head,
                           self.aux_head)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, fuse_aspp=False, attn_2b=False):
    """Multi-scale DeepLabV3+ variant with an ADNB encoder (its
    parameters mirror a deformable-transformer encoder) over the
    256-channel ASPP projection.

    :param num_classes: number of segmentation classes
    :param trunk: backbone name passed to get_trunk
    :param criterion: training loss (stored only)
    :param use_dpc: forwarded to get_aspp
    :param fuse_aspp: stored flag, presumably consumed in forward()
    :param attn_2b: attention head emits 2 channels when True, else 1
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    # NOTE(review): img_norm=False is passed here unlike the sibling
    # variants -- confirm get_aspp's expectation for this flag.
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc,
                                      img_norm=False)
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)
    # self.asnb = ASNB(low_in_channels=48, high_in_channels=256,
    #                  out_channels=256, key_channels=64,
    #                  value_channels=256, dropout=0., sizes=([1]),
    #                  norm_type='batchnorm', attn_scale=0.25)
    self.adnb = ADNB(d_model=256, nhead=8, num_encoder_layers=2,
                     dim_feedforward=256, dropout=0.5, activation="relu",
                     num_feature_levels=1, enc_n_points=4)

    # Semantic segmentation prediction head
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head
    if self.attn_2b:
        attn_ch = 2
    else:
        attn_ch = 1
    scale_in_ch = 256 + 48
    self.scale_attn = make_attn_head(in_ch=scale_in_ch, out_ch=attn_ch)

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.bot_fine)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.scale_attn)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, init_all=False):
    """DeepLabV3+ variant where ASPP is replaced by an AFNB attention
    block (256-channel output) fused with the 48-channel skip.

    NOTE(review): use_dpc is unused (the ASPP construction is commented
    out) and the AFNB input channels (2048/4096) are hard-coded rather
    than derived from high_level_ch -- confirm they match the trunk.
    """
    super(DeepV3PlusATTN, self).__init__()
    self.criterion = criterion
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    # self.aspp, aspp_out_ch = get_aspp(high_level_ch,
    #                                   bottleneck_ch=256,
    #                                   output_stride=8,
    #                                   dpc=use_dpc)
    # self.attn = APNB(in_channels=high_level_ch,
    #                  out_channels=high_level_ch, key_channels=256,
    #                  value_channels=256, dropout=0.5, sizes=([1]),
    #                  norm_type='batchnorm', psp_size=(1, 3, 6, 8))
    self.attn = AFNB(low_in_channels=2048, high_in_channels=4096,
                     out_channels=256, key_channels=64,
                     value_channels=256, dropout=0.8, sizes=([1]),
                     norm_type='batchnorm', psp_size=(1, 3, 6, 8))
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(256, 256, kernel_size=1, bias=False)
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
    if init_all:
        initialize_weights(self.attn)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.bot_fine)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk='resnet-101', criterion=None,
             criterion_aux=None, variant='D', skip='m1', skip_num=48,
             args=None):
    """DeepLabV3+ with selectable trunk and instance-whitening support.

    :param num_classes: number of segmentation classes
    :param trunk: one of shufflenetv2 / mnasnet_05 / mnasnet_10 /
        mobilenetv2 / resnet-18/50/101/152 / resnext-50/101 /
        wide_resnet-50/101
    :param criterion: main training loss (stored only)
    :param criterion_aux: auxiliary loss for the dsn head (stored only)
    :param variant: dilation scheme -- 'D' (output stride 8),
        'D4' (os 4, resnet family only), 'D16' (os 16); anything else
        leaves the trunk undilated (os 32)
    :param skip, skip_num: accepted for API compatibility; unused below
    :param args: must provide wt_layer, relax_denom and clusters for
        the whitening (IRW/ISW) configuration
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.trunk = trunk

    if trunk == 'shufflenetv2':
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = Shufflenet.shufflenet_v2_x1_0(pretrained=True,
                                               iw=self.args.wt_layer)

        class Layer0(nn.Module):
            # Wraps conv1 + maxpool; optionally swaps the BN for an
            # instance-whitening norm (iw 1/2 also return a w matrix
            # collected into w_arr).
            def __init__(self, iw):
                super(Layer0, self).__init__()
                self.layer = nn.Sequential(resnet.conv1, resnet.maxpool)
                self.instance_norm_layer = resnet.instance_norm_layer1
                self.iw = iw

            def forward(self, x_tuple):
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflnet layer 0 forward path")
                    return
                # conv -> (iw-norm | bn) -> relu -> maxpool
                x = self.layer[0][0](x)
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[0][1](x)
                x = self.layer[0][2](x)
                x = self.layer[1](x)
                return [x, w_arr]

        class Layer4(nn.Module):
            # Wraps conv5 with the same optional instance-whitening.
            def __init__(self, iw):
                super(Layer4, self).__init__()
                self.layer = resnet.conv5
                self.instance_norm_layer = resnet.instance_norm_layer2
                self.iw = iw

            def forward(self, x_tuple):
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflnet layer 4 forward path")
                    return
                # conv -> (iw-norm | bn) -> relu
                x = self.layer[0](x)
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[1](x)
                x = self.layer[2](x)
                return [x, w_arr]

        self.layer0 = Layer0(iw=self.args.wt_layer[2])
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = Layer4(iw=self.args.wt_layer[6])
        if self.variant == 'D':
            # Replace stride-2 convs with dilated convs (os becomes 8).
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(
                resnet.layers[0], resnet.layers[1], resnet.layers[2],
                resnet.layers[3], resnet.layers[4], resnet.layers[5],
                resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(
                resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(
                resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(
                resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(
                resnet.layers[14], resnet.layers[15],
                resnet.layers[16])  # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(
                resnet.layers[0], resnet.layers[1], resnet.layers[2],
                resnet.layers[3], resnet.layers[4], resnet.layers[5],
                resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(
                resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(
                resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(
                resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(
                resnet.layers[14], resnet.layers[15],
                resnet.layers[16])  # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = Mobilenet.mobilenet_v2(pretrained=True,
                                        iw=self.args.wt_layer)
        self.layer0 = nn.Sequential(resnet.features[0],
                                    resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2],
                                    resnet.features[3],
                                    resnet.features[4],
                                    resnet.features[5],
                                    resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7],
                                    resnet.features[8],
                                    resnet.features[9],
                                    resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11],
        #     resnet.features[12], resnet.features[13],
        #     resnet.features[14], resnet.features[15],
        #     resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17],
        #     resnet.features[18])
        self.layer3 = nn.Sequential(resnet.features[11],
                                    resnet.features[12],
                                    resnet.features[13],
                                    resnet.features[14],
                                    resnet.features[15],
                                    resnet.features[16],
                                    resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    else:
        # ResNet / ResNeXt / WideResNet family (defaults: ResNet-50+).
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-18':
            channel_1st = 3
            channel_2nd = 64
            channel_3rd = 64
            channel_4th = 128
            prev_final_channel = 256
            final_channel = 512
            resnet = Resnet.resnet18(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-50':
            resnet = Resnet.resnet50(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # three 3 X 3
            resnet = Resnet.resnet101(pretrained=True,
                                      wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(
                resnet.conv1, resnet.bn1, resnet.relu1,
                resnet.conv2, resnet.bn2, resnet.relu2,
                resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")
        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
        if self.variant == 'D':
            # Dilate the 3x3 convs and neutralize the downsample strides.
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = \
                        (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")

    # Map the variant to its output stride.
    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32
    self.output_stride = os

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256,
                                                   output_stride=os)
    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))
    # Auxiliary (deeply-supervised) head off the penultimate stage.
    self.dsn = nn.Sequential(
        nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1,
                  padding=1),
        Norm2d(512),
        nn.ReLU(inplace=True),
        nn.Dropout2d(0.1),
        nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0,
                  bias=True)
    )
    initialize_weights(self.dsn)
    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)

    # Setting the flags
    self.eps = 1e-5
    self.whitening = False

    # Per-trunk channel tables indexed by whitening-layer position.
    if trunk == 'resnet-101':
        self.three_input_layer = True
        in_channel_list = [64, 64, 128, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [32, 32, 64, 128, 256, 512, 1024]
    elif trunk == 'resnet-18':
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 64, 128, 256, 512]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 32, 64, 128, 256]
    elif trunk == 'shufflenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 24, 116, 232, 464, 1024]
    elif trunk == 'mobilenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 16, 32, 64, 320, 1280]
    else:  # ResNet-50
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 128, 256, 512, 1024]

    # Build covariance trackers for each whitened layer:
    # wt_layer[i] == 1 -> IRW, == 2 -> ISW.
    self.cov_matrix_layer = []
    self.cov_type = []
    for i in range(len(self.args.wt_layer)):
        if self.args.wt_layer[i] > 0:
            self.whitening = True
            if self.args.wt_layer[i] == 1:
                self.cov_matrix_layer.append(
                    CovMatrix_IRW(dim=in_channel_list[i],
                                  relax_denom=self.args.relax_denom))
                self.cov_type.append(self.args.wt_layer[i])
            elif self.args.wt_layer[i] == 2:
                self.cov_matrix_layer.append(
                    CovMatrix_ISW(dim=in_channel_list[i],
                                  relax_denom=self.args.relax_denom,
                                  clusters=self.args.clusters))
                self.cov_type.append(self.args.wt_layer[i])
def __init__(self, num_classes, trunk='WideResnet38', criterion=None,
             ngf=64, input_nc=1, output_nc=2):
    """Audio-visual DeepWV3Plus: a U-Net-style audio encoder/decoder
    plus separate ASPP heads for semantic segmentation and depth.

    :param num_classes: number of segmentation classes
    :param trunk: logged only; the trunk is hard-coded to
        wider_resnet38_a2
    :param criterion: training loss; an MSE loss is also created
    :param ngf: base channel width of the audio U-Net
    :param input_nc: input channels of the audio U-Net encoder
    :param output_nc: output channels of the audio U-Net decoder
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    self.MSEcriterion = torch.nn.MSELoss()
    logging.info("Trunk: %s", trunk)

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # DataParallel wrapping presumably matches 'module.'-prefixed
    # checkpoint keys -- confirm against the .tar file.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load(
            '/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/semantic-segmentation-master/pretrained_models/wider_resnet38.pth.tar',
            map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    # Fix: was a bare `except:` that would also swallow
    # KeyboardInterrupt/SystemExit; the deliberate best-effort behavior
    # (warn and continue) is preserved.
    except Exception:
        print("=====================Could not load ImageNet weights=======================")
        print("Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models.")

    # (kept for reference) selectively loading a pretrained audio U-Net:
    # audio_unet = AudioNet_multitask(ngf=64, input_nc=2)
    # Acheckpoint = torch.load('.../3_audio.pth', map_location='cpu')
    # pretrained_dict filtered to drop the resized upconv/conv1x1 keys,
    # then audio_unet.load_state_dict(model_dict)

    # Audio U-Net encoder (five conv blocks, ngf -> 8*ngf).
    self.audionet_convlayer1 = unet_conv(input_nc, ngf)
    self.audionet_convlayer2 = unet_conv(ngf, ngf * 2)
    self.audionet_convlayer3 = unet_conv(ngf * 2, ngf * 4)
    self.audionet_convlayer4 = unet_conv(ngf * 4, ngf * 8)
    self.audionet_convlayer5 = unet_conv(ngf * 8, ngf * 8)
    # Audio U-Net decoder.
    # 1296 (audio-visual feature) = 784 (visual feature) + 512 (audio feature)
    self.audionet_upconvlayer1 = unet_upconv(1024, ngf * 8)
    self.audionet_upconvlayer2 = unet_upconv(ngf * 8, ngf * 4)
    self.audionet_upconvlayer3 = unet_upconv(ngf * 4, ngf * 2)
    self.audionet_upconvlayer4 = unet_upconv(ngf * 2, ngf)
    # outermost layer use a sigmoid to bound the mask
    self.audionet_upconvlayer5 = unet_upconv(ngf, output_nc, True)
    self.conv1x1 = create_conv(4096, 2, 1, 0)

    wide_resnet = wide_resnet.module
    # self.unet = audio_unet
    # NOTE(review): unlike the sibling DeepWV3Plus variants the trunk
    # stages (mod1..mod7, pool2, pool3) are NOT kept as submodules --
    # the assignments exist only inside a disabled docstring in the
    # original. Confirm forward() does not rely on them.
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                   output_stride=8)
    self.depthaspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                        output_stride=8)

    self.bot_aud1 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
    self.bot_multiaud = nn.Conv2d(512, 512, kernel_size=1, bias=False)
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(320, 256, kernel_size=1, bias=False)
    self.bot_depthaspp = nn.Conv2d(320, 128, kernel_size=1, bias=False)

    # Segmentation head.
    self.final = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
    # Single-channel depth regression head.
    self.final_depth = nn.Sequential(
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
        Norm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
        Norm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 1, kernel_size=1, bias=False))

    initialize_weights(self.final)
    initialize_weights(self.bot_aud1)
    initialize_weights(self.bot_multiaud)
def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D', skip='m1', skip_num=48):
    """
    DeepLabV3+ with a selectable ResNet/SE-ResNeXt trunk.

    Args:
        num_classes: number of output classes of the final classifier.
        trunk: one of 'seresnext-50', 'seresnext-101', 'resnet-50', 'resnet-101'.
        criterion: training loss (stored for the caller to use).
        variant: 'D' (output stride 8 via dilation), 'D16' (dilate layer4
            only), anything else leaves the trunk undilated.
        skip: which low-level feature to use for the decoder skip
            connection: 'm1' (256 ch) or 'm2' (512 ch).
        skip_num: channel count of the projected skip feature.

    Raises:
        ValueError: if ``trunk`` or ``skip`` is not one of the values above.
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.variant = variant
    self.skip = skip
    self.skip_num = skip_num

    if trunk == 'seresnext-50':
        resnet = SEresnext.se_resnext50_32x4d()
    elif trunk == 'seresnext-101':
        resnet = SEresnext.se_resnext101_32x4d()
    elif trunk == 'resnet-50':
        resnet = Resnet.resnet50()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                      resnet.relu, resnet.maxpool)
    elif trunk == 'resnet-101':
        resnet = Resnet.resnet101()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                      resnet.relu, resnet.maxpool)
    else:
        raise ValueError("Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    def _dilate(layer, d):
        # Turn a stride-2 stage into a dilated stage so spatial resolution
        # is preserved (standard DeepLab trunk surgery).
        for n, m in layer.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

    if self.variant == 'D':
        _dilate(self.layer3, 2)
        _dilate(self.layer4, 4)
    elif self.variant == 'D16':
        _dilate(self.layer4, 2)
    else:
        print("Not using Dilation ")

    self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256, output_stride=8)

    if self.skip == 'm1':
        self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
    elif self.skip == 'm2':
        self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
    else:
        # Was `raise Exception(...)` — ValueError is consistent with the
        # trunk check above and still caught by callers catching Exception.
        raise ValueError('Not a valid skip')

    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    # Decoder / prediction head over concatenated ASPP + skip features.
    self.final = nn.Sequential(
        nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final)
def __init__(self, num_classes, trunk='wrn38', criterion=None, use_dpc=False, fuse_aspp=False, attn_2b=False, bn_head=False):
    """
    DeepLabV3+ variant with a multi-scale attention head.

    Args:
        num_classes: number of semantic classes of the segmentation head.
        trunk: backbone name passed to ``get_trunk``.
        criterion: training loss (stored for the caller to use).
        use_dpc: use the DPC variant of ASPP.
        fuse_aspp: stored flag, consumed in forward().
        attn_2b: stored flag, consumed in forward().
        bn_head: force the BatchNorm flavour of the attention head even if
            ``cfg.MODEL.ATTNSCALE_BN_HEAD`` is off.
    """
    super(ASDV3P, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc)
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Semantic segmentation prediction head.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head: takes the (256+48)-ch decoder
    # features of every scale concatenated, emits one logit map per scale.
    assert cfg.MODEL.N_SCALES is not None
    self.scales = sorted(cfg.MODEL.N_SCALES)
    num_scales = len(self.scales)
    if cfg.MODEL.ATTNSCALE_BN_HEAD or bn_head:
        self.scale_attn = nn.Sequential(
            nn.Conv2d(num_scales * (256 + 48), 256, kernel_size=3,
                      padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_scales, kernel_size=1, bias=False))
    else:
        self.scale_attn = nn.Sequential(
            nn.Conv2d(num_scales * (256 + 48), 512, kernel_size=3,
                      padding=1, bias=False),
            nn.ReLU(inplace=True),
            # BUG FIX: original used kernel_size=1 with padding=1, which
            # enlarges the output by 2 px per dimension so the attention
            # logits cannot align with the per-scale features; a 1x1 conv
            # needs no padding.
            nn.Conv2d(512, num_scales, kernel_size=1, bias=False))

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.bot_fine)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.scale_attn)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None, variant='D', skip='m1', skip_num=48, args=None):
    """
    DeepLabV3+ decoder with optional HANet attention layers over a wide
    choice of trunks (ShuffleNetV2, MNASNet, MobileNetV2, ResNet/ResNeXt/
    Wide-ResNet families).

    Args:
        num_classes: number of output classes.
        trunk: backbone identifier (see branches below).
        criterion: main segmentation loss.
        criterion_aux: auxiliary (DSN) loss.
        variant: dilation variant — 'D' (OS 8), 'D4' (OS 4), 'D16' (OS 16),
            anything else leaves the trunk undilated (OS 32).
        skip, skip_num: skip-connection selection (stored by callers'
            convention; decoder below uses fixed channel counts).
        args: parsed CLI args; fields read here: hanet, hanet_set,
            hanet_pos, pos_rfactor, pooling, dropout, pos_noise, aux_loss.
    """
    super(DeepV3PlusHANet, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.num_attention_layer = 0
    self.trunk = trunk
    # Count how many of the five possible HANet insertion points are active.
    for i in range(5):
        if args.hanet[i] > 0:
            self.num_attention_layer += 1
    print("#### HANet layers", self.num_attention_layer)

    # ---- Trunk selection. Each branch sets the per-stage channel counts
    # (channel_1st..channel_4th, prev_final_channel, final_channel) and the
    # layer0..layer4 feature stages, then rewires strides to dilations
    # according to `variant`.
    if trunk == 'shufflenetv2':
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = models.shufflenet_v2_x1_0(pretrained=True)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.maxpool)
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = resnet.conv5
        if self.variant == 'D':
            # Replace stride-2 convs with dilated stride-1 convs (OS 8).
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = models.mobilenet_v2(pretrained=True)
        self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                    resnet.features[4], resnet.features[5], resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                    resnet.features[9], resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13], resnet.features[14], resnet.features[15], resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])
        self.layer3 = nn.Sequential(
            resnet.features[11], resnet.features[12], resnet.features[13],
            resnet.features[14], resnet.features[15], resnet.features[16],
            resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    else:
        # ResNet-family trunks share the standard channel layout.
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-50':
            resnet = Resnet.resnet50()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # three 3 X 3
            resnet = Resnet.resnet101()
            # This resnet-101 variant uses a deep stem: three 3x3 convs
            # instead of a single 7x7.
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                          resnet.conv2, resnet.bn2, resnet.relu2,
                                          resnet.conv3, resnet.bn3, resnet.relu3,
                                          resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")
        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
        if self.variant == 'D':
            # Dilate layer3 (x2) and layer4 (x4) -> output stride 8.
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            # Dilate layer2/3/4 (x2/x4/x8) -> output stride 4.
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")

    # Map the dilation variant to the resulting output stride for ASPP.
    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, output_stride=os)

    # Skip projection (low-level features) and ASPP bottleneck.
    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))

    # Decoder body (304 = 256 ASPP + 48 skip) and final classifier.
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

    # Optional auxiliary (deep supervision) head off the penultimate stage.
    if self.args.aux_loss is True:
        self.dsn = nn.Sequential(
            nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
            Norm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
        initialize_weights(self.dsn)

    # Up to five HANet attention modules, one per decoder stage boundary;
    # args.hanet[i] == 1 enables stage i.
    if self.args.hanet[0] == 1:
        self.hanet0 = HANet_Conv(prev_final_channel, final_channel,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet0)
    if self.args.hanet[1] == 1:
        self.hanet1 = HANet_Conv(final_channel, 1280,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet1)
    if self.args.hanet[2] == 1:
        self.hanet2 = HANet_Conv(1280, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet2)
    if self.args.hanet[3] == 1:
        self.hanet3 = HANet_Conv(304, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet3)
    if self.args.hanet[4] == 1:
        # Last stage uses max pooling regardless of args.pooling.
        self.hanet4 = HANet_Conv(256, num_classes,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling='max',
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet4)

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)
def __init__(self, num_classes, trunk=None, criterion=None):
    """
    Gated Shape CNN (GSCNN): WideResNet38 trunk plus a gated shape stream
    (dsn/res/gate modules) fused with an ASPP segmentation stream.

    Args:
        num_classes: number of semantic classes of the segmentation head.
        trunk: unused (kept for interface compatibility with sibling nets).
        criterion: training loss (stored for the caller to use).

    Raises:
        RuntimeError: if the WideResNet38 ImageNet checkpoint cannot be
            loaded (this network requires pretrained trunk weights).
    """
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load(
            './network/pretrained_models/wider_resnet38.pth.tar',
            map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    except Exception as err:  # was a bare `except:`; also chain the cause
        print(
            "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
        )
        raise RuntimeError(
            "=====================Could not load ImageNet weights of WideResNet38 network.======================="
        ) from err
    wide_resnet = wide_resnet.module

    # Keep only the trunk stages we need, then drop the wrapper object.
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # Shape stream: 1-channel side outputs tapped from trunk stages...
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)
    # ...refined by residual blocks and gated by the side outputs.
    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)
    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)
    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    # Segmentation stream: ASPP over the deepest trunk features.
    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, output_stride=8)
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    # 1280 ASPP channels + 256 shape-stream channels.
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)