def main():
    """Train the decoder network for the target celebrity.

    Command line:
        --config: path to the decoder config (read by load_decoder_config).
        --info:   optional suffix appended to config["experiment_name"]
                  to distinguish this run's log directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_decoder_config(args.config)
    if args.info:
        config["experiment_name"] += args.info

    pprint.pprint(config)

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading
    logger.info('Initializing Datasource')
    train_iterator = data.celebv_data_iterator(
        dataset_mode="decoder",
        celeb_name=config["trg_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
    )

    # Logs go to <logdir>/decoder/<trg_celeb_name>/<experiment_name>.
    monitor = nm.Monitor(
        os.path.join(config["logdir"], "decoder", config["trg_celeb_name"],
                     config["experiment_name"]))

    # Optimizers: one Adam each for generator and discriminator, sharing
    # the same lr / beta1 from the config.
    solver_netG = S.Adam(alpha=config["train"]["lr"],
                         beta1=config["train"]["beta1"])
    solver_netD = S.Adam(alpha=config["train"]["lr"],
                         beta1=config["train"]["beta1"])

    # Network builders (functions, not instantiated graphs).
    netG = models.netG_decoder
    netD = models.netD_decoder

    train(config, netG, netD, solver_netG, solver_netD, train_iterator,
          monitor)
# --- Beispiel #2 (example 2) — separator left over from extraction ---
def main():
    """Run decoder inference (test mode) for the target celebrity.

    Command line:
        config:         path to the decoder config (positional).
        --param-file:   explicit generator checkpoint; when omitted, the
                        most recently modified "netG_*" file under the
                        training log directory is used.
        --num-test, -n: optional cap on the number of test samples,
                        stored into config["num_test"].
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', default=None, type=str)
    parser.add_argument('--param-file', default=None, type=str)
    parser.add_argument('--num-test', '-n', default=None, type=int)
    args = parser.parse_args()
    param_file = args.param_file

    config = load_decoder_config(args.config)

    config["num_test"] = args.num_test

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading — deterministic order (shuffle=False) for evaluation.
    logger.info('Initializing Datasource')
    train_iterator = data.celebv_data_iterator(
        dataset_mode="decoder",
        celeb_name=config["trg_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode="test",
        batch_size=config["test"]["batch_size"],
        shuffle=False,
        with_memory_cache=config["test"]["with_memory_cache"],
        with_file_cache=config["test"]["with_file_cache"],
    )

    monitor = nm.Monitor(
        os.path.join(config["test"]["logdir"], "decoder",
                     config["trg_celeb_name"], config["experiment_name"]))

    # Network builder.
    netG = models.netG_decoder
    if not param_file:
        # Fall back to the newest generator checkpoint by modification time.
        pattern = os.path.join(config["logdir"], "decoder",
                               config["trg_celeb_name"],
                               config["experiment_name"], "netG_*")
        candidates = glob.glob(pattern)
        if not candidates:
            # Explicit error instead of an opaque IndexError on [-1].
            raise FileNotFoundError(
                f"no checkpoint matches {pattern!r}; pass --param-file")
        param_file = max(candidates, key=os.path.getmtime)

    test(config, netG, train_iterator, monitor, param_file)
def _transformer_test_iterator(config, celeb_name):
    """Build a deterministic (shuffle=False) test-mode data iterator for
    one celebrity, using the shared preprocess settings from *config*."""
    return data.celebv_data_iterator(
        dataset_mode="transformer",
        celeb_name=celeb_name,
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode="test",
        batch_size=config["test"]["batch_size"],
        shuffle=False,
        with_memory_cache=config["test"]["with_memory_cache"],
        with_file_cache=config["test"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])


def main():
    """Run transformer inference (A2B direction) between two celebrities.

    Command line:
        config:         path to the transformer config (positional).
        --param-file:   explicit A2B generator checkpoint; when omitted, the
                        newest "netG_transformer_A2B_*" file is used.
        --num-test, -n: optional cap on the number of test samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', default=None, type=str)
    parser.add_argument('--param-file', default=None, type=str)
    parser.add_argument('--num-test', '-n', default=None, type=int)
    args = parser.parse_args()
    param_file = args.param_file

    config = load_transformer_config(args.config)

    config["num_test"] = args.num_test

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading — one iterator per domain (source, target).
    logger.info('Initializing Datasource')
    train_iterators = (
        _transformer_test_iterator(config, config["src_celeb_name"]),
        _transformer_test_iterator(config, config["trg_celeb_name"]),
    )

    # monitor
    monitor = nm.Monitor(
        os.path.join(config["test"]["logdir"], "transformer",
                     f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                     config["experiment_name"]))

    # Network builders: both directions share the same architecture.
    netG = {
        'netG_A2B': models.netG_transformer,
        'netG_B2A': models.netG_transformer
    }
    if param_file:
        param_file_A2B = param_file
    else:
        # Fall back to the newest A2B checkpoint by modification time.
        pattern = os.path.join(
            config["logdir"], "transformer",
            f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
            config["experiment_name"], "netG_transformer_A2B_*")
        candidates = glob.glob(pattern)
        if not candidates:
            # Explicit error instead of an opaque IndexError on [-1].
            raise FileNotFoundError(
                f"no checkpoint matches {pattern!r}; pass --param-file")
        param_file_A2B = max(candidates, key=os.path.getmtime)

    test_transformer(config, netG, train_iterators, monitor, param_file_A2B)
def _latest_checkpoint(pattern, kind):
    """Return the most recently modified file matching glob *pattern*.

    Raises FileNotFoundError (with a *kind*-specific message) when nothing
    matches, instead of the opaque IndexError a bare sorted(...)[-1] gives.
    """
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(
            f"no {kind} checkpoint matches {pattern!r}")
    return max(candidates, key=os.path.getmtime)


def main():
    """Full reenactment inference: encoder -> transformer -> decoder.

    Each stage has its own config and optional parameter file on the command
    line; a missing parameter file falls back to the newest checkpoint in
    that stage's training log directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder-config', default=None, type=str)
    parser.add_argument('--transformer-config', default=None, type=str)
    parser.add_argument('--decoder-config', default=None, type=str)

    parser.add_argument('--src-celeb-name', default=None, type=str)
    parser.add_argument('--trg-celeb-name', default=None, type=str)

    parser.add_argument('--encoder-param-file', default=None, type=str)
    parser.add_argument('--transformer-param-file', default=None, type=str)
    parser.add_argument('--decoder-param-file', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    encoder_param_file = args.encoder_param_file
    transformer_param_file = args.transformer_param_file
    decoder_param_file = args.decoder_param_file

    encoder_config = load_encoder_config(args.encoder_config)
    transformer_config = load_transformer_config(args.transformer_config)
    decoder_config = load_decoder_config(args.decoder_config)

    src_celeb_name = args.src_celeb_name
    trg_celeb_name = args.trg_celeb_name

    # The transformer is direction-specific; refuse a mismatched target.
    assert trg_celeb_name == transformer_config[
        "trg_celeb_name"], f"not trained on {trg_celeb_name}."

    if args.info:
        decoder_config["experiment_name"] += args.info

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {decoder_config["context"]}.')
    ctx = get_extension_context(decoder_config["context"],
                                device_id=decoder_config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading — source-celebrity frames, deterministic order, batch 1.
    logger.info('Initializing Datasource')
    test_iterator = data.celebv_data_iterator(
        dataset_mode="decoder",
        celeb_name=src_celeb_name,
        data_dir=decoder_config["train_dir"],
        ref_dir=decoder_config["ref_dir"],
        mode="test",
        batch_size=1,
        shuffle=False,
        with_memory_cache=decoder_config["test"]["with_memory_cache"],
        with_file_cache=decoder_config["test"]["with_file_cache"],
    )

    # Encoder
    encoder_netG = models.stacked_hourglass_net
    if not encoder_param_file:
        encoder_param_file = _latest_checkpoint(
            os.path.join(encoder_config["logdir"],
                         encoder_config["dataset_mode"],
                         encoder_config["experiment_name"], "model",
                         "model_epoch-*"),
            "encoder")

    # Transformer
    transformer_netG = models.netG_transformer
    if not transformer_param_file:
        transformer_param_file = _latest_checkpoint(
            os.path.join(
                transformer_config["logdir"],
                transformer_config["dataset_mode"],
                f'{transformer_config["src_celeb_name"]}2{transformer_config["trg_celeb_name"]}',
                transformer_config["experiment_name"],
                "netG_transformer_A2B_*"),
            "transformer")

    # Decoder
    decoder_netG = models.netG_decoder
    if not decoder_param_file:
        decoder_param_file = _latest_checkpoint(
            os.path.join(decoder_config["logdir"],
                         decoder_config["dataset_mode"],
                         decoder_config["trg_celeb_name"],
                         decoder_config["experiment_name"], "netG_*"),
            "decoder")

    monitor = nm.Monitor(
        os.path.join("reenactment_result",
                     f'{src_celeb_name}2{decoder_config["trg_celeb_name"]}',
                     decoder_config["experiment_name"]))

    test(encoder_config, transformer_config, decoder_config, encoder_netG,
         transformer_netG, decoder_netG, src_celeb_name, trg_celeb_name,
         test_iterator, monitor, encoder_param_file, transformer_param_file,
         decoder_param_file)
def _transformer_train_iterator(config, celeb_name):
    """Build a training data iterator for one celebrity, using the shared
    train/preprocess settings from *config*."""
    return data.celebv_data_iterator(
        dataset_mode="transformer",
        celeb_name=celeb_name,
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])


def _transformer_adam(config, lr_scale=1.0):
    """Adam solver using the training hyper-parameters from *config*;
    *lr_scale* scales the base learning rate (e.g. 0.5 for discriminators)."""
    return S.Adam(alpha=lr_scale * config["train"]["lr"],
                  beta1=config["train"]["beta1"],
                  beta2=config["train"]["beta2"])


def main():
    """Train the bidirectional landmark transformer (A2B / B2A generators
    plus one discriminator per domain) between two celebrities.

    Command line:
        --config: path to the transformer config.
        --info:   optional suffix appended to config["experiment_name"].
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_transformer_config(args.config)
    if args.info:
        config["experiment_name"] += args.info

    pprint.pprint(config)

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading — one iterator per domain (source, target).
    logger.info('Initializing Datasource')
    train_iterators = (
        _transformer_train_iterator(config, config["src_celeb_name"]),
        _transformer_train_iterator(config, config["trg_celeb_name"]),
    )

    # monitor
    monitor = nm.Monitor(
        os.path.join(config["logdir"], "transformer",
                     f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                     config["experiment_name"]))

    # Network builders: both generator directions share one architecture,
    # as do both discriminators.
    netG = {
        'netG_A2B': models.netG_transformer,
        'netG_B2A': models.netG_transformer
    }
    netD = {
        'netD_A': models.netD_transformer,
        'netD_B': models.netD_transformer
    }

    # Optimizers: discriminators run at half the generator learning rate.
    solver_netG = {
        'netG_A2B': _transformer_adam(config),
        'netG_B2A': _transformer_adam(config),
    }
    solver_netD = {
        'netD_A': _transformer_adam(config, lr_scale=0.5),
        'netD_B': _transformer_adam(config, lr_scale=0.5),
    }

    train_transformer(config, netG, netD, solver_netG, solver_netD,
                      train_iterators, monitor)