def main(args):
    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert each image to a tensor and run the completion network on it
    image_list = os.listdir(args.input_dir)
    for image_name in image_list:
        img = Image.open(os.path.join(args.input_dir, image_name))
        img = img.resize((256, 256), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10

        x = transforms.ToTensor()(img)
        x = torch.unsqueeze(x, dim=0)

        # inpaint
        with torch.no_grad():
            # distort_input = distort_images(x)  # distort
            output = model(x)
            # distort_output = torch.cat((x, distort_input, distort_output), dim=-1)  # concatenate
            out_path = os.path.join(args.output_dir, image_name)
            save_image(output, out_path, nrow=3)
            print('output img was saved as %s.' % out_path)
def main(args):
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    mpv = torch.tensor(config['mpv']).view(1, 3, 1, 1)
    model = CompletionNetwork()
    if config['data_parallel']:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor
    img = Image.open(args.input_img)
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask
    mask = gen_input_mask(
        shape=(1, 1, x.shape[2], x.shape[3]),
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )

    # inpaint
    with torch.no_grad():
        x_mask = x - x * mask + mpv * mask
        input = torch.cat((x_mask, mask), dim=1)
        output = model(input)
        inpainted = poisson_blend(x, output, mask)
        imgs = torch.cat((x, x_mask, inpainted), dim=0)
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
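# Note: gen_input_mask and poisson_blend are imported helpers and are not defined in
# these listings. As a rough, hypothetical sketch (the name and defaults below are
# assumptions, not the project's actual implementation), gen_input_mask is taken to
# return an (N, 1, H, W) float tensor that is 1 inside randomly placed rectangular
# holes and 0 elsewhere, with hole sizes drawn from the given (min, max) ranges.
import random
import torch

def gen_input_mask_sketch(shape, hole_size=((32, 64), (32, 64)), max_holes=1):
    n, _, h, w = shape
    mask = torch.zeros(shape)
    for i in range(n):
        for _ in range(random.randint(1, max_holes)):
            hole_w = random.randint(hole_size[0][0], hole_size[0][1])
            hole_h = random.randint(hole_size[1][0], hole_size[1][1])
            offset_x = random.randint(0, w - hole_w)
            offset_y = random.randint(0, h - hole_h)
            mask[i, :, offset_y:offset_y + hole_h, offset_x:offset_x + hole_w] = 1.0
    return mask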
def main(args):
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    mpv = config['mean_pv']  # scalar mean pixel value of the training set
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor
    img = Image.open(args.input_img)
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask
    msk = gen_input_mask(
        shape=x.shape,
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )

    # inpaint
    with torch.no_grad():
        input = x - x * msk + mpv * msk
        output = model(input)
        inpainted = poisson_blend(input, output, msk)
        imgs = torch.cat((x, input, inpainted), dim=-1)  # concatenate along width
        save_image(imgs, args.output_img, nrow=3)  # save_image returns None, so don't assign the result
    print('output img was saved as %s.' % args.output_img)
def predict(model_path, input_img):
    model = CompletionNetwork()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    img = input_img.resize((224, 224))
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, 0)
    # print(x.shape)

    model.eval()
    with torch.no_grad():
        output = model(x)
    # save_image(output, args.output_img, nrow=3)
    # print('output img was saved as %s.' % args.output_img)
    return transforms.ToPILImage()(output[0]).convert("RGB")
def main(args):
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    config['mpv'] = 0.13465263  # hard-coded mean pixel value; overrides the value stored in the config file
    mpv = torch.tensor(config['mpv']).view(1, 1, 1, 1)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor
    img = Image.open(args.input_img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask
    mask = gen_input_mask(shape=(1, 1, x.shape[2], x.shape[3]))

    # inpaint
    model.eval()
    with torch.no_grad():
        x_mask = x - x * mask + mpv * mask
        input = torch.cat((x_mask, mask), dim=1)
        output = model(input)
        inpainted = rejoiner(x_mask, output, mask)
        imgs = torch.cat((x, x_mask, inpainted), dim=0)
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
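# rejoiner is not shown in these snippets. From how it is called above it presumably
# pastes the network output into the hole region while keeping the known pixels,
# i.e. a plain composite rather than the Poisson blending used in the other variants.
# A minimal sketch under that assumption:
def rejoiner_sketch(x_mask, output, mask):
    # mask == 1 inside the holes, 0 elsewhere
    return x_mask * (1.0 - mask) + output * mask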
def main(args): # ================================================ # Preparation # ================================================ if not torch.cuda.is_available(): raise Exception('At least one gpu must be available.') gpu = torch.device('cuda:0') # create result directory (if necessary) if not os.path.exists(args.result_dir): os.makedirs(args.result_dir) for phase in ['phase_1', 'phase_2', 'phase_3']: if not os.path.exists(os.path.join(args.result_dir, phase)): os.makedirs(os.path.join(args.result_dir, phase)) # load dataset trnsfm = transforms.Compose([ transforms.Resize(args.cn_input_size), transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), transforms.ToTensor(), ]) print('loading dataset... (it may take a few minutes)') train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm, recursive_search=args.recursive_search) test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm, recursive_search=args.recursive_search) train_loader = DataLoader(train_dset, batch_size=(args.bsize // args.bdivs), shuffle=True) # compute mpv (mean pixel value) of training dataset if args.mpv is None: mpv = np.zeros(shape=(1, )) pbar = tqdm(total=len(train_dset.imgpaths), desc='computing mean pixel value of training dataset...') for imgpath in train_dset.imgpaths: img = Image.open(imgpath) x = np.array(img) / 255. mpv += x.mean(axis=(0, 1)) pbar.update() mpv /= len(train_dset.imgpaths) pbar.close() else: mpv = np.array(args.mpv) # save training config mpv_json = [] for i in range(1): mpv_json.append(float(mpv[i])) args_dict = vars(args) # args_dict['mpv'] = mpv_json with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) # make mpv & alpha tensors mpv = torch.tensor(mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu) alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu) # ================================================ # Training Phase 1 # ================================================ # load completion network model_cn = CompletionNetwork() if args.init_model_cn is not None: model_cn.load_state_dict( torch.load(args.init_model_cn, map_location='cpu')) if args.data_parallel: model_cn = DataParallel(model_cn) model_cn = model_cn.to(gpu) opt_cn = Adadelta(model_cn.parameters()) # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_1) while pbar.n < args.steps_1: for x in train_loader: # forward x = x.to(gpu) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) loss = completion_network_loss(x, output, mask) # backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() opt_cn.zero_grad() pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_1 == 0: model_cn.eval() with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = rejoiner(x_mask, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: 
torch.save(model_cn.module.state_dict(), model_cn_path) else: torch.save(model_cn.state_dict(), model_cn_path) model_cn.train() if pbar.n >= args.steps_1: break pbar.close() # ================================================ # Training Phase 2 # ================================================ # load context discriminator model_cd = ContextDiscriminator( local_input_shape=(1, args.ld_input_size, args.ld_input_size), global_input_shape=(1, args.cn_input_size, args.cn_input_size), ) if args.init_model_cd is not None: model_cd.load_state_dict( torch.load(args.init_model_cd, map_location='cpu')) if args.data_parallel: model_cd = DataParallel(model_cd) model_cd = model_cd.to(gpu) opt_cd = Adadelta(model_cd.parameters(), lr=0.1) bceloss = BCELoss() # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_2) while pbar.n < args.steps_2: for x in train_loader: # fake forward x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd( (input_ld_fake.to(gpu), input_gd_fake.to(gpu))) loss_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_real = bceloss(output_real, real) # reduce loss = (loss_fake + loss_real) / 2. 
# backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cd.step() opt_cd.zero_grad() pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_2 == 0: model_cn.eval() with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = rejoiner(x_mask, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cd.state_dict(), model_cd_path) model_cn.train() if pbar.n >= args.steps_2: break pbar.close() # ================================================ # Training Phase 3 # ================================================ cnt_bdivs = 0 pbar = tqdm(total=args.steps_3) while pbar.n < args.steps_3: for x in train_loader: # forward model_cd x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) # fake forward fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, input_gd_fake)) loss_cd_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_cd_real = bceloss(output_real, real) # reduce loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2. # backward model_cd loss_cd.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: # optimize opt_cd.step() opt_cd.zero_grad() # forward model_cn loss_cn_1 = completion_network_loss(x, output_cn, mask) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, (input_gd_fake))) loss_cn_2 = bceloss(output_fake, real) # reduce loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2. 
# backward model_cn loss_cn.backward() if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() opt_cn.zero_grad() pbar.set_description( 'phase 3 | train loss (cd): %.5f (cn): %.5f' % (loss_cd.cpu(), loss_cn.cpu())) pbar.update() # test if pbar.n % args.snaperiod_3 == 0: model_cn.eval() with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = rejoiner(x_mask, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cn.module.state_dict(), model_cn_path) torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cn.state_dict(), model_cn_path) torch.save(model_cd.state_dict(), model_cd_path) model_cn.train() if pbar.n >= args.steps_3: break pbar.close()
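# gen_hole_area and crop are used by every discriminator phase above but are not part
# of this listing. A plausible sketch, assuming gen_hole_area picks a random
# ld_input_size window (the region the local discriminator will see) and crop cuts
# that window out of a batch of images:
import random

def gen_hole_area_sketch(size, mask_size):
    mask_w, mask_h = mask_size
    area_w, area_h = size
    offset_x = random.randint(0, mask_w - area_w)
    offset_y = random.randint(0, mask_h - area_h)
    return ((offset_x, offset_y), (area_w, area_h))

def crop_sketch(x, area):
    (offset_x, offset_y), (w, h) = area
    return x[:, :, offset_y:offset_y + h, offset_x:offset_x + w]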
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(args.data_dir, trnsfm, trnsfm2)
    test_dset = ImageDataset(args.data_dir, trnsfm, trnsfm2)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True,
                              num_workers=5)
    alpha = torch.tensor(args.alpha).to(gpu)

    ####################################################################################
    # Create model G
    model_cn = CompletionNetwork()
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.init_model_cn is not None:
        model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters(), lr=0.8)
    else:
        opt_cn = Adam(model_cn.parameters())
    model_cn = model_cn.to(gpu)

    # Create model D
    model_cd = GlobalDiscriminator_P((3, 640, 480))
    if args.data_parallel:
def main(args): # ================================================ # Preparation # ================================================ args.data_dir = os.path.expanduser(args.data_dir) args.result_dir = os.path.expanduser(args.result_dir) if torch.cuda.is_available() == False: raise Exception('At least one gpu must be available.') if args.num_gpus == 1: # train models in a single gpu gpu_cn = torch.device('cuda:0') gpu_cd = gpu_cn else: # train models in different two gpus gpu_cn = torch.device('cuda:0') gpu_cd = torch.device('cuda:1') # create result directory (if necessary) if os.path.exists(args.result_dir) == False: os.makedirs(args.result_dir) for s in ['phase_1', 'phase_2', 'phase_3']: if os.path.exists(os.path.join(args.result_dir, s)) == False: os.makedirs(os.path.join(args.result_dir, s)) # dataset trnsfm = transforms.Compose([ transforms.Resize(args.cn_input_size), transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), transforms.ToTensor(), ]) print('loading dataset... (it may take a few minutes)') train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm) test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm) train_loader = DataLoader(train_dset, batch_size=args.bsize, shuffle=True) # compute the mean pixel value of train dataset mean_pv = 0. imgpaths = train_dset.imgpaths[:min(args.max_mpv_samples, len(train_dset))] if args.comp_mpv: pbar = tqdm(total=len(imgpaths), desc='computing the mean pixel value') for imgpath in imgpaths: img = Image.open(imgpath) x = np.array(img, dtype=np.float32) / 255. mean_pv += x.mean() pbar.update() mean_pv /= len(imgpaths) pbar.close() mpv = torch.tensor(mean_pv).to(gpu_cn) # save training config args_dict = vars(args) args_dict['mean_pv'] = mean_pv with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) # ================================================ # Training Phase 1 # ================================================ # model & optimizer model_cn = CompletionNetwork() model_cn = model_cn.to(gpu_cn) if args.optimizer == 'adadelta': opt_cn = Adadelta(model_cn.parameters()) else: opt_cn = Adam(model_cn.parameters()) # training pbar = tqdm(total=args.steps_1) while pbar.n < args.steps_1: for x in train_loader: opt_cn.zero_grad() # generate hole area hole_area = gen_hole_area( size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2]), ) # create mask msk = gen_input_mask( shape=x.shape, hole_size=( (args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h), ), hole_area=hole_area, max_holes=args.max_holes, ) # merge x, mask, and mpv msg = 'phase 1 |' x = x.to(gpu_cn) msk = msk.to(gpu_cn) input = x - x * msk + mpv * msk output = model_cn(input) # optimize loss = completion_network_loss(x, output, msk) loss.backward() opt_cn.step() msg += ' train loss: %.5f' % loss.cpu() pbar.set_description(msg) pbar.update() # test if pbar.n % args.snaperiod_1 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.bsize) x = x.to(gpu_cn) input = x - x * msk + mpv * msk output = model_cn(input) completed = poisson_blend(input, output, msk) imgs = torch.cat((input.cpu(), completed.cpu()), dim=0) save_image(imgs, os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n), nrow=len(x)) torch.save( model_cn.state_dict(), os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n)) if pbar.n >= args.steps_1: break pbar.close() # ================================================ # Training Phase 2 # 
================================================ # model, optimizer & criterion model_cd = ContextDiscriminator( local_input_shape=(3, args.ld_input_size, args.ld_input_size), global_input_shape=(3, args.cn_input_size, args.cn_input_size), ) model_cd = model_cd.to(gpu_cd) if args.optimizer == 'adadelta': opt_cd = Adadelta(model_cd.parameters()) else: opt_cd = Adam(model_cd.parameters()) criterion_cd = BCELoss() # training pbar = tqdm(total=args.steps_2) while pbar.n < args.steps_2: for x in train_loader: x = x.to(gpu_cn) opt_cd.zero_grad() # ================================================ # fake # ================================================ hole_area = gen_hole_area( size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2]), ) # create mask msk = gen_input_mask( shape=x.shape, hole_size=( (args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h), ), hole_area=hole_area, max_holes=args.max_holes, ) fake = torch.zeros((len(x), 1)).to(gpu_cd) msk = msk.to(gpu_cn) input_cn = x - x * msk + mpv * msk output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area) input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd)) output_fake = model_cd(input_fake) loss_fake = criterion_cd(output_fake, fake) # ================================================ # real # ================================================ hole_area = gen_hole_area( size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2]), ) real = torch.ones((len(x), 1)).to(gpu_cd) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area) input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd)) output_real = model_cd(input_real) loss_real = criterion_cd(output_real, real) # ================================================ # optimize # ================================================ loss = (loss_fake + loss_real) / 2. 
loss.backward() opt_cd.step() msg = 'phase 2 |' msg += ' train loss: %.5f' % loss.cpu() pbar.set_description(msg) pbar.update() # test if pbar.n % args.snaperiod_2 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.bsize) x = x.to(gpu_cn) input = x - x * msk + mpv * msk output = model_cn(input) completed = poisson_blend(input, output, msk) imgs = torch.cat((input.cpu(), completed.cpu()), dim=0) save_image(imgs, os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n), nrow=len(x)) torch.save( model_cd.state_dict(), os.path.join(args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n)) if pbar.n >= args.steps_2: break pbar.close() # ================================================ # Training Phase 3 # ================================================ # training alpha = torch.tensor(args.alpha).to(gpu_cd) pbar = tqdm(total=args.steps_3) while pbar.n < args.steps_3: for x in train_loader: x = x.to(gpu_cn) # ================================================ # train model_cd # ================================================ opt_cd.zero_grad() # fake hole_area = gen_hole_area( size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2]), ) # create mask msk = gen_input_mask( shape=x.shape, hole_size=( (args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h), ), hole_area=hole_area, max_holes=args.max_holes, ) fake = torch.zeros((len(x), 1)).to(gpu_cd) msk = msk.to(gpu_cn) input_cn = x - x * msk + mpv * msk output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area) input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd)) output_fake = model_cd(input_fake) loss_cd_1 = criterion_cd(output_fake, fake) # real hole_area = gen_hole_area( size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2]), ) real = torch.ones((len(x), 1)).to(gpu_cd) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area) input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd)) output_real = model_cd(input_real) loss_cd_2 = criterion_cd(output_real, real) # optimize loss_cd = (loss_cd_1 + loss_cd_2) * alpha / 2. loss_cd.backward() opt_cd.step() # ================================================ # train model_cn # ================================================ opt_cn.zero_grad() loss_cn_1 = completion_network_loss(x, output_cn, msk).to(gpu_cd) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area) input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd)) output_fake = model_cd(input_fake) loss_cn_2 = criterion_cd(output_fake, real) # optimize loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2. loss_cn.backward() opt_cn.step() msg = 'phase 3 |' msg += ' train loss (cd): %.5f' % loss_cd.cpu() msg += ' train loss (cn): %.5f' % loss_cn.cpu() pbar.set_description(msg) pbar.update() # test if pbar.n % args.snaperiod_3 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.bsize) x = x.to(gpu_cn) input = x - x * msk + mpv * msk output = model_cn(input) completed = poisson_blend(input, output, msk) imgs = torch.cat((input.cpu(), completed.cpu()), dim=0) save_image(imgs, os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n), nrow=len(x)) torch.save( model_cn.state_dict(), os.path.join(args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n)) torch.save( model_cd.state_dict(), os.path.join(args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n)) if pbar.n >= args.steps_3: break pbar.close()
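# sample_random_batch is undefined here as well. For the datasets above, which return a
# single image tensor per item, it presumably just stacks batch_size randomly chosen
# test images. A sketch under that assumption (the paired dataset in the last snippet
# would need a variant that stacks each element of the tuple separately):
import random
import torch

def sample_random_batch_sketch(dataset, batch_size=32):
    idxs = random.sample(range(len(dataset)), batch_size)
    return torch.stack([dataset[i] for i in idxs], dim=0)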
def main(args):
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    mpv = torch.tensor(config['mpv']).view(1, 1, 1, 1)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor
    import torchvision as tv
    img = Image.open(args.input_img)
    # img = tv.transforms.Grayscale(num_output_channels=1),
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask (this randomly generated mask is overwritten below by the masks
    # loaded from disk)
    mask = gen_input_mask(
        shape=(1, 1, x.shape[2], x.shape[3]),
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )
    # print(mask.shape)
    # print(mask)

    # build the mask filenames from the input path: 'test' is replaced by 'masks' and
    # '_mask<i>' is inserted before the 4-character file extension
    temp_str = str(args.input_img).replace("test", "masks")
    temp_index = len(temp_str) - 4
    out_img = torch.Tensor()
    for i in range(3):
        # print(mask_filename)
        mask_filename = temp_str[:temp_index] + '_mask' + str(i) + temp_str[temp_index:]
        mask_img = Image.open(mask_filename).convert('L')
        mask_img_inverted = PIL.ImageOps.invert(mask_img)  # currently unused
        # mask_transformed = mask_trans(mask_img_inverted)
        mask_trans = transforms.ToTensor()
        mask_transformed = mask_trans(mask_img)
        mask_shape = (1, 1, x.shape[2], x.shape[3])
        new_mask = torch.zeros(mask_shape)
        new_mask[0, 0, :, :] = mask_transformed
        mask = new_mask

        with torch.no_grad():
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model(input)
            inpainted = poisson_blend(x, output, mask)
            binary_out = inpainted.clone()
            binary_out = gray_to_binary(binary_out)
            out_img = torch.cat((out_img, x, x_mask, inpainted, binary_out), dim=0)

    save_image(out_img, args.output_img, nrow=4)
    print('output img was saved as %s.' % args.output_img)
def main(args): # ================================================ # Preparation # ================================================ args.data_dir = os.path.expanduser(args.data_dir) args.result_dir = os.path.expanduser(args.result_dir) if args.init_model_cn != None: args.init_model_cn = os.path.expanduser(args.init_model_cn) if args.init_model_cd != None: args.init_model_cd = os.path.expanduser(args.init_model_cd) if torch.cuda.is_available() == False: raise Exception('At least one gpu must be available.') else: gpu = torch.device('cuda:0') # create result directory (if necessary) if os.path.exists(args.result_dir) == False: os.makedirs(args.result_dir) for s in ['phase_1', 'phase_2', 'phase_3']: if os.path.exists(os.path.join(args.result_dir, s)) == False: os.makedirs(os.path.join(args.result_dir, s)) # dataset trnsfm = transforms.Compose([ transforms.Resize(args.cn_input_size), transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), transforms.ToTensor(), ]) print('loading dataset... (it may take a few minutes)') train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm) test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm) train_loader = DataLoader(train_dset, batch_size=(args.bsize // args.bdivs), shuffle=True) # compute mean pixel value of training dataset mpv = 0. if args.mpv == None: pbar = tqdm(total=len(train_dset.imgpaths), desc='computing mean pixel value for training dataset...') for imgpath in train_dset.imgpaths: img = Image.open(imgpath) x = np.array(img, dtype=np.float32) / 255. mpv += x.mean() pbar.update() mpv /= len(train_dset.imgpaths) pbar.close() else: mpv = args.mpv mpv = torch.tensor(mpv).to(gpu) alpha = torch.tensor(args.alpha).to(gpu) # save training config args_dict = vars(args) args_dict['mpv'] = float(mpv) with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) # ================================================ # Training Phase 1 # ================================================ model_cn = CompletionNetwork() if args.data_parallel: model_cn = DataParallel(model_cn) if args.init_model_cn != None: model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) if args.optimizer == 'adadelta': opt_cn = Adadelta(model_cn.parameters()) else: opt_cn = Adam(model_cn.parameters()) model_cn = model_cn.to(gpu) # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_1) while pbar.n < args.steps_1: for x in train_loader: # forward x = x.to(gpu) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) output = model_cn(x - x * msk + mpv * msk) loss = completion_network_loss(x, output, msk) # backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() # clear grads opt_cn.zero_grad() # update progbar pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_1 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) completed = 
poisson_blend(input, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) torch.save(model_cn.state_dict(), model_cn_path) # terminate if pbar.n >= args.steps_1: break pbar.close() # ================================================ # Training Phase 2 # ================================================ model_cd = ContextDiscriminator( local_input_shape=(3, args.ld_input_size, args.ld_input_size), global_input_shape=(3, args.cn_input_size, args.cn_input_size), ) if args.data_parallel: model_cd = DataParallel(model_cd) if args.init_model_cd != None: model_cd.load_state_dict(torch.load(args.init_model_cd, map_location='cpu')) if args.optimizer == 'adadelta': opt_cd = Adadelta(model_cd.parameters()) else: opt_cd = Adam(model_cd.parameters()) model_cd = model_cd.to(gpu) bceloss = BCELoss() # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_2) while pbar.n < args.steps_2: for x in train_loader: # fake forward x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) output_cn = model_cn(x - x * msk + mpv * msk) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake.to(gpu), input_gd_fake.to(gpu))) loss_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_real = bceloss(output_real, real) # reduce loss = (loss_fake + loss_real) / 2. 
# backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cd.step() # clear grads opt_cd.zero_grad() # update progbar pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_2 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) completed = poisson_blend(input, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join(args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_2: break pbar.close() # ================================================ # Training Phase 3 # ================================================ # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_3) while pbar.n < args.steps_3: for x in train_loader: # forward model_cd x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) # fake forward fake = torch.zeros((len(x), 1)).to(gpu) output_cn = model_cn(x - x * msk + mpv * msk) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, input_gd_fake)) loss_cd_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_cd_real = bceloss(output_real, real) # reduce loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2. # backward model_cd loss_cd.backward() # forward model_cn loss_cn_1 = completion_network_loss(x, output_cn, msk) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, (input_gd_fake))) loss_cn_2 = bceloss(output_fake, real) # reduce loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2. 
# backward model_cn loss_cn.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cd.step() opt_cn.step() # clear grads opt_cd.zero_grad() opt_cn.zero_grad() # update progbar pbar.set_description('phase 3 | train loss (cd): %.5f (cn): %.5f' % (loss_cd.cpu(), loss_cn.cpu())) pbar.update() # test if pbar.n % args.snaperiod_3 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) msk = gen_input_mask( shape=x.shape, hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) completed = poisson_blend(input, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) model_cd_path = os.path.join(args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) torch.save(model_cn.state_dict(), model_cn_path) torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_3: break pbar.close()
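# completion_network_loss is referenced by every training phase but never shown. In the
# GLCIC formulation it is the reconstruction loss restricted to the hole region, so a
# reasonable sketch (an assumption, not necessarily the repository's exact code) is MSE
# computed only where the mask is 1:
from torch.nn.functional import mse_loss

def completion_network_loss_sketch(x, output, mask):
    # zero out everything outside the holes before comparing
    return mse_loss(output * mask, x * mask)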
def GAN_patching_inputs(images, predicted):  # images and their predicted tensors
    global N
    model = CompletionNetwork()
    model.load_state_dict(torch.load("cifar10_inpainting", map_location='cuda'))
    model.eval()
    model = model.to(device)
    cleanimgs = list(range(len(images)))  # GAN inpainted

    # This is to apply Grad CAM to the loaded images
    # --------------------------------------------
    for j in range(len(images)):
        N += 1
        image = images[j]
        image = unnormalize(image)  # unnormalize to [0 1] to feed into GAN
        image = torch.unsqueeze(image, 0)  # unsqueeze meaning adding 1D to the tensor
        start_time = time.time()
        mask = gcam(image)  # get the mask through GradCAM
        cond_mask = mask >= MASK_COND
        mask = cond_mask.astype(int)
        # ---------------------------------------
        mask = np.expand_dims(mask, axis=0)  # add 1D to mask
        mask = np.expand_dims(mask, axis=0)
        mask = torch.tensor(mask)  # convert mask to tensor 1,1,32,32
        mask = mask.type(torch.FloatTensor)
        mask = mask.to(device)
        x = image  # original test image
        mpv = [0.4914655575466156, 0.4821903321331739, 0.4465675537097454]
        mpv = torch.tensor(mpv).view(1, 3, 1, 1)
        mpv = mpv.to(device)

        # inpaint
        with torch.no_grad():
            x_mask = x - x * mask + mpv * mask  # generate the occluded input [0 1]
            inputx = torch.cat((x_mask, mask), dim=1)
            output = model(inputx)  # generate the output for the occluded input [0 1]
            end_time = time.time()
            GAN_process_time = 1000.0 * (end_time - start_time)  # convert to ms
            GAN_process_time = round(GAN_process_time, 3)
            np.savetxt('runtime.csv', (N, GAN_process_time), delimiter=',')

            # image restoration
            inpainted = poisson_blend_old(x_mask, output, mask)  # this is GAN output [0 1]
            inpainted = inpainted.to(device)

            # store GAN output
            clean_input = inpainted
            clean_input = normalize_tensor_batch(clean_input)  # normalize to [-1 1]
            clean_input = torch.squeeze(clean_input)  # remove the 1st dimension
            cleanimgs[j] = clean_input.cpu().numpy()  # store to a list

    # this is the tensor for GAN output
    cleanimgs_tensor = torch.from_numpy(np.asarray(cleanimgs))
    cleanimgs_tensor = cleanimgs_tensor.type(torch.FloatTensor)
    cleanimgs_tensor = cleanimgs_tensor.to(device)
    return cleanimgs_tensor
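# unnormalize and normalize_tensor_batch are not defined in this listing. The inline
# comments above say the classifier works in [-1, 1] while the inpainting network works
# in [0, 1], so a minimal sketch consistent with those comments (an assumption; the
# original helpers may instead use per-channel mean/std statistics):
def unnormalize_sketch(x):
    # [-1, 1] -> [0, 1]
    return (x + 1.0) / 2.0

def normalize_tensor_batch_sketch(x):
    # [0, 1] -> [-1, 1]
    return x * 2.0 - 1.0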
def main(args): # ================================================ # Preparation # ================================================ args.data_dir = os.path.expanduser(args.data_dir) args.result_dir = os.path.expanduser(args.result_dir) if args.init_model_cn != None: args.init_model_cn = os.path.expanduser(args.init_model_cn) if args.init_model_cd != None: args.init_model_cd = os.path.expanduser(args.init_model_cd) if torch.cuda.is_available() == False: raise Exception('At least one gpu must be available.') else: gpu = torch.device('cuda:0') # create result directory (if necessary) if os.path.exists(args.result_dir) == False: os.makedirs(args.result_dir) for s in ['phase_1', 'phase_2', 'phase_3']: if os.path.exists(os.path.join(args.result_dir, s)) == False: os.makedirs(os.path.join(args.result_dir, s)) # dataset trnsfm = transforms.Compose([ transforms.Resize(args.cn_input_size), transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), transforms.ToTensor(), ]) print('loading dataset... (it may take a few minutes)') train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm, recursive_search=args.recursive_search) test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm, recursive_search=args.recursive_search) train_loader = DataLoader(train_dset, batch_size=(args.bsize // args.bdivs), shuffle=True) # compute mean pixel value of training dataset mpv = np.zeros(shape=(3, )) if args.mpv == None: pbar = tqdm(total=len(train_dset.imgpaths), desc='computing mean pixel value for training dataset...') for imgpath in train_dset.imgpaths: img = Image.open(imgpath) x = np.array(img, dtype=np.float32) / 255. mpv += x.mean(axis=(0, 1)) pbar.update() mpv /= len(train_dset.imgpaths) pbar.close() else: mpv = np.array(args.mpv) # save training config mpv_json = [] for i in range(3): mpv_json.append(float(mpv[i])) # convert to json serializable type args_dict = vars(args) args_dict['mpv'] = mpv_json with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) # make mpv & alpha tensor mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu) alpha = torch.tensor(args.alpha).to(gpu) my_writer = SummaryWriter(log_dir='log') # ================================================ # Training Phase 1 # ================================================ model_cn = CompletionNetwork() if args.data_parallel: model_cn = DataParallel(model_cn) if args.init_model_cn != None: new = OrderedDict() for k, v in torch.load(args.init_model_cn, map_location='cpu').items(): new['module.' 
+ k] = v # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) model_cn.load_state_dict(new) print('第一阶段加载模型成功!') # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) if args.optimizer == 'adadelta': opt_cn = Adadelta(model_cn.parameters()) else: opt_cn = Adam(model_cn.parameters()) model_cn = model_cn.to(gpu) # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_1) pbar.n = 90000 while pbar.n < args.steps_1: for i, x in enumerate(train_loader): # forward x = x.to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) loss_mse = completion_network_loss(x, output, mask) loss_contextual = contextual_loss(x, output, mask) loss = loss_mse + 0.004 * loss_contextual # backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() # clear grads opt_cn.zero_grad() # update progbar pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) pbar.update() my_writer.add_scalar('mse_loss', loss_mse, pbar.n * len(train_loader) + i) # test if pbar.n % args.snaperiod_1 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cn.module.state_dict(), model_cn_path) else: torch.save(model_cn.state_dict(), model_cn_path) # terminate if pbar.n >= args.steps_1: break pbar.close() # ================================================ # Training Phase 2 # ================================================ model_cd = ContextDiscriminator( local_input_shape=(3, args.ld_input_size, args.ld_input_size), global_input_shape=(3, args.cn_input_size, args.cn_input_size), arc=args.arc, ) if args.data_parallel: model_cd = DataParallel(model_cd) if args.init_model_cd != None: # model_cd.load_state_dict(torch.load(args.init_model_cd, map_location='cpu')) new = OrderedDict() for k, v in torch.load(args.init_model_cd, map_location='cpu').items(): new['module.' 
+ k] = v # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu')) model_cd.load_state_dict(new) print('第二阶段加载模型成功!') if args.optimizer == 'adadelta': opt_cd = Adadelta(model_cd.parameters()) else: opt_cd = Adam(model_cd.parameters()) model_cd = model_cd.to(gpu) bceloss = BCELoss() # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_2) pbar.n = 120000 while pbar.n < args.steps_2: for x in train_loader: # fake forward x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() # 输入全局判别器的生成图片 input_ld_fake = crop(input_gd_fake, hole_area_fake) # 输入局部判别器的生成图片 output_fake = model_cd( (input_ld_fake.to(gpu), input_gd_fake.to(gpu))) loss_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) # 输入全局判别器的真实图片 output_real = model_cd( (input_ld_real, input_gd_real)) # 输入局部判别器的生成图片 loss_real = bceloss(output_real, real) # reduce loss = (loss_fake + loss_real) / 2. # backward loss.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cd.step() # clear grads opt_cd.zero_grad() # update progbar pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu()) pbar.update() # test if pbar.n % args.snaperiod_2 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) # 泊松融合 imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_2: break pbar.close() # ================================================ # Training Phase 3 # ================================================ # training cnt_bdivs = 0 pbar = tqdm(total=args.steps_3) pbar.n = 120000 while pbar.n < args.steps_3: for i, x in enumerate(train_loader): # forward model_cd x = x.to(gpu) hole_area_fake = gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) # fake forward fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask input_cn = 
torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, input_gd_fake)) loss_cd_fake = bceloss(output_fake, fake) # real forward hole_area_real = gen_hole_area(size=(args.ld_input_size, args.ld_input_size), mask_size=(x.shape[3], x.shape[2])) real = torch.ones((len(x), 1)).to(gpu) input_gd_real = x input_ld_real = crop(input_gd_real, hole_area_real) output_real = model_cd((input_ld_real, input_gd_real)) loss_cd_real = bceloss(output_real, real) # reduce loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2. # backward model_cd loss_cd.backward() cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: # optimize opt_cd.step() # clear grads opt_cd.zero_grad() # forward model_cn loss_cn_1 = completion_network_loss(x, output_cn, mask) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, (input_gd_fake))) loss_cn_2 = bceloss(output_fake, real) loss_cn_3 = contextual_loss(x, output_cn, mask) # reduce loss_cn = (loss_cn_1 + alpha * loss_cn_2 + 4e-3 * loss_cn_3) / 2. # backward model_cn loss_cn.backward() my_writer.add_scalar('mse_loss', loss_cn_1, (90000 + pbar.n) * len(train_loader) + i) if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() # clear grads opt_cn.zero_grad() # update progbar pbar.set_description( 'phase 3 | train loss (cd): %.5f (cn): %.5f' % (loss_cd.cpu(), loss_cn.cpu())) pbar.update() # test if pbar.n % args.snaperiod_3 == 0: with torch.no_grad(): x = sample_random_batch( test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area( (args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat( (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join( args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) model_cd_path = os.path.join( args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) if args.data_parallel: torch.save(model_cn.module.state_dict(), model_cn_path) torch.save(model_cd.module.state_dict(), model_cd_path) else: torch.save(model_cn.state_dict(), model_cn_path) torch.save(model_cd.state_dict(), model_cd_path) # terminate if pbar.n >= args.steps_3: break pbar.close()
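# poisson_blend (and poisson_blend_old above) are not included in these snippets. A
# common way to implement them is OpenCV's seamlessClone (Poisson image editing), which
# pastes the network output into the hole so the seam is not visible. A simplified
# single-image sketch under that assumption; it expects (1, 3, H, W) inputs in [0, 1]
# and a (1, 1, H, W) binary mask, and seamlessClone will fail if the hole touches the
# image border:
import cv2
import numpy as np
import torch

def poisson_blend_sketch(x, output, mask):
    dst = (x[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    src = (output[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    msk = (mask[0, 0].cpu().numpy() * 255).astype(np.uint8)
    ys, xs = np.where(msk > 0)
    if len(xs) == 0:
        return x.clone()  # nothing to blend
    center = (int(xs.mean()), int(ys.mean()))
    blended = cv2.seamlessClone(src, dst, msk, center, cv2.NORMAL_CLONE)
    out = torch.from_numpy(blended.astype(np.float32) / 255.0).permute(2, 0, 1)
    return out.unsqueeze(0)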
def main(args): if args.wandb != "tmp": wandb.init(project=args.wandb, config=args) if not torch.cuda.is_available(): raise Exception('At least one gpu must be available.') gpu = torch.device('cuda:0') # create result directory (if necessary) if not os.path.exists(args.result_dir): os.makedirs(args.result_dir) for phase in ['phase_1', 'phase_2', 'phase_3']: if not os.path.exists(os.path.join(args.result_dir, phase)): os.makedirs(os.path.join(args.result_dir, phase)) # load dataset print('loading dataset... (it may take a few minutes)') train_dset = masked_dataset(os.path.join(args.data_dir, 'train'), args.max_train) test_dset = masked_dataset(os.path.join(args.data_dir, 'test'), args.max_test) train_loader = DataLoader(train_dset, batch_size=(args.batch_size), shuffle=True) alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu) model_cn = CompletionNetwork() if args.init_model_cn is not None: model_cn.load_state_dict( torch.load(args.init_model_cn, map_location='cpu')) model_cn = model_cn.to(gpu) model_cn.train() opt_cn = Adam(model_cn.parameters(), lr=args.learning_rate) # training # ================================================ # Training Phase 1 # ================================================ pbar = tqdm(total=args.steps_1) epochs = 0 while epochs < args.steps_1: for i, (normal, masked) in tqdm(enumerate(train_loader, 0)): # forward # normal = torch.autograd.Variable(normal,requires_grad=True).to(gpu) # masked = torch.autograd.Variable(normal,requires_grad=True).to(gpu) output = model_cn(masked.to(gpu)) loss = torch.nn.functional.mse_loss(output, normal.to(gpu)) # backward loss.backward() # optimize opt_cn.step() opt_cn.zero_grad() if args.wandb != "tmp": wandb.log({"phase_1_train_loss": loss.cpu()}) pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) pbar.update() # test model_cn.eval() with torch.no_grad(): normal, masked = sample_random_batch( test_dset, batch_size=args.num_test_completions) normal = normal.to(gpu) masked = masked.to(gpu) output = model_cn(masked) # completed = output imgs = torch.cat((masked.cpu(), normal.cpu(), output.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(masked)) torch.save(model_cn.state_dict(), model_cn_path) model_cn.train() epochs += 1 pbar.close()
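# masked_dataset is not shown either. The loop above unpacks (normal, masked) pairs, so
# this variant trains the completion network as a plain masked-to-clean regression with
# no mask channel. A hypothetical minimal dataset with that interface (the 'clean' /
# 'masked' directory layout is an assumption):
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class MaskedDatasetSketch(Dataset):
    def __init__(self, data_dir, max_items=None):
        names = sorted(os.listdir(os.path.join(data_dir, 'clean')))[:max_items]
        self.pairs = [(os.path.join(data_dir, 'clean', n),
                       os.path.join(data_dir, 'masked', n)) for n in names]
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        clean_path, masked_path = self.pairs[idx]
        normal = self.to_tensor(Image.open(clean_path).convert('RGB'))
        masked = self.to_tensor(Image.open(masked_path).convert('RGB'))
        return normal, masked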
    trnsfm2 = transforms.Compose([
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(args.data_dir, trnsfm, trnsfm2, load2meme=False)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=False,
                              num_workers=5)
    alpha = torch.tensor(args.alpha).to(gpu)

    # Create model G
    model_cn = CompletionNetwork()
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.init_model_cn is not None:
        model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))
    model_cn = model_cn.to(gpu)
    transPIL = transforms.ToPILImage()

    def inference_G():
        cnt_bdivs = 0
        # Pro.start()