def main(args):
    """Train the single-channel GLCIC networks in three phases.

    Phase 1 trains the completion network (``model_cn``) alone on the
    reconstruction loss; phase 2 trains the context discriminator
    (``model_cd``) on completions of the fixed ``model_cn``; phase 3 trains
    both networks adversarially.  Every ``args.snaperiod_{1,2,3}`` optimizer
    steps a grid of test completions and the current weights are written to
    ``args.result_dir/phase_{1,2,3}``.

    Each optimizer step accumulates gradients over ``args.bdivs``
    micro-batches of ``args.bsize // args.bdivs`` images.

    Args:
        args: parsed command-line options (data/result directories, step
            counts, snapshot periods, batch settings, input sizes, ``mpv``,
            ``alpha``, optional initial weights, ``data_parallel``, ...).

    Raises:
        Exception: if no CUDA device is available.
        ValueError: if the training directory yields no images.
    """
    # ================================================
    # Preparation
    # ================================================
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (if necessary); makedirs also creates
    # args.result_dir itself
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    if len(train_dset.imgpaths) == 0:
        # an empty loader would make the `while pbar.n < steps` loops below
        # spin forever
        raise ValueError('no training images found in %s'
                         % os.path.join(args.data_dir, 'train'))
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mpv (mean pixel value) of training dataset
    if args.mpv is None:
        mpv = np.zeros(shape=(1, ))
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value of training dataset...')
        for imgpath in train_dset.imgpaths:
            with Image.open(imgpath) as img:
                x = np.array(img) / 255.
            # NOTE(review): assumes 2-D (grayscale) images; a multi-channel
            # image yields a per-channel mean that cannot be added in place
            # into this 1-element array.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        # accept either a scalar or a one-element sequence
        mpv = np.asarray(args.mpv, dtype=np.float64).reshape(-1)

    # save training config, recording the mpv actually used; a copy of the
    # namespace dict is written so that `args` itself is left untouched
    args_dict = dict(vars(args))
    args_dict['mpv'] = [float(mpv[0])]
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensors
    mpv = torch.tensor(mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu)
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # ------------------------------------------------
    # helpers shared by the three phases
    # ------------------------------------------------
    def _gen_mask(x):
        # one-channel hole mask with the batch and spatial size of x
        # NOTE(review): the holes are drawn independently of the hole areas
        # used for the local-discriminator crops below, so a crop need not
        # contain a hole -- confirm against gen_input_mask.
        return gen_input_mask(
            shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu)

    def _gen_hole_area(x):
        # area of size (ld_input_size, ld_input_size) for an image of x's
        # width/height, as drawn by gen_hole_area
        return gen_hole_area((args.ld_input_size, args.ld_input_size),
                             (x.shape[3], x.shape[2]))

    def _set_requires_grad(model, flag):
        # (un)freeze every parameter of `model`
        for param in model.parameters():
            param.requires_grad_(flag)

    def _state_dict(model):
        # weights without the DataParallel wrapper, so snapshots load into a
        # bare module
        if args.data_parallel:
            return model.module.state_dict()
        return model.state_dict()

    def _save_test_completions(phase, step):
        # complete a random test batch with model_cn and save the grid
        # [originals; masked inputs; completions] as <phase>/step<step>.png
        model_cn.eval()
        with torch.no_grad():
            x = sample_random_batch(
                test_dset, batch_size=args.num_test_completions).to(gpu)
            mask = _gen_mask(x)
            x_mask = x - x * mask + mpv * mask
            output = model_cn(torch.cat((x_mask, mask), dim=1))
            completed = rejoiner(x_mask, output, mask)
            imgs = torch.cat(
                (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
            save_image(imgs,
                       os.path.join(args.result_dir, phase,
                                    'step%d.png' % step),
                       nrow=len(x))
        model_cn.train()

    # ================================================
    # Training Phase 1
    # ================================================
    # load completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    model_cn = model_cn.to(gpu)
    opt_cn = Adadelta(model_cn.parameters())

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = _gen_mask(x)
            x_mask = x - x * mask + mpv * mask
            output = model_cn(torch.cat((x_mask, mask), dim=1))
            loss = completion_network_loss(x, output, mask)

            # backward (gradients accumulate over args.bdivs micro-batches)
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 1 | train loss: %.5f' % loss.item())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_1 == 0:
                    _save_test_completions('phase_1', pbar.n)
                    torch.save(
                        _state_dict(model_cn),
                        os.path.join(args.result_dir, 'phase_1',
                                     'model_cn_step%d' % pbar.n))
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # load context discriminator
    model_cd = ContextDiscriminator(
        local_input_shape=(1, args.ld_input_size, args.ld_input_size),
        global_input_shape=(1, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    model_cd = model_cd.to(gpu)
    opt_cd = Adadelta(model_cd.parameters(), lr=0.1)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward; model_cn is not trained in this phase, so its
            # forward pass is run without building an autograd graph
            x = x.to(gpu)
            hole_area_fake = _gen_hole_area(x)
            mask = _gen_mask(x)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            with torch.no_grad():
                input_gd_fake = model_cn(torch.cat((x_mask, mask), dim=1))
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_real = bceloss(output_real, real)

            # reduce & backward
            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description(
                    'phase 2 | train loss: %.5f' % loss.item())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_2 == 0:
                    _save_test_completions('phase_2', pbar.n)
                    torch.save(
                        _state_dict(model_cd),
                        os.path.join(args.result_dir, 'phase_2',
                                     'model_cd_step%d' % pbar.n))
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = _gen_hole_area(x)
            mask = _gen_mask(x)

            # fake forward (model_cn output detached: only model_cd learns
            # from loss_cd)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            output_cn = model_cn(torch.cat((x_mask, mask), dim=1))
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_cd_real = bceloss(output_real, real)

            # reduce & backward model_cd
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.
            loss_cd.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                opt_cd.zero_grad()

            # forward model_cn.  model_cd's parameters are frozen here: the
            # adversarial gradient still reaches output_cn through model_cd,
            # but it is no longer accumulated into model_cd's .grad buffers,
            # where the next opt_cd.step() would otherwise apply it and push
            # the discriminator towards labelling completions as real.
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            _set_requires_grad(model_cd, False)
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd((input_ld_fake, output_cn))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce & backward model_cn
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            _set_requires_grad(model_cd, True)
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.item(), loss_cn.item()))
                pbar.update()

                # test
                if pbar.n % args.snaperiod_3 == 0:
                    _save_test_completions('phase_3', pbar.n)
                    torch.save(
                        _state_dict(model_cn),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cn_step%d' % pbar.n))
                    torch.save(
                        _state_dict(model_cd),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cd_step%d' % pbar.n))
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
freeze(DG) unfreeze(G) mask = get_center_mask(img_size, bs) x_mask = x - x * mask + mpv * mask inp = torch.cat((x_mask, mask), dim=1) output = G(inp) hole_area = gen_hole_area((ld_input_size, ld_input_size), (x.shape[3], x.shape[2])) fake_crop = crop(output, hole_area) __, logit_dl = DL(fake_crop) __, logit_dg = DG(output) # calculate g_loss gan_loss = (-logit_dl.mean() - logit_dg.mean()) / 2 re_loss = completion_network_loss(x, output, mask) loss = gan_loss * 20 + re_loss gan.update(gan_loss.detach().cpu().numpy(), bs) re.update(re_loss.detach().cpu().numpy(), bs) # print("gan_loss:{:.3f}\tre_loss:{:.3f}".format(gan_loss, re_loss)) g_optimizer.zero_grad() loss.backward() g_optimizer.step() interval = time.time() - st st = time.time() if epoch % 20 == 0: test_model(test_set, G, epoch) print( "Epoch:{}\tTime:{:.2f}\tdl:{:.2f}\tdg:{:.2f}\tgan:{:.2f}\tre:{:.2f}"
def main(args):
    """Fine-tune GLCIC starting from a pretrained Torch7 completion network.

    Phase 1 (training the completion network from scratch) is disabled in
    this variant: ``model_cn`` is loaded from
    ``./glcic/completionnet_places2.t7``.  Phase 2 trains the context
    discriminator (``model_cd``) on completions of the fixed ``model_cn``,
    and phase 3 trains both networks adversarially.  Every
    ``args.snaperiod_{2,3}`` optimizer steps a grid of test completions and
    the current weights are written to ``args.result_dir/phase_{2,3}``.

    Each optimizer step accumulates gradients over ``args.bdivs``
    micro-batches of ``args.bsize // args.bdivs`` images.

    Args:
        args: parsed command-line options (data/result directories, step
            counts, snapshot periods, batch settings, hole sizes, input
            sizes, ``mpv``, ``alpha``, ``optimizer``, optional initial
            weights, ``data_parallel``, ...).

    Raises:
        Exception: if no CUDA device is available.
        ValueError: if the training directory yields no images.
    """
    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if args.init_model_cn is not None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd is not None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (if necessary); makedirs also creates
    # args.result_dir itself
    for s in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, s), exist_ok=True)

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    if len(train_dset.imgpaths) == 0:
        # an empty loader would make the `while pbar.n < steps` loops below
        # spin forever
        raise ValueError('no training images found in %s'
                         % os.path.join(args.data_dir, 'train'))
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value of training dataset
    mpv = 0.
    if args.mpv is None:
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            with Image.open(imgpath) as img:
                x = np.array(img, dtype=np.float32) / 255.
            mpv += x.mean()
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = args.mpv
    mpv = torch.tensor(mpv).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    # save training config
    args_dict = vars(args)
    args_dict['mpv'] = float(mpv)
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ------------------------------------------------
    # helpers shared by phases 2 and 3
    # ------------------------------------------------
    def _gen_hole_area(x):
        # area of size (ld_input_size, ld_input_size) for an image of x's
        # width/height, as drawn by gen_hole_area
        return gen_hole_area(size=(args.ld_input_size, args.ld_input_size),
                             mask_size=(x.shape[3], x.shape[2]))

    def _gen_mask(x, hole_area):
        # random hole mask for batch x, generated for `hole_area`
        return gen_input_mask(
            shape=x.shape,
            hole_size=((args.hole_min_w, args.hole_max_w),
                       (args.hole_min_h, args.hole_max_h)),
            hole_area=hole_area,
            max_holes=args.max_holes,
        ).to(gpu)

    def _set_requires_grad(model, flag):
        # (un)freeze every parameter of `model`
        for param in model.parameters():
            param.requires_grad_(flag)

    def _save_test_completions(phase, step):
        # complete a random test batch with model_cn and save the grid
        # [originals; masked inputs; completions] as <phase>/step<step>.png
        with torch.no_grad():
            x = sample_random_batch(
                test_dset, batch_size=args.num_test_completions).to(gpu)
            msk = _gen_mask(x, _gen_hole_area(x))
            x_masked = x - x * msk + mpv * msk
            output = model_cn(x_masked)
            completed = poisson_blend(x_masked, output, msk)
            imgs = torch.cat(
                (x.cpu(), x_masked.cpu(), completed.cpu()), dim=0)
            save_image(imgs,
                       os.path.join(args.result_dir, phase,
                                    'step%d.png' % step),
                       nrow=len(x))

    # ================================================
    # Training Phase 1 (disabled: pretrained model loaded instead)
    # ================================================
    # NOTE(review): `load_lua` is presumably torch.utils.serialization's,
    # which newer PyTorch releases no longer provide.  The loaded model is
    # also never moved to `gpu` although every batch fed to it below is --
    # confirm the model's type, device and expected input channels.
    data = load_lua('./glcic/completionnet_places2.t7')
    model_cn = data.model
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    )
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    # weights are loaded after the optional DataParallel wrap, i.e. with the
    # same key layout as the model_cd snapshots saved below
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward; model_cn is not trained in this phase, so its
            # forward pass is run without building an autograd graph
            x = x.to(gpu)
            hole_area_fake = _gen_hole_area(x)
            msk = _gen_mask(x, hole_area_fake)
            fake = torch.zeros((len(x), 1)).to(gpu)
            with torch.no_grad():
                input_gd_fake = model_cn(x - x * msk + mpv * msk)
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_real = bceloss(output_real, real)

            # reduce & backward (gradients accumulate over args.bdivs batches)
            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize & clear grads
                opt_cd.step()
                opt_cd.zero_grad()

                # update progbar
                pbar.set_description(
                    'phase 2 | train loss: %.5f' % loss.item())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_2 == 0:
                    _save_test_completions('phase_2', pbar.n)
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_2',
                                     'model_cd_step%d' % pbar.n))

                # terminate
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = _gen_hole_area(x)
            msk = _gen_mask(x, hole_area_fake)

            # fake forward (model_cn output detached: only model_cd learns
            # from loss_cd)
            fake = torch.zeros((len(x), 1)).to(gpu)
            output_cn = model_cn(x - x * msk + mpv * msk)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_cd_real = bceloss(output_real, real)

            # reduce & backward model_cd
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.
            loss_cd.backward()

            # forward model_cn.  model_cd's parameters are frozen here so the
            # generator's adversarial gradient reaches output_cn without being
            # accumulated into model_cd's .grad buffers (which opt_cd.step()
            # below would otherwise apply to the discriminator).
            loss_cn_1 = completion_network_loss(x, output_cn, msk)
            _set_requires_grad(model_cd, False)
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd((input_ld_fake, output_cn))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce & backward model_cn
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            _set_requires_grad(model_cd, True)

            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize both networks (the original stepped opt_cn twice
                # and never stepped opt_cd)
                opt_cd.step()
                opt_cn.step()

                # clear grads
                opt_cd.zero_grad()
                opt_cn.zero_grad()

                # update progbar
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.item(), loss_cn.item()))
                pbar.update()

                # test
                if pbar.n % args.snaperiod_3 == 0:
                    _save_test_completions('phase_3', pbar.n)
                    torch.save(
                        model_cn.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cn_step%d' % pbar.n))
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cd_step%d' % pbar.n))

                # terminate
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
def main(args):
    """Train GLCIC in three phases, optionally spreading the models over two GPUs.

    With ``args.num_gpus == 1`` both networks live on ``cuda:0``; otherwise
    the completion network (``model_cn``) is placed on ``cuda:0`` and the
    context discriminator (``model_cd``) on ``cuda:1``.  Phase 1 trains
    ``model_cn`` on the reconstruction loss, phase 2 trains ``model_cd`` on
    completions of the fixed ``model_cn``, and phase 3 trains both
    adversarially.  Every ``args.snaperiod_{1,2,3}`` steps a grid of test
    completions and the current weights are written to
    ``args.result_dir/phase_{1,2,3}``.

    Args:
        args: parsed command-line options (data/result directories, step
            counts, snapshot periods, ``bsize``, hole sizes, input sizes,
            ``comp_mpv``/``max_mpv_samples``, ``alpha``, ``optimizer``,
            ``num_gpus``, ...).

    Raises:
        Exception: if no CUDA device is available.
        ValueError: if the training directory yields no images.
    """
    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    if args.num_gpus == 1:
        # train models in a single gpu
        gpu_cn = torch.device('cuda:0')
        gpu_cd = gpu_cn
    else:
        # train models in different two gpus
        gpu_cn = torch.device('cuda:0')
        gpu_cd = torch.device('cuda:1')

    # create result directories (if necessary); makedirs also creates
    # args.result_dir itself
    for s in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, s), exist_ok=True)

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    if len(train_dset) == 0:
        # an empty loader would make the `while pbar.n < steps` loops below
        # spin forever
        raise ValueError('no training images found in %s'
                         % os.path.join(args.data_dir, 'train'))
    train_loader = DataLoader(train_dset, batch_size=args.bsize, shuffle=True)

    # compute the mean pixel value of train dataset
    mean_pv = 0.
    imgpaths = train_dset.imgpaths[:min(args.max_mpv_samples, len(train_dset))]
    if args.comp_mpv:
        pbar = tqdm(total=len(imgpaths), desc='computing the mean pixel value')
        for imgpath in imgpaths:
            with Image.open(imgpath) as img:
                x = np.array(img, dtype=np.float32) / 255.
            mean_pv += x.mean()
            pbar.update()
        mean_pv /= len(imgpaths)
        pbar.close()
    # the accumulated value is a NumPy scalar (np.float32 under NumPy >= 2
    # promotion rules), which json.dump cannot serialise
    mean_pv = float(mean_pv)
    mpv = torch.tensor(mean_pv).to(gpu_cn)

    # save training config
    args_dict = vars(args)
    args_dict['mean_pv'] = mean_pv
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ------------------------------------------------
    # helpers shared by the three phases
    # ------------------------------------------------
    def _gen_hole_area(x):
        # area of size (ld_input_size, ld_input_size) for an image of x's
        # width/height, as drawn by gen_hole_area
        return gen_hole_area(
            size=(args.ld_input_size, args.ld_input_size),
            mask_size=(x.shape[3], x.shape[2]),
        )

    def _gen_mask(x, hole_area):
        # random hole mask for batch x, generated for `hole_area`, on gpu_cn
        msk = gen_input_mask(
            shape=x.shape,
            hole_size=(
                (args.hole_min_w, args.hole_max_w),
                (args.hole_min_h, args.hole_max_h),
            ),
            hole_area=hole_area,
            max_holes=args.max_holes,
        )
        return msk.to(gpu_cn)

    def _save_test_completions(phase, step):
        # complete a random test batch with model_cn and save the grid
        # [masked inputs; completions] as <phase>/step<step>.png.  A fresh
        # mask is drawn for the test batch: reusing the last training mask
        # breaks when that training batch was a smaller, final batch.
        with torch.no_grad():
            x = sample_random_batch(test_dset, batch_size=args.bsize)
            x = x.to(gpu_cn)
            msk = _gen_mask(x, _gen_hole_area(x))
            input_cn = x - x * msk + mpv * msk
            output = model_cn(input_cn)
            completed = poisson_blend(input_cn, output, msk)
            imgs = torch.cat((input_cn.cpu(), completed.cpu()), dim=0)
            save_image(imgs,
                       os.path.join(args.result_dir, phase,
                                    'step%d.png' % step),
                       nrow=len(x))

    # ================================================
    # Training Phase 1
    # ================================================
    # model & optimizer
    model_cn = CompletionNetwork()
    model_cn = model_cn.to(gpu_cn)
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())

    # training
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            opt_cn.zero_grad()

            # create mask and merge x, mask, and mpv
            msk = _gen_mask(x, _gen_hole_area(x))
            x = x.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output = model_cn(input_cn)

            # optimize
            loss = completion_network_loss(x, output, msk)
            loss.backward()
            opt_cn.step()

            pbar.set_description('phase 1 | train loss: %.5f' % loss.item())
            pbar.update()

            # test
            if pbar.n % args.snaperiod_1 == 0:
                _save_test_completions('phase_1', pbar.n)
                torch.save(
                    model_cn.state_dict(),
                    os.path.join(args.result_dir, 'phase_1',
                                 'model_cn_step%d' % pbar.n))
            if pbar.n >= args.steps_1:
                break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # model, optimizer & criterion
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    )
    model_cd = model_cd.to(gpu_cd)
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    criterion_cd = BCELoss()

    # training
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            x = x.to(gpu_cn)
            opt_cd.zero_grad()

            # fake: completions of model_cn, which is not trained in this
            # phase, so no autograd graph is built for it
            hole_area_fake = _gen_hole_area(x)
            msk = _gen_mask(x, hole_area_fake)
            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            with torch.no_grad():
                input_gd_fake = model_cn(x - x * msk + mpv * msk)
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd)))
            loss_fake = criterion_cd(output_fake, fake)

            # real
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd(
                (input_ld_real.to(gpu_cd), x.to(gpu_cd)))
            loss_real = criterion_cd(output_real, real)

            # optimize
            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            opt_cd.step()

            pbar.set_description('phase 2 | train loss: %.5f' % loss.item())
            pbar.update()

            # test
            if pbar.n % args.snaperiod_2 == 0:
                _save_test_completions('phase_2', pbar.n)
                torch.save(
                    model_cd.state_dict(),
                    os.path.join(args.result_dir, 'phase_2',
                                 'model_cd_step%d' % pbar.n))
            if pbar.n >= args.steps_2:
                break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    alpha = torch.tensor(args.alpha).to(gpu_cd)
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            x = x.to(gpu_cn)

            # ------------------------------------------------
            # train model_cd
            # ------------------------------------------------
            opt_cd.zero_grad()

            # fake (model_cn output detached: only model_cd learns here)
            hole_area_fake = _gen_hole_area(x)
            msk = _gen_mask(x, hole_area_fake)
            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            output_cn = model_cn(x - x * msk + mpv * msk)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd)))
            loss_cd_1 = criterion_cd(output_fake, fake)

            # real
            hole_area_real = _gen_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd(
                (input_ld_real.to(gpu_cd), x.to(gpu_cd)))
            loss_cd_2 = criterion_cd(output_real, real)

            # optimize
            loss_cd = (loss_cd_1 + loss_cd_2) * alpha / 2.
            loss_cd.backward()
            opt_cd.step()

            # ------------------------------------------------
            # train model_cn
            # ------------------------------------------------
            opt_cn.zero_grad()
            loss_cn_1 = completion_network_loss(x, output_cn, msk).to(gpu_cd)
            # the local discriminator must look at the area the holes were
            # generated for (hole_area_fake), not at the real-sample area
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu_cd), output_cn.to(gpu_cd)))
            loss_cn_2 = criterion_cd(output_fake, real)

            # optimize (model_cd gradients produced here are discarded by
            # opt_cd.zero_grad() at the start of the next iteration)
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            opt_cn.step()

            pbar.set_description(
                'phase 3 | train loss (cd): %.5f train loss (cn): %.5f' %
                (loss_cd.item(), loss_cn.item()))
            pbar.update()

            # test
            if pbar.n % args.snaperiod_3 == 0:
                _save_test_completions('phase_3', pbar.n)
                torch.save(
                    model_cn.state_dict(),
                    os.path.join(args.result_dir, 'phase_3',
                                 'model_cn_step%d' % pbar.n))
                torch.save(
                    model_cd.state_dict(),
                    os.path.join(args.result_dir, 'phase_3',
                                 'model_cd_step%d' % pbar.n))
            if pbar.n >= args.steps_3:
                break
    pbar.close()
def main(args): # ================================================ # Preparation # ================================================ args.data_dir = os.path.expanduser(args.data_dir) args.result_dir = os.path.expanduser(args.result_dir) if args.init_model_cn != None: args.init_model_cn = os.path.expanduser(args.init_model_cn) if args.init_model_cd != None: args.init_model_cd = os.path.expanduser(args.init_model_cd) if torch.cuda.is_available() == False: raise Exception('At least one gpu must be available.') else: gpu = torch.device('cuda:0') # create result directory (if necessary) if os.path.exists(args.result_dir) == False: os.makedirs(args.result_dir) for s in ['phase_1', 'phase_2', 'phase_3']: if os.path.exists(os.path.join(args.result_dir, s)) == False: os.makedirs(os.path.join(args.result_dir, s)) # dataset trnsfm = transforms.Compose([ transforms.Resize(args.cn_input_size), transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), transforms.ToTensor(), ]) print('loading dataset... (it may take a few minutes)') train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm, recursive_search=args.recursive_search) test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm, recursive_search=args.recursive_search) train_loader = DataLoader(train_dset, batch_size=(args.bsize // args.bdivs), shuffle=True) # compute mean pixel value of training dataset mpv = np.zeros(shape=(3, )) if args.mpv == None: pbar = tqdm(total=len(train_dset.imgpaths), desc='computing mean pixel value for training dataset...') for imgpath in train_dset.imgpaths: img = Image.open(imgpath) x = np.array(img, dtype=np.float32) / 255. 
mpv += x.mean(axis=(0, 1)) pbar.update() mpv /= len(train_dset.imgpaths) pbar.close() else: mpv = np.array(args.mpv) # save training config mpv_json = [] for i in range(3): mpv_json.append(float(mpv[i])) # convert to json serializable type args_dict = vars(args) args_dict['mpv'] = mpv_json with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) # make mpv & alpha tensor mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu) alpha = torch.tensor(args.alpha).to(gpu) my_writer = SummaryWriter(log_dir='log') # ================================================ # Training Phase 1 # ================================================ model_cn = CompletionNetwork() if args.data_parallel: model_cn = DataParallel(model_cn) if args.init_model_cn != None: new = OrderedDict() for k, v in torch.load(args.init_model_cn, map_location='cpu').items(): new['module.' + k] = v # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) model_cn.load_state_dict(new) print('第一阶段加载模型成功!') # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) if args.optimizer == 'adadelta': opt_cn = Adadelta(model_cn.parameters()) else: opt_cn = Adam(model_cn.parameters()) model_cn = model_cn.to(gpu) # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_1) pbar.n = 90000 while pbar.n < args.steps_1: for i, x in enumerate(train_loader): # forward x = x.to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) loss_mse = completion_network_loss(x, output, mask) loss_contextual = contextual_loss(x, output, mask) loss = loss_mse + 0.004 * loss_contextual # backward loss.backward() 
cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() # clear grads opt_cn.zero_grad() # update progbar pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) pbar.update() my_writer.add_scalar('mse_loss', loss_mse, pbar.n * len(train_loader) + i) # test if pbar.n % args.snaperiod_1 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cn.module.state_dict(), model_cn_path) else: torch.save(model_cn.state_dict(), model_cn_path) # terminate if pbar.n >= args.steps_1: break pbar.close() # ================================================ # Training Phase 2 # ================================================ model_cd = ContextDiscriminator( local_input_shape=(3, args.ld_input_size, args.ld_input_size), global_input_shape=(3, args.cn_input_size, args.cn_input_size), arc=args.arc, ) if args.data_parallel: model_cd = DataParallel(model_cd) if args.init_model_cd != None: # model_cd.load_state_dict(torch.load(args.init_model_cd, map_location='cpu')) new = OrderedDict() for k, v in torch.load(args.init_model_cd, map_location='cpu').items(): new['module.' 
+ k] = v # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) model_cd.load_state_dict(new) print('第二阶段加载模型成功!') if args.optimizer == 'adadelta': opt_cd = Adadelta(model_cd.parameters()) else: opt_cd = Adam(model_cd.parameters()) model_cd = model_cd.to(gpu) bceloss = BCELoss() # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_2) pbar.n = 120000 while pbar.n < args.steps_2: for x in train_loader: # fake forward x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() # 输入全局判别器的生成图片 input_ld_fake = crop(input_gd_fake, hole_area_fake) # 输入局部判别器的生成图片 output_fake = model_cd( (input_ld_fake.to(gpu), input_gd_fake.to(gpu))) loss_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) # 输入全局判别器的真实图片 output_real = model_cd( (input_ld_real, input_gd_real)) # 输入局部判别器的生成图片 loss_real = bceloss(output_real, real) # reduce loss = (loss_fake + loss_real) / 2. 
# backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cd.step() # clear grads opt_cd.zero_grad() # update progbar pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_2 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) # 泊松融合 imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_2: break pbar.close() # ================================================ # Training Phase 3 # ================================================ # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_3) pbar.n = 120000 while pbar.n < args.steps_3: for i, x in enumerate(train_loader): # forward model_cd x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) # fake forward fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = 
torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, input_gd_fake)) loss_cd_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_cd_real = bceloss(output_real, real) # reduce loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2. # backward model_cd loss_cd.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: # optimize opt_cd.step() # clear grads opt_cd.zero_grad() # forward model_cn loss_cn_1 = completion_network_loss(x, output_cn, mask) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, (input_gd_fake))) loss_cn_2 = bceloss(output_fake, real) loss_cn_3 = contextual_loss(x, output_cn, mask) # reduce loss_cn = (loss_cn_1 + alpha * loss_cn_2 + 4e-3 * loss_cn_3) / 2. 
# backward model_cn loss_cn.backward() my_writer.add_scalar('mse_loss', loss_cn_1, (90000 + pbar.n) * len(train_loader) + i) if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() # clear grads opt_cn.zero_grad() # update progbar pbar.set_description( 'phase 3 | train loss (cd): %.5f (cn): %.5f' % (loss_cd.cpu(), loss_cn.cpu())) pbar.update() # test if pbar.n % args.snaperiod_3 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cn.module.state_dict(), model_cn_path) torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cn.state_dict(), model_cn_path) torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_3: break pbar.close()