def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)

    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    model = CoattentionNet(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.', ''): v for k, v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict(convert_state_dict(saved_state_dict["model"]))

    model.eval()
    model.cuda()

    if args.dataset == 'voc12':
        testloader = data.DataLoader(VOCDataTestSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=args.img_mean),
                                     batch_size=1, shuffle=False, pin_memory=True)
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
        voc_colorize = VOCColorize()
    elif args.dataset == 'cityscapes':
        testloader = data.DataLoader(CityscapesTestDataSet(args.data_dir, args.data_list, f_scale=args.f_scale, mean=args.img_mean),
                                     batch_size=1, shuffle=False, pin_memory=True)  # f_scale: resize the input image by f_scale
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')  # size = (h, w)
        voc_colorize = VOCColorize()
    elif args.dataset == 'davis':  # for DAVIS 2016
        db_test = db.PairwiseImg(train=False, inputRes=(473, 473), db_root_dir=args.data_dir, transform=None,
                                 seq_name=None, sample_range=args.sample_range)  # db_root_dir --> '/path/to/DAVIS-2016' train path
        testloader = data.DataLoader(db_test, batch_size=1, shuffle=False, num_workers=0)
        voc_colorize = VOCColorize()
    else:
        print("dataset error")

    data_list = []
    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):
            os.makedirs(args.seg_save_dir)
            os.makedirs(args.vis_save_dir)

    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp = ''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name = temp[0]
        print(args.seq_name)

        # frame index within the current sequence
        if old_temp == args.seq_name:
            my_index = my_index + 1
        else:
            my_index = 0

        # average the segmentation branch over `sample_range` reference frames
        output_sum = 0
        for i in range(0, args.sample_range):
            search = batch['search' + '_' + str(i)]
            search_im = search
            #print(search_im.size())
            output = model(Variable(target, volatile=True).cuda(),
                           Variable(search_im, volatile=True).cuda())
            #print(output[0])  # the model returns two outputs
            output_sum = output_sum + output[0].data[0, 0].cpu().numpy()  # segmentation-branch result
            #np.save('infer' + str(i) + '.npy', output1)
            #output2 = output[1].data[0, 0].cpu().numpy()  # interp

        output1 = output_sum / args.sample_range

        # resize to the original DAVIS frame resolution
        first_image = np.array(Image.open(args.data_dir + '/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape
        output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))

        if 0:  # optional DenseCRF refinement (disabled)
            original_image = target[0]
            #print('image type:', type(original_image.numpy()))
            original_image = original_image.numpy()
            original_image = original_image.transpose((2, 1, 0))
            original_image = cv2.resize(original_image, (original_shape[1], original_shape[0]))

            unary = np.zeros((2, original_shape[0] * original_shape[1]), dtype='float32')
            #unary[0, :, :] = res_saliency/255
            #unary[1, :, :] = 1 - res_saliency/255
            EPSILON = 1e-8
            tau = 1.05
            crf = dcrf.DenseCRF(original_shape[1] * original_shape[0], 2)

            anno_norm = (output1 - np.min(output1)) / (np.max(output1) - np.min(output1))  # res_saliency / 255.
            n_energy = 1.0 - anno_norm + EPSILON  # -np.log((1.0 - anno_norm + EPSILON))  # / (tau * sigmoid(1 - anno_norm))
            p_energy = anno_norm + EPSILON        # -np.log(anno_norm + EPSILON)          # / (tau * sigmoid(anno_norm))
            #unary = unary.reshape((2, -1))
            #print(unary.shape)
            unary[1, :] = p_energy.flatten()
            unary[0, :] = n_energy.flatten()
            crf.setUnaryEnergy(unary_from_softmax(unary))

            feats = create_pairwise_gaussian(sdims=(3, 3), shape=original_shape[:2])
            crf.addPairwiseEnergy(feats, compat=3, kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)
            feats = create_pairwise_bilateral(sdims=(10, 10), schan=(1, 1, 1),  # original values: 60, 60 and 5, 5, 5
                                              img=original_image, chdim=2)
            crf.addPairwiseEnergy(feats, compat=5, kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            Q = crf.inference(5)
            MAP = np.argmax(Q, axis=0)
            output1 = MAP.reshape((original_shape[0], original_shape[1]))

        mask = (output1 * 255).astype(np.uint8)
        #print(mask.shape[0])
        mask = Image.fromarray(mask)

        if args.dataset == 'voc12':
            # NOTE: this branch appears stale; `size` and `name` are not defined in this script
            print(output.shape)
            print(size)
            output = output[:, :size[0], :size[1]]
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                seg_filename = os.path.join(args.seg_save_dir, '{}.png'.format(name[0]))
                color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                color_file.save(seg_filename)
        elif args.dataset == 'davis':
            save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)
            old_temp = args.seq_name
            if not os.path.exists(save_dir_res):
                os.makedirs(save_dir_res)
            if args.save_segimage:
                my_index1 = str(my_index).zfill(5)
                seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))
                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                mask.save(seg_filename)
                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)), axis=0)
                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)
        elif args.dataset == 'cityscapes':
            # NOTE: this branch appears stale; `name` is not defined in this script
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                output_color = cityscapes_colorize_mask(output)
                output = Image.fromarray(output)
                output.save('%s/%s.png' % (args.seg_save_dir, name[0]))
                output_color.save('%s/%s_color.png' % (args.seg_save_dir, name[0]))
        else:
            print("dataset error")
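# ---------------------------------------------------------------------------
# Hedged sketch, not the repository's own code: `convert_state_dict`, used
# when restoring the checkpoint above, is assumed to strip the 'module.'
# prefix that torch.nn.DataParallel adds to parameter names, so that a
# checkpoint saved from a multi-GPU run loads into a single-GPU CoattentionNet.
# A minimal version under that assumption could look like this.
# ---------------------------------------------------------------------------
from collections import OrderedDict

def convert_state_dict(state_dict):
    """Drop a leading 'module.' from every key, if present."""
    converted = OrderedDict()
    for k, v in state_dict.items():
        converted[k[len('module.'):] if k.startswith('module.') else k] = v
    return converted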
def main(): print("=====> Configure dataset and pretrained model") configure_dataset_init_model(args) print(args) print(" current dataset: ", args.dataset) print(" init model: ", args.restore_from) print("=====> Set GPU for training") if args.cuda: print("====> Use gpu id: '{}'".format(args.gpus)) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus if not torch.cuda.is_available(): raise Exception("No GPU found or Wrong gpu id, please run without --cuda") # Select which GPU, -1 if CPU #gpu_id = args.gpus #device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") print("=====> Random Seed: ", args.random_seed) torch.manual_seed(args.random_seed) if args.cuda: torch.cuda.manual_seed(args.random_seed) h, w = map(int, args.input_size.split(',')) input_size = (h, w) cudnn.enabled = True print("=====> Building network") saved_state_dict = torch.load(args.restore_from) # saved_state_dict = torch.load(args.restore_from, map_location='cpu') #### model = CoattentionNet(num_classes=args.num_classes) #print(model) new_params = model.state_dict().copy() for i in saved_state_dict["model"]: #Scale.layer5.conv2d_list.3.weight i_parts = i.split('.') # 针对多GPU的情况 #i_parts.pop(1) #print('i_parts: ', '.'.join(i_parts[1:-1])) #if not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn': #init model pretrained on COCO, class name=21, layer5 is ASPP new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict["model"][i] #print('copy {}'.format('.'.join(i_parts[1:]))) print("=====> Loading init weights, pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes") model.load_state_dict(new_params) #只用到resnet的第5个卷积层的参数 #print(model.keys()) if args.cuda: #model.to(device) if torch.cuda.device_count()>1: print("torch.cuda.device_count()=",torch.cuda.device_count()) model = torch.nn.DataParallel(model).cuda() #multi-card data parallel else: print("single GPU for training") model = model.cuda() #1-card data parallel start_epoch=0 print("=====> Whether resuming from a checkpoint, for continuing training") if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) start_epoch = checkpoint["epoch"] model.load_state_dict(checkpoint["model"]) else: print("=> no checkpoint found at '{}'".format(args.resume)) model.train() cudnn.benchmark = True if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) print('=====> Computing network parameters') total_paramters = netParams(model) print('Total network parameters: ' + str(total_paramters)) print("=====> Preparing training data") if args.dataset == 'voc12': trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True) elif args.dataset == 'cityscapes': trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), batch_size = args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True) elif args.dataset == 'davis': #for davis 2016 db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir, transform=None) #db_root_dir() --> '/path/to/DAVIS-2016' train path 
# trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0) trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0, drop_last=True) #### else: print("dataset error") optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate }, #针对特定层进行学习,有些层不学习 {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}], lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer.zero_grad() logFileLoc = args.snapshot_dir + args.logFile if os.path.isfile(logFileLoc): logger = open(logFileLoc, 'a') else: logger = open(logFileLoc, 'w') logger.write("Parameters: %s" % (str(total_paramters))) logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n')) logger.flush() print("=====> Begin to train") train_len=len(trainloader) print(" iteration numbers of per epoch: ", train_len) print(" epoch num: ", args.maxEpoches) print(" max iteration: ", args.maxEpoches*train_len) for epoch in range(start_epoch, int(args.maxEpoches)): np.random.seed(args.random_seed + epoch) for i_iter, batch in enumerate(trainloader,0): #i_iter from 0 to len-1 #print("i_iter=", i_iter, "epoch=", epoch) target, target_gt, search, search_gt = batch['target'], batch['target_gt'], batch['search'], batch['search_gt'] images, labels = batch['img'], batch['img_gt'] #print(labels.size()) images.requires_grad_() images = Variable(images).cuda() labels = Variable(labels.float().unsqueeze(1)).cuda() target.requires_grad_() target = Variable(target).cuda() target_gt = Variable(target_gt.float().unsqueeze(1)).cuda() search.requires_grad_() search = Variable(search).cuda() search_gt = Variable(search_gt.float().unsqueeze(1)).cuda() optimizer.zero_grad() lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch, max_iter = args.maxEpoches * train_len) #print(images.size()) if i_iter%3 ==0: #对于静态图片的训练 pred1, pred2, pred3 = model(images, images) loss = 0.1*(loss_calc1(pred3, labels) + 0.8* loss_calc2(pred3, labels) ) loss.backward() else: pred1, pred2, pred3 = model(target, search) loss = loss_calc1(pred1, target_gt) + 0.8* loss_calc2(pred1, target_gt) + loss_calc1(pred2, search_gt) + 0.8* loss_calc2(pred2, search_gt)#class_balanced_cross_entropy_loss(pred, labels, size_average=False) loss.backward() optimizer.step() print("===> Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}".format(epoch, i_iter, train_len, loss.data, lr)) logger.write("Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}\n".format(epoch, i_iter, train_len, loss.data, lr)) logger.flush() print("=====> saving model") state={"epoch": epoch+1, "model": model.state_dict()} torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+"_"+str(epoch)+'.pth')) end = timeit.default_timer() print( float(end-start)/3600, 'h') logger.write("total training time: {:.2f} h\n".format(float(end-start)/3600)) logger.close()
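# ---------------------------------------------------------------------------
# Hedged sketch, an assumption rather than the repository's own helper:
# `adjust_learning_rate(optimizer, cur_iter, epoch, max_iter=...)` called in
# the training loop above is assumed to follow the DeepLab-style "poly"
# schedule, with the second parameter group (get_10x_lr_params) kept at 10x
# the base rate. `base_lr` and `power` here are stand-ins; base_lr should
# mirror args.learning_rate.
# ---------------------------------------------------------------------------
def adjust_learning_rate(optimizer, cur_iter, epoch, max_iter, base_lr=2.5e-4, power=0.9):
    """Polynomial decay of the base learning rate; returns the new base lr."""
    lr = base_lr * (1.0 - float(cur_iter) / max_iter) ** power
    optimizer.param_groups[0]['lr'] = lr       # pretrained encoder (1x group)
    optimizer.param_groups[1]['lr'] = 10 * lr  # newly added layers (10x group)
    return lr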
def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)

    model = CoattentionNet(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.', ''): v for k, v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict(convert_state_dict(saved_state_dict["model"]))

    model.eval()
    model.cuda()

    if args.dataset == 'voc12':
        testloader = data.DataLoader(VOCDataTestSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=args.img_mean),
                                     batch_size=1, shuffle=False, pin_memory=True)
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
        voc_colorize = VOCColorize()
    elif args.dataset == 'davis':  # for DAVIS 2016
        db_test = db.PairwiseImg(train=False, inputRes=(473, 473), db_root_dir=args.data_dir, transform=None,
                                 seq_name=None, sample_range=args.sample_range)  # db_root_dir --> '/path/to/DAVIS-2016' train path
        testloader = data.DataLoader(db_test, batch_size=1, shuffle=False, num_workers=0)
        #voc_colorize = VOCColorize()
    else:
        print("dataset error")

    data_list = []
    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):
            os.makedirs(args.seg_save_dir)
            os.makedirs(args.vis_save_dir)

    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp = ''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name = temp[0]
        print(args.seq_name)

        # frame index within the current sequence
        if old_temp == args.seq_name:
            my_index = my_index + 1
        else:
            my_index = 0

        # average the segmentation branch over `sample_range` reference frames
        output_sum = 0
        for i in range(0, args.sample_range):
            search = batch['search' + '_' + str(i)]
            search_im = search
            #print(search_im.size())
            output = model(Variable(target, volatile=True).cuda(),
                           Variable(search_im, volatile=True).cuda())
            #print(output[0])  # the model returns two outputs
            output_sum = output_sum + output[0].data[0, 0].cpu().numpy()  # segmentation-branch result
            #np.save('infer' + str(i) + '.npy', output1)
            #output2 = output[1].data[0, 0].cpu().numpy()  # interp

        output1 = output_sum / args.sample_range

        # resize to the original DAVIS frame resolution
        first_image = np.array(Image.open(args.data_dir + '/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape
        output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))

        mask = (output1 * 255).astype(np.uint8)
        #print(mask.shape[0])
        mask = Image.fromarray(mask)

        if args.dataset == 'voc12':
            # NOTE: this branch appears stale; `size` and `name` are not defined in this script
            print(output.shape)
            print(size)
            output = output[:, :size[0], :size[1]]
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                seg_filename = os.path.join(args.seg_save_dir, '{}.png'.format(name[0]))
                color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                color_file.save(seg_filename)
        elif args.dataset == 'davis':
            save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name)
            old_temp = args.seq_name
            if not os.path.exists(save_dir_res):
                os.makedirs(save_dir_res)
            if args.save_segimage:
                my_index1 = str(my_index).zfill(5)
                seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1))
                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                mask.save(seg_filename)
                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)), axis=0)
                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)
        else:
            print("dataset error")
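# ---------------------------------------------------------------------------
# Usage note, an assumption about downstream evaluation rather than part of
# this script: the PNGs written above are soft masks scaled to 0-255 (the
# averaged segmentation-branch output), while the DAVIS-2016 benchmark scores
# binary masks. A separate thresholding pass such as this hypothetical helper
# may therefore be needed before evaluation.
# ---------------------------------------------------------------------------
import numpy as np
from PIL import Image

def binarize_mask(png_path, out_path, threshold=128):
    """Load a saved soft mask and write a binary {0, 255} mask."""
    soft = np.array(Image.open(png_path))
    Image.fromarray(((soft >= threshold) * 255).astype(np.uint8)).save(out_path)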