def __init__(self, num_classes, trunk='wrn38', criterion=None):
    """Multi-scale DeepLabV3+ with a sigmoid scale-attention head.

    Args:
        num_classes: number of segmentation classes.
        trunk: backbone name passed to get_trunk.
        criterion: training loss, stored on the module.
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion

    # Trunk and ASPP at output stride 8.
    self.backbone, low_level_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_ch = get_aspp(high_level_ch,
                                  bottleneck_ch=256,
                                  output_stride=8)

    # 1x1 projections for the decoder skip connection and the ASPP output.
    self.bot_fine = nn.Conv2d(low_level_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_ch, 256, kernel_size=1, bias=False)

    def _conv_bn_relu(in_ch, out_ch):
        # One 3x3 conv -> Norm2d -> ReLU unit, returned as a list so the
        # Sequentials below are built in the exact same order as before.
        return [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1,
                          bias=False),
                Norm2d(out_ch),
                nn.ReLU(inplace=True)]

    # Semantic segmentation prediction head.
    self.final = nn.Sequential(
        *_conv_bn_relu(256 + 48, 256),
        *_conv_bn_relu(256, 256),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head; consumes the concatenation of the
    # decoder features from two scales, hence 2 * (256 + 48) inputs.
    attn_in_ch = 2 * (256 + 48)
    self.scale_attn = nn.Sequential(
        *_conv_bn_relu(attn_in_ch, 256),
        *_conv_bn_relu(256, 256),
        nn.Conv2d(256, 1, kernel_size=1, bias=False),
        nn.Sigmoid())

    if cfg.OPTIONS.INIT_DECODER:
        for module in (self.bot_fine, self.bot_aspp,
                       self.scale_attn, self.final):
            initialize_weights(module)
    else:
        initialize_weights(self.final)
def _make_transition_layer(
        self, num_channels_pre_layer, num_channels_cur_layer):
    """Build per-branch transition modules between two HRNet stages.

    Returns an nn.ModuleList with one entry per current-stage branch:
    None when the branch passes through unchanged, a 3x3 conv block when
    only the channel count changes, and a chain of stride-2 conv blocks
    when a new lower-resolution branch must be created.
    """
    n_cur = len(num_channels_cur_layer)
    n_pre = len(num_channels_pre_layer)

    transitions = []
    for i in range(n_cur):
        if i < n_pre:
            # Existing branch: adapt channels only if they differ.
            if num_channels_cur_layer[i] == num_channels_pre_layer[i]:
                transitions.append(None)
            else:
                transitions.append(nn.Sequential(
                    nn.Conv2d(num_channels_pre_layer[i],
                              num_channels_cur_layer[i],
                              3, 1, 1, bias=False),
                    Norm2d(num_channels_cur_layer[i],
                           momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=relu_inplace)))
        else:
            # New branch: downsample from the last previous branch with
            # (i + 1 - n_pre) stride-2 conv blocks; only the final block
            # switches to the target channel count.
            downsamples = []
            for j in range(i + 1 - n_pre):
                in_ch = num_channels_pre_layer[-1]
                out_ch = (num_channels_cur_layer[i]
                          if j == i - n_pre else in_ch)
                downsamples.append(nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, 3, 2, 1, bias=False),
                    Norm2d(out_ch, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=relu_inplace)))
            transitions.append(nn.Sequential(*downsamples))

    return nn.ModuleList(transitions)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, init_all=False):
    """DeepLabV3+ semantic segmentation network.

    Args:
        num_classes: number of segmentation classes.
        trunk: backbone name passed to get_trunk.
        criterion: training loss, stored on the module.
        use_dpc: forwarded to get_aspp as the `dpc` flag.
        init_all: when True, re-initialize the ASPP and both decoder
            projections in addition to the final head.
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion

    self.backbone, low_level_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_ch = get_aspp(high_level_ch,
                                  bottleneck_ch=256,
                                  output_stride=8,
                                  dpc=use_dpc)

    # Decoder projections: low-level skip to 48 ch, ASPP output to 256 ch.
    self.bot_fine = nn.Conv2d(low_level_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_ch, 256, kernel_size=1, bias=False)

    # Final prediction head over the concatenated fine + ASPP features.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    to_init = ([self.aspp, self.bot_aspp, self.bot_fine, self.final]
               if init_all else [self.final])
    for module in to_init:
        initialize_weights(module)
def __init__(self, in_dim, reduction_dim=256, output_stride=16,
             rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (DeepLab-style) module.

    Args:
        in_dim: number of input channels.
        reduction_dim: channels produced by each pyramid branch.
        output_stride: trunk output stride; 8 doubles the dilation rates,
            16 uses them as given.
        rates: base dilation rates for the 3x3 branches.

    Raises:
        ValueError: if output_stride is neither 8 nor 16.
    """
    super(_AtrousSpatialPyramidPoolingModule, self).__init__()
    # Check if we are using distributed BN and use the nn from encoding.nn
    # library rather than using standard pytorch.nn
    if output_stride == 8:
        rates = [2 * r for r in rates]
    elif output_stride == 16:
        pass
    else:
        # Bug fix: `raise '...'` raises a string, which is a TypeError in
        # Python 3 ("exceptions must derive from BaseException"). Raise a
        # proper exception type instead.
        raise ValueError(
            'output stride of {} not supported'.format(output_stride))

    self.features = []
    # 1x1 branch
    self.features.append(
        nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1,
                                bias=False),
                      Norm2d(reduction_dim), nn.ReLU(inplace=True)))
    # dilated 3x3 branches, one per rate
    for r in rates:
        self.features.append(nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                      dilation=r, padding=r, bias=False),
            Norm2d(reduction_dim),
            nn.ReLU(inplace=True)
        ))
    self.features = torch.nn.ModuleList(self.features)

    # image-level features: global pool + 1x1 projection
    self.img_pooling = nn.AdaptiveAvgPool2d(1)
    self.img_conv = nn.Sequential(
        nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
        Norm2d(reduction_dim), nn.ReLU(inplace=True))
def make_attn_head(in_ch, out_ch):
    """Build the scale-attention head: conv-BN-ReLU stack, a 1x1
    classifier conv, and a sigmoid. Layer choices come from cfg.MODEL."""
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH

    # Legacy architecture shortcut.
    if cfg.MODEL.MSCALE_OLDARCH:
        return old_make_attn_head(in_ch, bot_ch, out_ch)

    layers = OrderedDict()
    layers['conv0'] = nn.Conv2d(in_ch, bot_ch, kernel_size=3, padding=1,
                                bias=False)
    layers['bn0'] = Norm2d(bot_ch)
    layers['re0'] = nn.ReLU(inplace=True)

    # Optional second 3x3 conv block.
    if cfg.MODEL.MSCALE_INNER_3x3:
        layers['conv1'] = nn.Conv2d(bot_ch, bot_ch, kernel_size=3,
                                    padding=1, bias=False)
        layers['bn1'] = Norm2d(bot_ch)
        layers['re1'] = nn.ReLU(inplace=True)

    # Optional dropout before the classifier conv.
    if cfg.MODEL.MSCALE_DROPOUT:
        layers['drop'] = nn.Dropout(0.5)

    layers['conv2'] = nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False)
    layers['sig'] = nn.Sigmoid()

    return nn.Sequential(layers)
def __init__(self, in_dim, reduction_dim=256, output_stride=16,
             rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (DeepLab-style) module.

    Args:
        in_dim: number of input channels.
        reduction_dim: channels produced by each pyramid branch.
        output_stride: trunk output stride; 8 doubles the dilation rates,
            16 uses them as given.
        rates: base dilation rates for the 3x3 branches.

    Raises:
        ValueError: if output_stride is neither 8 nor 16.
    """
    super(AtrousSpatialPyramidPoolingModule, self).__init__()

    if output_stride == 8:
        rates = [2 * r for r in rates]
    elif output_stride == 16:
        pass
    else:
        # Bug fix: `raise '...'` raises a string, which is a TypeError in
        # Python 3 ("exceptions must derive from BaseException"). Raise a
        # proper exception type instead.
        raise ValueError(
            'output stride of {} not supported'.format(output_stride))

    self.features = []
    # 1x1 branch
    self.features.append(
        nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1,
                                bias=False),
                      Norm2d(reduction_dim), nn.ReLU(inplace=True)))
    # dilated 3x3 branches, one per rate
    for r in rates:
        self.features.append(nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                      dilation=r, padding=r, bias=False),
            Norm2d(reduction_dim),
            nn.ReLU(inplace=True)
        ))
    self.features = nn.ModuleList(self.features)

    # image-level features: global pool + 1x1 projection
    self.img_pooling = nn.AdaptiveAvgPool2d(1)
    self.img_conv = nn.Sequential(
        nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
        Norm2d(reduction_dim), nn.ReLU(inplace=True))
def make_seg_head(in_ch, out_ch):
    """Build the segmentation prediction head: two conv-BN-ReLU blocks
    followed by a 1x1 classifier conv."""
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    blocks = [
        nn.Conv2d(in_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False),
    ]
    return nn.Sequential(*blocks)
def __init__(self, num_classes, trunk=None, criterion=None):
    """GSCNN: a regular segmentation stream plus a gated shape stream.

    NOTE(review): `trunk` is accepted but unused here — the WideResNet38
    trunk is always built.
    """
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    # Build the WideResNet38 trunk; wrap then unwrap DataParallel so the
    # attribute layout matches checkpoints saved from a wrapped model.
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    wide_resnet = wide_resnet.module
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # Shape-stream side outputs: 1-channel taps at increasing depths.
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)

    # Shape stream: residual blocks with progressive channel reduction
    # (64 -> 32 -> 16 -> 8), fused down to a single edge channel.
    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

    # 2 -> 1 channel fusion conv (presumably edge map + a second edge
    # signal — confirm against forward()).
    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

    # Gated convs injecting regular-stream context into the shape stream.
    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)

    # Decoder projections for the fused segmentation head.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

    # Final fused segmentation head.
    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)
def __init__(self, inplanes, planes, stride=1, downsample=None):
    """Basic residual block: two 3x3 convs with batch norm.

    `downsample` is the optional projection applied to the identity
    path when shape or channel count changes.
    """
    super(BasicBlock, self).__init__()
    self.conv1 = conv3x3(inplanes, planes, stride)  # may downsample
    self.bn1 = Norm2d(planes, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=relu_inplace)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = Norm2d(planes, momentum=BN_MOMENTUM)
    self.downsample = downsample
    self.stride = stride
def _make_fuse_layers(self):
    """Create HRNet cross-resolution fusion layers.

    For each output branch i, builds one module per input branch j:
    a 1x1 channel projection when j > i (higher-resolution input),
    None when j == i, and a chain of stride-2 3x3 conv blocks when
    j < i (input must be downsampled i - j times).
    """
    if self.num_branches == 1:
        return None

    branches = self.num_branches
    channels = self.num_inchannels

    fuse_layers = []
    n_out = branches if self.multi_scale_output else 1
    for i in range(n_out):
        row = []
        for j in range(branches):
            if j > i:
                # Higher-resolution input: match channels with 1x1 conv
                # (upsampling happens in forward).
                row.append(nn.Sequential(
                    nn.Conv2d(channels[j], channels[i], 1, 1, 0,
                              bias=False),
                    Norm2d(channels[i], momentum=BN_MOMENTUM)))
            elif j == i:
                row.append(None)
            else:
                # Downsample input j by 2 a total of (i - j) times.
                steps = []
                for k in range(i - j):
                    is_last = (k == i - j - 1)
                    if is_last:
                        # Final step switches to the target channel
                        # count and carries no ReLU.
                        out_ch = channels[i]
                        steps.append(nn.Sequential(
                            nn.Conv2d(channels[j], out_ch,
                                      3, 2, 1, bias=False),
                            Norm2d(out_ch, momentum=BN_MOMENTUM)))
                    else:
                        out_ch = channels[j]
                        steps.append(nn.Sequential(
                            nn.Conv2d(channels[j], out_ch,
                                      3, 2, 1, bias=False),
                            Norm2d(out_ch, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=relu_inplace)))
                row.append(nn.Sequential(*steps))
        fuse_layers.append(nn.ModuleList(row))

    return nn.ModuleList(fuse_layers)
def old_make_attn_head(in_ch, bot_ch, out_ch):
    """Legacy scale-attention head: two conv-BN-ReLU blocks, a 1x1
    classifier conv and a sigmoid.

    Args:
        in_ch: input channels.
        bot_ch: bottleneck channels of the two 3x3 blocks.
        out_ch: number of attention channels produced.
    """
    attn = nn.Sequential(
        nn.Conv2d(in_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        # Bug fix: this was `kernel_size=out_ch`, which only acts as the
        # intended 1x1 projection when out_ch == 1 and silently changes
        # the spatial size for any other out_ch. Use an explicit 1x1 conv,
        # matching make_attn_head / make_seg_head.
        nn.Conv2d(bot_ch, out_ch, kernel_size=1, bias=False),
        nn.Sigmoid())
    init_attn(attn)
    return attn
def __init__(self, inplanes, planes, stride=1, downsample=None):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (carries `stride`)
    -> 1x1 expand to planes * self.expansion output channels.

    `downsample` is the optional projection applied to the identity
    path when shape or channel count changes.
    """
    super(Bottleneck, self).__init__()
    self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
    self.bn1 = Norm2d(planes, momentum=BN_MOMENTUM)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                           padding=1, bias=False)
    self.bn2 = Norm2d(planes, momentum=BN_MOMENTUM)
    self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                           kernel_size=1, bias=False)
    self.bn3 = Norm2d(planes * self.expansion, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=relu_inplace)
    self.downsample = downsample
    self.stride = stride
def __init__(self, **kwargs):
    """HRNet backbone: a two-conv stem, a stage-1 residual layer, and
    three multi-branch stages configured by cfg.MODEL.OCR_EXTRA.

    Sets self.high_level_ch to the total channel count of the final
    stage's branches.
    """
    extra = cfg.MODEL.OCR_EXTRA
    super(HighResolutionNet, self).__init__()

    # stem net: two stride-2 3x3 convs bring the input to 1/4 resolution
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                           bias=False)
    self.bn1 = Norm2d(64, momentum=BN_MOMENTUM)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                           bias=False)
    self.bn2 = Norm2d(64, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=relu_inplace)

    # stage 1: single-branch residual layer
    self.stage1_cfg = extra['STAGE1']
    num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
    block = blocks_dict[self.stage1_cfg['BLOCK']]
    num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
    self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
    stage1_out_channel = block.expansion * num_channels

    # stage 2
    self.stage2_cfg = extra['STAGE2']
    num_channels = self.stage2_cfg['NUM_CHANNELS']
    block = blocks_dict[self.stage2_cfg['BLOCK']]
    num_channels = [num_channels[i] * block.expansion
                    for i in range(len(num_channels))]
    self.transition1 = self._make_transition_layer(
        [stage1_out_channel], num_channels)
    self.stage2, pre_stage_channels = self._make_stage(
        self.stage2_cfg, num_channels)

    # stage 3
    self.stage3_cfg = extra['STAGE3']
    num_channels = self.stage3_cfg['NUM_CHANNELS']
    block = blocks_dict[self.stage3_cfg['BLOCK']]
    num_channels = [num_channels[i] * block.expansion
                    for i in range(len(num_channels))]
    self.transition2 = self._make_transition_layer(
        pre_stage_channels, num_channels)
    self.stage3, pre_stage_channels = self._make_stage(
        self.stage3_cfg, num_channels)

    # stage 4: keeps all branch outputs for the OCR head
    self.stage4_cfg = extra['STAGE4']
    num_channels = self.stage4_cfg['NUM_CHANNELS']
    block = blocks_dict[self.stage4_cfg['BLOCK']]
    num_channels = [num_channels[i] * block.expansion
                    for i in range(len(num_channels))]
    self.transition3 = self._make_transition_layer(
        pre_stage_channels, num_channels)
    self.stage4, pre_stage_channels = self._make_stage(
        self.stage4_cfg, num_channels, multi_scale_output=True)

    # Bug fix: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int is the documented replacement.
    self.high_level_ch = int(np.sum(pre_stage_channels))
def __init__(self, num_classes, input_size, trunk='WideResnet38',
             criterion=None):
    """DeepLabV3+ over a WideResNet38 trunk.

    Args:
        num_classes: number of segmentation classes.
        input_size: unused here; kept for caller compatibility.
        trunk: trunk name (logged only; WideResNet38 is always built).
        criterion: training loss, stored on the module.
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # TODO: Should this be even here ?
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load('weights/ResNet/wider_resnet38.pth.tar',
                                map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    except Exception:
        # Bug fix: this was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit. Missing weights remain a
        # best-effort warning, not an error.
        print(
            "=====================Could not load imagenet weights======================="
        )
    wide_resnet = wide_resnet.module

    # Expose the trunk's stages as direct attributes and drop the wrapper.
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)

    # Decoder projections.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    # Final prediction head.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.final)
def __init__(self, num_classes, trunk='WideResnet38', criterion=None):
    """DeepLabV3+ over a WideResNet38 trunk.

    ImageNet weights are loaded only when a criterion is supplied
    (i.e. at training time); failure to load is fatal in that case.
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    if criterion is not None:
        try:
            checkpoint = torch.load(
                './pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            # Bug fix: this was a bare `except:`, which also catches
            # KeyboardInterrupt/SystemExit and converts them into a
            # RuntimeError. Only genuine load errors should do that.
            print(
                "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
            )
            raise RuntimeError(
                "=====================Could not load ImageNet weights of WideResNet38 network.======================="
            )
    wide_resnet = wide_resnet.module

    # Expose the trunk's stages as direct attributes and drop the wrapper.
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)

    # Decoder projections.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    # Final prediction head.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.final)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                     stride=1):
    """Build one HRNet branch: num_blocks[branch_index] residual blocks.

    Side effect: updates self.num_inchannels[branch_index] to the
    branch's output width (num_channels[branch_index] * block.expansion).
    """
    in_ch = self.num_inchannels[branch_index]
    out_ch = num_channels[branch_index] * block.expansion

    # Projection shortcut when the first block changes resolution or width.
    downsample = None
    if stride != 1 or in_ch != out_ch:
        downsample = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride,
                      bias=False),
            Norm2d(out_ch, momentum=BN_MOMENTUM),
        )

    layers = [block(in_ch, num_channels[branch_index], stride, downsample)]
    self.num_inchannels[branch_index] = out_ch
    for _ in range(1, num_blocks[branch_index]):
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index]))
    return nn.Sequential(*layers)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, fuse_aspp=False, attn_2b=False,
             attn_cls=False):
    """Multi-scale DeepLabV3+ with a configurable scale-attention head.

    The attention head's output width is 2 when attn_2b, num_classes
    when attn_cls, else 1.
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.attn_cls = attn_cls

    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc)

    # Decoder projections.
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Semantic segmentation prediction head.
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head.
    if self.attn_2b:
        attn_ch = 2
    elif self.attn_cls:
        attn_ch = num_classes
    else:
        attn_ch = 1
    self.scale_attn = make_attn_head(in_ch=256 + 48, out_ch=attn_ch)

    to_init = ((self.bot_fine, self.bot_aspp, self.scale_attn, self.final)
               if cfg.OPTIONS.INIT_DECODER else (self.final,))
    for module in to_init:
        initialize_weights(module)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, fuse_aspp=False, attn_2b=False):
    """Multi-scale DeepLabV3+ variant with a deformable-attention neck
    (ADNB) between the decoder projections and the prediction heads."""
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b

    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc,
                                      img_norm=False)

    # Decoder projections.
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Deformable-attention neck. (An ASNB spatial-attention block was
    # previously used here instead.)
    self.adnb = ADNB(d_model=256, nhead=8, num_encoder_layers=2,
                     dim_feedforward=256, dropout=0.5, activation="relu",
                     num_feature_levels=1, enc_n_points=4)

    # Semantic segmentation prediction head.
    bot_ch = cfg.MODEL.SEGATTN_BOT_CH
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, bot_ch, kernel_size=3, padding=1, bias=False),
        Norm2d(bot_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(bot_ch, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head: 2 channels when attn_2b, else 1.
    attn_ch = 2 if self.attn_2b else 1
    self.scale_attn = make_attn_head(in_ch=256 + 48, out_ch=attn_ch)

    to_init = ((self.bot_fine, self.bot_aspp, self.scale_attn, self.final)
               if cfg.OPTIONS.INIT_DECODER else (self.final,))
    for module in to_init:
        initialize_weights(module)
def __init__(self, in_dim, reduction_dim=256, output_stride=16,
             rates=(6, 12, 18)):
    """ASPP variant with an extra branch that embeds a 1-channel edge map.

    All pyramid branches come from the parent class; this adds only
    `edge_conv`, which projects the edge map to reduction_dim channels.
    """
    super(ASPP_edge, self).__init__(in_dim=in_dim,
                                    reduction_dim=reduction_dim,
                                    output_stride=output_stride,
                                    rates=rates)
    self.edge_conv = nn.Sequential(
        nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False),
        Norm2d(reduction_dim), nn.ReLU(inplace=True))
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, init_all=False):
    """DeepLabV3+ variant whose neck is an asymmetric fusion non-local
    block (AFNB) over the two deepest trunk features, instead of ASPP.

    NOTE(review): `use_dpc` is accepted but unused in this variant.
    """
    super(DeepV3PlusATTN, self).__init__()
    self.criterion = criterion

    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)

    # AFNB fuses the 2048-ch and 4096-ch trunk outputs into a 256-ch
    # feature map. (An APNB / ASPP neck was previously used here.)
    self.attn = AFNB(low_in_channels=2048, high_in_channels=4096,
                     out_channels=256, key_channels=64,
                     value_channels=256, dropout=0.8, sizes=([1]),
                     norm_type='batchnorm', psp_size=(1, 3, 6, 8))

    # Decoder projections.
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(256, 256, kernel_size=1, bias=False)

    # Final prediction head.
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    to_init = ([self.attn, self.bot_aspp, self.bot_fine, self.final]
               if init_all else [self.final])
    for module in to_init:
        initialize_weights(module)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
    """Stack `blocks` residual blocks; only the first may downsample or
    change the channel count."""
    out_ch = planes * block.expansion

    # Projection shortcut for the first block when shape/width changes.
    downsample = None
    if stride != 1 or inplanes != out_ch:
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, out_ch, kernel_size=1, stride=stride,
                      bias=False),
            Norm2d(out_ch, momentum=BN_MOMENTUM),
        )

    layers = [block(inplanes, planes, stride, downsample)]
    for _ in range(1, blocks):
        layers.append(block(out_ch, planes))
    return nn.Sequential(*layers)
def __init__(self, num_classes, trunk='WideResnet38', criterion=None,
             ngf=64, input_nc=1, output_nc=2):
    """Audio-visual DeepWV3Plus: a U-Net audio encoder/decoder plus ASPP
    heads for segmentation and depth prediction.

    Args:
        num_classes: number of segmentation classes.
        trunk: trunk name (logged only; WideResNet38 is always built).
        criterion: segmentation loss, stored on the module.
        ngf: base channel width of the audio U-Net.
        input_nc: audio U-Net input channels.
        output_nc: audio U-Net output channels (mask).
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    self.MSEcriterion = torch.nn.MSELoss()
    logging.info("Trunk: %s", trunk)

    # Build the WideResNet38 trunk and try to warm-start from ImageNet
    # weights; training proceeds without them if the file is missing.
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load(
            '/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/semantic-segmentation-master/pretrained_models/wider_resnet38.pth.tar',
            map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    except Exception:
        # Bug fix: this was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit; loading stays best-effort.
        print("=====================Could not load ImageNet weights=======================")
        print("Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models.")

    # Previously the audio U-Net was warm-started from a checkpoint:
    # audio_unet = AudioNet_multitask(ngf=64, input_nc=2)
    # Acheckpoint = torch.load('/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/audio/audioSynthesis/checkpoints/synBi2Bi_16_25/3_audio.pth', map_location='cpu')
    # pretrained_dict = Acheckpoint
    # model_dict = audio_unet.state_dict()
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k!='audionet_upconvlayer1.0.weight' and k!='audionet_upconvlayer5.0.weight' and k!='audionet_upconvlayer5.0.bias' and k!='conv1x1.0.weight' and k!='conv1x1.0.bias' and k!='conv1x1.1.weight' and k!='conv1x1.1.bias' and k!='conv1x1.1.running_mean' and k!='conv1x1.1.running_var'}
    # model_dict.update(pretrained_dict)
    # audio_unet.load_state_dict(model_dict)

    # Audio U-Net encoder.
    self.audionet_convlayer1 = unet_conv(input_nc, ngf)
    self.audionet_convlayer2 = unet_conv(ngf, ngf * 2)
    self.audionet_convlayer3 = unet_conv(ngf * 2, ngf * 4)
    self.audionet_convlayer4 = unet_conv(ngf * 4, ngf * 8)
    self.audionet_convlayer5 = unet_conv(ngf * 8, ngf * 8)

    # Audio U-Net decoder.
    # 1296 (audio-visual feature) = 784 (visual feature) + 512 (audio feature)
    self.audionet_upconvlayer1 = unet_upconv(1024, ngf * 8)
    self.audionet_upconvlayer2 = unet_upconv(ngf * 8, ngf * 4)
    self.audionet_upconvlayer3 = unet_upconv(ngf * 4, ngf * 2)
    self.audionet_upconvlayer4 = unet_upconv(ngf * 2, ngf)
    # outermost layer uses a sigmoid to bound the mask
    self.audionet_upconvlayer5 = unet_upconv(ngf, output_nc, True)
    self.conv1x1 = create_conv(4096, 2, 1, 0)

    wide_resnet = wide_resnet.module
    # self.unet = audio_unet
    # The visual trunk stages are currently unused:
    # self.mod1 = wide_resnet.mod1
    # self.mod2 = wide_resnet.mod2
    # self.mod3 = wide_resnet.mod3
    # self.mod4 = wide_resnet.mod4
    # self.mod5 = wide_resnet.mod5
    # self.mod6 = wide_resnet.mod6
    # self.mod7 = wide_resnet.mod7
    # self.pool2 = wide_resnet.pool2
    # self.pool3 = wide_resnet.pool3
    del wide_resnet

    # ASPP necks for segmentation and depth.
    self.aspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                   output_stride=8)
    self.depthaspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                        output_stride=8)

    # Projection convs.
    self.bot_aud1 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
    self.bot_multiaud = nn.Conv2d(512, 512, kernel_size=1, bias=False)
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(320, 256, kernel_size=1, bias=False)
    self.bot_depthaspp = nn.Conv2d(320, 128, kernel_size=1, bias=False)

    # Segmentation prediction head.
    self.final = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Depth prediction head (single-channel output).
    self.final_depth = nn.Sequential(
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
        Norm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
        Norm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 1, kernel_size=1, bias=False))

    initialize_weights(self.final)
    initialize_weights(self.bot_aud1)
    initialize_weights(self.bot_multiaud)
def __init__(self, num_classes, trunk='wrn38', criterion=None,
             use_dpc=False, fuse_aspp=False, attn_2b=False,
             bn_head=False):
    """DeepLabV3+ with an explicit per-scale attention head (ASDV3P).

    The attention head consumes the concatenated decoder features of all
    cfg.MODEL.N_SCALES scales and emits one attention logit per scale.
    """
    super(ASDV3P, self).__init__()
    self.criterion = criterion
    self.fuse_aspp = fuse_aspp
    self.attn_2b = attn_2b
    self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=256,
                                      output_stride=8,
                                      dpc=use_dpc)
    self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False)

    # Semantic segmentation prediction head
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Scale-attention prediction head
    assert cfg.MODEL.N_SCALES is not None
    self.scales = sorted(cfg.MODEL.N_SCALES)
    num_scales = len(self.scales)
    if cfg.MODEL.ATTNSCALE_BN_HEAD or bn_head:
        self.scale_attn = nn.Sequential(
            nn.Conv2d(num_scales * (256 + 48), 256, kernel_size=3,
                      padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_scales, kernel_size=1, bias=False))
    else:
        self.scale_attn = nn.Sequential(
            nn.Conv2d(num_scales * (256 + 48), 512, kernel_size=3,
                      padding=1, bias=False),
            nn.ReLU(inplace=True),
            # Bug fix: this 1x1 projection had padding=1, which pads the
            # attention logits to (H+2, W+2) and misaligns them with the
            # features they weight. A 1x1 conv needs no padding (compare
            # the BN-head branch above).
            nn.Conv2d(512, num_scales, kernel_size=1, padding=0,
                      bias=False))

    if cfg.OPTIONS.INIT_DECODER:
        initialize_weights(self.bot_fine)
        initialize_weights(self.bot_aspp)
        initialize_weights(self.scale_attn)
        initialize_weights(self.final)
    else:
        initialize_weights(self.final)
def __init__(self, num_classes, trunk=None, criterion=None):
    """GSCNN: regular segmentation stream plus a gated shape stream,
    warm-started from ImageNet WideResNet38 weights.

    NOTE(review): `trunk` is accepted but unused here — the WideResNet38
    trunk is always built.
    """
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load(
            './network/pretrained_models/wider_resnet38.pth.tar',
            map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    except Exception:
        # Bug fix: this was a bare `except:`, which also catches
        # KeyboardInterrupt/SystemExit and converts them into a
        # RuntimeError. Only genuine load errors should do that.
        print(
            "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
        )
        raise RuntimeError(
            "=====================Could not load ImageNet weights of WideResNet38 network.======================="
        )
    wide_resnet = wide_resnet.module

    # Expose the trunk's stages as direct attributes and drop the wrapper.
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # Shape-stream side outputs: 1-channel taps at increasing depths.
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)

    # Shape stream: residual blocks with progressive channel reduction
    # (64 -> 32 -> 16 -> 8), fused down to a single edge channel.
    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

    # 2 -> 1 channel fusion conv (presumably edge map + a second edge
    # signal — confirm against forward()).
    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

    # Gated convs injecting regular-stream context into the shape stream.
    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)

    # Decoder projections for the fused segmentation head.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

    # Final fused segmentation head.
    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)
def BNReLU(ch):
    """Norm2d followed by ReLU, packaged as one Sequential."""
    norm = Norm2d(ch)
    act = nn.ReLU()
    return nn.Sequential(norm, act)
def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None,
             variant='D', skip='m1', skip_num=48, args=None):
    """DeepLabV3+ with a selectable trunk and optional instance-whitening layers.

    Args:
        num_classes: number of output semantic classes.
        trunk: backbone name ('shufflenetv2', 'mnasnet_05', 'mnasnet_10',
            'mobilenetv2', or a ResNet/ResNeXt/WideResNet variant).
        criterion: main segmentation loss.
        criterion_aux: auxiliary (DSN) loss.
        variant: dilation scheme — 'D' (output stride 8), 'D4' (os 4),
            'D16' (os 16); anything else keeps the trunk's native os 32.
        skip, skip_num: accepted for interface compatibility (unused here).
        args: experiment namespace; args.wt_layer selects a per-layer
            whitening type (1 = IRW, 2 = ISW), args.relax_denom and
            args.clusters configure the covariance modules.
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.trunk = trunk

    if trunk == 'shufflenetv2':
        # Per-stage channel counts for this trunk.
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = Shufflenet.shufflenet_v2_x1_0(pretrained=True, iw=self.args.wt_layer)

        class Layer0(nn.Module):
            # Wraps the stem (conv1 + maxpool) and optionally swaps its BN
            # for the trunk's instance-whitening norm, selected by `iw`.
            def __init__(self, iw):
                super(Layer0, self).__init__()
                self.layer = nn.Sequential(resnet.conv1, resnet.maxpool)
                self.instance_norm_layer = resnet.instance_norm_layer1
                self.iw = iw

            def forward(self, x_tuple):
                # x_tuple is [features, list of whitening terms].
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflnet layer 0 forward path")
                    return
                x = self.layer[0][0](x)  # stem conv
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        # IRW/ISW norms also return a covariance term.
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[0][1](x)  # original BN
                x = self.layer[0][2](x)  # ReLU
                x = self.layer[1](x)  # maxpool
                return [x, w_arr]

        class Layer4(nn.Module):
            # Same wrapping for the final conv5 block.
            def __init__(self, iw):
                super(Layer4, self).__init__()
                self.layer = resnet.conv5
                self.instance_norm_layer = resnet.instance_norm_layer2
                self.iw = iw

            def forward(self, x_tuple):
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflnet layer 4 forward path")
                    return
                x = self.layer[0](x)  # conv
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[1](x)  # original BN
                x = self.layer[2](x)  # ReLU
                return [x, w_arr]

        self.layer0 = Layer0(iw=self.args.wt_layer[2])
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = Layer4(iw=self.args.wt_layer[6])

        if self.variant == 'D':
            # Replace the strided convs of the last two stages with dilated
            # ones so the effective output stride becomes 8.
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            # Slice the flat MNASNet layer list into DeepLab-style stages.
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = Mobilenet.mobilenet_v2(pretrained=True, iw=self.args.wt_layer)
        self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                    resnet.features[4], resnet.features[5], resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                    resnet.features[9], resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13], resnet.features[14], resnet.features[15], resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])
        self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12],
                                    resnet.features[13], resnet.features[14],
                                    resnet.features[15], resnet.features[16],
                                    resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    else:
        # ResNet family defaults; resnet-18 overrides them below.
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-18':
            channel_1st = 3
            channel_2nd = 64
            channel_3rd = 64
            channel_4th = 128
            prev_final_channel = 256
            final_channel = 512
            resnet = Resnet.resnet18(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-50':
            resnet = Resnet.resnet50(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # three 3 X 3
            # Deep stem variant: three 3x3 convs instead of one 7x7.
            resnet = Resnet.resnet101(pretrained=True, wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                          resnet.conv2, resnet.bn2, resnet.relu2,
                                          resnet.conv3, resnet.bn3, resnet.relu3,
                                          resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")
        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # In-place dilation surgery on bottleneck conv2 and the downsample
        # projection so spatial resolution is preserved in later stages.
        if self.variant == 'D':
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")

    # Output stride implied by the chosen variant.
    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32
    self.output_stride = os

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, output_stride=os)

    # Decoder: low-level skip projection + ASPP bottleneck + two-stage head.
    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

    # Auxiliary (DSN) head off the penultimate stage.
    self.dsn = nn.Sequential(
        nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
        Norm2d(512),
        nn.ReLU(inplace=True),
        nn.Dropout2d(0.1),
        nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
    )
    initialize_weights(self.dsn)
    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)

    # Setting the flags
    self.eps = 1e-5
    self.whitening = False

    # Channel widths of each whitening-capable position per trunk;
    # index i corresponds to args.wt_layer[i] (0 entries = unsupported).
    if trunk == 'resnet-101':
        self.three_input_layer = True
        in_channel_list = [64, 64, 128, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [32, 32, 64, 128, 256, 512, 1024]
    elif trunk == 'resnet-18':
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 64, 128, 256, 512]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 32, 64, 128, 256]
    elif trunk == 'shufflenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 24, 116, 232, 464, 1024]
    elif trunk == 'mobilenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 16, 32, 64, 320, 1280]
    else:  # ResNet-50
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 128, 256, 512, 1024]

    # One covariance tracker per requested whitening layer
    # (1 = instance-relaxed whitening, 2 = instance-selective whitening).
    self.cov_matrix_layer = []
    self.cov_type = []
    for i in range(len(self.args.wt_layer)):
        if self.args.wt_layer[i] > 0:
            self.whitening = True
            if self.args.wt_layer[i] == 1:
                self.cov_matrix_layer.append(CovMatrix_IRW(dim=in_channel_list[i], relax_denom=self.args.relax_denom))
                self.cov_type.append(self.args.wt_layer[i])
            elif self.args.wt_layer[i] == 2:
                self.cov_matrix_layer.append(CovMatrix_ISW(dim=in_channel_list[i], relax_denom=self.args.relax_denom, clusters=self.args.clusters))
                self.cov_type.append(self.args.wt_layer[i])
def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None,
             variant='D', skip='m1', skip_num=48, args=None):
    """DeepLabV3+ augmented with up to five HANet attention modules.

    Args:
        num_classes: number of output semantic classes.
        trunk: backbone name ('shufflenetv2', 'mnasnet_05', 'mnasnet_10',
            'mobilenetv2', or a ResNet/ResNeXt/WideResNet variant).
        criterion: main segmentation loss.
        criterion_aux: auxiliary (DSN) loss, built only if args.aux_loss.
        variant: dilation scheme — 'D' (output stride 8), 'D4' (os 4),
            'D16' (os 16); anything else keeps the trunk's native os 32.
        skip, skip_num: accepted for interface compatibility (unused here).
        args: experiment namespace; args.hanet[i] enables HANet at stage i,
            args.hanet_set / hanet_pos / pos_rfactor / pooling / dropout /
            pos_noise configure each HANet_Conv.
    """
    super(DeepV3PlusHANet, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.num_attention_layer = 0
    self.trunk = trunk

    # Count how many of the five insertion points have HANet enabled.
    for i in range(5):
        if args.hanet[i] > 0:
            self.num_attention_layer += 1
    print("#### HANet layers", self.num_attention_layer)

    if trunk == 'shufflenetv2':
        # Per-stage channel counts for this trunk.
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = models.shufflenet_v2_x1_0(pretrained=True)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.maxpool)
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = resnet.conv5
        if self.variant == 'D':
            # Swap strided convs for dilated ones (output stride 8).
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            # Slice the flat MNASNet layer list into DeepLab-style stages.
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])  # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])  # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15], resnet.layers[16])  # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = models.mobilenet_v2(pretrained=True)
        self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                    resnet.features[4], resnet.features[5], resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                    resnet.features[9], resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13], resnet.features[14], resnet.features[15], resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])
        self.layer3 = nn.Sequential(
            resnet.features[11], resnet.features[12], resnet.features[13],
            resnet.features[14], resnet.features[15], resnet.features[16],
            resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")
    else:
        # ResNet family channel layout.
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-50':
            resnet = Resnet.resnet50()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # three 3 X 3
            # Deep stem variant: three 3x3 convs instead of one 7x7.
            resnet = Resnet.resnet101()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                          resnet.conv2, resnet.bn2, resnet.relu2,
                                          resnet.conv3, resnet.bn3, resnet.relu3,
                                          resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")
        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # In-place dilation surgery on bottleneck conv2 and the downsample
        # projection so spatial resolution is preserved in later stages.
        if self.variant == 'D':
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation ")

    # Output stride implied by the chosen variant.
    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, output_stride=os)

    # Decoder: low-level skip projection + ASPP bottleneck + two-stage head.
    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

    # Optional auxiliary (DSN) head off the penultimate stage.
    if self.args.aux_loss is True:
        self.dsn = nn.Sequential(
            nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
            Norm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
        initialize_weights(self.dsn)

    # One HANet_Conv per enabled insertion point; (in, out) channel pairs
    # follow the feature widths at each stage of the decoder pipeline.
    if self.args.hanet[0] == 1:
        self.hanet0 = HANet_Conv(prev_final_channel, final_channel,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet0)

    if self.args.hanet[1] == 1:
        self.hanet1 = HANet_Conv(final_channel, 1280,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet1)

    if self.args.hanet[2] == 1:
        self.hanet2 = HANet_Conv(1280, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet2)

    if self.args.hanet[3] == 1:
        self.hanet3 = HANet_Conv(304, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet3)

    if self.args.hanet[4] == 1:
        # Final attention layer uses max pooling regardless of args.pooling.
        self.hanet4 = HANet_Conv(256, num_classes,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling='max',
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet4)

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)
def __init__(self, in_channel, out_channel, kernel_size=3, r_factor=64, layer=3, pos_injection=2,
             is_encoding=1, pos_rfactor=8, pooling='mean', dropout_prob=0.0, pos_noise=0.0):
    """Height-driven attention (HANet) head: rows are pooled to a 1-D
    column, passed through a small Conv1d bottleneck (2 or 3 layers) with
    optional positional injection, and squashed by a sigmoid.

    Args:
        in_channel / out_channel: channel widths of the attended features.
        kernel_size: kernel of the last Conv1d.
        r_factor: bottleneck ratio; > 0 shrinks (ceil(in/r)), < 0 expands
            (in * |r|).
        layer: 2- or 3-layer attention stack.
        pos_injection: 1 = inject position before the first conv,
            2 = after it.
        is_encoding: 0 = learned PosEmbedding1D, 1 = fixed PosEncoding1D.
        pos_rfactor: row-pooling reduction factor (rows = 128 // pos_rfactor).
        pooling: 'mean' for average row pooling, anything else uses max.
        dropout_prob: optional Dropout2d probability.
        pos_noise: noise level passed to the positional module.

    Raises:
        ValueError: if r_factor == 0, or if is_encoding is not 0/1 while
            pos_rfactor > 0.
    """
    super(HANet_Conv, self).__init__()

    self.pooling = pooling
    self.pos_injection = pos_injection
    self.layer = layer
    self.dropout_prob = dropout_prob
    self.sigmoid = nn.Sigmoid()

    if r_factor > 0:
        mid_1_channel = math.ceil(in_channel / r_factor)
    elif r_factor < 0:
        # Negative r_factor means expand instead of shrink.
        r_factor = r_factor * -1
        mid_1_channel = in_channel * r_factor
    else:
        # Fixed: r_factor == 0 previously fell through both branches and
        # left mid_1_channel unbound, crashing with NameError below.
        raise ValueError("r_factor must be non-zero")

    if self.dropout_prob > 0:
        self.dropout = nn.Dropout2d(self.dropout_prob)

    self.attention_first = nn.Sequential(
        nn.Conv1d(in_channels=in_channel, out_channels=mid_1_channel,
                  kernel_size=1, stride=1, padding=0, bias=False),
        # NOTE(review): Norm2d applied to Conv1d output — assumes the
        # project's Norm2d handles (N, C, L) input; confirm.
        Norm2d(mid_1_channel),
        nn.ReLU(inplace=True))

    if layer == 2:
        self.attention_second = nn.Sequential(
            nn.Conv1d(in_channels=mid_1_channel, out_channels=out_channel,
                      kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, bias=True))
    elif layer == 3:
        mid_2_channel = (mid_1_channel * 2)
        self.attention_second = nn.Sequential(
            nn.Conv1d(in_channels=mid_1_channel, out_channels=mid_2_channel,
                      kernel_size=3, stride=1, padding=1, bias=True),
            Norm2d(mid_2_channel),
            nn.ReLU(inplace=True))
        self.attention_third = nn.Sequential(
            nn.Conv1d(in_channels=mid_2_channel, out_channels=out_channel,
                      kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, bias=True))

    if self.pooling == 'mean':
        #print("##### average pooling")
        self.rowpool = nn.AdaptiveAvgPool2d((128 // pos_rfactor, 1))
    else:
        #print("##### max pooling")
        self.rowpool = nn.AdaptiveMaxPool2d((128 // pos_rfactor, 1))

    if pos_rfactor > 0:
        if is_encoding == 0:
            # Learned positional embedding.
            if self.pos_injection == 1:
                self.pos_emb1d_1st = PosEmbedding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
            elif self.pos_injection == 2:
                self.pos_emb1d_2nd = PosEmbedding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
        elif is_encoding == 1:
            # Fixed sinusoidal positional encoding.
            if self.pos_injection == 1:
                self.pos_emb1d_1st = PosEncoding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
            elif self.pos_injection == 2:
                self.pos_emb1d_2nd = PosEncoding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
        else:
            # Fixed: was print(...) + exit(), which killed the whole process
            # from library code; raising keeps the error catchable.
            raise ValueError("Not supported position encoding")
def __init__(self, trunk='WideResnet38', criterion=None, criterion2=None):
    """Two-headed DeepLabV3+ on WideResNet38: a 19-class semantic head and
    a 2-class traversability head, each with its own ASPP decoder.

    Args:
        trunk: only logged; the trunk is always WideResNet38-A2.
        criterion: semantic loss; also gates checkpoint loading (ImageNet
            weights are only loaded when training, i.e. criterion is set).
        criterion2: traversability loss.

    Raises:
        RuntimeError: if the ImageNet WideResNet38 checkpoint cannot be
            loaded from ./pretrained_models/wider_resnet38.pth.tar.
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    self.criterion2 = criterion2
    logging.info("Trunk: %s", trunk)

    tasks = ['semantic', 'traversability']
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True, tasks=tasks)
    # DataParallel wrapper so the checkpoint's 'module.*' keys line up.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    if criterion is not None:
        try:
            checkpoint = torch.load(
                './pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            # Partial load: copy only parameters whose name AND shape match,
            # so the task-specific parts may diverge from the checkpoint.
            net_state_dict = wide_resnet.state_dict()
            loaded_dict = checkpoint['state_dict']
            new_loaded_dict = {}
            for k in net_state_dict:
                if k in loaded_dict and net_state_dict[k].size(
                ) == loaded_dict[k].size():
                    new_loaded_dict[k] = loaded_dict[k]
                else:
                    logging.info("Skipped loading parameter %s", k)
            net_state_dict.update(new_loaded_dict)
            wide_resnet.load_state_dict(net_state_dict)
            del checkpoint
        except Exception as err:
            # Fixed: was a bare 'except:' (also swallowed KeyboardInterrupt
            # / SystemExit); original cause is now chained for debugging.
            print(
                "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
            )
            raise RuntimeError(
                "=====================Could not load ImageNet weights of WideResNet38 network.======================="
            ) from err
    wide_resnet = wide_resnet.module

    # Learned per-task weights for combining the two losses.
    self.task_weights = torch.nn.Parameter(
        torch.zeros(2, requires_grad=True))

    # Shared trunk: the seven WideResNet38 stages.
    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    # Semantic head (19 Cityscapes classes).
    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                   output_stride=8)
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 19, kernel_size=1, bias=False))

    # Traversability head (2 classes): a mirrored, independent decoder.
    self.aspp2 = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                    output_stride=8)
    self.bot_fine2 = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp2 = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
    self.final2 = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 2, kernel_size=1, bias=False))
def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D',
             skip='m1', skip_num=48):
    """DeepLabV3+ over an SE-ResNeXt or ResNet trunk.

    Args:
        num_classes: number of output semantic classes.
        trunk: 'seresnext-50', 'seresnext-101', 'resnet-50' or 'resnet-101'.
        criterion: optional training loss module.
        variant: dilation scheme — 'D' (output stride 8) or 'D16' (os 16);
            anything else keeps the trunk's native stride.
        skip: which low-level feature to skip-connect ('m1' = 256 ch,
            'm2' = 512 ch).
        skip_num: channel width of the skip projection.

    Raises:
        ValueError: on an unknown trunk name.
        Exception: on an unknown skip name.
    """
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.variant = variant
    self.skip = skip
    self.skip_num = skip_num

    if trunk == 'seresnext-50':
        resnet = SEresnext.se_resnext50_32x4d()
    elif trunk == 'seresnext-101':
        resnet = SEresnext.se_resnext101_32x4d()
    elif trunk == 'resnet-50':
        resnet = Resnet.resnet50()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    elif trunk == 'resnet-101':
        resnet = Resnet.resnet101()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    else:
        raise ValueError("Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    # In-place dilation surgery: strided bottleneck convs (and their
    # downsample projections) become dilated so resolution is preserved.
    if self.variant == 'D':
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif self.variant == 'D16':
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        # raise 'unknown deepv3 variant: {}'.format(self.variant)
        print("Not using Dilation ")

    self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256,
                                                   output_stride=8)

    # Skip projection from the chosen low-level stage.
    if self.skip == 'm1':
        self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
    elif self.skip == 'm2':
        self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
    else:
        raise Exception('Not a valid skip')

    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    # Segmentation head over ASPP + skip features.
    self.final = nn.Sequential(
        nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final)