Code Example #1
def test(args):
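    """Evaluate a trained generator on the test split.

    Loads the config and generator weights, runs the test iterator through
    the generator with backprop disabled, saves the output images, and
    writes per-image PSNR/SSIM values (and their means) to text files.
    """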
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(0).use()
    gen = load_gen(config)
    chainer.serializers.load_npz(args.gen_model, gen)
    gen.to_gpu()
    xp = gen.xp
    _, test = yaml_utils.load_dataset(config)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 config.batchsize_test,
                                                 repeat=False,
                                                 shuffle=False)

    results_dir = args.results_dir
    images_dir = os.path.join(results_dir, 'images')
    if not os.path.exists(images_dir):
        os.makedirs(images_dir)

    n = 0
    psnr = []
    ssim = []
    while True:
        x, t, batchsize = get_batch(test_iter, xp)

        with chainer.using_config('train', False), chainer.using_config(
                'enable_backprop', False):
            y = gen(x)
            x = x.array.get() * 127.5 + 127.5
            y = np.clip(y.array.get() * 127.5 + 127.5, 0., 255.)
            t = t.array.get() * 127.5 + 127.5

            _psnr, _ssim = save_images(x, y, t, images_dir, current_n=n)
            psnr += _psnr
            ssim += _ssim
            n += len(x)

        if test_iter.is_new_epoch:
            test_iter.reset()
            break

    print('psnr: {}'.format(np.mean(psnr)))
    print('ssim: {}'.format(np.mean(ssim)))

    psnr = list(map(str, psnr))
    ssim = list(map(str, ssim))

    with open(os.path.join(results_dir, 'psnr.txt'), 'w') as f:
        f.write('\n'.join(psnr))
    with open(os.path.join(results_dir, 'ssim.txt'), 'w') as f:
        f.write('\n'.join(ssim))
Code Example #2
def test(args):
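    """Run a trained generator over the test split and save the output images.

    The NIR/RGB test directories and image lists can optionally be
    overridden from the command line; no metrics are computed.
    """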
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    if args.dir_nir is not None:
        config.dataset['args']['args_test']['dir_nir'] = args.dir_nir
        config.dataset['args']['args_test']['imlist_nir'] = args.imlist_nir
    if args.dir_rgb is not None:
        config.dataset['args']['args_test']['dir_rgb'] = args.dir_rgb
        config.dataset['args']['args_test']['imlist_rgb'] = args.imlist_rgb
    chainer.cuda.get_device_from_id(0).use()
    gen = load_gen(config)
    chainer.serializers.load_npz(args.gen_model, gen)
    gen.to_gpu()
    xp = gen.xp
    _, test = yaml_utils.load_dataset(config)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 config.batchsize_test,
                                                 repeat=False,
                                                 shuffle=False)

    results_dir = args.results_dir
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    n = 0
    while True:
        x, batchsize = get_batch(test_iter, xp)

        with chainer.using_config('train', False), chainer.using_config(
                'enable_backprop', False):
            out = gen(x)
            out = np.clip(out.array.get() * 127.5 + 127.5, 0., 255.)
            x = x.array.get() * 127.5 + 127.5

            save_images(x, out, results_dir, current_n=n)
            n += len(out)

        if test_iter.is_new_epoch:
            test_iter.reset()
            break

    return None
Code Example #3
def main():
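    """Fine-tune selected layers of a pre-trained generator.

    Loads generator, discriminator and encoder weights, disables updates
    for every generator layer except those in `layers_to_train`, and trains
    with a reconstruction loss (computed through the discriminator) plus a
    weighted pixel loss.
    """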
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--input_dir', type=str, default='./data/imagenet')
    parser.add_argument('--truth_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot file to use')
    parser.add_argument('--enc_model',
                        type=str,
                        default='',
                        help='path to the encoder .npz file')
    parser.add_argument('--gen_model',
                        type=str,
                        default='',
                        help='path to the generator .npz file')
    parser.add_argument('--dis_model',
                        type=str,
                        default='',
                        help='path to the discriminator .npz file')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis, enc = load_models(config)

    chainer.serializers.load_npz(args.gen_model, gen, strict=False)
    chainer.serializers.load_npz(args.dis_model, dis)
    chainer.serializers.load_npz(args.enc_model, enc)

    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)
    enc.to_gpu(device=args.gpu)
    models = {"gen": gen, "dis": dis, "enc": enc}
    opt_gen = make_optimizer(gen,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_gen.add_hook(chainer.optimizer.WeightDecay(config.weight_decay))
    opt_gen.add_hook(chainer.optimizer.GradientClipping(config.grad_clip))

    # disable update of pre-trained weights
    layers_to_train = ['lA1', 'lA2', 'lB1', 'lB2', 'preluW', 'preluMiddleW']
    for layer in gen.children():
        if layer.name not in layers_to_train:
            layer.disable_update()

    lmd_pixel = 0.05

    def fast_loss(out, gt):
        l1 = reconstruction_loss(dis, out, gt)
        l2 = lmd_pixel * pixel_loss(out, gt)
        loss = l1 + l2
        return loss

    gen.set_fast_loss(fast_loss)

    opts = {"opt_gen": opt_gen}

    # Dataset
    config['dataset']['args']['root_input'] = args.input_dir
    config['dataset']['args']['root_truth'] = args.truth_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = [
        "loss_noab", "loss1", "loss2", "loss3", "fast_alpha", "loss_ae",
        "fast_benefit", "min_slope", "max_slope", "min_slope_middle",
        "max_slope_middle"
    ]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.ParameterStatistics(gen),
                   trigger=(config.display_interval, 'iteration'))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_reconstruction_auxab(enc,
                                               gen,
                                               out,
                                               n_classes=gen.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    trainer.extend(ext_opt_gen)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #4
def main():
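    """Distributed GAN training with ChainerMN.

    Each process uses its intra-rank GPU, the dataset is scattered from
    rank 0, and rank 0 alone registers the snapshot, logging, sampling and
    Inception-score extensions.
    """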
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    config['dataset']['args']['root'] = args.data_dir
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
    else:
        _ = yaml_utils.load_dataset(
            config)  # Dummy, for adding path to the dataset module
        dataset = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    if comm.rank == 0:
        # Set up logging
        trainer.extend(extensions.snapshot(),
                       trigger=(config.snapshot_interval, 'iteration'))
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                           trigger=(config.snapshot_interval, 'iteration'))
        trainer.extend(
            extensions.LogReport(keys=report_keys,
                                 trigger=(config.display_interval,
                                          'iteration')))
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=(config.display_interval, 'iteration'))
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                       trigger=(config.evaluation_interval // 10, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(calc_inception(gen,
                                      n_ims=5000,
                                      splits=1,
                                      path=args.inception_model_path),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(
            extensions.ProgressBar(
                update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #5
def main():
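    """Single-GPU training for an image-to-image GAN.

    Trains generator and discriminator with Adam, periodically evaluates
    PSNR/SSIM on the test iterator via the evaluation function named in the
    config, and plots every reported key.
    """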
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--results_dir', type=str, default='./results',
                        help='directory to save the results to')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    device = 0
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    multiprocessing.set_start_method('forkserver')

    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}

    # Optimizer
    opt_gen = make_optimizer(gen, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}

    # Dataset
    train, test = yaml_utils.load_dataset(config)

    # Iterator
    train_iter = chainer.iterators.MultiprocessIterator(train, config.batchsize, n_processes=args.loaderjob)
    test_iter = chainer.iterators.SerialIterator(test, config.batchsize_test, repeat=False, shuffle=False)

    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': train_iter,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "loss_l1", "psnr", "ssim"]
    eval_func = yaml_utils.load_eval_func(config)
    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
    trainer.extend(eval_func(test_iter, gen, dst=out), trigger=(config.evaluation_interval, 'iteration'), priority=chainer.training.extension.PRIORITY_WRITER)
    for key in report_keys:
        trainer.extend(extensions.PlotReport(key, trigger=(config.evaluation_interval, 'iteration'), file_name='{}.png'.format(key)))
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #6
def main():
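    """Single-GPU GAN training with explicitly seeded model weights.

    Calls `seed_weights` on both networks, prints a weight checksum and
    pauses on `input()` as a debug check, then trains with the usual
    sampling and Inception-score extensions.
    """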
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis = load_models(config)
    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)

    seed_weights(gen)
    seed_weights(dis)
    print(np.sum(gen.block2.c2.b))
    _ = input()

    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    # The CIFAR10 dataset handler does not take "root" as an argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(gen, out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(gen,
                                  n_ims=5000,
                                  splits=1,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #7
def main():
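    """Data-parallel GAN training on up to four GPUs.

    Copies the generator/discriminator onto each device listed in
    `devices`, prints the named links of both networks, and evaluates
    Inception score, FID and the discriminator's largest singular values
    during training.
    """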
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/sr.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results',
                        help='directory to save the results to')
    parser.add_argument(
        '--inception_model_path',
        type=str,
        default='./datasets/inception_model/inception_score.model',
        help='path to the inception model')
    parser.add_argument(
        '--stat_file',
        type=str,
        default='./datasets/inception_model/fid_stats_cifar10_train.npz',
        help='path to the precomputed FID statistics (.npz) file')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    # set up the model
    devices = {'main': 0, 'second': 1, 'third': 2, 'fourth': 3}
    gen, dis = load_models(config)
    model = {"gen": gen, "dis": dis}
    names = list(six.iterkeys(devices))
    try:
        names.remove('main')
    except ValueError:
        raise KeyError("devices must contain a 'main' key.")
    models = {'main': model}

    for name in names:
        g = copy.deepcopy(model['gen'])
        d = copy.deepcopy(model['dis'])
        if devices[name] >= 0:
            g.to_gpu(device=devices[name])
            d.to_gpu(device=devices[name])
        models[name] = {"gen": g, "dis": d}

    if devices['main'] >= 0:
        models['main']['gen'].to_gpu(device=devices['main'])
        models['main']['dis'].to_gpu(device=devices['main'])

    links = [[name, link] for name, link in sorted(dis.namedlinks())]
    for name, link in links:
        print(name)
    links = [[name, link] for name, link in sorted(gen.namedlinks())]
    for name, link in links:
        print(name)

    # Optimizer
    opt_gen = make_optimizer(models['main']['gen'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(models['main']['dis'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'devices': devices,
        'models': models,
        'iterator': iterator,
        'optimizer': opts
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "FID"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models['main'].values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(models['main']['gen'],
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(models['main']['gen'], out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(models['main']['gen'],
                                         out,
                                         rows=10,
                                         cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(models['main']['gen'],
                                  n_ims=5000,
                                  splits=1,
                                  dst=args.results_dir,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(models['main']['gen'],
                            n_ims=5000,
                            dst=args.results_dir,
                            path=args.inception_model_path,
                            stat_file=args.stat_file),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(monitor_largest_singular_values(models['main']['dis'],
                                                   dst=args.results_dir),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)

    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #8
def main():
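    """Train an encoder against a fixed, pre-trained GAN.

    Loads generator and discriminator weights, optimizes only the encoder,
    and periodically saves reconstructions produced by the encoder/generator
    pair.
    """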
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--input_dir', type=str, default='./data/imagenet')
    parser.add_argument('--truth_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--gen_model',
                        type=str,
                        default='',
                        help='path to the generator .npz file')
    parser.add_argument('--dis_model',
                        type=str,
                        default='',
                        help='path to the discriminator .npz file')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis, enc = load_models(config)

    chainer.serializers.load_npz(args.gen_model, gen)
    chainer.serializers.load_npz(args.dis_model, dis)

    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)
    enc.to_gpu(device=args.gpu)
    models = {"gen": gen, "dis": dis, "enc": enc}

    opt_enc = make_optimizer(enc,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_enc": opt_enc}
    # Dataset
    config['dataset']['args']['root_input'] = args.input_dir
    config['dataset']['args']['root_truth'] = args.truth_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss", "min_slope", "max_slope", "min_z", "max_z"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_reconstruction(enc,
                                         gen,
                                         out,
                                         n_classes=gen.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_enc = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_enc)
    trainer.extend(ext_opt_enc)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Code Example #9
File: train_mn.py  Project: grigorisg9gr/rocgan
def main():
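    """ChainerMN training entry point for the RoCGAN encoder/decoder/discriminator.

    Keeps moving averages of the encoder and decoder, optionally runs a
    validation extension (SSIM/MAE/FID) on a separate file list, and adds
    ChainerUI extensions when running on known host machines.
    """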
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml')
    parser.add_argument('--n_devices', type=int)
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--results_dir', type=str, default='results_rocgan')
    parser.add_argument('--inception_model_path', type=str,
                        default='/home/user/inception/inception.model')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--enc_snapshot', type=str, default=None, help='path to the encoder snapshot')
    parser.add_argument('--dec_snapshot', type=str, default=None, help='path to the decoder snapshot')
    parser.add_argument('--dis_snapshot', type=str, default=None, help='path to the discriminator snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--multiprocessing', action='store_true', default=False)
    parser.add_argument('--validation', type=int, default=1)
    parser.add_argument('--valid_fn', type=str, default='files_valid_4k.txt', 
                        help='filename of the validation file')
    parser.add_argument('--label', type=str, default='synth')
    parser.add_argument('--stats_fid', type=str, default='', help='path for FID stats')
    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    # # ensure that the paths of the config are correct.
    config = ensure_config_paths(config)
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    # # get the pc name, e.g. for chainerui.
    pcname = gethostname()
    imperialpc = 'doc.ic.ac.uk' in pcname or pcname in ['ladybug', 'odysseus']
    print('Init on pc: {}.'.format(pcname))
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    enc, dec, dis = load_models_cgan(config)
    if chainer.cuda.available:
        enc.to_gpu()
        dec.to_gpu()
        dis.to_gpu()
    else:
        print('No GPU found!!!\n')
    mma1 = ModelMovingAverage(0.999, enc)
    mma2 = ModelMovingAverage(0.999, dec)
    models = {'enc': enc, 'dec': dec, 'dis': dis}
    if args.enc_snapshot is not None:
        print('Loading encoder: {}.'.format(args.enc_snapshot))
        chainer.serializers.load_npz(args.enc_snapshot, enc)
    if args.dec_snapshot is not None:
        print('Loading decoder: {}.'.format(args.dec_snapshot))
        chainer.serializers.load_npz(args.dec_snapshot, dec)
    if args.dis_snapshot is not None:
        print('Loading discriminator: {}.'.format(args.dis_snapshot))
        chainer.serializers.load_npz(args.dis_snapshot, dis)
    # # convenience function for optimizer:
    func_opt = lambda net: make_optimizer(net, comm, chmn=args.multiprocessing,
                                          alpha=config.adam['alpha'], beta1=config.adam['beta1'], 
                                          beta2=config.adam['beta2'])
    # Optimizer
    opt_enc = func_opt(enc)
    opt_dec = func_opt(dec)
    opt_dis = func_opt(dis)
    opts = {'opt_enc': opt_enc, 'opt_dec': opt_dec, 'opt_dis': opt_dis}
    # Dataset
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
        printtime('Length of dataset: {}.'.format(len(dataset)))
        if args.validation:
            # # add the validation db if we do perform validation.
            db_valid = yaml_utils.load_dataset(config, validation=True, valid_path=args.valid_fn)
    else:
        _ = yaml_utils.load_dataset(config)  # Dummy, for adding path to the dataset module
        dataset = None
        if args.validation:
            _ = yaml_utils.load_dataset(config, validation=True, valid_path=args.valid_fn)
            db_valid = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    if args.validation:
        db_valid = chainermn.scatter_dataset(db_valid, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    if args.multiprocessing:
        # # In minoas this might fail with the forkserver.py error.
        iterator = chainer.iterators.MultiprocessIterator(dataset, config.batchsize,
                                                          n_processes=args.loaderjob)
        if args.validation:
            iter_val = chainer.iterators.MultiprocessIterator(db_valid, config.batchsize,
                                                              n_processes=args.loaderjob,
                                                              shuffle=False, repeat=False)
    else:
        iterator = chainer.iterators.SerialIterator(dataset, config.batchsize)
        if args.validation:
            iter_val = chainer.iterators.SerialIterator(db_valid, config.batchsize,
                                                        shuffle=False, repeat=False)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
        'mma1': mma1,
        'mma2': mma2,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    if not args.test:
        mainf = '{}_{}'.format(strftime('%Y_%m_%d__%H_%M_%S'), args.label)
        out = os.path.join(args.results_dir, mainf, '')
    else:
        out = 'results/test'
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    # # abbreviations below: inc -> inception, gadv -> grad_adv, lgen -> generator loss,
    # # {m, sd, b}[var] -> {mean, std, best position} [var],
    report_keys = ['loss_dis', 'lgen_adv', 'dis_real', 'dis_fake', 'loss_l1',
                   'mssim', 'sdssim', 'mmae', 'loss_projl', 'FID']

    if comm.rank == 0:
        # Set up logging
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
#         trainer.extend(extensions.snapshot_object(
#             mma.avg_model, mma.avg_model.__class__.__name__ + '_avgmodel_{.updater.iteration}.npz'),
#             trigger=(config.snapshot_interval, 'iteration'))
        trainer.extend(extensions.LogReport(trigger=(config.display_interval, 'iteration')))
        trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
        if args.validation:
            # # add the appropriate extension for validating the model.
            models_mma = {'enc': mma1.avg_model, 'dec': mma2.avg_model, 'dis': dis}
            trainer.extend(validation_trainer(models_mma, iter_val, n=len(db_valid), export_best=True, 
                                              pout=out, p_inc=args.inception_model_path, eval_fid=True,
                                              sfile=args.stats_fid),
                           trigger=(config.evaluation_interval, 'iteration'),
                           priority=extension.PRIORITY_WRITER)

        trainer.extend(extensions.ProgressBar(update_interval=config.display_interval))
        if imperialpc:
            # [ChainerUI] Observe learning rate
            trainer.extend(extensions.observe_lr(optimizer_name='opt_dis'))
            # [ChainerUI] enable to send commands from ChainerUI
            trainer.extend(CommandsExtension())
            # [ChainerUI] save 'args' to show experimental conditions
            save_args(args, out)

    # # convenience function for linearshift in optimizer:
    func_opt_shift = lambda optim1: extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                                           (config.iteration_decay_start, 
                                                            config.iteration), optim1)
    # # define the actual extensions (for optimizer shift).
    trainer.extend(func_opt_shift(opt_enc))
    trainer.extend(func_opt_shift(opt_dec))
    trainer.extend(func_opt_shift(opt_dis))

    if args.resume:
        print('Resume Trainer')
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    printtime('start training')
    trainer.run()
    plot_losses_log(out, savefig=True)
Code Example #10
def main():
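    """Distributed GAN training with a moving-average generator.

    Rank 0 builds a small validation batch from the training data as a
    proxy for i.i.d. samples and evaluates the generator against it with
    KL/NDB divergence metrics; the sampling extensions use the averaged model.
    """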
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_path',
        type=str,
        default='jobs/pinet/fashionmnist_cnn_prodpoly_linear.yml')
    parser.add_argument('--n_devices', type=int)
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')
    parser.add_argument('--results_dir',
                        type=str,
                        default='results_polynomial')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='/home/user/inception/inception.model')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--gen_snapshot',
                        type=str,
                        default=None,
                        help='path to the generator snapshot')
    parser.add_argument('--dis_snapshot',
                        type=str,
                        default=None,
                        help='path to the discriminator snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--multiprocessing',
                        action='store_true',
                        default=False)
    parser.add_argument('--label', type=str, default='synth')
    parser.add_argument('--batch_val', type=int, default=1000)
    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    # # ensure that the paths of the config are correct.
    config = ensure_config_paths(config)
    try:
        comm = chainermn.create_communicator(args.communicator)
    except Exception:
        # Fall back to the default communicator if the requested one fails.
        comm = chainermn.create_communicator()
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    # # get the pc name, e.g. for chainerui.
    pcname = gethostname()
    print('Init on pc: {}.'.format(pcname))
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    mma = ModelMovingAverage(0.999, gen)
    models = {"gen": gen, "dis": dis}
    if args.gen_snapshot is not None:
        print('Loading generator: {}.'.format(args.gen_snapshot))
        chainer.serializers.load_npz(args.gen_snapshot, gen)
    if args.dis_snapshot is not None:
        print('Loading discriminator: {}.'.format(args.dis_snapshot))
        chainer.serializers.load_npz(args.dis_snapshot, dis)
    # Optimizer
    # # convenience function for optimizer:
    func_opt = lambda net, alpha, wdr0=0: make_optimizer(
        net,
        comm,
        chmn=args.multiprocessing,
        alpha=alpha,
        beta1=config.adam['beta1'],
        beta2=config.adam['beta2'],
        weight_decay_rate=wdr0)
    # Optimizer
    wdr = 0 if 'weight_decay_rate' not in config.updater[
        'args'] else config.updater['args']['weight_decay_rate_gener']
    opt_gen = func_opt(gen, config.adam['alpha'], wdr0=wdr)
    keydopt = 'alphad' if 'alphad' in config.adam.keys() else 'alpha'
    opt_dis = func_opt(dis, config.adam[keydopt])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    if hasattr(dis, 'fix_last') and hasattr(dis, 'lin') and dis.fix_last:
        # # This should be used with care. It fixes the linear layer that
        # # makes the classification in the discriminator.
        print('Fixing the linear layer of the discriminator!')
        dis.disable_update()
    # Dataset
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
        # # even though not new samples, use as proxy for iid validation ones.
        if hasattr(dataset, 'n_concats') and dataset.n_concats == 1:
            valid_samples = (dataset.base[:args.batch_val] + 1) * 127.5
        else:
            valid_samples = [(dataset.get_example(i)[0] + 1) * 127.5
                             for i in range(args.batch_val)]
            # # convert the validation to an array as required by the kl script.
            valid_samples = np.array(valid_samples, dtype=np.float32)
    else:
        _ = yaml_utils.load_dataset(
            config)  # Dummy, for adding path to the dataset module
        dataset = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    if args.multiprocessing:
        # # In minoas this might fail with the forkserver.py error.
        iterator = chainer.iterators.MultiprocessIterator(
            dataset, config.batchsize, n_processes=args.loaderjob)
    else:
        iterator = chainer.iterators.SerialIterator(dataset, config.batchsize)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
        'mma': mma,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    if not args.test:
        mainf = '{}_{}'.format(strftime('%Y_%m_%d__%H_%M_%S'), args.label)
        out = os.path.join(args.results_dir, mainf, '')
    else:
        out = 'results/test'
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = [
        'loss_dis', 'loss_gen', 'kl', 'ndb', 'JS', 'dis_real', 'dis_fake'
    ]

    if comm.rank == 0:
        # Set up logging
        #         trainer.extend(extensions.snapshot(), trigger=(config.snapshot_interval, 'iteration'))
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                           trigger=(config.snapshot_interval, 'iteration'))

        trainer.extend(
            extensions.LogReport(trigger=(config.display_interval,
                                          'iteration')))
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=(config.display_interval, 'iteration'))
        trainer.extend(
            extensions.ProgressBar(update_interval=config.display_interval))

        if gen.n_classes == 0:
            trainer.extend(sample_generate(mma.avg_model, out),
                           trigger=(config.evaluation_interval, 'iteration'),
                           priority=extension.PRIORITY_WRITER)
            print('unconditional image generation extension added.')
        else:
            trainer.extend(sample_generate_conditional(
                mma.avg_model, out, n_classes=gen.n_classes),
                           trigger=(config.evaluation_interval, 'iteration'),
                           priority=extension.PRIORITY_WRITER)

        trainer.extend(divergence_trainer(gen,
                                          valid_samples,
                                          metric=['kl', 'ndb'],
                                          batch=args.batch_val),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)

    # # convenience function for linearshift in optimizer:
    func_opt_shift = lambda optim1: extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), optim1)
    # # define the actual extensions (for optimizer shift).
    trainer.extend(func_opt_shift(opt_gen))
    trainer.extend(func_opt_shift(opt_dis))

    if args.resume:
        print("Resume Trainer")
        chainer.serializers.load_npz(args.resume, trainer)

    m1 = 'Generator params: {}. Discriminator params: {}.'
    print(m1.format(gen.count_params(), dis.count_params()))
    # Run the training
    print("start training")
    trainer.run()
    print('The output dir was {}.'.format(out))
    plot_losses_log(out, savefig=True)
Code Example #11
File: finetune.py  Project: zeta1999/FreezeD
def main():
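    """Fine-tune a pre-trained GAN on a new dataset with FreezeD.

    Copies the source generator/discriminator weights into the target
    networks, disables updates for the first `--layer` discriminator blocks,
    and keeps the snapshots that achieve the best FID.
    """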
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--gpu', type=int, default=0, help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./dataset/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/temp', help='directory to save the results to')
    parser.add_argument('--loaderjob', type=int, help='number of parallel data loading processes')
    parser.add_argument('--layer', type=int, default=0, help='freeze discriminator layer')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()

    # Fix randomness
    random.seed(config['seed'])
    np.random.seed(config['seed'])
    cupy.random.seed(config['seed'])

    # Model
    G_src, D_src = load_models(config, mode='source')
    G_tgt, D_tgt = load_models(config, mode='target')

    G_src.to_gpu(device=args.gpu)
    D_src.to_gpu(device=args.gpu)
    G_tgt.to_gpu(device=args.gpu)
    D_tgt.to_gpu(device=args.gpu)

    models = {"G_src": G_src, "D_src": D_src, "G_tgt": G_tgt, "D_tgt": D_tgt}

    # Optimizer
    opt_G_tgt = make_optimizer(G_tgt, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_D_tgt = make_optimizer(D_tgt, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])

    for i in range(1, args.layer + 1):  # freeze discriminator
        getattr(D_tgt, f'block{i}').disable_update()

    opts = {"opt_G_tgt": opt_G_tgt, "opt_D_tgt": opt_D_tgt}

    # Dataset
    config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)

    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "loss_FM", "FID"]

    # Set up logging
    trainer.extend(extensions.snapshot(filename='snapshot_best'),
                   trigger=training.triggers.MinValueTrigger("FID", trigger=(config.evaluation_interval, 'iteration')))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(m, m.__class__.__name__ + '_best.npz'),
                       trigger=training.triggers.MinValueTrigger("FID", trigger=(config.evaluation_interval, 'iteration')))

    trainer.extend(extensions.LogReport(keys=report_keys, trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_generate_conditional(G_tgt, out, n_classes=G_tgt.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(G_tgt, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(G_tgt, n_ims=5000, stat_file=config['eval']['stat_file']),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))

    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.), (config.iteration_decay_start, config.iteration), opt_G_tgt)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.), (config.iteration_decay_start, config.iteration), opt_D_tgt)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)

    # Load source networks
    chainer.serializers.load_npz(config['pretrained']['gen'], trainer.updater.models['G_src'])
    chainer.serializers.load_npz(config['pretrained']['dis'], trainer.updater.models['D_src'])
    load_parameters(trainer.updater.models['G_src'], trainer.updater.models['G_tgt'])
    load_parameters(trainer.updater.models['D_src'], trainer.updater.models['D_tgt'])

    # Run the training
    print("start training")
    trainer.run()
Code Example #12
def main():
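    """Single-GPU GAN training that imports pre-built models from `model_loader`
    and reports both Inception score and FID.

    Note: `args` (config_path, gpu, data_dir, results_dir, loaderjob,
    inception_model_path, FID_stat_file, snapshot) is assumed to be parsed at
    module level, e.g. with argparse, before main() is called.
    """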

    config = yaml_utils.Config(yaml.load(open(args.config_path), Loader=yaml.FullLoader))
    chainer.cuda.get_device_from_id(args.gpu).use()
    from model_loader import gen, dis

    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(
        gen, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_dis = make_optimizer(
        dis, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    # Cifar10 and STL10 dataset handler does not take "root" as argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset' and \
            config['dataset']['dataset_name'] != 'STL10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std", "FID_mean", "FID_std"]
    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(gen, out, n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(gen, out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    # trainer.extend(calc_inception(gen, n_ims=50000, dst=out, splits=10, path=args.inception_model_path),
    #                trigger=(config.evaluation_interval, 'iteration'),
    #                priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception_and_FID(gen, n_ims=50000, dst=out, path=args.inception_model_path, splits=10, stat_file=args.FID_stat_file),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()