def main():
    """Evaluate a pre-trained super-resolution + semantic-segmentation pipeline.

    Loads the evaluation split, rebuilds the SR network named by
    ``opt.sr_model_name`` and the SemSeg network named by ``opt.seg_model_name``,
    restores both from ``opt.models_dir/<exp_name>/``, calls ``check_mkdir`` on
    the ``Results/`` and ``heat_maps/`` output directories and finally runs
    ``test``. Exits the process with a message if either network name is not
    supported.
    """
    print('===> Loading datasets')
    test_set = get_eval_set(opt.data_dir, opt.test_dir, opt.sr_upscale_factor, opt.num_classes)
    test_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=1, shuffle=False)

    # Both checkpoints live in the same per-experiment folder.
    weights_dir = os.path.join(opt.models_dir, exp_name)

    print('Building SR model ', opt.sr_model_name)
    if opt.sr_model_name == 'DBPN':
        sr_model = DBPN(num_channels=3, base_filter=64, feat=256, num_stages=7,
                        scale_factor=opt.sr_upscale_factor)
        # Wrapped before loading so the parameter names carry the 'module.'
        # prefix (presumably the checkpoint was saved from a wrapped model).
        sr_model = torch.nn.DataParallel(sr_model, device_ids=gpus_list)
        model_name = os.path.join(weights_dir, opt.sr_model)
        print(model_name)
        # Deserialise onto the CPU; the model is moved to the GPU below.
        sr_model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')
    else:
        sys.exit('Invalid SR network')

    print('Building SemSeg model', opt.seg_model_name)
    if opt.seg_model_name == 'segnet':
        seg_model = segnet(num_classes=opt.num_classes, in_channels=3)
        seg_model = torch.nn.DataParallel(seg_model, device_ids=gpus_list)
        model_name = os.path.join(weights_dir, opt.seg_model)
        print(model_name)
        # Same CPU mapping as the SR checkpoint, so loading does not depend on
        # which GPU wrote the file (or on a GPU being present at all).
        seg_model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SemSeg model is loaded.')
    else:
        sys.exit('Invalid Semantic segmentation network')

    if cuda:
        sr_model = sr_model.cuda(gpus_list[0])
        seg_model = seg_model.cuda(gpus_list[0])

    # Parents are listed before children.
    for out_dir in ('Results',
                    os.path.join('Results', exp_name),
                    os.path.join('Results', exp_name, 'segmentation'),
                    os.path.join('Results', exp_name, 'super-resolution'),
                    'heat_maps',
                    os.path.join('heat_maps', exp_name)):
        check_mkdir(out_dir)

    test(test_loader, sr_model, seg_model)
def __init__(self, upscale_factor, is_training=False, center=None, nf=64, nframes=3):
    """Construct the VRCNN sub-modules.

    Args:
        upscale_factor: super-resolution scale; stored on the instance and
            forwarded to ``SRnet`` and to the ``DBPN`` upsampling branch.
        is_training: stored on the instance and forwarded to ``SRnet``.
        center: index of the reference frame; defaults to ``nframes // 2``.
        nf: number of feature channels used by ``conv1``, the residual
            feature extractor and the ``fea_L2_*`` / ``fea_L3_*`` convolutions.
        nframes: number of input frames; in this constructor it is only used
            to derive the default ``center``.

    NOTE: submodules are registered in the order they are assigned below.
    That order fixes the ``state_dict`` key order and the order in which
    parameter initialisation draws from the global RNG, so keep it stable.
    """
    # Previous signature: def __init__(self,upscale_factor, is_training=False):
    super(VRCNN, self).__init__()
    self.upscale_factor = upscale_factor
    self.center = nframes // 2 if center is None else center
    self.is_training = is_training
    # self.OFRnet = OFRnet(upscale_factor=upscale_factor, is_training=is_training)
    # 1 input channel -> nf feature maps (which channel is fed in is not
    # visible here; presumably luminance -- TODO confirm).
    self.conv1 = nn.Conv2d(1, nf, 3, 1, 1, bias=True)
    # functools.partial(f, a, ...) fixes some of f's arguments (here the
    # keyword nf) and returns a new callable.
    ResidualBlock_noBN_f = functools.partial(ResidualBlock_noBN, nf=nf)
    self.feature_extraction = make_layer(ResidualBlock_noBN_f, 3)
    self.pcd_align = PCD_Align()
    self.SRnet = SRnet(upscale_factor=upscale_factor, is_training=is_training)
    self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    # Stride-2 convs halve height/width (rounding up); presumably these build
    # the level-2 / level-3 features consumed by PCD alignment -- TODO confirm.
    self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    # Reduce the number of channels from 33 to 3.
    self.reduce = nn.Conv2d(33, 3, 3, 1, 1, bias=False)
    # Single-image SR upsampling branch (DBPN on a 1-channel input).
    self.dbpn = DBPN(num_channels=1, base_filter=64, feat=256, num_stages=7, scale_factor=upscale_factor)
    self.conv_final1 = nn.Conv2d(1, 1, 3, 1, 1, bias=True)
    self.conv_final2 = nn.Conv2d(1, 1, 3, 1, 1, bias=True)
    self.conv_final3 = nn.Conv2d(1, 1, 3, 1, 1, bias=True)
    # Attention module.
    self.sa_ca = spatial_channel_attention()
if opt.model_type == 'DBPNLL': model = DBPNLL(num_channels=3, base_filter=64, feat=256, num_stages=10, scale_factor=opt.upscale_factor) ###D-DBPN elif opt.model_type == 'DBPN-RES-MR64-3': model = DBPNITER(num_channels=3, base_filter=64, feat=256, num_stages=3, scale_factor=opt.upscale_factor) ###D-DBPN else: model = DBPN(num_channels=3, base_filter=64, feat=256, num_stages=7, scale_factor=opt.upscale_factor) ###D-DBPN if cuda: model = torch.nn.DataParallel(model, device_ids=gpus_list) model.load_state_dict( torch.load(opt.model, map_location=lambda storage, loc: storage)) print('Pre-trained SR model is loaded.') if cuda: model = model.cuda(gpus_list[0]) def eval():
def main():
    """Jointly train the super-resolution and semantic-segmentation networks.

    Data: when ``opt.val_dir`` is given it is used for validation. Otherwise a
    20% hold-out of the training set (shuffled with ``opt.seed``) is used.

    Models: builds the SR network (``opt.sr_model_name``) and the SemSeg network
    (``opt.seg_model_name``), optionally restoring pre-trained weights from
    ``opt.save_folder``. Exits the process if either name is not supported.

    Loop: runs ``train``/``validate`` for epochs ``opt.start_iter`` to
    ``opt.epoch_num``. ``validate`` returns a score where higher is better.
    - The SemSeg LR is halved on plateaus of that score.
    - The SR LR is divided by 10 every ``opt.epoch_num // 2`` epochs.
    - Whenever the score improves, a ``'tmp_best'`` checkpoint is written.
    - Snapshots are written every ``opt.snapshots`` epochs.
    - The best-scoring models are written as ``'best'`` at the end.
    """
    print('===> Loading datasets')
    train_set = get_training_set(opt.data_dir, opt.train_dir, opt.patch_size, opt.sr_patch_size,
                                 opt.sr_upscale_factor, opt.num_classes, opt.sr_data_augmentation)
    if opt.val_dir is not None:
        val_set = get_eval_set(opt.data_dir, opt.val_dir, opt.sr_upscale_factor, opt.num_classes)
        # Shuffle so the training order is randomised every epoch, matching the
        # SubsetRandomSampler used in the hold-out branch below.
        train_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
                                  batch_size=opt.batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_set, num_workers=opt.threads, batch_size=1)
    else:
        # No separate validation set: hold out a seeded 20% of the training set.
        validation_split = .2
        dataset_size = len(train_set)
        indices = list(range(dataset_size))
        split = int(np.floor(validation_split * dataset_size))
        np.random.seed(opt.seed)
        np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]
        train_sampler = SubsetRandomSampler(train_indices)
        val_sampler = SubsetRandomSampler(val_indices)
        train_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
                                  batch_size=opt.batch_size, sampler=train_sampler)
        val_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
                                batch_size=1, sampler=val_sampler)

    print('Building SR model ', opt.sr_model_name)
    if opt.sr_model_name == 'DBPN':
        sr_model = DBPN(num_channels=3, base_filter=64, feat=256, num_stages=7,
                        scale_factor=opt.sr_upscale_factor)
        sr_model = torch.nn.DataParallel(sr_model, device_ids=gpus_list)
        if opt.sr_pretrained:
            # Proper join: plain concatenation silently breaks when save_folder
            # has no trailing separator.
            model_name = os.path.join(opt.save_folder, opt.sr_pretrained_model)
            print(model_name)
            sr_model.load_state_dict(
                torch.load(model_name, map_location=lambda storage, loc: storage))
            print('Pre-trained SR model is loaded.')
    else:
        sys.exit('Invalid SR network')

    print('Building SemSeg model', opt.seg_model_name)
    if opt.seg_model_name == 'segnet':
        seg_model = segnet(num_classes=opt.num_classes, in_channels=3)
        if not opt.seg_pretrained:
            # Skipped when a checkpoint is about to overwrite the weights anyway.
            seg_model.init_vgg16_params()
            print('segnet params initialized')
        # Wrap exactly once: nesting another DataParallel would rename every
        # parameter to 'module.module.*', which no longer matches the
        # single-wrapped model that seg_pretrained checkpoints load into.
        seg_model = torch.nn.DataParallel(seg_model, device_ids=gpus_list)
        if opt.seg_pretrained:
            model_name = os.path.join(opt.save_folder, opt.seg_pretrained_model)
            print(model_name)
            seg_model.load_state_dict(
                torch.load(model_name, map_location=lambda storage, loc: storage))
            print('Pre-trained SemSeg model is loaded.')
    else:
        sys.exit('Invalid Semantic segmentation network')

    sr_criterion = nn.L1Loss()
    psnr_criterion = nn.MSELoss()
    if 'grss' in opt.data_dir:
        # Targets labelled -1 are excluded from the loss (assuming
        # CrossEntropyLoss2d forwards ignore_index to the PyTorch loss).
        seg_criterion = CrossEntropyLoss2d(ignore_index=-1)
    else:
        seg_criterion = CrossEntropyLoss2d()

    if cuda:
        sr_model = sr_model.cuda(gpus_list[0])
        seg_model = seg_model.cuda(gpus_list[0])
        sr_criterion = sr_criterion.cuda(gpus_list[0])
        psnr_criterion = psnr_criterion.cuda(gpus_list[0])
        # Same device as everything else; a bare .cuda() would use whatever
        # the current default device is, and would crash on CPU-only runs.
        seg_criterion = seg_criterion.cuda(gpus_list[0])

    sr_optimizer = optim.Adam(sr_model.parameters(), lr=opt.sr_lr, betas=(0.9, 0.999), eps=1e-8)
    seg_optimizer = optim.Adam(seg_model.parameters(), lr=opt.seg_lr,
                               weight_decay=opt.seg_weight_decay,
                               betas=(opt.seg_momentum, 0.99))
    # The scheduler is stepped with the validation IoU, which is maximised
    # (compared with '>' below), so it must run in 'max' mode; 'min' would cut
    # the learning rate exactly while the IoU keeps improving.
    scheduler = ReduceLROnPlateau(seg_optimizer, 'max', factor=0.5, patience=opt.seg_lr_patience,
                                  min_lr=2.5e-5, verbose=True)

    check_mkdir(os.path.join('outputs', exp_name))
    check_mkdir(os.path.join('outputs', exp_name, 'segmentation'))
    check_mkdir(os.path.join('outputs', exp_name, 'super-resolution'))
    check_mkdir(os.path.join(opt.save_folder, exp_name))

    # Score the (possibly pre-trained) models before training to seed the
    # best-model tracking.
    best_iou = validate(0, val_loader, sr_model, seg_model, sr_criterion, psnr_criterion,
                        seg_criterion, sr_optimizer, seg_optimizer)
    best_epoch = 0
    # Snapshot copies, not the live models: otherwise, if no epoch beats the
    # baseline, the final 'best' checkpoint would contain last-epoch weights.
    best_model = (copy.deepcopy(sr_model), copy.deepcopy(seg_model))
    # Integer period: a float period (epoch_num / 2) never divides an integer
    # epoch before the last one when epoch_num is odd.
    sr_decay_every = max(1, opt.epoch_num // 2)
    epoch = best_epoch  # keeps the final save valid if the loop runs zero times
    for epoch in range(opt.start_iter, opt.epoch_num + 1):
        train(epoch, train_loader, sr_model, seg_model, sr_criterion, psnr_criterion,
              seg_criterion, sr_optimizer, seg_optimizer)
        val_results = validate(epoch, val_loader, sr_model, seg_model, sr_criterion,
                               psnr_criterion, seg_criterion, sr_optimizer, seg_optimizer)
        if val_results > best_iou:
            best_iou = val_results
            best_epoch = epoch
            print('New best iou ', best_iou)
            best_model = (copy.deepcopy(sr_model), copy.deepcopy(seg_model))
            checkpoint(epoch, sr_model, seg_model, 'tmp_best')
        else:
            print('Best iou epoch: ', best_epoch, ':', best_iou)
        scheduler.step(val_results)

        # Divide the SR learning rate by 10 at the halfway point (and multiples).
        if epoch % sr_decay_every == 0:
            for param_group in sr_optimizer.param_groups:
                param_group['lr'] /= 10.0
            print('SR Learning rate decay: lr={}'.format(
                sr_optimizer.param_groups[0]['lr']))

        if epoch % opt.snapshots == 0:
            checkpoint(epoch, sr_model, seg_model)

    print('Saving final best model')
    checkpoint(epoch, best_model[0], best_model[1], 'best')