def __init__(self, num_classes, trunk=None, criterion=None):
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    # WideResNet-38 backbone; wrapped/unwrapped via DataParallel so module
    # names line up with pretrained checkpoints.
    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    wide_resnet = wide_resnet.module

    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # Side outputs feeding the shape stream.
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)

    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, output_stride=8)

    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)
def main(config=None):
    try:
        global params
        global x
        global y
        # config should be a dictionary
        if config:
            params = _type_convert(config)
        else:
            # load params from disk
            with open('params.json') as f:
                params = _type_convert(json.load(f))
        # TODO: validate the params dictionary
        channels = len(params['bands'])
        x = tf.placeholder(tf.float32, [None, 84, 84, channels])
        y = tf.placeholder(tf.float32, [None, params['n_classes']])
        net = Resnet.get_network(x, params['block_config'], params['train'],
                                 params['global_avg_pool'], params['2016_update'])
        if params['train']:
            _train_network(net)
        else:
            _use_network(net)
    except KeyboardInterrupt:
        print('\nInterrupted')
        run = 'y' == input('Run eval statistics? (y/[n]) ')
        if run:
            print('Not implemented yet')
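# Illustrative invocation of main() with an in-memory config dict. The keys
# mirror what main() reads above; the values are assumptions for the example,
# not project defaults.
if __name__ == '__main__':
    example_config = {
        'bands': ['red', 'green', 'nir'],   # channel count = len(bands)
        'n_classes': 10,
        'block_config': [3, 4, 6, 3],       # assumed ResNet-style block counts
        'train': True,
        'global_avg_pool': True,
        '2016_update': True,                # toggle for the updated ResNet variant
    }
    main(config=example_config)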
def __init__(self, num_classes, trunk='seresnext-50', criterion=None,
             variant='D', skip='m1', skip_num=48):
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.variant = variant
    self.skip = skip
    self.skip_num = skip_num

    if trunk == 'seresnext-50':
        resnet = SEresnext.se_resnext50_32x4d()
    elif trunk == 'seresnext-101':
        resnet = SEresnext.se_resnext101_32x4d()
    elif trunk == 'resnet-50':
        resnet = Resnet.resnet50()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                      resnet.relu, resnet.maxpool)
    elif trunk == 'resnet-101':
        resnet = Resnet.resnet101()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                      resnet.relu, resnet.maxpool)
    else:
        raise ValueError("Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    if self.variant == 'D':
        # Dilate layers 3 and 4 instead of striding (output stride 8).
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif self.variant == 'D16':
        # Only dilate layer 4 (output stride 16).
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        # raise 'unknown deepv3 variant: {}'.format(self.variant)
        print("Not using Dilation")

    self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256, output_stride=8)

    if self.skip == 'm1':
        self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
    elif self.skip == 'm2':
        self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
    else:
        raise ValueError('Not a valid skip')

    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    self.final = nn.Sequential(
        nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final)
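# A minimal, self-contained sketch of the dilation surgery performed above:
# rewriting a stride-2 conv as a stride-1 dilated conv keeps the spatial
# resolution while preserving the layer's receptive field. The module and
# tensor here are illustrative, not taken from the project.
import torch
import torch.nn as nn

conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
x = torch.randn(1, 64, 32, 32)
print(conv(x).shape)   # torch.Size([1, 64, 16, 16]) -- stride 2 halves H and W

# The same in-place edit the constructor applies to 'conv2' modules:
conv.dilation, conv.padding, conv.stride = (2, 2), (2, 2), (1, 1)
print(conv(x).shape)   # torch.Size([1, 64, 32, 32]) -- resolution preserved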
from torch.utils.tensorboard import SummaryWriter
import torchvision

if __name__ == '__main__':
    opt = gather_options()
    print_options(opt)
    device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids \
        else torch.device('cpu')

    trainloader, testloader = loadData(opt)
    dataset_size = len(trainloader)
    print('#training images = %d' % dataset_size)

    net = Resnet(opt.input_nc, num_classes=opt.num_classes, norm=opt.norm, nl=opt.nl)
    net = init_net(net, init_type='normal', gpu_ids=[0])
    if opt.continue_train:
        load_networks(opt, net)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)
    scheduler = get_scheduler(optimizer, opt)

    iter = 0
    running_loss = 0.0
    correct = 0.0
    total = 0
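# The snippet ends before the training loop. A minimal sketch of how these
# counters are typically consumed; the loop structure and epoch count are
# assumptions, not the original code.
for epoch in range(100):                      # epoch count is illustrative
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        iter += 1
    scheduler.step()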
            gamma=0.1, last_epoch=checkpoint['epoch'])
        print("Loading from epoch:", checkpoint['epoch'],
              'scheduler:', self.scheduler.last_epoch,
              'map:', checkpoint['map'],
              'rank1:', checkpoint['rank'][0])
        return checkpoint['epoch'], checkpoint['rank'], checkpoint['map']


if __name__ == '__main__':
    mp.set_start_method(opt.start_method, True)
    if opt.model_name == 'MGN':
        model = MGN()
        loss = Loss_MGN()
    elif opt.model_name == 'Resnet':
        model = Resnet()
        loss = Loss_Resnet()
    elif opt.model_name == 'CGN':
        model = CGN()
        loss = Loss_CGN()
    elif opt.model_name == 'SN':
        model = SN()
        loss = Loss_SN()
    elif opt.model_name == 'FPN':
        model = FPN()
        loss = Loss_FPN()
    elif opt.model_name == 'AN':
        model = AN()
        loss = Loss_AN()
    else:
        # Without this guard an unknown name would leave `model` undefined.
        raise ValueError('Unknown model_name: {}'.format(opt.model_name))

    if opt.mode == 'train':
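# Design note: the if/elif ladder above can be collapsed into a registry dict,
# which also makes the unknown-name error automatic. A sketch using the same
# class and loss names as above:
MODELS = {
    'MGN': (MGN, Loss_MGN),
    'Resnet': (Resnet, Loss_Resnet),
    'CGN': (CGN, Loss_CGN),
    'SN': (SN, Loss_SN),
    'FPN': (FPN, Loss_FPN),
    'AN': (AN, Loss_AN),
}
model_cls, loss_cls = MODELS[opt.model_name]   # KeyError flags unknown names
model, loss = model_cls(), loss_cls()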
def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None,
             variant='D', skip='m1', skip_num=48, args=None):
    super(DeepV3PlusHANet, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.num_attention_layer = 0
    self.trunk = trunk

    # Count how many of the five candidate positions get a HANet layer.
    for i in range(5):
        if args.hanet[i] > 0:
            self.num_attention_layer += 1
    print("#### HANet layers", self.num_attention_layer)

    if trunk == 'shufflenetv2':
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = models.shufflenet_v2_x1_0(pretrained=True)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.maxpool)
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = resnet.conv5
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])   # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])    # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15],
                                        resnet.layers[16])                     # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])   # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])    # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15],
                                        resnet.layers[16])                     # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = models.mobilenet_v2(pretrained=True)
        self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                    resnet.features[4], resnet.features[5],
                                    resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                    resnet.features[9], resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13],
        #                             resnet.features[14], resnet.features[15], resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])
        self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12],
                                    resnet.features[13], resnet.features[14],
                                    resnet.features[15], resnet.features[16],
                                    resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    else:
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-50':
            resnet = Resnet.resnet50()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # stem with three 3x3 convs
            resnet = Resnet.resnet101()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                          resnet.conv2, resnet.bn2, resnet.relu2,
                                          resnet.conv3, resnet.bn3, resnet.relu3,
                                          resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")

        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        if self.variant == 'D':
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")

    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, output_stride=os)

    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),  # 304 = 256 (ASPP) + 48 (fine)
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

    if self.args.aux_loss is True:
        self.dsn = nn.Sequential(
            nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
            Norm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
        initialize_weights(self.dsn)

    if self.args.hanet[0] == 1:
        self.hanet0 = HANet_Conv(prev_final_channel, final_channel,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet0)
    if self.args.hanet[1] == 1:
        self.hanet1 = HANet_Conv(final_channel, 1280,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet1)
    if self.args.hanet[2] == 1:
        self.hanet2 = HANet_Conv(1280, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet2)
    if self.args.hanet[3] == 1:
        self.hanet3 = HANet_Conv(304, 256,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling=self.args.pooling,
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet3)
    if self.args.hanet[4] == 1:
        self.hanet4 = HANet_Conv(256, num_classes,
                                 self.args.hanet_set[0], self.args.hanet_set[1], self.args.hanet_set[2],
                                 self.args.hanet_pos[0], self.args.hanet_pos[1],
                                 pos_rfactor=self.args.pos_rfactor, pooling='max',
                                 dropout_prob=self.args.dropout, pos_noise=self.args.pos_noise)
        initialize_weights(self.hanet4)

    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)
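# Sanity check for the variant -> output-stride mapping above: these backbones
# downsample by 32 overall, and each stage whose stride-2 convs are rewritten
# as dilated stride-1 convs halves the output stride. Illustrative helper only.
def expected_output_stride(variant):
    dilated_stages = {'D': 2, 'D4': 3, 'D16': 1}.get(variant, 0)
    return 32 // (2 ** dilated_stages)

assert expected_output_stride('D') == 8
assert expected_output_stride('D4') == 4
assert expected_output_stride('D16') == 16
assert expected_output_stride('other') == 32   # no dilation: os stays 32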
         '0002_c1s1_000776_01.jpg',
         '0007_c3s3_077419_03.jpg',
         '0007_c2s3_070952_01.jpg',
         '0010_c6s4_002427_02.jpg',
         '0010_c6s4_002452_02.jpg']
paths = ['../market1501/Market1501/bounding_box_train/' + _ for _ in paths]

origin_transforms = transforms.Compose([
    transforms.Resize((384, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
imgs = [origin_transforms(default_loader(path)) for path in paths]

model = Resnet()
# map_location belongs to torch.load, not to load_state_dict.
checkpoint = torch.load('models/market1501/weights/Resnet/checkpoint_600.pth.tar',
                        map_location='cuda')
model.load_state_dict(checkpoint['state_dict'])

###########################################
# img1 = get_layers(imgs[0])[1][10]
# img2 = get_layers(imgs[1])[1][10]
# img3 = get_layers(imgs[4])[1][10]
# output1_1 = spatial_attention(img1, img1)
# output1_2 = spatial_attention(img1, img2)
# output1_3 = spatial_attention(img1, img3)
# draw(img1, 'simg_1')
# draw(img2, 'simg_2')
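# Note: map_location is an argument of torch.load (it controls the device onto
# which tensors are deserialized); nn.Module.load_state_dict does not accept
# it. To load the same GPU-trained checkpoint on a CPU-only machine:
#
#     checkpoint = torch.load(
#         'models/market1501/weights/Resnet/checkpoint_600.pth.tar',
#         map_location=torch.device('cpu'))
#     model.load_state_dict(checkpoint['state_dict'])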
def __init__(self, num_classes, trunk='resnet-101', criterion=None, criterion_aux=None,
             variant='D', skip='m1', skip_num=48, args=None):
    super(DeepV3Plus, self).__init__()
    self.criterion = criterion
    self.criterion_aux = criterion_aux
    self.variant = variant
    self.args = args
    self.trunk = trunk

    if trunk == 'shufflenetv2':
        channel_1st = 3
        channel_2nd = 24
        channel_3rd = 116
        channel_4th = 232
        prev_final_channel = 464
        final_channel = 1024
        resnet = Shufflenet.shufflenet_v2_x1_0(pretrained=True, iw=self.args.wt_layer)

        class Layer0(nn.Module):
            def __init__(self, iw):
                super(Layer0, self).__init__()
                self.layer = nn.Sequential(resnet.conv1, resnet.maxpool)
                self.instance_norm_layer = resnet.instance_norm_layer1
                self.iw = iw

            def forward(self, x_tuple):
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflenet layer 0 forward path")
                    return
                x = self.layer[0][0](x)
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        # IRW/ISW modes also return the whitening matrix.
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[0][1](x)
                x = self.layer[0][2](x)
                x = self.layer[1](x)
                return [x, w_arr]

        class Layer4(nn.Module):
            def __init__(self, iw):
                super(Layer4, self).__init__()
                self.layer = resnet.conv5
                self.instance_norm_layer = resnet.instance_norm_layer2
                self.iw = iw

            def forward(self, x_tuple):
                if len(x_tuple) == 2:
                    w_arr = x_tuple[1]
                    x = x_tuple[0]
                else:
                    print("error in shufflenet layer 4 forward path")
                    return
                x = self.layer[0](x)
                if self.iw >= 1:
                    if self.iw == 1 or self.iw == 2:
                        x, w = self.instance_norm_layer(x)
                        w_arr.append(w)
                    else:
                        x = self.instance_norm_layer(x)
                else:
                    x = self.layer[1](x)
                x = self.layer[2](x)
                return [x, w_arr]

        self.layer0 = Layer0(iw=self.args.wt_layer[2])
        self.layer1 = resnet.stage2
        self.layer2 = resnet.stage3
        self.layer3 = resnet.stage4
        self.layer4 = Layer4(iw=self.args.wt_layer[6])

        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
        if trunk == 'mnasnet_05':
            resnet = models.mnasnet0_5(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 24
            channel_4th = 48
            prev_final_channel = 160
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])   # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])    # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15],
                                        resnet.layers[16])                     # 1280
        else:
            resnet = models.mnasnet1_0(pretrained=True)
            channel_1st = 3
            channel_2nd = 16
            channel_3rd = 40
            channel_4th = 96
            prev_final_channel = 320
            final_channel = 1280
            print("# of layers", len(resnet.layers))
            self.layer0 = nn.Sequential(resnet.layers[0], resnet.layers[1], resnet.layers[2],
                                        resnet.layers[3], resnet.layers[4], resnet.layers[5],
                                        resnet.layers[6], resnet.layers[7])   # 16
            self.layer1 = nn.Sequential(resnet.layers[8], resnet.layers[9])    # 24, 40
            self.layer2 = nn.Sequential(resnet.layers[10], resnet.layers[11])  # 48, 96
            self.layer3 = nn.Sequential(resnet.layers[12], resnet.layers[13])  # 160, 320
            self.layer4 = nn.Sequential(resnet.layers[14], resnet.layers[15],
                                        resnet.layers[16])                     # 1280
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    elif trunk == 'mobilenetv2':
        channel_1st = 3
        channel_2nd = 16
        channel_3rd = 32
        channel_4th = 64
        # prev_final_channel = 160
        prev_final_channel = 320
        final_channel = 1280
        resnet = Mobilenet.mobilenet_v2(pretrained=True, iw=self.args.wt_layer)
        self.layer0 = nn.Sequential(resnet.features[0], resnet.features[1])
        self.layer1 = nn.Sequential(resnet.features[2], resnet.features[3],
                                    resnet.features[4], resnet.features[5],
                                    resnet.features[6])
        self.layer2 = nn.Sequential(resnet.features[7], resnet.features[8],
                                    resnet.features[9], resnet.features[10])
        # self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12], resnet.features[13],
        #                             resnet.features[14], resnet.features[15], resnet.features[16])
        # self.layer4 = nn.Sequential(resnet.features[17], resnet.features[18])
        self.layer3 = nn.Sequential(resnet.features[11], resnet.features[12],
                                    resnet.features[13], resnet.features[14],
                                    resnet.features[15], resnet.features[16],
                                    resnet.features[17])
        self.layer4 = nn.Sequential(resnet.features[18])
        if self.variant == 'D':
            for n, m in self.layer2.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer3.named_modules():
                if isinstance(m, nn.Conv2d) and m.stride == (2, 2):
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")
    else:
        channel_1st = 3
        channel_2nd = 64
        channel_3rd = 256
        channel_4th = 512
        prev_final_channel = 1024
        final_channel = 2048
        if trunk == 'resnet-18':
            channel_1st = 3
            channel_2nd = 64
            channel_3rd = 64
            channel_4th = 128
            prev_final_channel = 256
            final_channel = 512
            resnet = Resnet.resnet18(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-50':
            resnet = Resnet.resnet50(wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnet-101':  # stem with three 3x3 convs
            resnet = Resnet.resnet101(pretrained=True, wt_layer=self.args.wt_layer)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1,
                                          resnet.conv2, resnet.bn2, resnet.relu2,
                                          resnet.conv3, resnet.bn3, resnet.relu3,
                                          resnet.maxpool)
        elif trunk == 'resnet-152':
            resnet = Resnet.resnet152()
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-50':
            resnet = models.resnext50_32x4d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'resnext-101':
            resnet = models.resnext101_32x8d(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-50':
            resnet = models.wide_resnet50_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        elif trunk == 'wide_resnet-101':
            resnet = models.wide_resnet101_2(pretrained=True)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        else:
            raise ValueError("Not a valid network arch")

        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        if self.variant == 'D':
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D4':
            for n, m in self.layer2.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (8, 8), (8, 8), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif self.variant == 'D16':
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            # raise 'unknown deepv3 variant: {}'.format(self.variant)
            print("Not using Dilation")

    if self.variant == 'D':
        os = 8
    elif self.variant == 'D4':
        os = 4
    elif self.variant == 'D16':
        os = 16
    else:
        os = 32
    self.output_stride = os

    self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, output_stride=os)

    self.bot_fine = nn.Sequential(
        nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True))
    self.bot_aspp = nn.Sequential(
        nn.Conv2d(1280, 256, kernel_size=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final1 = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),  # 304 = 256 (ASPP) + 48 (fine)
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True))
    self.final2 = nn.Sequential(
        nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

    self.dsn = nn.Sequential(
        nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
        Norm2d(512),
        nn.ReLU(inplace=True),
        nn.Dropout2d(0.1),
        nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))

    initialize_weights(self.dsn)
    initialize_weights(self.aspp)
    initialize_weights(self.bot_aspp)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final1)
    initialize_weights(self.final2)

    # Setting the flags
    self.eps = 1e-5
    self.whitening = False

    if trunk == 'resnet-101':
        self.three_input_layer = True
        in_channel_list = [64, 64, 128, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [32, 32, 64, 128, 256, 512, 1024]
    elif trunk == 'resnet-18':
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 64, 128, 256, 512]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 32, 64, 128, 256]
    elif trunk == 'shufflenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 24, 116, 232, 464, 1024]
    elif trunk == 'mobilenetv2':
        self.three_input_layer = False
        in_channel_list = [0, 0, 16, 32, 64, 320, 1280]
    else:  # ResNet-50
        self.three_input_layer = False
        in_channel_list = [0, 0, 64, 256, 512, 1024, 2048]  # 8128, 32640, 130816
        out_channel_list = [0, 0, 32, 128, 256, 512, 1024]

    self.cov_matrix_layer = []
    self.cov_type = []
    for i in range(len(self.args.wt_layer)):
        if self.args.wt_layer[i] > 0:
            self.whitening = True
            if self.args.wt_layer[i] == 1:
                self.cov_matrix_layer.append(
                    CovMatrix_IRW(dim=in_channel_list[i],
                                  relax_denom=self.args.relax_denom))
                self.cov_type.append(self.args.wt_layer[i])
            elif self.args.wt_layer[i] == 2:
                self.cov_matrix_layer.append(
                    CovMatrix_ISW(dim=in_channel_list[i],
                                  relax_denom=self.args.relax_denom,
                                  clusters=self.args.clusters))
                self.cov_type.append(self.args.wt_layer[i])
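# The CovMatrix_IRW / CovMatrix_ISW helpers are not shown in this excerpt; a
# minimal sketch of the per-instance channel covariance such whitening losses
# are built on (standardize each instance's feature map, then correlate
# channels over spatial positions). The function name is illustrative.
import torch

def instance_covariance(feat, eps=1e-5):
    """feat: (B, C, H, W) -> per-sample channel covariance of shape (B, C, C)."""
    b, c, h, w = feat.shape
    x = feat.view(b, c, -1)                             # flatten spatial dims
    x = x - x.mean(dim=2, keepdim=True)                 # zero mean per channel
    x = x / (x.var(dim=2, keepdim=True) + eps).sqrt()   # unit variance per channel
    return torch.bmm(x, x.transpose(1, 2)) / (h * w)    # (B, C, C)

cov = instance_covariance(torch.randn(2, 64, 32, 32))
print(cov.shape)   # torch.Size([2, 64, 64])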
def __init__(self, num_classes, trunk=None, criterion=None):
    super(GSCNN, self).__init__()
    self.criterion = criterion
    self.num_classes = num_classes

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    try:
        checkpoint = torch.load(
            './network/pretrained_models/wider_resnet38.pth.tar',
            map_location='cpu')
        wide_resnet.load_state_dict(checkpoint['state_dict'])
        del checkpoint
    except Exception:
        print("Please download the ImageNet weights of WideResNet38 in our "
              "repo to ./pretrained_models/wider_resnet38.pth.tar.")
        raise RuntimeError(
            "Could not load ImageNet weights of WideResNet38 network.")
    wide_resnet = wide_resnet.module

    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    self.interpolate = F.interpolate
    del wide_resnet

    # Side outputs feeding the shape stream.
    self.dsn1 = nn.Conv2d(64, 1, 1)
    self.dsn3 = nn.Conv2d(256, 1, 1)
    self.dsn4 = nn.Conv2d(512, 1, 1)
    self.dsn7 = nn.Conv2d(4096, 1, 1)

    self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
    self.d1 = nn.Conv2d(64, 32, 1)
    self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
    self.d2 = nn.Conv2d(32, 16, 1)
    self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
    self.d3 = nn.Conv2d(16, 8, 1)
    self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

    self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

    self.gate1 = gsc.GatedSpatialConv2d(32, 32)
    self.gate2 = gsc.GatedSpatialConv2d(16, 16)
    self.gate3 = gsc.GatedSpatialConv2d(8, 8)

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, output_stride=8)

    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

    self.final_seg = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    self.sigmoid = nn.Sigmoid()
    initialize_weights(self.final_seg)
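# Illustrative construction of the model above. It requires the WideResNet38
# ImageNet checkpoint at ./network/pretrained_models/wider_resnet38.pth.tar;
# num_classes=19 (Cityscapes) is an assumption for this example.
net = GSCNN(num_classes=19)
print(sum(p.numel() for p in net.parameters()))   # rough parameter-count sanity check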