def main(args):
    """Classify every image of a dataset and write annotated copies to disk.

    Reads ``args.input_data`` (dataset location), ``args.class_list`` (text
    file with one class name per non-empty line) and ``args.model_file``
    (checkpoint: either a bare state dict or a dict holding it under
    ``'state_dict'``). For each image the predicted class name is drawn in a
    label box in the top-left corner and the result is saved under
    ``<input_data>/out`` using the original file name.

    Exits with status 1 when the dataset path or the model file is missing.
    """
    import sys  # function-scope import: only needed for the error exits

    transform = getTransforms()

    data_path = args.input_data
    if not os.path.exists(data_path):
        print('ERROR: No dataset named {}'.format(data_path))
        sys.exit(1)

    dataset = EvalDataset(data_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)

    # One class per line; blank lines are ignored.
    with open(args.class_list, 'r') as class_file:
        class_names = [line.strip() for line in class_file if line.strip()]

    model = ResNet(num_layers=18, num_classes=len(class_names)).to(DEVICE)
    model = model.eval()

    output_dir = os.path.join(data_path, 'out')
    os.makedirs(output_dir, exist_ok=True)

    model_file = args.model_file
    if not os.path.exists(model_file):
        print('model_file "{}" does not exists.'.format(model_file))
        sys.exit(1)

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this run uses (e.g. a CPU-only host).
    checkpoint = torch.load(model_file, map_location=DEVICE)
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=False)
    print('=> loaded {}'.format(model_file))

    font = cv2.FONT_HERSHEY_SIMPLEX
    with torch.no_grad():
        for data, path in dataloader:
            outputs = model(data.to(DEVICE))
            _, predicted = torch.max(outputs, 1)
            predicted = predicted[0].item()  # batch_size is 1
            class_text = class_names[predicted]
            print(class_text, path)

            image = cv2.imread(path[0], cv2.IMREAD_COLOR)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                print('WARNING: could not read image {}'.format(path[0]))
                continue

            # Filled white label box with a border, then the class name.
            image = cv2.rectangle(image, (0, 0), (150, 25), (255, 255, 255), -1)
            image = cv2.rectangle(image, (0, 0), (150, 25), (255, 0, 0), 2)
            # The original passed a 2-element colour (255, 0); OpenCV pads
            # missing channels with 0, so this is the same colour spelled out.
            cv2.putText(image, class_text, (5, 15), font, 0.5, (255, 0, 0), 1,
                        cv2.LINE_AA)
            cv2.imwrite(os.path.join(output_dir, os.path.basename(path[0])), image)
# Loss and optimizer. model.last_part is given its own learning rate
# (args.lr * 0.1); model.first_part falls back to the optimizer default
# (args.lr).
criterion = nn.MSELoss()
optimizer = optim.Adam([{
    'params': model.first_part.parameters()
}, {
    'params': model.last_part.parameters(),
    'lr': args.lr * 0.1
}], lr=args.lr)

# Training data is shuffled and batched; evaluation runs one sample at a time.
train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
eval_dataset = EvalDataset(args.eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

# Snapshot of the initial weights plus best-so-far bookkeeping.
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(args.num_epochs):
    # Step decay: the rate drops by 10x once epoch reaches
    # int(args.num_epochs * 0.8).
    # NOTE(review): this writes the same value into every param group, so the
    # 'lr': args.lr * 0.1 given to model.last_part above is overwritten from
    # the first epoch on -- confirm whether that is intended.
    # NOTE(review): int(args.num_epochs * 0.8) is 0 when args.num_epochs < 2,
    # which makes the floor division raise ZeroDivisionError.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr * (0.1**(epoch // int(args.num_epochs * 0.8)))

    model.train()
    epoch_losses = AverageMeter()

    # (the snippet is cut off inside the next expression)
    with tqdm(total=(len(train_dataset) -
def main():
    """Train the generator with an L1 content loss and validate each epoch.

    Configuration comes from the module-level ``opt`` (``patch_size``,
    ``upscale_factor``, ``batchSize``, ``generatorLR``, ``epochs``, ``outf``,
    ``noiseL``). Per epoch the function:

    * sets the learning rate to ``generatorLR * 0.1 ** (epoch // decay_every)``
      where ``decay_every`` is 80% of ``opt.epochs``;
    * trains on ``train_BSD500.h5`` and logs loss / PSNR / SSIM to TensorBoard;
    * saves ``<outf>/model/rdn_final_<epoch>.pth``;
    * evaluates on ``test_BSD500.h5`` and writes a preview grid of up to five
      (ground truth, prediction) pairs to
      ``<outf>/training_results/<noiseL>/epoch<epoch>.png``.

    Inputs are moved to whatever device the model ends up on (GPU when
    ``torch.cuda.is_available()``), so the run no longer depends on
    ``opt.cuda`` agreeing with the hardware.
    """
    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = TrainDataset('train_BSD500.h5', opt.patch_size,
                                 int(opt.upscale_factor[0]))
    print(len(dataset_train))
    dataset_val = EvalDataset('test_BSD500.h5')
    loader_train = DataLoader(dataset=dataset_train, num_workers=1,
                              batch_size=opt.batchSize, shuffle=True)
    loader_val = DataLoader(dataset=dataset_val, batch_size=1)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    netG = make_model(opt)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    content_criterion = nn.L1Loss()

    # Move to GPU
    if torch.cuda.is_available():
        netG.cuda()
        content_criterion.cuda()
    # Single source of truth for tensor placement: follow the model.
    device = next(netG.parameters()).device

    optim_rdn = optim.Adam(netG.parameters(), lr=opt.generatorLR)

    # torch.save does not create missing directories.
    os.makedirs('%s/model' % opt.outf, exist_ok=True)

    writer = SummaryWriter(opt.outf)
    step = 0
    # Guard against a zero period (opt.epochs < 2), which would raise
    # ZeroDivisionError in the schedule below.
    decay_every = max(1, int(opt.epochs * 0.8))
    max_preview_pairs = 5

    for epoch in range(opt.epochs):
        mean_generator_content_loss = 0.0
        mean_generator_PSNRs = 0.0
        mean_generator_SSIMs = 0.0

        for param_group in optim_rdn.param_groups:
            param_group['lr'] = opt.generatorLR * (0.1 ** (epoch // decay_every))

        for i, (lrimg, hrimg) in enumerate(loader_train):
            # The former per-sample "add noise" loop was a no-op
            # (lrimg[j] = lrimg[j]) and indexed opt.batchSize samples, which
            # raised IndexError on a short final batch; it has been removed.
            high_res_real = hrimg.to(device)
            high_res_fake = netG(lrimg.to(device))

            ######### Train generator #########
            netG.zero_grad()
            generator_content_loss = content_criterion(high_res_fake, high_res_real)
            loss_value = generator_content_loss.item()
            mean_generator_content_loss += loss_value
            generator_content_loss.backward()
            optim_rdn.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' %
                             (epoch, opt.epochs, i, len(loader_train), loss_value))
            out_train = torch.clamp(high_res_fake, 0., 1.)
            # NOTE(review): output is clamped to [0, 1] yet data_range=255.0
            # is passed -- confirm against batch_PSNR's expectations.
            psnr_train, ssim_train = batch_PSNR(out_train, high_res_real,
                                                scale=3, data_range=255.0)
            mean_generator_PSNRs += psnr_train
            mean_generator_SSIMs += ssim_train
            if step % 10 == 0:
                # Log the scalar values
                writer.add_scalar('generator_content_loss', loss_value, step)
            step += 1

        psnr_avg_train = mean_generator_PSNRs / len(loader_train)
        ssim_avg_train = mean_generator_SSIMs / len(loader_train)

        sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f\n' %
                         (epoch, opt.epochs, i, len(loader_train),
                          mean_generator_content_loss / len(loader_train)))
        print("\n[epoch %d] PSNR_train: %.4f" % (epoch+1, psnr_avg_train))
        print("\n[epoch %d] SSIM_train: %.4f" % (epoch+1, ssim_avg_train))
        writer.add_scalar('PSNR on training data', psnr_avg_train, epoch)
        writer.add_scalar('SSIM on training data', ssim_avg_train, epoch)

        torch.save(netG.state_dict(),
                   '%s/model/rdn_final_%d.pth' % (opt.outf, epoch))

        ## the end of each epoch
        # NOTE(review): netG.eval() was commented out in the original, so the
        # model is validated in training mode; left unchanged.
        # validate
        psnr_val = 0
        ssim_val = 0.0
        val_images = []
        numofex = opt.noiseL
        # torch.no_grad() replaces Variable(..., volatile=True), which modern
        # PyTorch ignores (autograd graphs were being built during validation).
        with torch.no_grad():
            for lrimg_val, hrimg_val in loader_val:
                hrimg_val = hrimg_val.to(device)
                out_val = netG(lrimg_val.to(device))
                psnr_val_e, ssim_val_e = batch_PSNR(out_val, hrimg_val,
                                                    scale=3, data_range=255.0)
                psnr_val += psnr_val_e
                ssim_val += ssim_val_e
                if len(val_images) < 2 * max_preview_pairs:
                    # CHW -> HWC for the preview grid.
                    val_images.extend([
                        np.transpose(hrimg_val[0].cpu().numpy(), (1, 2, 0)),
                        np.transpose(out_val[0].cpu().numpy(), (1, 2, 0)),
                    ])

        if val_images:
            output_image = make_grid(val_images, nrow=2, nline=1)
            result_dir = '%s/training_results/%d/' % (opt.outf, numofex)
            os.makedirs(result_dir, exist_ok=True)
            save_result(output_image,
                        path='%s/training_results/%d/epoch%d.png' %
                        (opt.outf, numofex, epoch))

        psnr_val /= len(dataset_val)
        ssim_val /= len(dataset_val)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch+1, psnr_val))
        print("\n[epoch %d] SSIM_val: %.4f" % (epoch+1, ssim_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('SSIM on validation data', ssim_val, epoch)
def train(self, args):
    """Train the enhancement network with waveform, STFT and fbank losses.

    Copies the run configuration from ``args`` onto ``self``, builds the
    datasets / loaders / network / optimizer, optionally resumes from
    ``args.resume_model``, then trains for ``args.max_epoch`` epochs.

    Every ``eval_steps`` iterations the model is validated and
    ``<model_name>_latest.model`` (plus ``<model_name>_best.model`` when the
    eval loss improved) is written to ``model_path``. After each epoch
    STOI / SNR / PESQ are computed via ``self.validate_with_metrics`` and a
    per-epoch checkpoint ``<model_name>-<epoch>.model`` is written.

    Reads ``self.frame_size``, ``self.frame_shift``, ``self.num_test_sentences``,
    ``self.model_name``, ``self.device``, ``self.width`` and ``self.srate``,
    none of which are assigned here, and the module-level ``summary`` writer.
    """
    with open(args.train_list, 'r') as train_list_file:
        self.train_list = [line.strip() for line in train_list_file.readlines()]
    self.eval_file = args.eval_file
    self.num_train_sentences = args.num_train_sentences
    self.batch_size = args.batch_size
    self.lr = args.lr
    self.max_epoch = args.max_epoch
    self.model_path = args.model_path
    self.log_path = args.log_path
    self.fig_path = args.fig_path
    self.eval_plot_num = args.eval_plot_num
    self.eval_steps = args.eval_steps
    self.resume_model = args.resume_model
    self.wav_path = args.wav_path
    self.train_wav_path = args.train_wav_path
    self.tool_path = args.tool_path

    # create a training dataset and an evaluation dataset
    trainSet = TrainingDataset(self.train_list,
                               frame_size=self.frame_size,
                               frame_shift=self.frame_shift)
    evalSet = EvalDataset(self.eval_file, self.num_test_sentences)

    # create data loaders for training and evaluation
    train_loader = DataLoader(trainSet,
                              batch_size=self.batch_size,
                              shuffle=True,
                              num_workers=16,
                              collate_fn=TrainCollate())
    eval_loader = DataLoader(evalSet,
                             batch_size=1,
                             shuffle=False,
                             num_workers=4,
                             collate_fn=EvalCollate())

    # create a network
    print('model', self.model_name)
    net = Net(device=self.device, L=self.frame_size, width=self.width)
    net.to(self.device)
    print('Number of learnable parameters: %d' % numParams(net))
    print(net)

    criterion = mse_loss()
    criterion1 = stftm_loss(device=self.device)
    optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
    # Per-epoch learning rates; epochs beyond the end reuse the last entry.
    self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3

    if self.resume_model:
        print('Resume model from "%s"' % self.resume_model)
        checkpoint = Checkpoint()
        checkpoint.load(self.resume_model)
        start_epoch = checkpoint.start_epoch
        start_iter = checkpoint.start_iter
        best_loss = checkpoint.best_loss
        net.load_state_dict(checkpoint.state_dict)
        optimizer.load_state_dict(checkpoint.optimizer)
    else:
        print('Training from scratch.')
        start_epoch = 0
        start_iter = 0
        best_loss = np.inf

    num_train_batches = self.num_train_sentences // self.batch_size
    total_train_batch = self.max_epoch * num_train_batches
    print('num_train_sentences', self.num_train_sentences)
    print('batches_per_epoch', num_train_batches)
    print('total_train_batch', total_train_batch)
    print('batch_size', self.batch_size)
    print('model_name', self.model_name)

    counter = int(start_epoch * num_train_batches + start_iter)
    counter1 = 0  # batches processed in this run (for mean time/batch)
    print('counter', counter)
    ttime = 0.
    cnt = 0.
    iteration = 0
    print('best_loss', best_loss)

    # Checkpoint file names are fixed for the run. Defining them up front
    # keeps the end-of-epoch save from hitting an unbound ``best_model`` when
    # no periodic evaluation ran before the first epoch finished.
    model_name = self.model_name + '_latest.model'
    best_model = self.model_name + '_best.model'

    # Stateless feature extractor: build once, not once per batch.
    feature_maker = Fbank(sample_rate=16000, n_fft=400, n_mels=40)

    for epoch in range(start_epoch, self.max_epoch):
        accu_train_loss = 0.0
        net.train()
        # Clamp so max_epoch > len(lr_list) does not raise IndexError.
        epoch_lr = self.lr_list[min(epoch, len(self.lr_list) - 1)]
        for param_group in optimizer.param_groups:
            param_group['lr'] = epoch_lr

        start = timeit.default_timer()
        for i, (features, labels, nframes, feat_size, label_size,
                get_filename) in enumerate(train_loader):
            iteration += 1
            i += start_iter  # continue numbering after a resume
            features, labels = features.to(self.device), labels.to(self.device)
            loss_mask = compLossMask(labels, nframes=nframes)

            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = net(features)

            # Fbank-domain loss: each utterance of the batch is enhanced,
            # written to a wav file, read back and compared in fbank space.
            # NOTE(review): est_sig is re-read from disk, so it presumably
            # carries no autograd history -- confirm that this term is meant
            # to (and actually does) contribute gradients.
            loss_fbank = 0
            for t in range(len(get_filename)):
                # Context manager closes the HDF5 handle (it used to leak).
                with h5py.File(get_filename[t], 'r') as reader:
                    feature_asr = reader['noisy_raw'][:]
                    label_asr = reader['clean_raw'][:]
                feat_asr_size = int(feat_size[t][0].item())
                output_asr = self.train_asr_forward(feature_asr, net)
                est_output_asr = output_asr[:feat_asr_size]
                ideal_labels_asr = label_asr

                # save the training wavs
                est_path = os.path.join(self.train_wav_path,
                                        '{}_est.wav'.format(t + 1))
                ideal_path = os.path.join(self.train_wav_path,
                                          '{}_ideal.wav'.format(t + 1))
                sf.write(est_path, normalize_wav(est_output_asr)[0], self.srate)
                sf.write(ideal_path, normalize_wav(ideal_labels_asr)[0], self.srate)

                # read wav
                est_sig = sb.dataio.dataio.read_audio(est_path).unsqueeze(
                    axis=0).to(self.device)
                ideal_sig = sb.dataio.dataio.read_audio(ideal_path).unsqueeze(
                    axis=0).to(self.device)
                est_sig_feats = feature_maker(est_sig)
                ideal_sig_feats = feature_maker(ideal_sig)

                # fbank_loss (mean reduction is the default; the original
                # passed the deprecated positional size_average=True)
                loss_fbank += F.mse_loss(est_sig_feats, ideal_sig_feats)
            loss_fbank /= 100 * len(get_filename)

            outputs = outputs[:, :, :labels.shape[-1]]
            loss1 = criterion(outputs, labels, loss_mask, nframes)
            loss2 = criterion1(outputs, labels, loss_mask, nframes)
            loss = 0.4 * loss1 + 0.1 * loss2 + 0.5 * loss_fbank
            loss.backward()
            optimizer.step()

            # calculate losses
            running_loss = loss.data.item()
            accu_train_loss += running_loss
            # train-loss show
            summary.add_scalar('Train Loss', accu_train_loss, iteration)
            cnt += 1.
            counter1 += 1

            # Drop references before validation to release memory early.
            del loss, loss_fbank, loss1, loss2, outputs, loss_mask, features, labels

            end = timeit.default_timer()
            curr_time = end - start
            ttime += curr_time
            mtime = ttime / counter1
            print(
                'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'.format(
                    i + 1, num_train_batches, epoch + 1, self.max_epoch,
                    running_loss, curr_time, mtime))

            if (i + 1) % self.eval_steps == 0:
                avg_train_loss = accu_train_loss / cnt
                avg_eval_loss = self.validate(net, eval_loader, iteration)
                print('Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
                    epoch + 1, self.max_epoch, i + 1,
                    self.num_train_sentences // self.batch_size,
                    avg_train_loss, avg_eval_loss))

                is_best = avg_eval_loss < best_loss
                best_loss = avg_eval_loss if is_best else best_loss

                checkpoint = Checkpoint(epoch, i, avg_train_loss, avg_eval_loss,
                                        best_loss, net.state_dict(),
                                        optimizer.state_dict())
                checkpoint.save(is_best,
                                os.path.join(self.model_path, model_name),
                                os.path.join(self.model_path, best_model))
                logging(self.log_path, self.model_name + '_loss_log.txt',
                        checkpoint, self.eval_steps)

                accu_train_loss = 0.0
                cnt = 0.
                net.train()

            if (i + 1) % num_train_batches == 0:
                break

            # Restart the batch timer last so validation time is not charged
            # to the next batch.
            start = timeit.default_timer()

        avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
        net.train()
        print('#' * 50)
        print('')
        print('After {} epoch the performance on validation score is a s follows:'.format(epoch + 1))
        print('')
        print('STOI: {:.4f}'.format(avg_st))
        print('SNR: {:.4f}'.format(avg_sn))
        print('PESQ: {:.4f}'.format(avg_pe))
        for param_group in optimizer.param_groups:
            print('learning_rate', param_group['lr'])
        print('')
        print('#' * 50)

        checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                net.state_dict(), optimizer.state_dict())
        checkpoint.save(False,
                        os.path.join(self.model_path,
                                     self.model_name + '-{}.model'.format(epoch + 1)),
                        os.path.join(self.model_path, best_model))
        metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                       epoch, [avg_st, avg_sn, avg_pe])
        # Only the resumed epoch starts mid-way. Keep this an int: the former
        # float 0. turned every later iteration index into a float.
        start_iter = 0
def test(self, args):
    """Run the trained network over every test set and store the results.

    ``args.test_list`` names one test file per line. For each entry the
    network output and the clean reference of every utterance are written to
    ``<name>_s_est.dat`` / ``<name>_s_ideal.dat`` (HDF5, one dataset per
    utterance index, float32) inside ``args.prediction_path``, and the
    per-utterance MSE and timing are printed.

    The ``<name>_mix.dat`` file is no longer opened: its contents were read
    for every utterance but never used.

    Reads ``self.device``, ``self.frame_size``, ``self.width`` and
    ``self.num_test_sentences``, which are not assigned here.
    """
    with open(args.test_list, 'r') as test_list_file:
        self.test_list = [line.strip() for line in test_list_file.readlines()]
    self.model_name = args.model_name
    self.model_file = args.model_file
    self.test_mixture_path = args.test_mixture_path
    self.prediction_path = args.prediction_path

    # create a network
    print('model', self.model_name)
    net = Net(device=self.device, L=self.frame_size, width=self.width)
    net.to(self.device)
    print('Number of learnable parameters: %d' % numParams(net))
    print(net)
    net.eval()

    print('Load model from "%s"' % self.model_file)
    checkpoint = Checkpoint()
    checkpoint.load(self.model_file)
    net.load_state_dict(checkpoint.state_dict)

    with torch.no_grad():
        for i, test_entry in enumerate(self.test_list):
            filename_input = test_entry.split('/')[-1]
            start1 = timeit.default_timer()
            print('{}/{}, Started working on {}.'.format(
                i + 1, len(self.test_list), test_entry))
            print('')

            filename_s_ideal = filename_input.replace('.samp', '_s_ideal.dat')
            filename_s_est = filename_input.replace('.samp', '_s_est.dat')

            # create a test dataset
            testSet = EvalDataset(
                os.path.join(self.test_mixture_path, test_entry),
                self.num_test_sentences)
            # create a data loader for test
            test_loader = DataLoader(testSet,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=2,
                                     collate_fn=EvalCollate())

            ttime = 0.
            # Context managers close both outputs even if an utterance fails.
            with h5py.File(os.path.join(self.prediction_path, filename_s_ideal), 'w') as f_s_ideal, \
                    h5py.File(os.path.join(self.prediction_path, filename_s_est), 'w') as f_s_est:
                for k, (mix_raw, cln_raw) in enumerate(test_loader):
                    start = timeit.default_timer()
                    est_s = self.eval_forward(mix_raw, net)
                    # Trim the estimate to the input length.
                    est_s = est_s[:mix_raw.size]
                    ideal_s = cln_raw

                    f_s_ideal.create_dataset(str(k), data=ideal_s.astype(np.float32), chunks=True)
                    f_s_est.create_dataset(str(k), data=est_s.astype(np.float32), chunks=True)

                    # compute eval_loss
                    test_loss = np.mean((est_s - ideal_s) ** 2)

                    end = timeit.default_timer()
                    curr_time = end - start
                    ttime += curr_time
                    # Plain running mean. The original then overwrote it with
                    # (mean * k + curr) / (k + 1), counting the current
                    # utterance twice.
                    mtime = ttime / (k + 1)
                    print('{}/{}, test_loss = {:.4f}, time/utterance = {:.4f}, '
                          'mtime/utternace = {:.4f}'.format(
                              k + 1, self.num_test_sentences, test_loss,
                              curr_time, mtime))

            end1 = timeit.default_timer()
            print('********** Finisehe working on {}. time taken = {:.4f} **********'.format(
                filename_input, end1 - start1))
            print('')
# Seed PyTorch's RNG before the model is constructed.
torch.manual_seed(args.seed)
model = CSRCNN(scale_factor=args.scale).to(device)

# Active loss: Charbonnier. Alternatives left by the author:
# HuberLoss(delta=0.9), nn.L1Loss(), nn.MSELoss()
criterion =CharbonnierLoss(delta=0.0001)

# Both listed parameter groups set lr = args.lr * 0.1 explicitly, so the
# optimizer-level default lr=args.lr applies to neither of them.
# NOTE(review): model.mid_part is excluded from the optimizer (its group is
# commented out) -- confirm that is intended.
optimizer = optim.Adam([
    {'params': model.first_part.parameters(), 'lr': args.lr * 0.1},
    # {'params': model.mid_part.parameters(), 'lr': args.lr * 0.1},
    {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
], lr=args.lr)

# Shuffled training batches; the short final batch is kept (drop_last=False).
train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=False)

# Three evaluation sets, each iterated one sample at a time.
eval_dataset = EvalDataset(args.eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)
eval_dataset1 = EvalDataset(args.eval_file1)
eval_dataloader1 = DataLoader(dataset=eval_dataset1, batch_size=1)
eval_dataset2 = EvalDataset(args.eval_file2)
eval_dataloader2 = DataLoader(dataset=eval_dataset2, batch_size=1)

# Snapshot of the initial weights plus best-so-far bookkeeping.
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

# Epoch axis (1-based) and empty per-epoch history lists.
# NOTE(review): presumably psrn_Set14 / psrn_BSD200 collect the PSNR on
# eval_file1 / eval_file2 -- the consuming code is outside this snippet.
epoch_num=range(1,args.num_epochs+1)
psrn=[]
loss_num=[]
psrn_Set14=[]
psrn_BSD200=[]
def main():
    """Pretrain the RDN generator, then train it adversarially (SRGAN style).

    Configuration comes from the module-level ``opt`` (``patch_size``,
    ``upscale_factor``, ``batchSize``, ``generatorLR``, ``discriminatorLR``,
    ``epochs``, ``outf``, ``noiseL``).

    Stage 1 trains the generator alone for 10 epochs with an L1 content loss
    and saves ``<outf>/model/generator_pretrain.pth``.

    Stage 2 alternates discriminator and generator updates for ``opt.epochs``
    epochs (generator loss = content + 1e-3 * adversarial, both learning
    rates decayed 10x every 50 epochs). After each epoch it saves both
    networks, validates on ``test_DIV.h5`` and writes a preview grid of up to
    five (ground truth, prediction) pairs to
    ``<outf>/training_results/<noiseL>/epoch<epoch>.png``.

    Inputs and targets are moved to whatever device the generator ends up on
    (GPU when ``torch.cuda.is_available()``), so the run no longer depends on
    ``opt.cuda`` agreeing with the hardware.
    """
    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = TrainDataset('train_DIV_new.h5', opt.patch_size,
                                 int(opt.upscale_factor[0]))
    dataset_val = EvalDataset('test_DIV.h5')
    loader_train = DataLoader(dataset=dataset_train, num_workers=1,
                              batch_size=opt.batchSize, shuffle=True)
    loader_val = DataLoader(dataset=dataset_val, batch_size=1)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    netG = Generator_RDN(opt)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

    content_criterion = nn.L1Loss()
    adversarial_criterion = nn.BCELoss()

    # Move to GPU
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
    # Single source of truth for tensor placement: follow the generator.
    device = next(netG.parameters()).device
    ones_const = torch.ones(opt.batchSize, 1).to(device)

    # torch.save does not create missing directories.
    os.makedirs('%s/model' % opt.outf, exist_ok=True)

    optim_rdn = optim.Adam(netG.parameters(), lr=1e-4)

    writer = SummaryWriter(opt.outf)
    step = 0

    # ---------- Stage 1: generator pretraining (content loss only) ----------
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        mean_generator_content_loss = 0.0

        for i, (lrimg, hrimg) in enumerate(loader_train):
            # The former per-sample noise loop generated noise that was never
            # applied and indexed opt.batchSize samples, which raised
            # IndexError on a short final batch; it has been removed.
            high_res_real = hrimg.to(device)
            high_res_fake = netG(lrimg.to(device))

            ######### Train generator #########
            netG.zero_grad()
            generator_content_loss = content_criterion(high_res_fake, high_res_real)
            content_value = generator_content_loss.item()
            mean_generator_content_loss += content_value
            generator_content_loss.backward()
            optim_rdn.step()

            ######### Status and display #########
            out_train = torch.clamp(high_res_fake, 0., 1.)
            psnr_train, ssim_train = batch_PSNR(out_train, high_res_real,
                                                scale=3.0, data_range=1.)
            if step % 10 == 0:
                # Log the scalar values
                writer.add_scalar('generator_content_loss', content_value, step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            step += 1

        # The total used to be a stale hard-coded 2.
        sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f\n' %
                         (epoch, pretrain_epochs, i, len(loader_train),
                          mean_generator_content_loss / len(loader_train)))

    # Do checkpointing
    torch.save(netG.state_dict(), '%s/model/generator_pretrain.pth' % opt.outf)

    # ---------- Stage 2: SRGAN-RDN training ----------
    optim_generator = optim.Adam(netG.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(netD.parameters(), lr=opt.discriminatorLR)
    scheduler_dis = torch.optim.lr_scheduler.StepLR(optim_discriminator, 50, 0.1)
    scheduler_gen = torch.optim.lr_scheduler.StepLR(optim_generator, 50, 0.1)
    print('SRGAN training')
    step_new = 0
    max_preview_pairs = 5

    for epoch in range(opt.epochs):
        mean_generator_content_loss = 0.0
        mean_generator_adversarial_loss = 0.0
        mean_generator_total_loss = 0.0
        mean_discriminator_loss = 0.0
        netG.train()

        for i, (lrimg, hrimg) in enumerate(loader_train):
            # Use the real batch size: the last batch may be shorter than
            # opt.batchSize, which used to break the BCE target shapes.
            batch = hrimg.size(0)

            high_res_real = hrimg.to(device)
            high_res_fake = netG(lrimg.to(device))
            # Noisy (smoothed) labels: real in [0.7, 1.2), fake in [0, 0.3).
            target_real = (torch.rand(batch, 1) * 0.5 + 0.7).to(device)
            target_fake = (torch.rand(batch, 1) * 0.3).to(device)

            ######### Train discriminator #########
            netD.zero_grad()
            # detach(): the discriminator update must not reach the generator.
            discriminator_loss = adversarial_criterion(netD(high_res_real), target_real) + \
                adversarial_criterion(netD(high_res_fake.detach()), target_fake)
            discriminator_value = discriminator_loss.item()
            mean_discriminator_loss += discriminator_value
            discriminator_loss.backward()
            optim_discriminator.step()

            ######### Train generator #########
            netG.zero_grad()
            generator_content_loss = content_criterion(high_res_fake, high_res_real)
            content_value = generator_content_loss.item()
            mean_generator_content_loss += content_value
            generator_adversarial_loss = adversarial_criterion(
                netD(high_res_fake), ones_const[:batch])
            adversarial_value = generator_adversarial_loss.item()
            mean_generator_adversarial_loss += adversarial_value

            generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
            total_value = generator_total_loss.item()
            mean_generator_total_loss += total_value
            generator_total_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write(
                '\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f'
                % (epoch, opt.epochs, i, len(loader_train), discriminator_value,
                   content_value, adversarial_value, total_value))
            out_train = torch.clamp(high_res_fake, 0., 1.)
            psnr_train, ssim_train = batch_PSNR(out_train, high_res_real,
                                                scale=3.0, data_range=1.)
            if step_new % 10 == 0:
                # Log the scalar values. Content loss and PSNR keep using
                # ``step`` so their curves continue from pretraining.
                writer.add_scalar('generator_content_loss', content_value, step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
                writer.add_scalar('discriminator_loss', discriminator_value, step_new)
                writer.add_scalar('generator_adversarial_loss', adversarial_value, step_new)
                writer.add_scalar('generator_total_loss', total_value, step_new)
            step += 1
            step_new += 1

        sys.stdout.write(
            '\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n'
            % (epoch, opt.epochs, i, len(loader_train),
               mean_discriminator_loss / len(loader_train),
               mean_generator_content_loss / len(loader_train),
               mean_generator_adversarial_loss / len(loader_train),
               mean_generator_total_loss / len(loader_train)))

        # Do checkpointing
        torch.save(netG.state_dict(),
                   '%s/model/generator_final_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(),
                   '%s/model/discriminator_final%d.pth' % (opt.outf, epoch))

        ## the end of each epoch
        netG.eval()
        # validate
        psnr_val = 0
        ssim_val = 0.0
        val_images = []
        numofex = opt.noiseL
        # torch.no_grad() replaces Variable(..., volatile=True), which modern
        # PyTorch ignores (autograd graphs were being built during validation).
        with torch.no_grad():
            for lrimg_val, hrimg_val in loader_val:
                hrimg_val = hrimg_val.to(device)
                out_val = netG(lrimg_val.to(device))
                psnr_val_e, ssim_val_e = batch_PSNR(out_val, hrimg_val,
                                                    scale=3.0, data_range=1.)
                psnr_val += psnr_val_e
                ssim_val += ssim_val_e
                if len(val_images) < 2 * max_preview_pairs:
                    # CHW -> HWC for the preview grid.
                    val_images.extend([
                        np.transpose(hrimg_val[0].cpu().numpy(), (1, 2, 0)),
                        np.transpose(out_val[0].cpu().numpy(), (1, 2, 0)),
                    ])

        if val_images:
            output_image = make_grid(val_images, nrow=2, nline=1)
            result_dir = '%s/training_results/%d/' % (opt.outf, numofex)
            os.makedirs(result_dir, exist_ok=True)
            save_result(output_image,
                        path='%s/training_results/%d/epoch%d.png' %
                        (opt.outf, numofex, epoch))

        psnr_val /= len(dataset_val)
        ssim_val /= len(dataset_val)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
        print("\n[epoch %d] SSIM_val: %.4f" % (epoch + 1, ssim_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('SSIM on validation data', ssim_val, epoch)

        # Step the schedulers after the epoch's optimizer updates; stepping
        # first (as before) shifts the decay one epoch early on PyTorch >= 1.1.
        scheduler_gen.step()
        scheduler_dis.step()
# Load model checkpoint that is to be evaluated
checkpoint = torch.load(checkpoint)
model = checkpoint['model']
model = model.to(device)

# Switch to eval mode
model.eval()

# Load test data
test_dataset = PascalVOCDataset(data_folder,
                                split='test',
                                keep_difficult=keep_difficult)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          collate_fn=test_dataset.collate_fn,
                                          num_workers=workers,
                                          pin_memory=True)

# Image-only loader used by evaluate_new(), one image per batch.
eval_dataset = EvalDataset()
eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                          batch_size=1,
                                          shuffle=False,
                                          collate_fn=eval_dataset.collate_fn,
                                          num_workers=1,
                                          pin_memory=True)


def evaluate_new(loader, model):
    """Run detection over ``loader`` and print the raw results per batch.

    :param loader: DataLoader whose batches are image tensors only (such as
        ``eval_loader`` above)
    :param model: detection model providing ``detect_objects``
    """
    model.eval()

    with torch.no_grad():
        # Iterate the loader that was passed in; the original ignored the
        # parameter and always read the module-level eval_loader.
        for i, images in enumerate(tqdm(loader, desc='Evaluating')):
            images = images.to(device)  # (N, 3, 300, 300)

            predicted_locs, predicted_scores = model(images)

            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
                predicted_locs,
                predicted_scores,
                min_score=0.01,
                max_overlap=0.45,
                top_k=200)

            print(i, images)
            print(det_boxes_batch, det_labels_batch, det_scores_batch)
# Best checkpoint for this network/scale: ./checkpoint/<network>/x<scale>/best.pth
pth_path = './checkpoint/' + str(network) + '/x' + str(
    scale) + '/best.pth'
print('Loading weights:', pth_path)
# NOTE(review): no map_location is passed, so a checkpoint saved on a GPU
# presumably fails to load on a CPU-only host -- confirm the deployment target.
checkpoint = torch.load(pth_path)
model.load_state_dict(checkpoint)

# Disabled alternative: load only the checkpoint keys that exist in the model.
# model_dict = model.state_dict()
# pretrained_dict = checkpoint
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)

# Evaluation set file: ./h5file_<datasetfortest>_x<scale>_test
eval_file_ = "./h5file_" + datasetfortest + "_x" + str(
    scale) + "_test"
eval_dataset = EvalDataset(eval_file_)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

model.eval()
epoch_psnr = AverageMeter()
epoch_ssim = AverageMeter()

for data in eval_dataloader:
    inputs, labels = data
    if network == "SRCNN":
        # For SRCNN the input is bilinearly upscaled before use.
        # NOTE(review): the import is re-executed on every iteration, and the
        # factor comes from opt.scale while the paths above use `scale` --
        # confirm the two always agree.
        import torch.nn.functional as F
        inputs = F.interpolate(inputs, scale_factor=opt.scale, mode='bilinear')
def train(self, args):
    """Run the training loop, validating and checkpointing along the way.

    Settings are copied from ``args`` onto ``self``. A ``TrainingDataset`` and
    an ``EvalDataset`` are wrapped in data loaders, and a ``Net`` is optimised
    with Adam on ``0.8 * mse_loss + 0.2 * stftm_loss``. The learning rate of
    every parameter group is reset at the start of each epoch from the fixed
    schedule ``self.lr_list``.

    Every ``eval_steps`` iterations ``self.validate`` is called, a
    "latest"/"best" checkpoint pair is saved and the losses are logged. After
    every epoch ``self.validate_with_metrics`` is called, its three values are
    printed as STOI/SNR/PESQ, and a per-epoch checkpoint is written.

    Besides ``args``, the method reads ``self.frame_size``,
    ``self.frame_shift``, ``self.num_test_sentences``, ``self.model_name``,
    ``self.device`` and ``self.width``, which must be set beforehand.

    Args:
        args: Object providing ``train_list`` (path of a text file with one
            entry per line), ``eval_file``, ``num_train_sentences``,
            ``batch_size``, ``lr``, ``max_epoch``, ``model_path``,
            ``log_path``, ``fig_path``, ``eval_plot_num``, ``eval_steps``,
            ``resume_model``, ``wav_path`` and ``tool_path``.

    Returns:
        None.
    """
    with open(args.train_list, 'r') as train_list_file:
        self.train_list = [line.strip() for line in train_list_file]

    self.eval_file = args.eval_file
    self.num_train_sentences = args.num_train_sentences
    self.batch_size = args.batch_size
    self.lr = args.lr
    self.max_epoch = args.max_epoch
    self.model_path = args.model_path
    self.log_path = args.log_path
    self.fig_path = args.fig_path
    self.eval_plot_num = args.eval_plot_num
    self.eval_steps = args.eval_steps
    self.resume_model = args.resume_model
    self.wav_path = args.wav_path
    self.tool_path = args.tool_path

    # create a training dataset and an evaluation dataset
    trainSet = TrainingDataset(self.train_list,
                               frame_size=self.frame_size,
                               frame_shift=self.frame_shift)
    evalSet = EvalDataset(self.eval_file, self.num_test_sentences)

    # create data loaders for training and evaluation
    train_loader = DataLoader(trainSet,
                              batch_size=self.batch_size,
                              shuffle=True,
                              num_workers=16,
                              collate_fn=TrainCollate())
    eval_loader = DataLoader(evalSet,
                             batch_size=1,
                             shuffle=False,
                             num_workers=4,
                             collate_fn=EvalCollate())

    # create a network
    print('model', self.model_name)
    net = Net(device=self.device, L=self.frame_size, width=self.width)
    net.to(self.device)
    print('Number of learnable parameters: %d' % numParams(net))
    print(net)

    criterion = mse_loss()
    criterion1 = stftm_loss(device=self.device)
    optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
    # Per-epoch learning-rate schedule (15 entries); it overrides the initial
    # `self.lr` at the start of every epoch.
    self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3

    if self.resume_model:
        print('Resume model from "%s"' % self.resume_model)
        checkpoint = Checkpoint()
        checkpoint.load(self.resume_model)
        start_epoch = checkpoint.start_epoch
        start_iter = checkpoint.start_iter
        best_loss = checkpoint.best_loss
        net.load_state_dict(checkpoint.state_dict)
        optimizer.load_state_dict(checkpoint.optimizer)
    else:
        print('Training from scratch.')
        start_epoch = 0
        start_iter = 0
        best_loss = np.inf

    num_train_batches = self.num_train_sentences // self.batch_size
    total_train_batch = self.max_epoch * num_train_batches
    print('num_train_sentences', self.num_train_sentences)
    print('batches_per_epoch', num_train_batches)
    print('total_train_batch', total_train_batch)
    print('batch_size', self.batch_size)
    print('model_name', self.model_name)

    counter = int(start_epoch * num_train_batches + start_iter)
    counter1 = 0
    print('counter', counter)
    ttime = 0.
    cnt = 0.
    print('best_loss', best_loss)

    # Checkpoint file names are fixed for the whole run. They are bound here,
    # before the loops, because the per-epoch save below needs `best_model`
    # even when no `eval_steps` evaluation has happened yet.
    latest_model = self.model_name + '_latest.model'
    best_model = self.model_name + '_best.model'

    for epoch in range(start_epoch, self.max_epoch):
        accu_train_loss = 0.0
        net.train()

        # Clamp the index so that training longer than the schedule keeps the
        # last learning rate instead of raising IndexError.
        epoch_lr = self.lr_list[min(epoch, len(self.lr_list) - 1)]
        for param_group in optimizer.param_groups:
            param_group['lr'] = epoch_lr

        start = timeit.default_timer()
        for i, (features, labels, nframes) in enumerate(train_loader):
            # When resuming, continue the iteration count of the first epoch.
            i += start_iter
            features, labels = features.to(self.device), labels.to(
                self.device)
            loss_mask = compLossMask(labels, nframes=nframes)

            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = net(features)
            # Trim the output along the last axis to the label length.
            outputs = outputs[:, :, :labels.shape[-1]]
            loss1 = criterion(outputs, labels, loss_mask, nframes)
            loss2 = criterion1(outputs, labels, loss_mask, nframes)
            loss = 0.8 * loss1 + 0.2 * loss2
            loss.backward()
            optimizer.step()

            # calculate losses
            running_loss = loss.data.item()
            accu_train_loss += running_loss
            cnt += 1.
            counter += 1
            counter1 += 1

            # Drop references to the batch tensors before timing/evaluation.
            del loss, loss1, loss2, outputs, loss_mask, features, labels

            end = timeit.default_timer()
            curr_time = end - start
            ttime += curr_time
            mtime = ttime / counter1
            print(
                'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'
                .format(i + 1, num_train_batches, epoch + 1, self.max_epoch,
                        running_loss, curr_time, mtime))
            start = timeit.default_timer()

            if (i + 1) % self.eval_steps == 0:
                start = timeit.default_timer()
                avg_train_loss = accu_train_loss / cnt
                avg_eval_loss = self.validate(net, eval_loader)
                net.train()
                print(
                    'Epoch [%d/%d], Iter [%d/%d] ( TrainLoss: %.4f | EvalLoss: %.4f )'
                    % (epoch + 1, self.max_epoch, i + 1, num_train_batches,
                       avg_train_loss, avg_eval_loss))

                is_best = bool(avg_eval_loss < best_loss)
                if is_best:
                    best_loss = avg_eval_loss

                checkpoint = Checkpoint(epoch, i, avg_train_loss,
                                        avg_eval_loss, best_loss,
                                        net.state_dict(),
                                        optimizer.state_dict())
                checkpoint.save(is_best,
                                os.path.join(self.model_path, latest_model),
                                os.path.join(self.model_path, best_model))
                logging(self.log_path, self.model_name + '_loss_log.txt',
                        checkpoint, self.eval_steps)

                # Restart the running average for the next evaluation window.
                accu_train_loss = 0.0
                cnt = 0.
                net.train()

            # An epoch is `num_train_batches` iterations, regardless of how
            # many batches the loader could still produce.
            if (i + 1) % num_train_batches == 0:
                break

        avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
        net.train()
        print('#' * 50)
        print('')
        print(
            'After {} epoch the performance on validation score is a s follows:'
            .format(epoch + 1))
        print('')
        print('STOI: {:.4f}'.format(avg_st))
        print('SNR: {:.4f}'.format(avg_sn))
        print('PESQ: {:.4f}'.format(avg_pe))
        for param_group in optimizer.param_groups:
            print('learning_rate', param_group['lr'])
        print('')
        print('#' * 50)

        # Per-epoch snapshot; `False` means it is never flagged as best.
        checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                net.state_dict(), optimizer.state_dict())
        checkpoint.save(
            False,
            os.path.join(self.model_path,
                         self.model_name + '-{}.model'.format(epoch + 1)),
            os.path.join(self.model_path, best_model))
        metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                       epoch, [avg_st, avg_sn, avg_pe])

        # Only the first (possibly resumed) epoch starts at an offset. Keep
        # this an int so the iteration index stays an int in later epochs.
        start_iter = 0
def main(args):
    """Evaluate a saved detector on image/video lists and store its predictions.

    The checkpoint at ``args.init_model`` supplies both the training-time
    arguments (``checkpoint['args']``) and the detector weights
    (``checkpoint['detector']``). One data loader is built per entry of
    ``args.eval_ilists`` (flagged as image) and ``args.eval_vlists`` (flagged
    as video). For every loader the predicted point locations are mapped
    through the per-sample ``transthetas`` and ``denormalize_points``, then
    saved together with the ground-truth points and image paths to
    ``<args.save_path>/<args.model_name>-result-<index>.pth``.

    Args:
        args: Object providing ``workers``, ``rand_seed``, ``init_model``,
            ``eval_ilists``, ``eval_vlists``, ``num_pts``, ``batch_size``,
            ``use_stable``, ``save_path`` and ``model_name``.

    Raises:
        AssertionError: If CUDA is unavailable, if the configured and actual
            detector downsample factors differ, or if ``args.use_stable``
            does not have exactly two entries.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(args.workers)
    print('Training Base Detector : prepare_seed : {:}'.format(args.rand_seed))
    prepare_seed(args.rand_seed)

    logger = prepare_logger(args)

    checkpoint = load_checkpoint(args.init_model)
    xargs = checkpoint['args']
    logger.log('Previous args : {:}'.format(xargs))

    # General Data Augmentation
    # Kept as an equality test on purpose: `not xargs.use_gray` would also
    # send a `None` value down the colour branch, which `== False` does not.
    if xargs.use_gray == False:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    eval_transform = transforms.Compose2V([
        transforms.ToTensor(),
        normalize,
        transforms.PreCrop(xargs.pre_crop_expand),
        transforms.CenterCrop(xargs.crop_max),
    ])

    # Model Configure Load
    model_config = load_configure(xargs.model_config, logger)
    shape = (xargs.height, xargs.width)
    logger.log('--> {:}\n--> Sigma : {:}, Shape : {:}'.format(
        model_config, xargs.sigma, shape))

    # Evaluation Dataloader
    # Image lists first, then video lists; both are built the same way and
    # differ only in the `is_video` flag stored next to the loader.
    eval_loaders = []
    for eval_lists, is_video in ((args.eval_ilists, False),
                                 (args.eval_vlists, True)):
        if eval_lists is None:
            continue
        for eval_list in eval_lists:
            eval_data = EvalDataset(eval_transform, xargs.sigma,
                                    model_config.downsample,
                                    xargs.heatmap_type, shape,
                                    xargs.use_gray, xargs.data_indicator)
            eval_data.load_list(eval_list, args.num_pts, xargs.boxindicator,
                                xargs.normalizeL, True)
            eval_loader = torch.utils.data.DataLoader(
                eval_data,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_loader, is_video))

    # define the detector
    detector = obtain_pro_model(model_config, xargs.num_pts, xargs.sigma,
                                xargs.use_gray)
    assert model_config.downsample == detector.downsample, 'downsample is not correct : {:} vs {:}'.format(
        model_config.downsample, detector.downsample)
    logger.log("=> detector :\n {:}".format(detector))
    logger.log("=> Net-Parameters : {:} MB".format(
        count_parameters_in_MB(detector)))
    logger.log('=> Eval-Transform : {:}'.format(eval_transform))

    detector = detector.cuda()
    net = torch.nn.DataParallel(detector)
    net.eval()
    net.load_state_dict(checkpoint['detector'])
    cpu = torch.device('cpu')

    assert len(args.use_stable) == 2

    num_pts = xargs.num_pts
    # Row of ones used to extend point matrices by one row; torch.cat copies
    # its inputs, so one shared instance is safe.
    ones_row = torch.ones(1, num_pts)
    is_heatmap = xargs.procedure == 'heatmap'

    for iLOADER, (loader, is_video) in enumerate(eval_loaders):
        logger.log(
            '{:} The [{:2d}/{:2d}]-th test set [{:}] = {:} with {:} batches.'.
            format(time_string(), iLOADER, len(eval_loaders),
                   'video' if is_video else 'image', loader.dataset,
                   len(loader)))

        all_points, all_results, all_image_ps = [], [], []
        with torch.no_grad():
            for (inputs, targets, masks, normpoints, transthetas, image_index,
                 nopoints, shapes) in loader:
                image_index = image_index.squeeze(1).tolist()
                _, _, H, W = inputs.size()
                # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W]
                if is_heatmap:
                    _, _, batch_locs, _ = net(inputs)
                    # Drop the last entry along dim 1 of the predictions.
                    batch_locs = batch_locs[:, :-1, :]
                else:
                    batch_locs = net(inputs)
                batch_locs = batch_locs.detach().to(cpu)

                # evaluate the training data
                for ibatch, (imgidx, _) in enumerate(
                        zip(image_index, nopoints)):
                    if is_heatmap:
                        norm_locs = normalize_points(
                            (H, W), batch_locs[ibatch].transpose(1, 0))
                        norm_locs = torch.cat((norm_locs, ones_row), dim=0)
                    else:
                        norm_locs = torch.cat(
                            (batch_locs[ibatch].permute(1, 0), ones_row),
                            dim=0)
                    transtheta = transthetas[ibatch][:2, :]
                    norm_locs = torch.mm(transtheta, norm_locs)
                    real_locs = denormalize_points(shapes[ibatch].tolist(),
                                                   norm_locs)
                    # The extra row is filled with ones rather than with the
                    # predicted scores.
                    real_locs = torch.cat((real_locs, ones_row), dim=0)
                    xpoints = loader.dataset.labels[imgidx].get_points(
                    ).numpy()
                    image_path = loader.dataset.datas[imgidx]
                    # put into the list
                    all_points.append(torch.from_numpy(xpoints))
                    all_results.append(real_locs)
                    all_image_ps.append(image_path)

        total = len(all_points)
        logger.log(
            '{:} The [{:2d}/{:2d}]-th test set finishes evaluation : {:} frames/images'
            .format(time_string(), iLOADER, len(eval_loaders), total))

        # torch.stack raises on an empty list; skip lists that produced no
        # samples instead of aborting the remaining evaluations.
        if total == 0:
            logger.log(
                '{:} The [{:2d}/{:2d}]-th test set is empty : skip saving'.
                format(time_string(), iLOADER, len(eval_loaders)))
            continue

        # Disabled visualisation code, kept as a bare string (it has no
        # effect at runtime).
        """
        if args.use_stable[0] > 0:
            save_dir = Path( osp.join(args.save_path, '{:}-X-{:03d}'.format(args.model_name, iLOADER)) )
            save_dir.mkdir(parents=True, exist_ok=True)
            wrap_parallel = WrapParallel(save_dir, all_image_ps, all_results, all_points, 180, (255, 0, 0))
            wrap_loader = torch.utils.data.DataLoader(wrap_parallel, batch_size=args.workers, shuffle=False, num_workers=args.workers, pin_memory=True)
            for iL, INDEXES in enumerate(wrap_loader): _ = INDEXES
            cmd = 'ffmpeg -y -i {:}/%06d.png -framerate 30 {:}.avi'.format(save_dir, save_dir)
            logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
            os.system( cmd )

        if args.use_stable[1] > 0:
            save_dir = Path( osp.join(args.save_path, '{:}-Y-{:03d}'.format(args.model_name, iLOADER)) )
            save_dir.mkdir(parents=True, exist_ok=True)
            Xpredictions, Xgts = torch.stack(all_results), torch.stack(all_points)
            new_preds = fc_solve(Xgts, Xpredictions, is_cuda=True)
            wrap_parallel = WrapParallel(save_dir, all_image_ps, new_preds, all_points, 180, (0, 0, 255))
            wrap_loader = torch.utils.data.DataLoader(wrap_parallel, batch_size=args.workers, shuffle=False, num_workers=args.workers, pin_memory=True)
            for iL, INDEXES in enumerate(wrap_loader): _ = INDEXES
            cmd = 'ffmpeg -y -i {:}/%06d.png -framerate 30 {:}.avi'.format(save_dir, save_dir)
            logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
            os.system( cmd )
        """

        Xpredictions, Xgts = torch.stack(all_results), torch.stack(all_points)
        save_path = Path(
            osp.join(
                args.save_path,
                '{:}-result-{:03d}.pth'.format(args.model_name, iLOADER)))
        # NOTE(review): 'predictions' stores the Python list `all_results`,
        # not the stacked `Xpredictions`, while 'ground-truths' is stacked.
        # Left unchanged so existing readers of these files keep working.
        torch.save(
            {
                'paths': all_image_ps,
                'ground-truths': Xgts,
                'predictions': all_results
            }, save_path)
        logger.log('{:} save into {:}'.format(time_string(), save_path))

        # Permanently disabled rendering path, kept for reference.
        if False:
            new_preds = fc_solve_v2(Xgts, Xpredictions, is_cuda=True)
            # create the dir
            save_dir = Path(
                osp.join(args.save_path,
                         '{:}-T-{:03d}'.format(args.model_name, iLOADER)))
            save_dir.mkdir(parents=True, exist_ok=True)
            wrap_parallel = WrapParallelV2(save_dir, all_image_ps, Xgts,
                                           all_results, new_preds, all_points,
                                           180, [args.model_name, 'SRT'])
            wrap_parallel[0]
            wrap_loader = torch.utils.data.DataLoader(
                wrap_parallel,
                batch_size=args.workers,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            for iL, INDEXES in enumerate(wrap_loader):
                _ = INDEXES
            cmd = 'ffmpeg -y -i {:}/%06d.png -vb 5000k {:}.avi'.format(
                save_dir, save_dir)
            logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
            os.system(cmd)

    logger.close()
    return
# my_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=4500/opt.batch_size*opt.num_epochs,eta_min=0.000001) my_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=int(4000 / opt.batch_size * 200), gamma=0.5) train_dataset = TrainDataset(opt.train_file, patch_size=opt.patch_size, scale=opt.scale) train_dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True) eval_dataset = EvalDataset(opt.eval_file) eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1) best_weights = copy.deepcopy(model.state_dict()) best_epoch = 0 best_psnr = 0.0 for epoch in range(opt.num_epochs): for param_group in optimizer.param_groups: param_group['lr'] = opt.lr * (0.1**(epoch // int(opt.num_epochs * 0.8))) model.train() epoch_losses = AverageMeter() with tqdm(total=(len(train_dataset) -