def main():
    """Run a trained inpainting model on one test batch and save a visualization.

    Command-line options select the config file, the model variant
    (``v1`` -> ``InpaintCAModel``, ``v2`` -> ``InpaintGCModel``), the weight
    snapshot to load and the output image file name.

    Raises:
        ValueError: if ``--model`` is neither ``v1`` nor ``v2``.
        FileNotFoundError: if the ``--snapshot`` path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='gated_convolution.yml',
                        help='path to config file')
    parser.add_argument('--model', type=str, default='v1',
                        help='model type: v1 or v2')
    parser.add_argument('--snapshot', type=str, default='model.npz',
                        help='path to snapshot')
    parser.add_argument('--name', type=str, default='result.png',
                        help='file name to save results')
    args = parser.parse_args()

    config = Config(args.config_path)
    if args.model == "v1":
        # The v1 model is run with FREE_FORM disabled.
        config.FREE_FORM = False
        inpaint_model = InpaintCAModel(config)
    elif args.model == "v2":
        inpaint_model = InpaintGCModel(config)
    else:
        # Raise rather than `assert False`: asserts vanish under `python -O`,
        # and the old message lacked an f-prefix so the name was never shown.
        raise ValueError("Model name '{}' is invalid.".format(args.model))

    if config.GPU_ID >= 0:
        chainer.cuda.get_device(config.GPU_ID).use()
        inpaint_model.to_gpu()

    if not os.path.exists(args.snapshot):
        raise FileNotFoundError(
            "File '{}' does not exist.".format(args.snapshot))
    serializers.load_npz(args.snapshot, inpaint_model)

    xp = inpaint_model.xp

    # test data: take a single batch of 8 (image, mask) pairs
    test_dataset = Dataset(config, test=True, return_mask=True)
    test_iter = chainer.iterators.SerialIterator(test_dataset, 8)
    batch_and_mask = test_iter.next()
    batch_data, mask_data = zip(*batch_and_mask)
    batch_data = xp.array(batch_data)
    mask = xp.array(mask_data)
    # Rescale pixels to [-1, 1]; assumes input values lie in [0, 255] (TODO confirm).
    batch_pos = batch_data / 127.5 - 1.
    # Zero out the masked region; presumably mask channel 0 is 1 inside the hole.
    batch_incomplete = batch_pos * (1. - mask[:, :1])

    # inpaint (inference mode, no computational graph recorded)
    with chainer.using_config("train", False), \
            chainer.using_config("enable_backprop", False):
        _, x2, _ = inpaint_model.inpaintnet(batch_incomplete, mask, config)
        # Generated pixels inside the hole, original pixels elsewhere.
        batch_complete = x2 * mask[:, :1] + batch_incomplete * (1. - mask[:, :1])

    # visualization: three image groups (original, masked input, completed)
    # stacked along axis 0; batch_postprocess_images receives the group count
    # and per-group batch size (presumably to tile a grid -- confirm).
    viz_img = [batch_pos,
               batch_incomplete - mask[:, 1:] + mask[:, :1],
               batch_complete.data]
    batch_w = len(viz_img)
    batch_h = viz_img[0].shape[0]
    viz_img = xp.concatenate(viz_img, axis=0)
    viz_img = batch_postprocess_images(viz_img, batch_w, batch_h)
    viz_img = cuda.to_cpu(viz_img)
    Image.fromarray(viz_img).save(args.name)
def main():
    """Train ``InpaintCAModel`` using settings from ``contextual_attention.yml``.

    The optional ``--snapshot`` argument names an ``.npz`` file whose weights
    are loaded into the model before training starts. A non-existent path
    is reported and training proceeds from the freshly built model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    args = parser.parse_args()
    config = Config('contextual_attention.yml')

    # training data (plus batches of 8 test images for the 'test' iterator)
    train_dataset = Dataset(config)
    test_dataset = Dataset(config, test=True)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, config.BATCH_SIZE)
    test_iter = chainer.iterators.SerialIterator(test_dataset, 8)

    inpaint_model = InpaintCAModel(config)
    if config.GPU_ID >= 0:
        chainer.cuda.get_device(config.GPU_ID).use()
        inpaint_model.to_gpu()

    # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
    os.makedirs(config.EVAL_FOLDER, exist_ok=True)

    # optimizer: separate Adam instances for generator and discriminator
    optimizer = {
        "g_opt": optimizers.Adam(config.ALPHA, config.BETA1, config.BETA2),
        "d_opt": optimizers.Adam(config.ALPHA, config.BETA1, config.BETA2)
    }
    optimizer["g_opt"].setup(inpaint_model.inpaintnet)
    optimizer["d_opt"].setup(inpaint_model.discriminator)

    # Set up a trainer
    updater = Updater(
        model=inpaint_model,
        iterator={
            'main': train_iter,
            'test': test_iter
        },
        optimizer=optimizer,
        device=config.GPU_ID,
        config=config,
    )
    trainer = training.Trainer(updater, (config.MAX_ITERS, 'iteration'),
                               out=config.MODEL_RESTORE)

    trainer.extend(extensions.snapshot_object(
        inpaint_model, 'inpaint_model_{.updater.iteration}.npz'),
        trigger=(config.SNAPSHOT_INTERVAL, 'iteration'))

    log_keys = ['epoch', 'iteration', 'l1_loss', 'ae_loss', 'g_loss', 'd_loss']
    trainer.extend(
        extensions.LogReport(keys=log_keys, trigger=(20, 'iteration')))
    trainer.extend(extensions.PrintReport(log_keys),
                   trigger=(20, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=50))
    trainer.extend(inpaint_model.evaluation(config.EVAL_FOLDER),
                   trigger=(config.VAL_PSTEPS, 'iteration'))

    if args.snapshot:
        if os.path.exists(args.snapshot):
            print("Resume with snapshot:{}".format(args.snapshot))
            # Only inpaint_model's parameters are loaded here; optimizer state
            # and the trainer's iteration count are not restored.
            chainer.serializers.load_npz(args.snapshot, inpaint_model)
        else:
            print("{}: invalid snapshot path".format(args.snapshot))

    # Run the training
    trainer.run()