예제 #1
0
def main():
    """Run a trained inpainting model on one test batch and save a visualization.

    Command-line options:
        --config_path  YAML file passed to ``Config``.
        --model        ``v1`` builds ``InpaintCAModel`` (with
                       ``config.FREE_FORM`` forced to ``False``); ``v2`` builds
                       ``InpaintGCModel``.
        --snapshot     ``.npz`` weights loaded with ``serializers.load_npz``.
        --name         Path of the output image.

    The saved image is built from three groups of the same batch, concatenated
    along the batch axis: the original images, the masked inputs, and the
    inpainted results.

    Raises:
        SystemExit: raised by ``argparse`` when ``--model`` is not ``v1``/``v2``
            or when ``--snapshot`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='gated_convolution.yml',
                        help='path to config file')
    parser.add_argument('--model', type=str, default='v1', choices=('v1', 'v2'),
                        help='model type: v1 or v2')
    parser.add_argument('--snapshot', type=str, default='model.npz',
                        help='path to snapshot')
    parser.add_argument('--name', type=str, default='result.png',
                        help='file name to save results')

    args = parser.parse_args()

    # Fail fast, before building the model or touching the GPU. A real error
    # is used instead of ``assert``, which is stripped under ``python -O``.
    if not os.path.exists(args.snapshot):
        parser.error(f"File '{args.snapshot}' does not exist.")

    config = Config(args.config_path)

    if args.model == "v1":
        config.FREE_FORM = False
        inpaint_model = InpaintCAModel(config)
    elif args.model == "v2":
        inpaint_model = InpaintGCModel(config)
    else:
        # Unreachable while ``choices`` is set on --model; kept as a guard in
        # case the choices and this dispatch drift apart.
        raise ValueError(f"Model name '{args.model}' is invalid.")

    if config.GPU_ID >= 0:
        chainer.cuda.get_device(config.GPU_ID).use()
        inpaint_model.to_gpu()

    serializers.load_npz(args.snapshot, inpaint_model)

    # numpy or cupy, depending on where the model lives.
    xp = inpaint_model.xp

    # Test data; with return_mask=True each example is an (image, mask) pair.
    test_dataset = Dataset(config, test=True, return_mask=True)
    test_iter = chainer.iterators.SerialIterator(test_dataset, 8)

    batch_and_mask = test_iter.next()
    batch_data, mask_data = zip(*batch_and_mask)
    batch_data = xp.array(batch_data)
    mask = xp.array(mask_data)

    # Map pixel values from [0, 255] to [-1, 1] (assumes 8-bit input).
    batch_pos = batch_data / 127.5 - 1.
    # Mask channel 0 marks the region to inpaint; zero it out in the input.
    hole = mask[:, :1]
    batch_incomplete = batch_pos * (1. - hole)

    # Inference only: no dropout/BN updates and no computational graph.
    with chainer.using_config("train", False), chainer.using_config("enable_backprop", False):
        _, x2, _ = inpaint_model.inpaintnet(batch_incomplete, mask, config)

    # Keep the known pixels and take only the hole region from ``x2``.
    batch_complete = x2 * hole + batch_incomplete * (1. - hole)

    # Visualization groups: original, masked input (with the remaining mask
    # channels overlaid -- NOTE(review): their meaning is not visible here),
    # and the completed result.
    viz_img = [batch_pos, batch_incomplete - mask[:, 1:] + hole, batch_complete.data]

    n_groups = len(viz_img)
    batch_size = viz_img[0].shape[0]
    viz_img = xp.concatenate(viz_img, axis=0)
    viz_img = batch_postprocess_images(viz_img, n_groups, batch_size)
    viz_img = cuda.to_cpu(viz_img)
    Image.fromarray(viz_img).save(args.name)
예제 #2
0
def main():
    """Train ``InpaintCAModel`` with settings from ``contextual_attention.yml``.

    Command-line options:
        --snapshot  Optional ``.npz`` file whose weights are loaded into the
                    model before training. If the path does not exist, a
                    message is printed and training starts from the freshly
                    initialized model.

    During training, model snapshots are written into the trainer output
    directory ``config.MODEL_RESTORE`` every ``config.SNAPSHOT_INTERVAL``
    iterations. Losses are logged and printed every 20 iterations.
    ``inpaint_model.evaluation(config.EVAL_FOLDER)`` runs every
    ``config.VAL_PSTEPS`` iterations.

    Note: only the model weights are snapshotted and restored. The optimizer
    (Adam) state and the trainer's iteration count are not loaded on resume.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    args = parser.parse_args()

    config = Config('contextual_attention.yml')

    # Training data and a small fixed-size test batch for the updater.
    train_dataset = Dataset(config)
    test_dataset = Dataset(config, test=True)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, config.BATCH_SIZE)
    test_iter = chainer.iterators.SerialIterator(test_dataset, 8)

    inpaint_model = InpaintCAModel(config)
    if config.GPU_ID >= 0:
        chainer.cuda.get_device(config.GPU_ID).use()
        inpaint_model.to_gpu()

    # exist_ok avoids the race of a separate os.path.exists check.
    os.makedirs(config.EVAL_FOLDER, exist_ok=True)

    # One Adam optimizer each for the generator (inpaintnet) and discriminator.
    optimizer = {
        "g_opt": optimizers.Adam(config.ALPHA, config.BETA1, config.BETA2),
        "d_opt": optimizers.Adam(config.ALPHA, config.BETA1, config.BETA2),
    }
    optimizer["g_opt"].setup(inpaint_model.inpaintnet)
    optimizer["d_opt"].setup(inpaint_model.discriminator)

    # Set up a trainer
    updater = Updater(
        model=inpaint_model,
        iterator={
            'main': train_iter,
            'test': test_iter,
        },
        optimizer=optimizer,
        device=config.GPU_ID,
        config=config,
    )

    trainer = training.Trainer(updater, (config.MAX_ITERS, 'iteration'),
                               out=config.MODEL_RESTORE)
    trainer.extend(
        extensions.snapshot_object(
            inpaint_model, 'inpaint_model_{.updater.iteration}.npz'),
        trigger=(config.SNAPSHOT_INTERVAL, 'iteration'))

    log_interval = (20, 'iteration')
    log_keys = ['epoch', 'iteration', 'l1_loss', 'ae_loss', 'g_loss', 'd_loss']
    trainer.extend(extensions.LogReport(keys=log_keys, trigger=log_interval))
    trainer.extend(extensions.PrintReport(log_keys), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=50))

    trainer.extend(inpaint_model.evaluation(config.EVAL_FOLDER),
                   trigger=(config.VAL_PSTEPS, 'iteration'))

    # Best-effort resume: an invalid path is reported, not fatal.
    if args.snapshot:
        if os.path.exists(args.snapshot):
            print("Resume with snapshot:{}".format(args.snapshot))
            chainer.serializers.load_npz(args.snapshot, inpaint_model)
        else:
            print("{}: invalid snapshot path".format(args.snapshot))

    # Run the training
    trainer.run()