def main():
    """Train the deraining ``Network`` and log progress to TensorBoard.

    All configuration comes from the module-level ``opt`` object
    (``data_path``, ``batch_size``, ``use_GPU``, ``lr``, ``milestone``,
    ``epochs``, ``save_path``, ``save_freq``).

    If ``findLastCheckpoint`` reports a previous epoch in ``opt.save_path``,
    the weights from ``net_epoch<N>.pth`` are loaded and training resumes at
    epoch ``N``.  Only model weights are restored; optimizer state is neither
    saved nor reloaded.

    Side effects: writes TensorBoard events plus ``net_latest.pth`` (every
    epoch) and ``net_epoch<N>.pth`` (when ``epoch % opt.save_freq == 0``)
    into ``opt.save_path``.
    """
    print('Loading dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train, num_workers=4,
                              batch_size=opt.batch_size, shuffle=True)
    # Report the dataset size; len(loader_train) would be the batch count.
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = Network(nin=64, use_GPU=opt.use_GPU)
    print_network(model)

    # Loss functions: SSIM on the final and both intermediate outputs,
    # L1 on the final output.
    criterion = SSIM()
    criterion1 = nn.L1Loss()
    criterion2 = nn.MSELoss()  # not used by the loss below

    if opt.use_GPU:
        model = model.cuda()
        criterion.cuda()
        criterion1.cuda()
        criterion2.cuda()

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    # Multiply the learning rate by 0.2 at each epoch listed in opt.milestone.
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.2)

    writer = SummaryWriter(opt.save_path)

    # Resume from the latest epoch checkpoint, if any.
    initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    if initial_epoch > 0:
        print('resuming by loading epoch %d' % initial_epoch)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only runs;
        # load_state_dict then copies the weights onto the model's device.
        model.load_state_dict(
            torch.load(
                os.path.join(opt.save_path, 'net_epoch%d.pth' % initial_epoch),
                map_location='cpu'))

    # Continue the TensorBoard step counter after a resume instead of
    # rewriting the earlier curve from step 0.
    step = initial_epoch * len(loader_train)
    for epoch in range(initial_epoch, opt.epochs):
        # NOTE(review): passing the epoch to scheduler.step() is deprecated in
        # newer PyTorch; kept because it positions the schedule after a resume.
        scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])

        # ---- epoch training ----
        for i, (input_train, target_train) in enumerate(loader_train, 0):
            # The eval pass below switches mode, so re-enter train mode here.
            model.train()
            optimizer.zero_grad()
            if opt.use_GPU:
                input_train = input_train.cuda()
                target_train = target_train.cuda()

            out_train, r1, r2 = model(input_train)
            pixel_metric = criterion(target_train, out_train)
            loss1 = criterion(target_train, r1)
            loss2 = criterion(target_train, r2)
            loss3 = criterion1(target_train, out_train)
            # SSIM terms are negated (assumes SSIM() returns a similarity to be
            # maximised — TODO confirm); the L1 term is minimised directly.
            loss = -pixel_metric - loss1 - loss2 + loss3
            loss.backward()
            optimizer.step()

            # Training curve: eval-mode forward pass, no autograd graph needed.
            model.eval()
            with torch.no_grad():
                out_train, _, _ = model(input_train)
                out_train = torch.clamp(out_train, 0., 1.)
            psnr_train = batch_PSNR(out_train, target_train, 1.)
            print(
                "[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f,loss1: %.4f,loss2: %.4f,loss3: %.4f,PSNR: %.4f"
                % (epoch + 1, i + 1, len(loader_train), loss.item(),
                   pixel_metric.item(), loss1.item(), loss2.item(),
                   loss3.item(), psnr_train))
            if step % 10 == 0:
                writer.add_scalar('loss', loss.item(), step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            step += 1

        # ---- end of epoch: log images of the last batch ----
        # (assumes loader_train yielded at least one batch this epoch)
        model.eval()
        with torch.no_grad():
            out_train, _, _ = model(input_train)
            out_train = torch.clamp(out_train, 0., 1.)
        im_target = utils.make_grid(target_train.data, nrow=8,
                                    normalize=True, scale_each=True)
        im_input = utils.make_grid(input_train.data, nrow=8,
                                   normalize=True, scale_each=True)
        im_derain = utils.make_grid(out_train.data, nrow=8,
                                    normalize=True, scale_each=True)
        writer.add_image('clean image', im_target, epoch + 1)
        writer.add_image('rainy image', im_input, epoch + 1)
        writer.add_image('deraining image', im_derain, epoch + 1)

        # Save model; checkpoint file N holds the weights after N epochs.
        torch.save(model.state_dict(),
                   os.path.join(opt.save_path, 'net_latest.pth'))
        if epoch % opt.save_freq == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.save_path, 'net_epoch%d.pth' % (epoch + 1)))

    # Flush pending TensorBoard events.
    writer.close()
def main():
    """Train the ``DRN`` model with multi-stage SSIM losses and log to TensorBoard.

    All configuration comes from the module-level ``opt`` object
    (``data_path``, ``batchSize``, ``inter_iter``, ``intra_iter``,
    ``use_GPU``, ``lr``, ``milestone``, ``epochs``, ``save_folder``,
    ``save_freq``).

    The loss is the negated SSIM of the final output, minus a weighted
    SSIM term for each of the ``opt.inter_iter`` intermediate outputs.  The
    weights are 0.1, 0.2, 0.3, ... in stage order.

    If ``findLastCheckpoint`` reports a previous epoch in ``opt.save_folder``,
    the weights from ``net_epoch<N>.pth`` are loaded and training resumes at
    epoch ``N``.  Only model weights are restored.

    Side effects: writes TensorBoard events plus ``net_latest.pth`` (every
    epoch) and ``net_epoch<N>.pth`` (when ``epoch % opt.save_freq == 0``)
    into ``opt.save_folder``.
    """
    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = Dataset(train=True, data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train, num_workers=4,
                              batch_size=opt.batchSize, shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = DRN(channel=3, inter_iter=opt.inter_iter,
                intra_iter=opt.intra_iter, use_GPU=opt.use_GPU)
    print_network(model)

    criterion = SSIM()

    if opt.use_GPU:
        model = model.cuda()
        criterion.cuda()

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    # Halve the learning rate at each epoch listed in opt.milestone.
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.5)

    writer = SummaryWriter(opt.save_folder)

    # Resume from the latest epoch checkpoint, if any.
    initial_epoch = findLastCheckpoint(save_dir=opt.save_folder)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only runs;
        # load_state_dict then copies the weights onto the model's device.
        model.load_state_dict(
            torch.load(
                os.path.join(opt.save_folder, 'net_epoch%d.pth' % initial_epoch),
                map_location='cpu'))

    # Continue the TensorBoard step counter after a resume instead of
    # rewriting the earlier curve from step 0.
    step = initial_epoch * len(loader_train)
    for epoch in range(initial_epoch, opt.epochs):
        # NOTE(review): passing the epoch to scheduler.step() is deprecated in
        # newer PyTorch; kept because it positions the schedule after a resume.
        scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])

        for i, (input_train, target_train) in enumerate(loader_train, 0):
            # The eval pass below switches mode, so re-enter train mode here.
            model.train()
            optimizer.zero_grad()
            # Respect opt.use_GPU (previously inputs were always sent to CUDA).
            if opt.use_GPU:
                input_train = input_train.cuda()
                target_train = target_train.cuda()

            out_train, outs = model(input_train)
            pixel_loss = criterion(target_train, out_train)

            # One SSIM term per intermediate output, weighted 0.1, 0.2, ...
            # (the weight is accumulated exactly as in the original schedule).
            loss = -pixel_loss
            loss_list = []
            weight = 0.1
            for lossi in range(opt.inter_iter):
                stage_loss = criterion(target_train, outs[lossi])
                loss_list.append(stage_loss)
                loss = loss - weight * stage_loss
                weight = weight + 0.1
            loss.backward()
            optimizer.step()

            # Training curve: eval-mode forward pass, no autograd graph needed.
            model.eval()
            with torch.no_grad():
                out_train, _ = model(input_train)
                out_train = torch.clamp(out_train, 0., 1.)
            psnr_train = batch_PSNR(out_train, target_train, 1.)
            # One "lossK" entry per stage, so any opt.inter_iter is supported.
            stage_str = "".join(", loss%d: %.4f" % (k, stage.item())
                                for k, stage in enumerate(loss_list, 1))
            print("[epoch %d][%d/%d] loss: %.4f%s, PSNR_train: %.4f"
                  % (epoch + 1, i + 1, len(loader_train), loss.item(),
                     stage_str, psnr_train))
            if step % 10 == 0:
                writer.add_scalar('loss', loss.item(), step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            step += 1

        # ---- end of epoch: log images of the last batch ----
        # (assumes loader_train yielded at least one batch this epoch)
        model.eval()
        with torch.no_grad():
            out_train, _ = model(input_train)
            out_train = torch.clamp(out_train, 0., 1.)
        Img = utils.make_grid(target_train.data, nrow=8,
                              normalize=True, scale_each=True)
        Imgn = utils.make_grid(input_train.data, nrow=8,
                               normalize=True, scale_each=True)
        Irecon = utils.make_grid(out_train.data, nrow=8,
                                 normalize=True, scale_each=True)
        writer.add_image('clean image', Img, epoch)
        writer.add_image('noisy image', Imgn, epoch)
        writer.add_image('reconstructed image', Irecon, epoch)

        # Save model; checkpoint file N holds the weights after N epochs.
        torch.save(model.state_dict(),
                   os.path.join(opt.save_folder, 'net_latest.pth'))
        if epoch % opt.save_freq == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.save_folder, 'net_epoch%d.pth' % (epoch + 1)))

    # Flush pending TensorBoard events.
    writer.close()