def __init__(self, color=True, burst_length=8, blind_est=False, kernel_size=None,
             sep_conv=False, channel_att=False, spatial_att=False,
             upMode='bilinear', core_bias=False):
    """Build an ``Att_KPN`` denoiser plus a ``NoiseEstimate`` sub-module.

    All keyword arguments are forwarded unchanged to ``Att_KPN``; ``color``
    is additionally passed to ``NoiseEstimate``.

    Args:
        kernel_size: forwarded to ``Att_KPN``. ``None`` (the default) means
            ``[5]``; a fresh list is created per instance so that no two
            modules share one mutable default list.
    """
    super(Att_KPN_noise, self).__init__()
    if kernel_size is None:
        kernel_size = [5]
    self.Att_KPN = Att_KPN(color=color, burst_length=burst_length,
                           blind_est=blind_est, kernel_size=kernel_size,
                           sep_conv=sep_conv, channel_att=channel_att,
                           spatial_att=spatial_att, upMode=upMode,
                           core_bias=core_bias)
    self.noise_estimate = NoiseEstimate(color=color)
def __init__(self, color=True, burst_length=8, blind_est=False, kernel_size=None,
             sep_conv=False, channel_att=False, spatial_att=False,
             upMode='bilinear', core_bias=False, in_channel=3):
    """Build the ``Att_KPN`` core used by the DGF variant.

    All keyword arguments except ``in_channel`` are forwarded unchanged to
    ``Att_KPN``; it receives ``in_channel * burst_length`` input channels.

    Args:
        kernel_size: forwarded to ``Att_KPN``. ``None`` (the default) means
            ``[5]``; a fresh list is created per instance so that no two
            modules share one mutable default list.
        in_channel: channels per frame (default 3).
    """
    super(Att_KPN_DGF, self).__init__()
    if kernel_size is None:
        kernel_size = [5]
    self.Att_KPN = Att_KPN(
        color=color,
        burst_length=burst_length,
        blind_est=blind_est,
        kernel_size=kernel_size,
        sep_conv=sep_conv,
        channel_att=channel_att,
        spatial_att=spatial_att,
        upMode=upMode,
        core_bias=core_bias,
        # presumably the burst frames are stacked channel-wise -- confirm in Att_KPN
        in_channel=in_channel * burst_length,
    )
def test_multi(args):
    """Denoise every ``*.png`` in ``args.noise_dir`` and score it against its clean image.

    The clean reference for each noisy file is the same path with every
    ``"noisy"`` replaced by ``"clean"``. Each prediction is scored twice:
    upsampled to the reference resolution ("UP") and against the reference
    downsampled to the prediction resolution ("DOWN").

    Args:
        args: CLI namespace; reads ``model_type``, ``checkpoint``,
            ``load_type``, ``noise_dir`` and ``save_img`` (``''`` disables
            saving figures).
    """
    color = True
    burst_length = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model_type == "attKPN":
        model = Att_KPN(color=color, burst_length=burst_length, blind_est=True,
                        kernel_size=[5], sep_conv=False, channel_att=True,
                        spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color, burst_length=burst_length, blind_est=True,
                               kernel_size=[5], sep_conv=False, channel_att=True,
                               spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color, burst_length=burst_length, blind_est=True,
                    kernel_size=[5], sep_conv=False, channel_att=False,
                    spatial_att=False, upMode="bilinear", core_bias=False)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
        # Nothing to load: stop here instead of failing inside load_checkpoint.
        return

    # load trained model
    # torch.device never compares equal to a str, so test its .type instead.
    ckpt = load_checkpoint(checkpoint_dir, cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # nn.DataParallel checkpoints prefix every key with 'module.'; strip it
    # only where present so plain checkpoints load as well.
    new_state_dict = OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.to(device)
    model.eval()

    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    clean_path = [p.replace("noisy", "clean") for p in noisy_path]
    for i, (noisy_file, clean_file) in enumerate(zip(noisy_path, clean_path)):
        image_noise = load_data(noisy_file, burst_length)
        burst_noise = image_noise.to(device)
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            _, pred = model(feedData, burst_noise[:, 0:burst_length, ...])
        pred = pred.detach().cpu()

        with Image.open(clean_file) as clean_img:
            gt = transforms.ToTensor()(clean_img.convert('RGB'))
        gt = gt.unsqueeze(0)

        _, _, h_hr, w_hr = gt.size()
        _, _, h_lr, w_lr = pred.size()
        gt_down = F.interpolate(gt, (h_lr, w_lr), mode='bilinear', align_corners=True)
        pred_up = F.interpolate(pred, (h_hr, w_hr), mode='bilinear', align_corners=True)

        psnr_t_up = calculate_psnr(pred_up, gt)
        ssim_t_up = calculate_ssim(pred_up, gt)
        psnr_t_down = calculate_psnr(pred, gt_down)
        ssim_t_down = calculate_ssim(pred, gt_down)
        print(i, " UP : PSNR : ", str(psnr_t_up), " : SSIM : ", str(ssim_t_up),
              " : DOWN : PSNR : ", str(psnr_t_down), " : SSIM : ", str(ssim_t_down))

        if args.save_img != '':
            os.makedirs(args.save_img, exist_ok=True)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred_up[0])))
            plt.title("denoise KPN split " + args.model_type, fontsize=25)
            image_name = noisy_file.split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t_up)
                         + " : SSIM : " + str(ssim_t_up), fontsize=25)
            plt.savefig(os.path.join(args.save_img,
                                     image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one figure accumulates per image.
            plt.close()
def eval(args):
    """Evaluate a burst denoiser on up to 100 samples of a ``MultiLoader`` dataset.

    Prints per-image PSNR/SSIM and their average. If ``args.save_img`` is
    truthy, writes the first noisy frame, the prediction and the ground truth
    for each image to ``eval_img/``.

    Args:
        args: CLI namespace; reads ``color``, ``checkpoint``, ``noise_dir``,
            ``gt_dir``, ``image_size``, ``num_workers``, ``model_type``,
            ``cuda``, ``mGPU`` and ``save_img``.
    """
    color = args.color
    print('Eval Process......')
    burst_length = 8
    max_images = 100

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
        # Nothing to evaluate without weights.
        return

    # the path for saving eval images
    eval_dir = "eval_img"
    os.makedirs(eval_dir, exist_ok=True)

    # dataset and dataloader
    data_set = MultiLoader(noise_dir=args.noise_dir, gt_dir=args.gt_dir,
                           image_size=args.image_size)
    data_loader = DataLoader(data_set, batch_size=1, shuffle=False,
                             num_workers=args.num_workers)

    # model here
    if args.model_type == "attKPN":
        model = Att_KPN(color=color, burst_length=burst_length, blind_est=True,
                        kernel_size=[5], sep_conv=False, channel_att=False,
                        spatial_att=False, upMode="bilinear", core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color, burst_length=burst_length, blind_est=True,
                               kernel_size=[5], sep_conv=False, channel_att=False,
                               spatial_att=False, upMode="bilinear", core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color, burst_length=burst_length, blind_est=True,
                    kernel_size=[5], sep_conv=False, channel_att=False,
                    spatial_att=False, upMode="bilinear", core_bias=False)
    else:
        print(" Model type not valid")
        return
    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, cuda=args.cuda)
    # Normalise keys to the un-prefixed form and load into the wrapped network,
    # so the same checkpoint works with or without nn.DataParallel.
    new_state_dict = OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items())
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(new_state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()

    trans = transforms.ToPILImage()
    with torch.no_grad():
        psnr = 0.0
        ssim = 0.0
        n_images = 0
        torch.manual_seed(0)
        for i, (burst_noise, gt) in enumerate(data_loader):
            if i >= max_images:
                break
            if args.cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
            if color:
                b, N, c, h, w = burst_noise.size()
                feedData = burst_noise.view(b, -1, h, w)
            else:
                feedData = burst_noise
            pred_i, pred = model(feedData, burst_noise[:, 0:burst_length, ...])

            if not color:
                psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
                ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
                psnr_noisy = calculate_psnr(burst_noise[:, 0, ...].unsqueeze(1),
                                            gt.unsqueeze(1))
            else:
                psnr_t = calculate_psnr(pred, gt)
                ssim_t = calculate_ssim(pred, gt)
                psnr_noisy = calculate_psnr(burst_noise[:, 0, ...], gt)
            psnr += psnr_t
            ssim += ssim_t
            n_images += 1

            pred = torch.clamp(pred, 0.0, 1.0)
            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()

            if args.save_img:
                trans(burst_noise[0, 0, ...].squeeze()).save(os.path.join(
                    eval_dir, '{}_noisy_{:.2f}dB.png'.format(i, psnr_noisy)), quality=100)
                trans(pred.squeeze()).save(os.path.join(
                    eval_dir, '{}_pred_{:.2f}dB.png'.format(i, psnr_t)), quality=100)
                trans(gt.squeeze()).save(os.path.join(
                    eval_dir, '{}_gt.png'.format(i)), quality=100)
            print('{}-th image is OK, with PSNR: {:.2f} , SSIM: {:.4f}'.format(
                i, psnr_t, ssim_t))

        # The running sums were previously accumulated but never reported.
        if n_images:
            print('Average over {} images, PSNR: {:.2f} , SSIM: {:.4f}'.format(
                n_images, psnr / n_images, ssim / n_images))
def test_multi(dir, image_size, args):
    """Compare the selected model against a reference KPN on 10 loaded bursts.

    The selected model's weights come from ``checkpoints/<args.checkpoint>``,
    and the reference ``KPN`` weights from ``checkpoints/kpn``. When
    ``args.save_img`` is not ``''``, saves a side-by-side figure (selected
    model, KPN, first noisy frame) per iteration.

    Args:
        dir: source passed straight to ``load_data``.
        image_size: size passed straight to ``load_data``.
        args: CLI namespace; reads ``model_type``, ``checkpoint``,
            ``load_type`` and ``save_img``.
    """
    color = True
    burst_length = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.device never compares equal to a str, so test its .type instead.
    use_cuda = device.type == 'cuda'

    def _load_weights(net, state_dict):
        # Strip the 'module.' prefix nn.DataParallel adds, only where present.
        net.load_state_dict(OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()))

    if args.model_type == "attKPN":
        model = Att_KPN(color=color, burst_length=burst_length, blind_est=True,
                        kernel_size=[5], sep_conv=False, channel_att=False,
                        spatial_att=False, upMode="bilinear", core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color, burst_length=burst_length, blind_est=True,
                               kernel_size=[5], sep_conv=False, channel_att=False,
                               spatial_att=False, upMode="bilinear", core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color, burst_length=burst_length, blind_est=True,
                    kernel_size=[5], sep_conv=False, channel_att=False,
                    spatial_att=False, upMode="bilinear", core_bias=False)
    else:
        print(" Model type not valid")
        return
    model2 = KPN(color=color, burst_length=burst_length, blind_est=True,
                 kernel_size=[5], sep_conv=False, channel_att=False,
                 spatial_att=False, upMode="bilinear", core_bias=False)

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
        return
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, cuda=use_cuda)
    _load_weights(model, ckpt['state_dict'])

    checkpoint_dir = "checkpoints/" + "kpn"
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
        return
    # load the reference KPN -- into model2 (the old code's else-branch
    # overwrote `model` with these weights instead).
    ckpt = load_checkpoint(checkpoint_dir, cuda=use_cuda, best_or_latest=args.load_type)
    _load_weights(model2, ckpt['state_dict'])
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.to(device)
    model2.to(device)
    model.eval()
    model2.eval()

    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    for i in range(10):
        image_noise = load_data(dir, image_size, burst_length)
        begin = time.time()
        burst_noise = image_noise.to(device)
        print(burst_noise.size())
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred_i, pred = model(feedData, burst_noise[:, 0:burst_length, ...])
            _, pred2 = model2(feedData, burst_noise[:, 0:burst_length, ...])
        pred = pred.detach().cpu()
        pred2 = pred2.detach().cpu()
        print("Time : ", time.time() - begin)
        print(pred_i.size())
        print(pred.size())

        if args.save_img != '':
            os.makedirs(args.save_img, exist_ok=True)
            plt.figure(figsize=(10, 3))
            plt.subplot(1, 3, 1)
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise attKPN")
            plt.subplot(1, 3, 2)
            plt.imshow(np.array(trans(pred2[0])))
            plt.title("denoise KPN")
            plt.subplot(1, 3, 3)
            plt.imshow(np.array(trans(image_noise[0][0])))
            plt.title("noise ")
            image_name = str(i)
            plt.savefig(os.path.join(args.save_img,
                                     image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one figure accumulates per iteration.
            plt.close()
def eval(args):
    """Evaluate a single-frame denoiser on SIDD+ validation ``.mat`` files.

    Reads ``siddplus_valid_noisy_srgb`` from ``args.noise_dir`` and
    ``siddplus_valid_gt_srgb`` from ``args.gt_dir``. It then runs the first
    low-resolution frame produced by ``load_data`` through the model, prints
    per-image PSNR/SSIM, and finally prints their mean. When
    ``args.save_img`` is not ``''``, saves one figure per image there.

    Args:
        args: CLI namespace; reads ``burst_length``, ``model_type``,
            ``checkpoint``, ``load_type``, ``noise_dir``, ``gt_dir`` and
            ``save_img``.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model_type == "attKPN":
        model = Att_KPN(color=color, burst_length=burst_length, blind_est=True,
                        kernel_size=[5], sep_conv=False, channel_att=True,
                        spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet(color=color, burst_length=1, blind_est=True,
                                kernel_size=[5], sep_conv=False, channel_att=True,
                                spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color, burst_length=1, blind_est=True,
                               kernel_size=[5], sep_conv=False, channel_att=True,
                               spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color, burst_length=1, blind_est=True,
                    kernel_size=[5], sep_conv=False, channel_att=False,
                    spatial_att=False, upMode="bilinear", core_bias=False)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
        # Nothing to evaluate without weights.
        return

    # load trained model
    # torch.device never compares equal to a str, so test its .type instead.
    ckpt = load_checkpoint(checkpoint_dir, cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Strip the 'module.' prefix nn.DataParallel adds, only where present.
    new_state_dict = OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()

    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt_dir)['siddplus_valid_gt_srgb']
    i_imgs, _, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []
    for i_img in range(i_imgs):
        image_noise = transforms.ToTensor()(Image.fromarray(all_noisy_imgs[i_img]))
        image_noise_lr, _ = load_data(image_noise, burst_length)
        # Only the first low-resolution frame is fed to the network.
        burst_noise = image_noise_lr[:, 0:1, :, :, :].to(device)
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        with torch.no_grad():
            _, pred = model(feedData, burst_noise)
        pred = pred.detach().cpu()

        gt = transforms.ToTensor()(Image.fromarray(all_clean_imgs[i_img]))
        image_gt_lr, _ = load_data(gt, burst_length)
        # Keep gt on the CPU alongside pred; moving it to `device` caused a
        # CPU/CUDA mismatch inside the metric functions on GPU machines.
        gt = image_gt_lr[:, 0, :, :, :]

        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        psnrs.append(psnr_t)
        ssims.append(ssim_t)
        print(i_img, " UP : PSNR : ", str(psnr_t), " : SSIM : ", str(ssim_t))

        if args.save_img != '':
            os.makedirs(args.save_img, exist_ok=True)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = str(i_img)
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t)
                         + " : SSIM : " + str(ssim_t), fontsize=25)
            plt.savefig(os.path.join(args.save_img,
                                     image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one figure accumulates per image.
            plt.close()
    print("   AVG : PSNR : " + str(np.mean(psnrs)) + " : SSIM : " + str(np.mean(ssims)))
def train(num_workers, cuda, restart_train, mGPU):
    """Train a single-frame KPN-family denoiser and checkpoint it periodically.

    NOTE(review): most settings (batch size, epochs, dataset paths, model
    type, loss options, checkpoint name) are read from a module-level
    ``args`` namespace, which must exist before this is called.

    Args:
        num_workers: ``DataLoader`` worker processes.
        cuda: move the model and every batch to the GPU.
        restart_train: if True, start from epoch 0 without loading a checkpoint.
        mGPU: wrap the model in ``nn.DataParallel``.
    """
    color = True
    batch_size = args.batch_size
    lr = 2e-4
    lr_decay = 0.89125093813
    n_epoch = args.epoch
    save_freq = args.save_every
    loss_freq = args.loss_every
    lr_step_size = 100
    burst_length = args.burst_length

    # checkpoint path
    checkpoint_dir = "checkpoints/" + args.checkpoint
    os.makedirs(checkpoint_dir, exist_ok=True)
    # logs path: previous logs for this run name are wiped on every start.
    logs_dir = "checkpoints/logs/" + args.checkpoint
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # dataset and dataloader
    if args.data_type == 'real':
        data_set = SingleLoader_DGF(noise_dir=args.noise_dir, gt_dir=args.gt_dir,
                                    image_size=args.image_size,
                                    burst_length=burst_length)
    elif args.data_type == "synth":
        data_set = SingleLoader_DGF_synth(gt_dir=args.gt_dir,
                                          image_size=args.image_size,
                                          burst_length=burst_length)
    else:
        print("Wrong type data")
        return
    data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)

    # model here
    if args.model_type == "attKPN":
        model = Att_KPN(color=color, burst_length=burst_length, blind_est=True,
                        kernel_size=[5], sep_conv=False, channel_att=True,
                        spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet(color=color, burst_length=1, blind_est=True,
                                kernel_size=[5], sep_conv=False, channel_att=True,
                                spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color, burst_length=1, blind_est=True,
                               kernel_size=[5], sep_conv=False, channel_att=True,
                               spatial_att=True, upMode="bilinear", core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color, burst_length=1, blind_est=True,
                    kernel_size=[5], sep_conv=False, channel_att=False,
                    spatial_att=False, upMode="bilinear", core_bias=False)
    else:
        print(" Model type not valid")
        return
    if cuda:
        model = model.cuda()
    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # loss function here
    loss_func = LossBasic()
    if args.wavelet_loss:
        print("Use wavelet loss")
        loss_func2 = WaveletLoss()

    # Optimizer here
    optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()

    # learning rate scheduler here
    scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_decay)

    average_loss = MovingAverage(save_freq)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not restart_train:
        try:
            # torch.device never compares equal to a str, so test its .type.
            checkpoint = load_checkpoint(checkpoint_dir, cuda=device.type == 'cuda',
                                         best_or_latest=args.load_type)
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception as err:
            # Best-effort resume: fall back to a fresh run, but say why
            # (the old bare `except:` also hid state-dict mismatches).
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
            print('   reason: {!r}'.format(err))
    else:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf

    print('=> training')
    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        t1 = time.time()
        for step, (image_noise_hr, image_noise_lr, image_gt_hr, image_gt_lr) in enumerate(data_loader):
            # Only the first low-resolution frame is used as input and target.
            burst_noise = image_noise_lr[:, 0:1, :, :, :]
            gt = image_gt_lr[:, 0, :, :, :]
            if cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
            if color:
                b, N, c, h, w = burst_noise.size()
                feedData = burst_noise.view(b, -1, h, w)
            else:
                feedData = burst_noise

            # Forward pass (was missing, leaving `pred` unbound below).
            _, pred = model(feedData, burst_noise)

            loss_basic = loss_func(pred, gt)
            loss = loss_basic
            if args.wavelet_loss:
                loss_wave = loss_func2(pred, gt)
                loss = loss_basic + loss_wave

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Detach so the moving average does not keep autograd graphs alive.
            average_loss.update(loss.detach())

            if not color:
                pred = pred.unsqueeze(1)
                gt = gt.unsqueeze(1)

            if global_step % loss_freq == 0:
                # calculate PSNR
                psnr = calculate_psnr(pred.detach(), gt)
                ssim = calculate_ssim(pred.detach(), gt)

                # add scalars to tensorboardX
                log_writer.add_scalar('loss_basic', loss_basic, global_step)
                log_writer.add_scalar('loss_total', loss, global_step)
                log_writer.add_scalar('psnr', psnr, global_step)
                log_writer.add_scalar('ssim', ssim, global_step)

                print(
                    '{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t|'
                    ' loss: {:.4f}\t| PSNR: {:.2f}dB\t| SSIM: {:.4f}\t| time:{:.2f} seconds.'
                    .format(global_step, epoch, step, loss_basic, loss, psnr, ssim,
                            time.time() - t1))
                t1 = time.time()

            if global_step % save_freq == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir, global_step,
                                max_keep=10)
                print(
                    'Save : {:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t|'
                    ' loss: {:.4f}'.format(global_step, epoch, step, loss_basic, loss))
            global_step += 1
        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch, time.time() - epoch_start_time))

        # Decay the learning rate per epoch, but never below 5e-6.
        lr_cur = [param['lr'] for param in optimizer.param_groups]
        if lr_cur[0] > 5e-6:
            scheduler.step()
        else:
            for param in optimizer.param_groups:
                param['lr'] = 5e-6

    log_writer.close()