Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')

    args = parser.parse_args()
    config = yaml_utils.Config(
        yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    config['dataset']['args']['root'] = args.data_dir
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
    else:
        # Dummy call on non-root ranks, for adding the dataset module path.
        _ = yaml_utils.load_dataset(config)
        dataset = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    if comm.rank == 0:
        # Set up logging
        trainer.extend(extensions.snapshot(),
                       trigger=(config.snapshot_interval, 'iteration'))
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                           trigger=(config.snapshot_interval, 'iteration'))
        trainer.extend(
            extensions.LogReport(keys=report_keys,
                                 trigger=(config.display_interval,
                                          'iteration')))
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=(config.display_interval, 'iteration'))
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                       trigger=(config.evaluation_interval // 10, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(calc_inception(gen,
                                      n_ims=5000,
                                      splits=1,
                                      path=args.inception_model_path),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(
            extensions.ProgressBar(
                update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
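
Example #1 calls a make_optimizer(model, comm, ...) helper that the snippet never defines. A minimal sketch, assuming the standard ChainerMN pattern of wrapping a plain Adam optimizer with chainermn.create_multi_node_optimizer so gradients are all-reduced across workers (the body and its defaults are assumptions, not the original code):

import chainer
import chainermn


def make_optimizer(model, comm, alpha=0.0002, beta1=0.0, beta2=0.9):
    # NOTE: assumed implementation -- the original helper is not shown.
    # Wrap Adam so every update all-reduces gradients over the communicator.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2), comm)
    optimizer.setup(model)
    return optimizer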
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(
        yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis = load_models(config)
    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)

    seed_weights(gen)
    seed_weights(dis)
    # Debug check: print a weight checksum, then pause until Enter is pressed.
    print(np.sum(gen.block2.c2.b))
    _ = input()

    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    # The CIFAR10 dataset handler does not take "root" as an argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(gen, out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(gen,
                                  n_ims=5000,
                                  splits=1,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
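
Several examples funnel their output through create_result_dir(out, args.config_path, config), which is also not shown. A minimal sketch of what such a helper typically does, namely create the directory and keep a copy of the config for reproducibility (the body is an assumption):

import os
import shutil


def create_result_dir(result_dir, config_path, config):
    # NOTE: assumed implementation -- the original helper is not shown.
    # Create the output directory and store the config used for this run;
    # the parsed `config` object could be dumped here as well.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    shutil.copy(config_path, os.path.join(result_dir, 'config.yml'))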
Example #3
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path',
                        type=str,
                        default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name',
                        '-class',
                        type=str,
                        default='all_class',
                        help='class name (default: all_class(str))')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--batchsize', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=40000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=2500,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--out_image_interval',
                        type=int,
                        default=1250,
                        help='Interval of image output')
    parser.add_argument('--stage_interval',
                        type=int,
                        default=40000,
                        help='Interval of stage progress')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=1,
        help='number of discriminator update per generator update')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--gamma',
                        type=float,
                        default=750,
                        help='gradient penalty target')
    parser.add_argument('--pooling_comp',
                        type=float,
                        default=1.0,
                        help='pooling compensation coefficient')
    parser.add_argument('--pretrained_generator', type=str, default="")
    parser.add_argument('--pretrained_discriminator', type=str, default="")
    parser.add_argument('--initial_stage', type=float, default=0.0)
    parser.add_argument('--generator_smoothing', type=float, default=0.999)

    args = parser.parse_args()
    record_setting(args.out)

    report_keys = [
        "stage", "loss_dis", "loss_gp", "loss_gen", "g", "inception_mean",
        "inception_std", "FID"
    ]
    max_iter = args.max_iter

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    generator = Generator()
    generator_smooth = Generator()
    discriminator = Discriminator(pooling_comp=args.pooling_comp)

    # select GPU
    if args.gpu >= 0:
        generator.to_gpu()
        generator_smooth.to_gpu()
        discriminator.to_gpu()
        print("use gpu {}".format(args.gpu))

    if args.pretrained_generator != "":
        chainer.serializers.load_npz(args.pretrained_generator, generator)
    if args.pretrained_discriminator != "":
        chainer.serializers.load_npz(args.pretrained_discriminator,
                                     discriminator)
    copy_param(generator_smooth, generator)

    # Setup an optimizer
    def make_optimizer(model, alpha=0.001, beta1=0.0, beta2=0.99):
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        # optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(generator)
    opt_dis = make_optimizer(discriminator)

    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True

    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)

    # Set up a trainer
    updater = Updater(models=(generator, discriminator, generator_smooth),
                      iterator={'main': train_iter},
                      optimizer={
                          'opt_gen': opt_gen,
                          'opt_dis': opt_dis
                      },
                      device=args.gpu,
                      n_dis=args.n_dis,
                      lam=args.lam,
                      gamma=args.gamma,
                      smoothing=args.generator_smoothing,
                      initial_stage=args.initial_stage,
                      stage_interval=args.stage_interval)
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator_smooth, 'generator_smooth_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))

    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator_smooth, args.out),
                   trigger=(args.out_image_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator_smooth, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
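
Example #3 maintains a smoothed copy of the generator and initializes it via copy_param(generator_smooth, generator); during training the updater presumably blends the two with the smoothing=0.999 factor. A minimal sketch of that helper (the body is an assumption), matching parameters by their registered link paths:

def copy_param(target_link, source_link):
    # NOTE: assumed implementation -- the original helper is not shown.
    # Overwrite each parameter of `target_link` with the parameter of
    # `source_link` registered under the same path (e.g. '/block1/c0/W').
    source_params = dict(source_link.namedparams())
    for name, param in target_link.namedparams():
        param.data[...] = source_params[name].data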
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/sr.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results',
                        help='directory to save the results to')
    parser.add_argument(
        '--inception_model_path',
        type=str,
        default='./datasets/inception_model/inception_score.model',
        help='path to the inception model')
    parser.add_argument(
        '--stat_file',
        type=str,
        default='./datasets/inception_model/fid_stats_cifar10_train.npz',
        help='path to the FID statistics file')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(
        yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    # set up the model
    # Hard-coded four-GPU data-parallel layout; 'main' holds the master copy.
    devices = {'main': 0, 'second': 1, 'third': 2, 'fourth': 3}
    gen, dis = load_models(config)
    model = {"gen": gen, "dis": dis}
    names = list(six.iterkeys(devices))
    try:
        names.remove('main')
    except ValueError:
        raise KeyError("devices must contain a 'main' key.")
    models = {'main': model}

    for name in names:
        g = copy.deepcopy(model['gen'])
        d = copy.deepcopy(model['dis'])
        if devices[name] >= 0:
            g.to_gpu(device=devices[name])
            d.to_gpu(device=devices[name])
        models[name] = {"gen": g, "dis": d}

    if devices['main'] >= 0:
        models['main']['gen'].to_gpu(device=devices['main'])
        models['main']['dis'].to_gpu(device=devices['main'])

    # Print the link names of both networks for inspection.
    for name, _ in sorted(dis.namedlinks()):
        print(name)
    for name, _ in sorted(gen.namedlinks()):
        print(name)

    # Optimizer
    opt_gen = make_optimizer(models['main']['gen'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(models['main']['dis'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'devices': devices,
        'models': models,
        'iterator': iterator,
        'optimizer': opts
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "FID"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models['main'].values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(models['main']['gen'],
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(models['main']['gen'], out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(models['main']['gen'],
                                         out,
                                         rows=10,
                                         cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(models['main']['gen'],
                                  n_ims=5000,
                                  splits=1,
                                  dst=args.results_dir,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(models['main']['gen'],
                            n_ims=5000,
                            dst=args.results_dir,
                            path=args.inception_model_path,
                            stat_file=args.stat_file),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(monitor_largest_singular_values(models['main']['dis'],
                                                   dst=args.results_dir),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)

    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
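
All of these examples register image-sampling callbacks such as sample_generate(gen, out) as trainer extensions with writer priority. A minimal sketch of how such an extension can be built with chainer.training.make_extension; the generator call signature gen(n), the [-1, 1] value range, and the output layout are assumptions:

import os

import numpy as np
from PIL import Image
import chainer


def sample_generate(gen, dst, rows=8, cols=8, seed=0):
    # NOTE: sketch -- the original helper is not shown in the snippets.
    @chainer.training.make_extension()
    def make_image(trainer):
        np.random.seed(seed)
        n = rows * cols
        with chainer.using_config('train', False), \
                chainer.using_config('enable_backprop', False):
            x = gen(n)  # assumed: generator samples n images in [-1, 1]
        x = chainer.cuda.to_cpu(x.data)
        x = np.asarray(np.clip((x + 1.0) * 127.5, 0.0, 255.0), dtype=np.uint8)
        _, c, h, w = x.shape
        # Tile the batch into a rows x cols grid, one image per cell.
        x = x.reshape((rows, cols, c, h, w)).transpose(0, 3, 1, 4, 2)
        x = x.reshape((rows * h, cols * w, c))
        preview_dir = os.path.join(dst, 'preview')
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(x).save(os.path.join(
            preview_dir,
            'image_iter_{}.png'.format(trainer.updater.iteration)))
    return make_image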
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--gpu', type=int, default=0, help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./dataset/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/temp', help='directory to save the results to')
    parser.add_argument('--loaderjob', type=int, help='number of parallel data loading processes')
    parser.add_argument('--layer', type=int, default=0, help='freeze discriminator layer')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()

    # Fix randomness
    random.seed(config['seed'])
    np.random.seed(config['seed'])
    cupy.random.seed(config['seed'])

    # Model
    G_src, D_src = load_models(config, mode='source')
    G_tgt, D_tgt = load_models(config, mode='target')

    G_src.to_gpu(device=args.gpu)
    D_src.to_gpu(device=args.gpu)
    G_tgt.to_gpu(device=args.gpu)
    D_tgt.to_gpu(device=args.gpu)

    models = {"G_src": G_src, "D_src": D_src, "G_tgt": G_tgt, "D_tgt": D_tgt}

    # Optimizer
    opt_G_tgt = make_optimizer(G_tgt, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_D_tgt = make_optimizer(D_tgt, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])

    for i in range(1, args.layer + 1):  # freeze discriminator
        getattr(D_tgt, f'block{i}').disable_update()

    opts = {"opt_G_tgt": opt_G_tgt, "opt_D_tgt": opt_D_tgt}

    # Dataset
    config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)

    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "loss_FM", "FID"]

    # Set up logging
    trainer.extend(extensions.snapshot(filename='snapshot_best'),
                   trigger=training.triggers.MinValueTrigger("FID", trigger=(config.evaluation_interval, 'iteration')))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(m, m.__class__.__name__ + '_best.npz'),
                       trigger=training.triggers.MinValueTrigger("FID", trigger=(config.evaluation_interval, 'iteration')))

    trainer.extend(extensions.LogReport(keys=report_keys, trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_generate_conditional(G_tgt, out, n_classes=G_tgt.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(G_tgt, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(G_tgt, n_ims=5000, stat_file=config['eval']['stat_file']),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))

    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.), (config.iteration_decay_start, config.iteration), opt_G_tgt)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.), (config.iteration_decay_start, config.iteration), opt_D_tgt)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)

    # Load source networks
    chainer.serializers.load_npz(config['pretrained']['gen'], trainer.updater.models['G_src'])
    chainer.serializers.load_npz(config['pretrained']['dis'], trainer.updater.models['D_src'])
    load_parameters(trainer.updater.models['G_src'], trainer.updater.models['G_tgt'])
    load_parameters(trainer.updater.models['D_src'], trainer.updater.models['D_tgt'])

    # Run the training
    print("start training")
    trainer.run()
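
These examples all instantiate models, datasets, and updaters through yaml_utils lookups such as load_models(config) and yaml_utils.load_updater_class(config). A minimal sketch of the dynamic-import pattern such helpers typically follow; the config keys ('fn', 'name', 'args') and the config.models layout are assumptions:

import importlib


def load_class(module_path, class_name):
    # Resolve e.g. ('gen_models.resnet', 'ResNetGenerator') to a class object.
    return getattr(importlib.import_module(module_path), class_name)


def load_models(config):
    # NOTE: assumed implementation -- the original helper is not shown.
    gen_conf = config.models['generator']
    dis_conf = config.models['discriminator']
    gen = load_class(gen_conf['fn'], gen_conf['name'])(
        **gen_conf.get('args', {}))
    dis = load_class(dis_conf['fn'], dis_conf['name'])(
        **dis_conf.get('args', {}))
    return gen, dis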
Example #6
def main():
    # The original snippet used `args` without defining it; this parser is
    # reconstructed from the attributes the function accesses below, and the
    # defaults are assumptions borrowed from the examples above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--gpu', type=int, default=0, help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/gans', help='directory to save the results to')
    parser.add_argument('--inception_model_path', type=str, default='./datasets/inception_model', help='path to the inception model')
    parser.add_argument('--FID_stat_file', type=str, default='', help='path to the FID statistics file')
    parser.add_argument('--snapshot', type=str, default='', help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int, help='number of parallel data loading processes')
    args = parser.parse_args()

    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    from model_loader import gen, dis

    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(
        gen, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_dis = make_optimizer(
        dis, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    # Cifar10 and STL10 dataset handler does not take "root" as argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset' and \
            config['dataset']['dataset_name'] != 'STL10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std", "FID_mean", "FID_std"]
    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(gen, out, n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(gen, out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    # trainer.extend(calc_inception(gen, n_ims=50000, dst=out, splits=10, path=args.inception_model_path),
    #                trigger=(config.evaluation_interval, 'iteration'),
    #                priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception_and_FID(gen, n_ims=50000, dst=out, path=args.inception_model_path, splits=10, stat_file=args.FID_stat_file),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
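
A closing note on the learning-rate schedule shared by the examples: extensions.LinearShift('alpha', (config.adam['alpha'], 0.), (start, end), opt) holds Adam's alpha at its initial value until iteration start, then decays it linearly to zero at iteration end. A small stand-alone function mirroring that behavior, useful for sanity-checking a config:

def linear_shift_value(t, value_range, time_range):
    # Mirrors extensions.LinearShift: constant before `start`, constant after
    # `end`, linear interpolation in between.
    v1, v2 = value_range
    start, end = time_range
    if t <= start:
        return v1
    if t >= end:
        return v2
    return v1 + (v2 - v1) * (t - start) / (end - start)


# For instance, alpha = 0.0002 decayed over iterations 50000..100000 gives
# linear_shift_value(75000, (0.0002, 0.0), (50000, 100000)) == 0.0001.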