def main(args):
    """Three-phase GLCIC-style training of an image-completion GAN.

    Phase 1 trains the completion network (CN) alone on a reconstruction
    loss. Phase 2 trains the context discriminator (CD) alone with a BCE
    real/fake loss against frozen CN outputs. Phase 3 trains both jointly
    (CD adversarially, CN on reconstruction + adversarial loss).
    Snapshot image grids and model weights are written under
    ``args.result_dir/phase_{1,2,3}``; the run config is dumped to
    ``config.json``.

    Raises:
        Exception: if no CUDA device is available.
    """
    # ================================================
    # Preparation
    # ================================================
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directory (if necessary)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        if not os.path.exists(os.path.join(args.result_dir, phase)):
            os.makedirs(os.path.join(args.result_dir, phase))

    # load dataset
    # resize so the short side is cn_input_size, then random-crop a square
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(
        os.path.join(args.data_dir, 'train'),
        trnsfm,
        recursive_search=args.recursive_search)
    test_dset = ImageDataset(
        os.path.join(args.data_dir, 'test'),
        trnsfm,
        recursive_search=args.recursive_search)
    # bsize is the *effective* batch size: each loader batch is bsize/bdivs
    # and gradients accumulate over bdivs batches before an optimizer step
    train_loader = DataLoader(
        train_dset,
        batch_size=(args.bsize // args.bdivs),
        shuffle=True)

    # compute mpv (mean pixel value) of training dataset
    if args.mpv is None:
        # shape (1,): a single-channel mean — assumes x.mean(axis=(0, 1))
        # yields a scalar, i.e. grayscale images; TODO confirm
        mpv = np.zeros(shape=(1, ))
        pbar = tqdm(
            total=len(train_dset.imgpaths),
            desc='computing mean pixel value of training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    mpv_json = []
    for i in range(1):
        mpv_json.append(float(mpv[i]))
    args_dict = vars(args)
    # NOTE(review): the write-back of the computed mpv is disabled, so
    # mpv_json above is currently dead and config.json keeps args.mpv as
    # given (possibly None). Original line:
    # args_dict['mpv'] = mpv_json
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensors (mpv reshaped to broadcast over NCHW batches)
    mpv = torch.tensor(
        mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu)
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    # load completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    model_cn = model_cn.to(gpu)
    opt_cn = Adadelta(model_cn.parameters())

    # training
    cnt_bdivs = 0  # gradient-accumulation sub-batch counter
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu)
            # fill the hole region with the mean pixel value
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model_cn(input)
            loss = completion_network_loss(x, output, mask)

            # backward (gradients accumulate across bdivs sub-batches)
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test: periodic snapshot of sample completions + weights
                if pbar.n % args.snaperiod_1 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        # rejoiner presumably pastes the completed hole back
                        # into the masked image — TODO confirm its semantics
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()),
                            dim=0)
                        imgpath = os.path.join(
                            args.result_dir,
                            'phase_1',
                            'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir,
                            'phase_1',
                            'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        # DataParallel wraps the real model in .module
                        if args.data_parallel:
                            torch.save(
                                model_cn.module.state_dict(),
                                model_cn_path)
                        else:
                            torch.save(
                                model_cn.state_dict(),
                                model_cn_path)
                    model_cn.train()
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # load context discriminator
    # single-channel shapes: local patch (ld) and whole image (gd)
    model_cd = ContextDiscriminator(
        local_input_shape=(1, args.ld_input_size, args.ld_input_size),
        global_input_shape=(1, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    model_cd = model_cd.to(gpu)
    opt_cd = Adadelta(model_cd.parameters(), lr=0.1)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward: CN output (detached) should be classified fake
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            # detach so CD gradients do not flow into the completion net
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward: untouched images should be classified real
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test: snapshot uses the (frozen) CN for sample completions
                if pbar.n % args.snaperiod_2 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()),
                            dim=0)
                        imgpath = os.path.join(
                            args.result_dir,
                            'phase_2',
                            'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir,
                            'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(
                                model_cd.module.state_dict(),
                                model_cd_path)
                        else:
                            torch.save(
                                model_cd.state_dict(),
                                model_cd_path)
                    model_cn.train()
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # joint training: CD adversarially, CN on reconstruction + adversarial
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce (discriminator loss is scaled by alpha)
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize (counter is reset below, after the CN step)
                opt_cd.step()
                opt_cd.zero_grad()

            # forward model_cn (no detach here: adversarial gradients from
            # the discriminator flow back into the completion network)
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, (input_gd_fake)))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' % (
                        loss_cd.cpu(),
                        loss_cn.cpu()))
                pbar.update()
                # test: snapshot images plus both models' weights
                if pbar.n % args.snaperiod_3 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()),
                            dim=0)
                        imgpath = os.path.join(
                            args.result_dir,
                            'phase_3',
                            'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir,
                            'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir,
                            'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(
                                model_cn.module.state_dict(),
                                model_cn_path)
                            torch.save(
                                model_cd.module.state_dict(),
                                model_cd_path)
                        else:
                            torch.save(
                                model_cn.state_dict(),
                                model_cn_path)
                            torch.save(
                                model_cd.state_dict(),
                                model_cd_path)
                    model_cn.train()
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
def main(args):
    """Train a CompletionNetwork with a plain per-pixel MSE loss.

    NOTE(review): this redefines the GAN-style ``main`` earlier in the file;
    only this version is in effect when the module runs.

    Expects ``args`` to provide: wandb, result_dir, data_dir, max_train,
    max_test, batch_size, alpha, init_model_cn, learning_rate, steps_1,
    num_test_completions.

    Raises:
        Exception: if no CUDA device is available.
    """
    # optional experiment tracking; "tmp" is the sentinel for "disabled"
    if args.wandb != "tmp":
        wandb.init(project=args.wandb, config=args)

    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (if necessary)
    os.makedirs(args.result_dir, exist_ok=True)
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset — presumably yields (normal, masked) image pairs; see
    # masked_dataset for the exact contract
    print('loading dataset... (it may take a few minutes)')
    train_dset = masked_dataset(os.path.join(args.data_dir, 'train'),
                                args.max_train)
    test_dset = masked_dataset(os.path.join(args.data_dir, 'test'),
                               args.max_test)
    train_loader = DataLoader(train_dset, batch_size=args.batch_size,
                              shuffle=True)

    # NOTE(review): alpha is computed but never used in this function
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # build (and optionally warm-start) the completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    model_cn = model_cn.to(gpu)
    model_cn.train()
    opt_cn = Adam(model_cn.parameters(), lr=args.learning_rate)

    # training
    # ================================================
    # Training Phase 1
    # ================================================
    # NOTE(review): args.steps_1 counts epochs here while pbar counts
    # batches, so the bar's total is only a rough progress indication
    pbar = tqdm(total=args.steps_1)
    epochs = 0
    while epochs < args.steps_1:
        # fixed: the previous inner tqdm(enumerate(...)) opened a new,
        # never-closed progress bar every epoch, garbling console output
        for normal, masked in train_loader:
            # forward: reconstruct the unmasked image from the masked input
            output = model_cn(masked.to(gpu))
            loss = torch.nn.functional.mse_loss(output, normal.to(gpu))

            # backward + optimize every batch
            loss.backward()
            opt_cn.step()
            opt_cn.zero_grad()

            # .item() detaches the 0-dim loss tensor and yields a float
            if args.wandb != "tmp":
                wandb.log({"phase_1_train_loss": loss.item()})
            pbar.set_description('phase 1 | train loss: %.5f' % loss.item())
            pbar.update()

            # test snapshot — NOTE(review): this runs after *every* batch
            # (no snapshot period), which is expensive; confirm intended
            model_cn.eval()
            with torch.no_grad():
                normal, masked = sample_random_batch(
                    test_dset, batch_size=args.num_test_completions)
                normal = normal.to(gpu)
                masked = masked.to(gpu)
                output = model_cn(masked)
                # grid rows: masked input / ground truth / network output
                imgs = torch.cat(
                    (masked.cpu(), normal.cpu(), output.cpu()), dim=0)
                imgpath = os.path.join(
                    args.result_dir, 'phase_1', 'step%d.png' % pbar.n)
                model_cn_path = os.path.join(
                    args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n)
                save_image(imgs, imgpath, nrow=len(masked))
                torch.save(model_cn.state_dict(), model_cn_path)
            model_cn.train()
        epochs += 1
    pbar.close()