def main(args):
    imgs = load_image(args.input, args.ref)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)

    map_in = vgg(imgs['bic'].to(device), TARGET_LAYERS)
    map_ref = vgg(imgs['ref'].to(device), TARGET_LAYERS)
    map_ref_blur = vgg(imgs['ref_blur'].to(device), TARGET_LAYERS)

    with torch.no_grad(), timer('Feature swapping'):
        maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

    model = SRNTT(use_weights=args.use_weights).to(device)
    model.load_state_dict(torch.load(args.weight))

    img_hr = imgs['hr'].to(device)
    img_lr = imgs['lr'].to(device)

    maps = {
        k: torch.tensor(v).unsqueeze(0).to(device)
        for k, v in maps.items()
    }
    weights = torch.tensor(weights).reshape(1, 1, *weights.shape).to(device)

    with torch.no_grad(), timer('Inference'):
        _, img_sr = model(img_lr, maps, weights)

    psnr = PSNR()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    ssim = SSIM()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    print(f'[Result] PSNR:{psnr:.2f}, SSIM:{ssim:.4f}')

    save_image(img_sr.clamp(0, 1), './out.png')
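# NOTE: `timer` is used above but not defined in this file. A minimal sketch,
# assuming it is a context manager that prints the wall-clock time of the
# enclosed block under the given label; the project's own helper may differ.
import time
from contextlib import contextmanager


@contextmanager
def timer(label: str):
    start = time.time()
    try:
        yield
    finally:
        # e.g. "[Feature swapping] elapsed: 1.234s"
        print(f'[{label}] elapsed: {time.time() - start:.3f}s')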
def __init__(self, args):
    self.args = args
    args.logger.info('Initializing trainer')

    self.model = get_model(args)
    params_cnt = count_parameters(self.model)
    args.logger.info('params ' + str(params_cnt))
    torch.cuda.set_device(args.rank)
    self.model.cuda(args.rank)
    self.model = torch.nn.parallel.DistributedDataParallel(
        self.model, device_ids=[args.rank])

    train_dataset, val_dataset = get_dataset(args)

    if args.split == 'train':
        # train loss
        self.RGBLoss = RGBLoss(args, sharp=False)
        self.SegLoss = nn.CrossEntropyLoss()
        self.RGBLoss.cuda(args.rank)
        self.SegLoss.cuda(args.rank)

        if args.optimizer == 'adamax':
            self.optimizer = torch.optim.Adamax(
                list(self.model.parameters()), lr=args.learning_rate)
        elif args.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(
                self.model.parameters(), lr=args.learning_rate)
        elif args.optimizer == 'sgd':
            self.optimizer = torch.optim.SGD(
                self.model.parameters(), lr=args.learning_rate, momentum=0.9)

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler)
    else:
        # val criteria
        self.L1Loss = nn.L1Loss().cuda(args.rank)
        self.PSNRLoss = PSNR().cuda(args.rank)
        self.SSIMLoss = SSIM().cuda(args.rank)
        self.IoULoss = IoU().cuda(args.rank)
        self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=val_sampler)

    torch.backends.cudnn.benchmark = True
    self.global_step = 0
    self.epoch = 1

    if args.resume or (args.split != 'train' and not args.checkepoch_range):
        self.load_checkpoint()

    if args.rank == 0:
        writer_name = args.path + '/{}_int_{}_len_{}_{}_logs'.format(
            self.args.split, int(self.args.interval),
            self.args.vid_length, self.args.dataset)
        self.writer = SummaryWriter(writer_name)

    self.stand_heat_map = self.create_stand_heatmap()
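# NOTE: `count_parameters` above is not defined in this file. A minimal sketch
# of the usual helper, assuming it counts trainable parameters only:
def count_parameters(model):
    """Return the number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)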
def __init__(self, args):
    self.args = args
    args.logger.info('Initializing trainer')

    self.model = get_model(args)
    if args.load_G:
        # freeze the pretrained generator
        for p in self.model.netG.parameters():
            p.requires_grad = False
    torch.cuda.set_device(args.rank)
    self.model.cuda(args.rank)
    self.model = torch.nn.parallel.DistributedDataParallel(
        self.model, device_ids=[args.rank])

    train_dataset, val_dataset = get_dataset(args)

    if not args.val:
        # train loss
        self.RGBLoss = RGBLoss(args).cuda(args.rank)
        self.SegLoss = nn.CrossEntropyLoss().cuda(args.rank)
        self.GANLoss = GANMapLoss().cuda(args.rank)
        # self.GANLoss = GANLoss(tensor=torch.FloatTensor).cuda(args.rank)
        self.GANFeatLoss = nn.L1Loss().cuda(args.rank)

        if args.optG == 'adamax':
            self.optG = torch.optim.Adamax(
                self.model.module.netG.parameters(), lr=args.lr_G)
        if args.optD == 'sgd':
            self.optD = torch.optim.SGD(
                self.model.module.netD.parameters(), lr=args.lr_D,
                momentum=0.9)
        elif args.optD == 'adamax':
            self.optD = torch.optim.Adamax(
                self.model.module.netD.parameters(), lr=args.lr_D)

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler)
    else:
        # val criteria
        self.L1Loss = nn.L1Loss().cuda(args.rank)
        self.PSNRLoss = PSNR().cuda(args.rank)
        self.SSIMLoss = SSIM().cuda(args.rank)
        self.IoULoss = IoU().cuda(args.rank)
        self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=val_sampler)

    torch.backends.cudnn.benchmark = True
    self.global_step = 0
    self.epoch = 1

    if args.load_G:
        self.load_temp()
    elif args.resume or (args.val and not args.checkepoch_range):
        self.load_checkpoint()

    if args.rank == 0:
        if not args.load_G:
            if args.val:
                self.writer = SummaryWriter(args.path + '/val_logs') \
                    if args.interval == 2 else \
                    SummaryWriter(args.path + '/val_int_1_logs')
            else:
                self.writer = SummaryWriter(args.path + '/logs')
        else:
            if args.val:
                self.writer = SummaryWriter(args.path + '/dis_val_logs') \
                    if args.interval == 2 else \
                    SummaryWriter(args.path + '/dis_val_int_1_logs')
            else:
                self.writer = SummaryWriter(
                    args.path + '/dis_{}_logs'.format(args.session))

    self.heatmap = self.create_stand_heatmap()
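# NOTE: both trainers above assume the default process group is already
# initialized and that `args.rank` / `args.gpus` are filled in by the launcher
# before the trainer is constructed. A minimal sketch of that setup (the
# `nccl` backend and env:// init method are assumptions; the project's entry
# point may differ):
import torch


def init_distributed(args):
    # One process per GPU; RANK and WORLD_SIZE come from the launcher's env.
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.gpus = torch.distributed.get_world_size()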
def main(args):
    init_seeds(seed=args.seed)

    # split data
    files = [f.stem for f in Path(args.dataroot).glob('map/*.npz')]
    train_files, val_files = train_test_split(files, test_size=0.1)

    # define dataloaders
    train_set = ReferenceDataset(train_files, args.dataroot)
    val_set = ReferenceDatasetEval(val_files, args.dataroot)
    train_loader = DataLoader(
        train_set, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, args.batch_size, drop_last=True)

    # define networks
    netG = SRNTT(args.ngf, args.n_blocks, args.use_weights).to(device)
    netG.content_extractor.load_state_dict(torch.load(args.init_weight))
    if args.netD == 'image':
        netD = ImageDiscriminator(args.ndf).to(device)
    elif args.netD == 'patch':
        netD = Discriminator(args.ndf).to(device)

    # define criteria
    criterion_rec = nn.L1Loss().to(device)
    criterion_per = PerceptualLoss().to(device)
    criterion_adv = AdversarialLoss().to(device)
    criterion_tex = TextureLoss(args.use_weights).to(device)

    # metrics
    criterion_psnr = PSNR(max_val=1., mode='Y')
    criterion_ssim = SSIM(window_size=11)

    # define optimizers
    optimizer_G = optim.Adam(netG.parameters(), args.lr)
    optimizer_D = optim.Adam(netD.parameters(), args.lr)
    scheduler_G = StepLR(
        optimizer_G, int(args.n_epochs * len(train_loader) / 2), 0.1)
    scheduler_D = StepLR(
        optimizer_D, int(args.n_epochs * len(train_loader) / 2), 0.1)

    # for tensorboard
    writer = SummaryWriter(log_dir=f'runs/{args.pid}' if args.pid else None)

    if args.netG_pre is None:
        """ pretrain """
        step = 0
        for epoch in range(1, args.n_epochs_init + 1):
            for i, batch in enumerate(train_loader, 1):
                img_hr = batch['img_hr'].to(device)
                img_lr = batch['img_lr'].to(device)
                maps = {k: v.to(device) for k, v in batch['maps'].items()}
                weights = batch['weights'].to(device)

                _, img_sr = netG(img_lr, maps, weights)

                """ train G """
                optimizer_G.zero_grad()
                g_loss = criterion_rec(img_sr, img_hr)
                g_loss.backward()
                optimizer_G.step()

                """ logging """
                writer.add_scalar('pre/g_loss', g_loss.item(), step)
                if step % args.display_freq == 0:
                    writer.add_images('pre/img_lr', img_lr.clamp(0, 1), step)
                    writer.add_images('pre/img_hr', img_hr.clamp(0, 1), step)
                    writer.add_images('pre/img_sr', img_sr.clamp(0, 1), step)
                    log_txt = [
                        f'[Pre][Epoch{epoch}][{i}/{len(train_loader)}]',
                        f'G Loss: {g_loss.item()}'
                    ]
                    print(' '.join(log_txt))
                step += 1

                if args.debug:
                    break

            out_path = Path(writer.log_dir) / f'netG_pre{epoch:03}.pth'
            torch.save(netG.state_dict(), out_path)
    else:  # omit pre-training
        netG.load_state_dict(torch.load(args.netG_pre))
        if args.netD_pre:
            netD.load_state_dict(torch.load(args.netD_pre))

    """ train with all losses """
    step = 0
    for epoch in range(1, args.n_epochs + 1):
        """ training loop """
        netG.train()
        netD.train()
        for i, batch in enumerate(train_loader, 1):
            img_hr = batch['img_hr'].to(device)
            img_lr = batch['img_lr'].to(device)
            maps = {k: v.to(device) for k, v in batch['maps'].items()}
            weights = batch['weights'].to(device)

            _, img_sr = netG(img_lr, maps, weights)

            """ train D """
            optimizer_D.zero_grad()
            for p in netD.parameters():
                p.requires_grad = True
            for p in netG.parameters():
                p.requires_grad = False

            # compute WGAN loss
            d_out_real = netD(img_hr)
            d_loss_real = criterion_adv(d_out_real, True)
            d_out_fake = netD(img_sr.detach())
            d_loss_fake = criterion_adv(d_out_fake, False)
            d_loss = d_loss_real + d_loss_fake

            # gradient penalty
            gradient_penalty = compute_gp(netD, img_hr.data, img_sr.data)
            d_loss += 10 * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            """ train G """
            optimizer_G.zero_grad()
            for p in netD.parameters():
                p.requires_grad = False
            for p in netG.parameters():
                p.requires_grad = True

            # compute all losses
            loss_rec = criterion_rec(img_sr, img_hr)
            loss_per = criterion_per(img_sr, img_hr)
            loss_adv = criterion_adv(netD(img_sr), True)
            loss_tex = criterion_tex(img_sr, maps, weights)

            # optimize G with the combined, weighted loss
            g_loss = (loss_rec * args.lambda_rec
                      + loss_per * args.lambda_per
                      + loss_adv * args.lambda_adv
                      + loss_tex * args.lambda_tex)
            g_loss.backward()
            optimizer_G.step()

            """ logging """
            writer.add_scalar('train/g_loss', g_loss.item(), step)
            writer.add_scalar('train/loss_rec', loss_rec.item(), step)
            writer.add_scalar('train/loss_per', loss_per.item(), step)
            writer.add_scalar('train/loss_tex', loss_tex.item(), step)
            writer.add_scalar('train/loss_adv', loss_adv.item(), step)
            writer.add_scalar('train/d_loss', d_loss.item(), step)
            writer.add_scalar('train/d_real', d_loss_real.item(), step)
            writer.add_scalar('train/d_fake', d_loss_fake.item(), step)
            if step % args.display_freq == 0:
                writer.add_images('train/img_lr', img_lr, step)
                writer.add_images('train/img_hr', img_hr, step)
                writer.add_images('train/img_sr', img_sr.clamp(0, 1), step)
                log_txt = [
                    f'[Train][Epoch{epoch}][{i}/{len(train_loader)}]',
                    f'G Loss: {g_loss.item()}, D Loss: {d_loss.item()}'
                ]
                print(' '.join(log_txt))

            scheduler_G.step()
            scheduler_D.step()
            step += 1

            if args.debug:
                break

        """ validation loop """
        netG.eval()
        netD.eval()
        val_psnr, val_ssim = 0, 0
        tbar = tqdm(total=len(val_loader))
        for i, batch in enumerate(val_loader, 1):
            img_hr = batch['img_hr'].to(device)
            img_lr = batch['img_lr'].to(device)
            maps = {k: v.to(device) for k, v in batch['maps'].items()}
            weights = batch['weights'].to(device)

            with torch.no_grad():
                _, img_sr = netG(img_lr, maps, weights)

            val_psnr += criterion_psnr(img_hr, img_sr.clamp(0, 1)).item()
            val_ssim += criterion_ssim(img_hr, img_sr.clamp(0, 1)).item()

            tbar.update(1)

            if args.debug:
                break
        else:
            tbar.close()

        val_psnr /= len(val_loader)
        val_ssim /= len(val_loader)
        writer.add_scalar('val/psnr', val_psnr, epoch)
        writer.add_scalar('val/ssim', val_ssim, epoch)
        print(f'[Val][Epoch{epoch}] PSNR:{val_psnr:.4f}, SSIM:{val_ssim:.4f}')

        netG_path = Path(writer.log_dir) / f'netG_{epoch:03}.pth'
        netD_path = Path(writer.log_dir) / f'netD_{epoch:03}.pth'
        torch.save(netG.state_dict(), netG_path)
        torch.save(netD.state_dict(), netD_path)
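# NOTE: `compute_gp` is called in the D step above but not defined here. A
# minimal sketch of the standard WGAN-GP gradient penalty (Gulrajani et al.,
# 2017), assuming the project's helper has the signature used above:
import torch


def compute_gp(netD, real_data, fake_data):
    """Gradient penalty on random interpolations between real and fake."""
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
    interp = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_out = netD(interp)
    grads = torch.autograd.grad(
        outputs=d_out, inputs=interp,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True, retain_graph=True)[0]
    # penalize deviation of the gradient norm from 1
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()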
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--weight', '-w', type=str, required=True) parser.add_argument('--use_weights', action='store_true') args = parser.parse_args() dataset = CUFED5Dataset('/home/ubuntu/srntt-pytorch/data/CUFED5') dataloader = DataLoader(dataset) vgg = VGG(model_type='vgg19').to(device) swapper = Swapper().to(device) model = SRNTT(use_weights=args.use_weights).to(device) model.load_state_dict(torch.load(args.weight)) criterion_psnr = PSNR() table = [] tbar = tqdm(total=len(dataloader)) for batch_idx, batch in enumerate(dataloader): with torch.no_grad(): img_hr = batch['img_hr'].to(device) img_lr = batch['img_lr'].to(device) img_in_up = batch['img_in_up'].to(device) map_in = vgg(img_in_up, TARGET_LAYERS) row = [batch['filename'][0].split('_')[0]] for ref_idx in range(7): ref = batch['ref'][ref_idx] map_ref = vgg(ref['ref'].to(device), TARGET_LAYERS)
def __init__(self, args):
    self.args = args
    args.logger.info('Initializing trainer')
    # if not os.path.isdir('../predict'):  # only used in validation
    #     os.makedirs('../predict')
    self.model = get_model(args)
    if self.args.lock_coarse:
        for p in self.model.coarse_model.parameters():
            p.requires_grad = False
    torch.cuda.set_device(args.rank)
    self.model.cuda(args.rank)
    self.model = torch.nn.parallel.DistributedDataParallel(
        self.model, device_ids=[args.rank])

    train_dataset, val_dataset = get_dataset(args)

    if not args.val:
        # train loss
        self.coarse_RGBLoss = RGBLoss(args, sharp=False)
        self.refine_RGBLoss = RGBLoss(args, sharp=False, refine=True)
        self.SegLoss = nn.CrossEntropyLoss()
        self.GANLoss = GANLoss(tensor=torch.FloatTensor)
        self.coarse_RGBLoss.cuda(args.rank)
        self.refine_RGBLoss.cuda(args.rank)
        self.SegLoss.cuda(args.rank)
        self.GANLoss.cuda(args.rank)

        if args.optimizer == 'adamax':
            self.optG = torch.optim.Adamax(
                list(self.model.module.coarse_model.parameters())
                + list(self.model.module.refine_model.parameters()),
                lr=args.learning_rate)
        elif args.optimizer == 'adam':
            self.optG = torch.optim.Adam(
                self.model.parameters(), lr=args.learning_rate)
        elif args.optimizer == 'sgd':
            self.optG = torch.optim.SGD(
                self.model.parameters(), lr=args.learning_rate, momentum=0.9)
        # self.optD = torch.optim.Adam(
        #     self.model.module.discriminator.parameters(),
        #     lr=args.learning_rate)
        self.optD = torch.optim.SGD(
            self.model.module.discriminator.parameters(),
            lr=args.learning_rate, momentum=0.9)

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler)
    else:
        # val criteria
        self.L1Loss = nn.L1Loss().cuda(args.rank)
        self.PSNRLoss = PSNR().cuda(args.rank)
        self.SSIMLoss = SSIM().cuda(args.rank)
        self.IoULoss = IoU().cuda(args.rank)
        self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size // args.gpus,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            sampler=val_sampler)

    torch.backends.cudnn.benchmark = True
    self.global_step = 0
    self.epoch = 1

    if args.resume or (args.val and not args.checkepoch_range):
        self.load_checkpoint()

    if args.rank == 0:
        if args.val:
            self.writer = SummaryWriter(args.path + '/val_logs') \
                if args.interval == 2 else \
                SummaryWriter(args.path + '/val_int_1_logs')
        else:
            self.writer = SummaryWriter(args.path + '/logs')

    self.heatmap = self.create_stand_heatmap()
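# NOTE: `GANLoss(tensor=torch.FloatTensor)` above follows the pix2pixHD-style
# interface, where the loss builds real/fake target tensors on the fly. A
# minimal LSGAN-flavoured sketch of that idea (the project's class may use a
# different underlying criterion):
import torch
import torch.nn as nn


class GANLoss(nn.Module):
    def __init__(self, tensor=torch.FloatTensor):
        super().__init__()
        self.tensor = tensor
        self.loss = nn.MSELoss()  # LSGAN objective

    def forward(self, pred, target_is_real):
        # build a target tensor of the same shape as the prediction
        target = self.tensor(pred.size()).fill_(1.0 if target_is_real else 0.0)
        return self.loss(pred, target.to(pred.device))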