    def test_serialize_before_first_interval(self):
        self.trainer.updater.optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range)
        self._run_trainer(extension, self.expect[:self.interval - 1])
        target = dict()
        extension.serialize(DummySerializer(target))

        self.trainer.updater.optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range)
        extension.serialize(DummyDeserializer(target))
        self._run_trainer(extension, self.expect[self.interval - 1:])

    def test_serialize(self):
        self.trainer.updater.optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range)
        self._run_trainer(extension, self.expect[:len(self.expect) // 2])
        target = dict()
        extension.serialize(DummySerializer(target))

        self.trainer.updater.optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range)
        extension.serialize(DummyDeserializer(target))
        self._run_trainer(extension, self.expect[len(self.expect) // 2:])
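The tests above rely on DummySerializer and DummyDeserializer helpers that are not shown in the snippet. A minimal sketch of what such dict-backed helpers could look like (an assumption for illustration, not the test suite's actual code):

import chainer


class DummySerializer(chainer.serializer.Serializer):
    """Records every serialized value into a plain dict (sketch only)."""

    def __init__(self, target):
        self.target = target

    def __getitem__(self, key):
        raise NotImplementedError  # nested components are not needed here

    def __call__(self, key, value):
        self.target[key] = value
        return value


class DummyDeserializer(chainer.serializer.Deserializer):
    """Plays the recorded values back from the same dict (sketch only)."""

    def __init__(self, target):
        self.target = target

    def __getitem__(self, key):
        raise NotImplementedError

    def __call__(self, key, value):
        return self.target[key]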
Example 3
    def setUp(self):
        self.optimizer = mock.MagicMock()
        self.extension = extensions.LinearShift('x', self.value_range,
                                                self.time_range,
                                                self.optimizer)

        self.interval = 2
        self.trigger = training.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = self.optimizer
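The mocked updater.get_optimizer above hints at how the extension finds its target when no optimizer is passed explicitly. A hedged sketch of that fallback (an illustration of the behavior the fixture exercises, not Chainer's internal code):

def resolve_optimizer(trainer, explicit_optimizer=None):
    # If an optimizer was given to LinearShift, use it; otherwise fall back
    # to the trainer's main optimizer, which is what the mock above stubs.
    if explicit_optimizer is not None:
        return explicit_optimizer
    return trainer.updater.get_optimizer('main')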
Example 4
    def test_resume(self):
        new_optimizer = mock.Mock()
        new_extension = extensions.LinearShift('x', self.value_range,
                                               self.time_range, new_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        new_trainer = testing.get_trainer_with_mock_updater((5, 'iteration'))
        new_trainer.extend(new_extension)
        testing.save_and_load_npz(self.trainer, new_trainer)

        new_extension.initialize(new_trainer)
        self.assertEqual(new_optimizer.x, self.optimizer.x)
Example 5
def run_training(args):

    model = LossyModel(c_list=args.c_list, q_num=args.q_num)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    log_interval = (100, 'iteration')
    eval_interval = (5000, 'iteration')

    train = D.ImageDataset(args.root, args.train_paths)
    test = D.ImageDataset(args.root, args.test_paths)

    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    test_iter = chainer.iterators.MultiprocessIterator(test,
                                                       args.batchsize,
                                                       repeat=False,
                                                       shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               out=args.out)

    evaluator = extensions.Evaluator(test_iter, model, device=args.gpu)
    trainer.extend(evaluator, trigger=eval_interval, name='val')

    trainer.extend(extensions.ProgressBar())
    trainer.extend(
        extensions.LogReport(trigger=log_interval, log_name='log_lossy'))
    trainer.extend(extensions.snapshot_object(model, 'model_lossy'),
                   trigger=(args.iteration, 'iteration'))

    trainer.extend(
        extensions.PrintReport([
            'iteration',
            'main/MSSSIM',
            'val/main/MSSSIM',
            'elapsed_time',
        ]))

    trainer.extend(
        extensions.LinearShift('alpha', (0.001, 0),
                               (int(args.iteration * 0.75), args.iteration)))

    print('==========================================')
    trainer.run()
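The schedule built above shifts alpha from 0.001 to 0 between 75% of the iterations and the final iteration. A small, self-contained sketch of the linear interpolation this implies (an illustrative helper, not Chainer's implementation):

def linear_shift_value(t, value_range, time_range):
    """Value of the scheduled attribute after t calls (illustrative)."""
    v0, v1 = value_range
    t0, t1 = time_range
    if t <= t0:
        return v0
    if t >= t1:
        return v1
    rate = (t - t0) / (t1 - t0)
    return v0 + rate * (v1 - v0)


# e.g. with 100000 iterations the decay runs from iteration 75000 to 100000:
print(linear_shift_value(87500, (0.001, 0), (75000, 100000)))  # 0.0005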
Example 6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/sr.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results',
                        help='directory to save the results to')
    parser.add_argument(
        '--inception_model_path',
        type=str,
        default='./datasets/inception_model/inception_score.model',
        help='path to the inception model')
    parser.add_argument(
        '--stat_file',
        type=str,
        default='./datasets/inception_model/fid_stats_cifar10_train.npz',
        help='path to the FID statistics file')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    chainer.cuda.get_device_from_id(args.gpu).use()
    # set up the model
    devices = {'main': 0, 'second': 1, 'third': 2, 'fourth': 3}
    gen, dis = load_models(config)
    model = {"gen": gen, "dis": dis}
    names = list(six.iterkeys(devices))
    try:
        names.remove('main')
    except ValueError:
        raise KeyError("devices must contain a 'main' key.")
    models = {'main': model}

    for name in names:
        g = copy.deepcopy(model['gen'])
        d = copy.deepcopy(model['dis'])
        if devices[name] >= 0:
            g.to_gpu(device=devices[name])
            d.to_gpu(device=devices[name])
        models[name] = {"gen": g, "dis": d}

    if devices['main'] >= 0:
        models['main']['gen'].to_gpu(device=devices['main'])
        models['main']['dis'].to_gpu(device=devices['main'])

    links = [[name, link] for name, link in sorted(dis.namedlinks())]
    for name, link in links:
        print(name)
    links = [[name, link] for name, link in sorted(gen.namedlinks())]
    for name, link in links:
        print(name)

    # Optimizer
    opt_gen = make_optimizer(models['main']['gen'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(models['main']['dis'],
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'devices': devices,
        'models': models,
        'iterator': iterator,
        'optimizer': opts
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "FID"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models['main'].values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(models['main']['gen'],
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(models['main']['gen'], out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(models['main']['gen'],
                                         out,
                                         rows=10,
                                         cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(models['main']['gen'],
                                  n_ims=5000,
                                  splits=1,
                                  dst=args.results_dir,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(models['main']['gen'],
                            n_ims=5000,
                            dst=args.results_dir,
                            path=args.inception_model_path,
                            stat_file=args.stat_file),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(monitor_largest_singular_values(models['main']['dis'],
                                                   dst=args.results_dir),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)

    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
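This example calls a make_optimizer() helper that is defined elsewhere in its repository. A minimal sketch of what it plausibly looks like, modeled on the Adam-based helper that Example 10 below defines inline (an assumption, not the original file's code):

import chainer


def make_optimizer(model, alpha=0.0002, beta1=0.0, beta2=0.9):
    # Adam with the hyperparameters taken from the YAML config.
    optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
    optimizer.setup(model)
    return optimizer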
Example 7
def main():
    archs = {
        'resnet_multilabel': resnet_finetune_custom_multilabel.Encoder,
        'pspnet_resnet_multilabel': pspnet_resnet_multilabel.Encoder
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('--arch',
                        '-a',
                        choices=archs.keys(),
                        default='resnet_multilabel',
                        help='Convnet architecture')
    parser.add_argument('--class_num',
                        '-c',
                        type=int,
                        default=None,
                        help='number of classes')
    parser.add_argument('--batchsize',
                        '-B',
                        type=int,
                        default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch',
                        '-E',
                        type=int,
                        default=100,
                        help='Number of epochs to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob',
                        '-j',
                        type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean',
                        '-m',
                        default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--optimizer',
                        '-opt',
                        default='msgd',
                        help='Optimizer to use (msgd, sgd, adam, adagrad, '
                        'rmsprop or adadelta)')
    parser.add_argument('--root',
                        '-R',
                        default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--width',
                        '-wid',
                        type=int,
                        default=256,
                        help='Width of input images')
    parser.add_argument('--height',
                        '-hei',
                        type=int,
                        default=384,
                        help='Height of input images')
    parser.add_argument('--val_iter', '-val_iter', type=int, default=1000)
    parser.set_defaults(test=False)
    args = parser.parse_args()

    # Initialize the model to train
    model = archs[args.arch]()
    print("model:" + str(archs[args.arch]))
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make the GPU current
        model.to_gpu()

    # Load the datasets and mean file
    mean = np.load(args.mean)
    train = PreprocessedDataset('train_288416',
                                mean,
                                crop_size_x=args.width,
                                crop_size_y=args.height)
    val = PreprocessedDataset('test_288416',
                              mean,
                              crop_size_x=args.width,
                              crop_size_y=args.height,
                              random=False)

    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)
    # Set up an optimizer
    if args.optimizer == 'adagrad':
        optimizer = chainer.optimizers.AdaGrad()
        print(optimizer)
        print("lr:" + str(optimizer.lr))
    elif args.optimizer == 'sgd':
        optimizer = chainer.optimizers.SGD()
        print(optimizer)
        print("lr:" + str(optimizer.lr))
    elif args.optimizer == 'adam':
        optimizer = chainer.optimizers.Adam()
        print(optimizer)
        print("alpha:" + str(optimizer.alpha))
    elif args.optimizer == 'rmsprop':
        optimizer = chainer.optimizers.RMSprop()
        print(optimizer)
        print("lr:" + str(optimizer.lr))
    elif args.optimizer == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta()
        print(optimizer)
    else:
        optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
        print(optimizer)
        print("lr:" + str(optimizer.lr))

    optimizer.setup(model)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    log_interval = (10 if args.test else 100), 'iteration'
    val_interval = (100 if args.test else args.val_iter), 'iteration'

    trainer.extend(extensions.dump_graph('main/loss'))
    if not args.optimizer == "adam":
        trainer.extend(
            extensions.LinearShift("lr", (0.01, 0.001), (10000, 20000)))
    trainer.extend(TestModeEvaluator(val_iter, model, device=args.gpu),
                   trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        model, 'model_iteration_{.updater.iteration}'),
                   trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    #trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'validation/main/recall',
        'validation/main/precision', 'validation/main/f_value', 'lr'
    ]),
                   trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
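The LinearShift("lr", ...) schedule above is skipped for Adam because Adam is tuned through alpha rather than lr. A short, hedged illustration of how the attribute name changes with the optimizer family (illustrative values only):

from chainer import optimizers
from chainer.training import extensions

sgd = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
adam = optimizers.Adam(alpha=0.001)

# SGD-like optimizers are scheduled through 'lr' ...
lr_shift = extensions.LinearShift('lr', (0.01, 0.001), (10000, 20000), sgd)
# ... while Adam is scheduled through 'alpha'.
alpha_shift = extensions.LinearShift('alpha', (0.001, 0.0), (10000, 20000),
                                     adam)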
Example 8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--input_dir', type=str, default='./data/imagenet')
    parser.add_argument('--truth_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot file to use')
    parser.add_argument('--enc_model',
                        type=str,
                        default='',
                        help='path to the encoder .npz file')
    parser.add_argument('--gen_model',
                        type=str,
                        default='',
                        help='path to the generator .npz file')
    parser.add_argument('--dis_model',
                        type=str,
                        default='',
                        help='path to the discriminator .npz file')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis, enc = load_models(config)

    chainer.serializers.load_npz(args.gen_model, gen, strict=False)
    chainer.serializers.load_npz(args.dis_model, dis)
    chainer.serializers.load_npz(args.enc_model, enc)

    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)
    enc.to_gpu(device=args.gpu)
    models = {"gen": gen, "dis": dis, "enc": enc}
    opt_gen = make_optimizer(gen,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_gen.add_hook(chainer.optimizer.WeightDecay(config.weight_decay))
    opt_gen.add_hook(chainer.optimizer.GradientClipping(config.grad_clip))

    # disable update of pre-trained weights
    layers_to_train = ['lA1', 'lA2', 'lB1', 'lB2', 'preluW', 'preluMiddleW']
    for layer in gen.children():
        if layer.name not in layers_to_train:
            layer.disable_update()

    lmd_pixel = 0.05

    def fast_loss(out, gt):
        l1 = reconstruction_loss(dis, out, gt)
        l2 = lmd_pixel * pixel_loss(out, gt)
        loss = l1 + l2
        return loss

    gen.set_fast_loss(fast_loss)

    opts = {"opt_gen": opt_gen}

    # Dataset
    config['dataset']['args']['root_input'] = args.input_dir
    config['dataset']['args']['root_truth'] = args.truth_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = [
        "loss_noab", "loss1", "loss2", "loss3", "fast_alpha", "loss_ae",
        "fast_benefit", "min_slope", "max_slope", "min_slope_middle",
        "max_slope_middle"
    ]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.ParameterStatistics(gen),
                   trigger=(config.display_interval, 'iteration'))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_reconstruction_auxab(enc,
                                               gen,
                                               out,
                                               n_classes=gen.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    trainer.extend(ext_opt_gen)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Example 9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             comm,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    config['dataset']['args']['root'] = args.data_dir
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
    else:
        _ = yaml_utils.load_dataset(
            config)  # Dummy, for adding path to the dataset module
        dataset = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    if comm.rank == 0:
        # Set up logging
        trainer.extend(extensions.snapshot(),
                       trigger=(config.snapshot_interval, 'iteration'))
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                           trigger=(config.snapshot_interval, 'iteration'))
        trainer.extend(
            extensions.LogReport(keys=report_keys,
                                 trigger=(config.display_interval,
                                          'iteration')))
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=(config.display_interval, 'iteration'))
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                       trigger=(config.evaluation_interval // 10, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(calc_inception(gen,
                                      n_ims=5000,
                                      splits=1,
                                      path=args.inception_model_path),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
        trainer.extend(
            extensions.ProgressBar(
                update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
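Here make_optimizer() additionally receives the ChainerMN communicator. A hedged sketch of such a helper, following the standard multi-node pattern (an assumption about code that is not shown in the snippet):

import chainer
import chainermn


def make_optimizer(model, comm, alpha=0.0002, beta1=0.0, beta2=0.9):
    # Wrap a plain Adam in a multi-node optimizer so gradients are
    # all-reduced across workers before each update.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2), comm)
    optimizer.setup(model)
    return optimizer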
Example 10
def main():
    parser = argparse.ArgumentParser(description='Train 3D-Unet')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--root',
                        '-R',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='Root directory path of input image')
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--out',
                        '-o',
                        default='Results_trM1_ValiM2',
                        help='Directory to output the result')

    parser.add_argument('--model', '-m', default='', help='Load model data')
    parser.add_argument('--resume',
                        '-res',
                        default='',
                        help='Resume the training from snapshot')

    parser.add_argument('--training_list',
                        default='configs/M1.txt',
                        help='Path to training image list file')
    parser.add_argument('--training_coordinate_list',
                        type=str,
                        default='configs/M1.csv')

    parser.add_argument('--validation_list',
                        default='configs/M2.txt',
                        help='Path to validation image list file')
    parser.add_argument('--validation_coordinate_list',
                        type=str,
                        default='configs/M2.csv')

    args = parser.parse_args()
    # See https://stackoverflow.com/questions/21005822/what-does-os-path-abspathos-path-joinos-path-dirname-file-os-path-pardir
    config = yaml_utils.Config(
        yaml.load(
            open(os.path.join(os.path.dirname(__file__), args.config_path))))
    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(config.batchsize))
    print('# iteration: {}'.format(config.iteration))
    print('')

    # Load the datasets
    train = UnetDataset(args.root, args.training_list,
                        args.training_coordinate_list,
                        config.patch['patchside'])
    train_iter = chainer.iterators.SerialIterator(train,
                                                  batch_size=config.batchsize)

    validation = UnetDataset(args.root, args.validation_list,
                             args.validation_coordinate_list,
                             config.patch['patchside'])
    validation_iter = chainer.iterators.SerialIterator(
        validation, batch_size=config.batchsize, repeat=False, shuffle=False)

    # Set up a neural network to train
    print('Set up a neural network to train')
    unet = UNet3D(2)
    if args.model:
        chainer.serializers.load_npz(args.model, unet)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        unet.to_gpu()

    #Set up an optimizer
    def make_optimizer(model, alpha=0.0001, beta1=0.9, beta2=0.999):
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_unet = make_optimizer(model=unet,
                              alpha=config.adam['alpha'],
                              beta1=config.adam['beta1'],
                              beta2=config.adam['beta2'])
    #Set up a trainer
    updater = Unet3DUpdater(models=(unet),
                            iterator=train_iter,
                            optimizer={'unet': opt_unet},
                            device=args.gpu)

    def create_result_dir(base, result_dir, config_path, config):
        """https://github.com/pfnet-research/sngan_projection/blob/master/train.py"""
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        def copy_to_result_dir(fn, result_dir):
            bfn = os.path.basename(fn)
            shutil.copy(fn, '{}/{}'.format(result_dir, bfn))

        copy_to_result_dir(config_path, result_dir)
        copy_to_result_dir(os.path.join(base, config.unet['fn']), result_dir)

        copy_to_result_dir(os.path.join(base, config.updater['fn']),
                           result_dir)

    out = os.path.join(args.root, args.out)
    config_path = os.path.join(os.path.dirname(__file__), args.config_path)
    create_result_dir(args.root, out, config_path, config)

    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    #serializers.load_npz('C:\\Users\\yourb\\Documents\\GitHub\\3D-Unet\\Results_trM1_ValiM2\\snapshot_iter_10500.npz', trainer)

    # Set up logging
    snapshot_interval = (config.snapshot_interval, 'iteration')
    display_interval = (config.display_interval, 'iteration')
    evaluation_interval = (config.evaluation_interval, 'iteration')
    trainer.extend(UNet3DEvaluator(validation_iter, unet, device=args.gpu),
                   trigger=evaluation_interval)
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        unet, filename=unet.__class__.__name__ + '_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=display_interval))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Print selected entries of the log to stdout
    #report_keys = ['epoch', 'iteration', 'unet/loss','unet/dice','vali/unet/loss','vali/unet/dice']
    report_keys = ['iteration', 'unet/dice', 'vali/unet/dice']

    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=display_interval)

    # Use linear shift
    ext_opt_unet = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_unet)
    trainer.extend(ext_opt_unet)

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        #trainer.extend(extensions.PlotReport(['unet/loss','vali/unet/loss'], 'iteration', file_name='unet_loss.png',trigger=display_interval))
        trainer.extend(
            extensions.PlotReport(['unet/dice', 'vali/unet/dice'],
                                  'iteration',
                                  file_name='unet_dice.png',
                                  trigger=display_interval))

    if args.resume:
        # Resume from a snapshot
        print("Resume training with snapshot:{}".format(args.resume))
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    chainer.config.autotune = True
    print('Start training')
    trainer.run()
Example 11
def main():
    args = parse_args()
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
    config = configparser.ConfigParser()

    logger.info("read {}".format(args.config_path))
    config.read(args.config_path, "UTF-8")
    logger.info("setup devices")
    if chainer.backends.cuda.available:
        devices = setup_devices(config["training_param"]["gpus"])
    else:
        # cpu run
        devices = {"main": -1}
    seed = config.getint("training_param", "seed")
    logger.info("set random seed {}".format(seed))
    set_random_seed(devices, seed)

    result = os.path.expanduser(config["result"]["dir"])
    destination = os.path.join(result, "pose")
    logger.info("> copy code to {}".format(os.path.join(result, "src")))
    save_files(result)
    logger.info("> copy config file to {}".format(destination))
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copy(args.config_path, os.path.join(destination, "config.ini"))

    logger.info("{} chainer debug".format("enable" if args.debug else "disable"))
    chainer.set_debug(args.debug)
    chainer.global_config.autotune = True
    chainer.cuda.set_max_workspace_size(11388608)
    chainer.config.cudnn_fast_batch_normalization = True

    logger.info("> get dataset")
    train_set, val_set, hand_param = select_dataset(config, return_data=["train_set", "val_set", "hand_param"])
    model = select_model(config, hand_param)

    logger.info("> transform dataset")
    train_set = TransformDataset(train_set, model.encode)
    val_set = TransformDataset(val_set, model.encode)
    logger.info("> size of train_set is {}".format(len(train_set)))
    logger.info("> size of val_set is {}".format(len(val_set)))
    logger.info("> create iterators")
    batch_size = config.getint("training_param", "batch_size")
    n_processes = config.getint("training_param", "n_processes")

    train_iter = chainer.iterators.MultiprocessIterator(
        train_set, batch_size,
        n_processes=n_processes
    )
    test_iter = chainer.iterators.MultiprocessIterator(
        val_set, batch_size,
        repeat=False, shuffle=False,
        n_processes=n_processes,
    )

    logger.info("> setup optimizer")
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    logger.info("> setup parallel updater devices={}".format(devices))
    updater = training.updaters.ParallelUpdater(train_iter, optimizer, devices=devices)

    logger.info("> setup trainer")
    trainer = training.Trainer(
        updater,
        (config.getint("training_param", "train_iter"), "iteration"),
        destination,
    )

    logger.info("> setup extensions")
    trainer.extend(
        extensions.LinearShift("lr",
                               value_range=(config.getfloat("training_param", "learning_rate"), 0),
                               time_range=(0, config.getint("training_param", "train_iter"))
                               ),
        trigger=(1, "iteration")
    )

    trainer.extend(extensions.Evaluator(test_iter, model, device=devices["main"]))
    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport([
            "main/loss", "validation/main/loss",
        ], "epoch", file_name="loss.png"))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.PrintReport([
        "epoch", "elapsed_time", "lr",
        "main/loss", "validation/main/loss",
        "main/loss_resp", "validation/main/loss_resp",
        "main/loss_iou", "validation/main/loss_iou",
        "main/loss_coor", "validation/main/loss_coor",
        "main/loss_size", "validation/main/loss_size",
        "main/loss_limb", "validation/main/loss_limb",
        "main/loss_vect_cos", "validation/main/loss_vect_cos",
        "main/loss_vect_norm", "validation/main/loss_vect_cos",
        "main/loss_vect_square", "validation/main/loss_vect_square",
    ]))
    trainer.extend(extensions.ProgressBar())

    trainer.extend(extensions.snapshot(filename="best_snapshot"),
                   trigger=training.triggers.MinValueTrigger("validation/main/loss"))
    trainer.extend(extensions.snapshot_object(model, filename="bestmodel.npz"),
                   trigger=training.triggers.MinValueTrigger("validation/main/loss"))

    logger.info("> start training")
    trainer.run()
Example 12
def create_trainer(
    config: Config,
    output: Path,
):
    assert_config(config)
    if output.exists():
        raise Exception(f"output directory {output} already exists.")

    # model
    predictor = create_predictor(config.model)
    if config.train.trained_model is not None:
        chainer.serializers.load_npz(
            config.train.trained_model["predictor_path"], predictor)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )

    model.to_gpu(config.train.gpu[0])
    cuda.get_device_from_id(config.train.gpu[0]).use()

    # dataset
    dataset = create_dataset(config.dataset)
    batchsize_divided = config.train.batchsize // len(config.train.gpu)
    train_iter = MultiprocessIterator(dataset["train"], config.train.batchsize)
    test_iter = MultiprocessIterator(dataset["test"],
                                     batchsize_divided,
                                     repeat=False,
                                     shuffle=True)
    train_test_iter = MultiprocessIterator(dataset["train_test"],
                                           batchsize_divided,
                                           repeat=False,
                                           shuffle=True)

    if dataset["test_eval"] is not None:
        test_eval_iter = MultiprocessIterator(dataset["test_eval"],
                                              batchsize_divided,
                                              repeat=False,
                                              shuffle=True)
    else:
        test_eval_iter = None
    else:
        test_eval_iter = None

    # optimizer
    def create_optimizer(model):
        cp: Dict[str, Any] = copy(config.train.optimizer)
        n = cp.pop("name").lower()

        if n == "adam":
            optimizer = optimizers.Adam(**cp)
        elif n == "sgd":
            optimizer = optimizers.SGD(**cp)
        else:
            raise ValueError(n)

        optimizer.setup(model)

        if config.train.optimizer_gradient_clipping is not None:
            optimizer.add_hook(
                optimizer_hooks.GradientClipping(
                    config.train.optimizer_gradient_clipping))

        return optimizer

    optimizer = create_optimizer(model)
    if config.train.trained_model is not None:
        chainer.serializers.load_npz(
            config.train.trained_model["optimizer_path"], optimizer)

    # updater
    if len(config.train.gpu) <= 1:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            converter=concat_optional,
            device=config.train.gpu[0],
        )
    else:
        updater = ParallelUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            converter=concat_optional,
            devices={
                "main" if i == 0 else f"gpu{gpu}": gpu
                for i, gpu in enumerate(config.train.gpu)
            },
        )
    if config.train.trained_model is not None:
        updater.iteration = optimizer.t

    # trainer
    output.mkdir()
    config.save_as_json((output / "config.json").absolute())

    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = training.Trainer(updater, stop_trigger=trigger_stop, out=output)
    tb_writer = SummaryWriter(Path(output))

    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
        if config.train.trained_model is not None:
            shift_ext._t = optimizer.t
        trainer.extend(shift_ext)

    if config.train.ema_decay is not None:
        train_predictor = predictor
        predictor = deepcopy(predictor)
        ext = ExponentialMovingAverage(target=train_predictor,
                                       ema_target=predictor,
                                       decay=config.train.ema_decay)
        trainer.extend(ext, trigger=(1, "iteration"))

    ext = extensions.Evaluator(test_iter,
                               model,
                               concat_optional,
                               device=config.train.gpu[0])
    trainer.extend(ext, name="test", trigger=trigger_log)
    ext = extensions.Evaluator(train_test_iter,
                               model,
                               concat_optional,
                               device=config.train.gpu[0])
    trainer.extend(ext, name="train", trigger=trigger_log)

    if test_eval_iter is not None:
        generator = Generator(config=config,
                              model=predictor,
                              max_batch_size=config.train.batchsize)
        generate_evaluator = GenerateEvaluator(
            generator=generator,
            time_length=config.dataset.time_length_evaluate,
            local_padding_time_length=(
                config.dataset.local_padding_time_length_evaluate),
        )
        ext = extensions.Evaluator(
            test_eval_iter,
            generate_evaluator,
            concat_optional,
            device=config.train.gpu[0],
        )
        trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(predictor,
                                     filename="main_{.updater.iteration}.npz")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     optimizer, filename="optimizer_{.updater.iteration}.npz"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorBoardReport(writer=tb_writer), trigger=trigger_log)

    trainer.extend(extensions.dump_graph(root_name="main/loss"))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
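config.train.linear_shift and config.train.step_shift above are keyword dictionaries expanded directly into the extension constructors. A hedged illustration of what a linear_shift entry could contain (hypothetical values, not taken from any real config):

from chainer.training import extensions

# Hypothetical config values: decay 'alpha' from 1e-3 to 0 between the
# 100000th and 200000th call of the extension.
linear_shift = dict(attr='alpha', value_range=(1e-3, 0.0),
                    time_range=(100000, 200000))
shift_ext = extensions.LinearShift(**linear_shift)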
Example 13
def train_phase(generator, train, valid, args):

    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_batchsize = min(args.batchsize * len(args.gpu), len(train))
    valid_batchsize = args.batchsize
    train_iter = chainer.iterators.MultiprocessIterator(train, train_batchsize)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  valid_batchsize,
                                                  repeat=False,
                                                  shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=F.tanh,
                      lossfun=F.mean_absolute_error,
                      accfun=F.mean_absolute_error)

    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    if args.gpu[0] >= 0:
        chainer.backends.cuda.get_device_from_id(args.gpu[0]).use()
        if len(args.gpu) == 1:
            model.to_gpu()
            discriminator.to_gpu()

    # setup an optimizer
    optimizer_G = chainer.optimizers.Adam(alpha=args.lr,
                                          beta1=args.beta,
                                          beta2=0.999,
                                          eps=1e-08,
                                          amsgrad=False)
    optimizer_G.setup(model)
    optimizer_D = chainer.optimizers.Adam(alpha=args.lr,
                                          beta1=args.beta,
                                          beta2=0.999,
                                          eps=1e-08,
                                          amsgrad=False)
    optimizer_D.setup(discriminator)

    if args.decay > 0:
        optimizer_G.add_hook(chainer.optimizer_hooks.WeightDecay(args.decay))
        optimizer_D.add_hook(chainer.optimizer_hooks.WeightDecay(args.decay))

    # setup a trainer
    if len(args.gpu) == 1:
        updater = DCGANUpdater(
            iterator=train_iter,
            optimizer={
                'gen': optimizer_G,
                'dis': optimizer_D,
            },
            alpha=args.alpha,
            device=args.gpu[0],
        )

    else:
        devices = {'main': args.gpu[0]}
        for idx, g in enumerate(args.gpu[1:]):
            devices['slave_%d' % idx] = g

        raise NotImplementedError('The parallel updater is not supported..')

    if args.frequency == -1:
        frequency = max(args.iteration // 80, 1)
    else:
        frequency = max(1, args.frequency)

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr
    trainer.extend(
        extensions.LinearShift('alpha', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('alpha', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer

    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter,
                             model,
                             valid_file,
                             visualizer=visualizer,
                             n_vis=20,
                             device=args.gpu[0]),
                   trigger=(frequency, 'iteration'))

    trainer.extend(
        extensions.dump_graph('loss_gen', filename='generative_loss.dot'))
    trainer.extend(
        extensions.dump_graph('loss_cond', filename='conditional_loss.dot'))
    trainer.extend(
        extensions.dump_graph('loss_dis', filename='discriminative_loss.dot'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.npz'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_iter_{.updater.iteration:08}.npz'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_iter_{.updater.iteration:08}.npz'),
                   trigger=(frequency, 'iteration'))

    log_keys = [
        'loss_gen', 'loss_cond', 'loss_dis', 'validation/main/accuracy'
    ]

    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # train
    trainer.run()
Example 14
    def test_with_optimizer(self):
        optimizer = mock.Mock()
        optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range, optimizer)
        self._run_trainer(extension, self.expect, optimizer)
Example 15
    def test_basic(self):
        self.trainer.updater.optimizer.x = 0
        extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range)
        self._run_trainer(extension, self.expect)
Example 16
def main():
    parser = argparse.ArgumentParser(
        description='Train Cycle-GAN.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset-path',
                        help='Directory of dataset which should have '
                        '"trainA", "trainB", "testA" and "testB" '
                        'directory',
                        type=str,
                        required=True)
    parser.add_argument('--device',
                        help='GPU ID (negative ID indicates CPU)',
                        type=int,
                        default=0)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=200)
    parser.add_argument('--lambda',
                        help='Coefficient of cycle consistency loss',
                        type=float,
                        default=10.0,
                        dest='lambda_v')
    parser.add_argument('--n-blocks',
                        help='Number of Resnet Blocks (Generator)',
                        type=int,
                        default=9)
    parser.add_argument('--out',
                        help='Directory to output the results',
                        type=str,
                        default='./result')
    parser.add_argument('--shift-lr-after-n-epochs',
                        help='Linearly decay the learning rate to 0 after '
                        'n epochs',
                        type=int,
                        default=None)
    parser.add_argument('--trained-models',
                        help='Load trained models from directory which has '
                        '"x_dis.hdf5", "y_dis.hdf5", "g_gen.hdf5", '
                        '"f_gen.hdf5".',
                        type=str,
                        default=None)
    args = parser.parse_args()

    x_dis = Discriminator()
    y_dis = Discriminator()
    g_gen = Generator(n_blocks=args.n_blocks)
    f_gen = Generator(n_blocks=args.n_blocks)
    if args.trained_models:
        load_hdf5(os.path.join(args.trained_models, 'x_dis.hdf5'), x_dis)
        load_hdf5(os.path.join(args.trained_models, 'y_dis.hdf5'), y_dis)
        load_hdf5(os.path.join(args.trained_models, 'g_gen.hdf5'), g_gen)
        load_hdf5(os.path.join(args.trained_models, 'f_gen.hdf5'), f_gen)
    if args.device >= 0:
        cuda.get_device_from_id(args.device).use()
        x_dis.to_gpu()
        y_dis.to_gpu()
        g_gen.to_gpu()
        f_gen.to_gpu()

    opt_x_dis = optimizers.Adam(0.0002)
    opt_x_dis.setup(x_dis)
    opt_y_dis = optimizers.Adam(0.0002)
    opt_y_dis.setup(y_dis)
    opt_g_gen = optimizers.Adam(0.0002)
    opt_g_gen.setup(g_gen)
    opt_f_gen = optimizers.Adam(0.0002)
    opt_f_gen.setup(f_gen)

    train_iter, test_a_iter, test_b_iter =\
        make_dataset_iterator(args.dataset_path)

    updater = CycleGANUpdater(train_iter=train_iter,
                              optimizer={
                                  'x_dis': opt_x_dis,
                                  'y_dis': opt_y_dis,
                                  'g_gen': opt_g_gen,
                                  'f_gen': opt_f_gen
                              },
                              device=args.device,
                              lambda_v=args.lambda_v)

    trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.snapshot(filename='trainer_snapshot.npz'),
                   trigger=(10, 'epoch'))
    trainer.extend(output_fake_images(g_gen, f_gen, test_a_iter, test_b_iter,
                                      args.out),
                   trigger=(2, 'epoch'))
    if args.shift_lr_after_n_epochs:
        trainer.extend(extensions.LinearShift(
            'alpha', (0.0002, 0), (args.shift_lr_after_n_epochs, args.epochs),
            opt_x_dis),
                       trigger=(1, 'epoch'))
        trainer.extend(extensions.LinearShift(
            'alpha', (0.0002, 0), (args.shift_lr_after_n_epochs, args.epochs),
            opt_y_dis),
                       trigger=(1, 'epoch'))
        trainer.extend(extensions.LinearShift(
            'alpha', (0.0002, 0), (args.shift_lr_after_n_epochs, args.epochs),
            opt_g_gen),
                       trigger=(1, 'epoch'))
        trainer.extend(extensions.LinearShift(
            'alpha', (0.0002, 0), (args.shift_lr_after_n_epochs, args.epochs),
            opt_f_gen),
                       trigger=(1, 'epoch'))
    trainer.run()

    save_hdf5(os.path.join(args.out, 'x_dis.hdf5'), x_dis)
    save_hdf5(os.path.join(args.out, 'y_dis.hdf5'), y_dis)
    save_hdf5(os.path.join(args.out, 'g_gen.hdf5'), g_gen)
    save_hdf5(os.path.join(args.out, 'f_gen.hdf5'), f_gen)
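The four LinearShift extensions above decay Adam's alpha from 0.0002 to 0 between --shift-lr-after-n-epochs and --epochs. The following is a minimal, standalone sketch (not part of the example) of the schedule such an extension is expected to produce, assuming it clamps outside the time range and interpolates linearly inside it; the epoch numbers are arbitrary.

def linear_shift(t, value_range=(0.0002, 0.0), time_range=(100, 200)):
    """Value a LinearShift-style schedule would assign at time t."""
    v1, v2 = value_range
    t1, t2 = time_range
    if t <= t1:
        return v1
    if t >= t2:
        return v2
    return v1 + (v2 - v1) * (t - t1) / float(t2 - t1)


if __name__ == '__main__':
    for epoch in (0, 100, 150, 200):
        print(epoch, linear_shift(epoch))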
Example 17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml')
    parser.add_argument('--n_devices', type=int)
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--results_dir', type=str, default='results_rocgan')
    parser.add_argument('--inception_model_path', type=str,
                        default='/home/user/inception/inception.model')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--enc_snapshot', type=str, default=None, help='path to the encoder snapshot')
    parser.add_argument('--dec_snapshot', type=str, default=None, help='path to the decoder snapshot')
    parser.add_argument('--dis_snapshot', type=str, default=None, help='path to the discriminator snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--multiprocessing', action='store_true', default=False)
    parser.add_argument('--validation', type=int, default=1)
    parser.add_argument('--valid_fn', type=str, default='files_valid_4k.txt', 
                        help='filename of the validation file')
    parser.add_argument('--label', type=str, default='synth')
    parser.add_argument('--stats_fid', type=str, default='', help='path for FID stats')
    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    # # ensure that the paths of the config are correct.
    config = ensure_config_paths(config)
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    # # get the pc name, e.g. for chainerui.
    pcname = gethostname()
    imperialpc = 'doc.ic.ac.uk' in pcname or pcname in ['ladybug', 'odysseus']
    print('Init on pc: {}.'.format(pcname))
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')
    enc, dec, dis = load_models_cgan(config)
    if chainer.cuda.available:
        enc.to_gpu()
        dec.to_gpu()
        dis.to_gpu()
    else:
        print('No GPU found!!!\n')
    mma1 = ModelMovingAverage(0.999, enc)
    mma2 = ModelMovingAverage(0.999, dec)
    models = {'enc': enc, 'dec': dec, 'dis': dis}
    if args.enc_snapshot is not None:
        print('Loading encoder: {}.'.format(args.enc_snapshot))
        chainer.serializers.load_npz(args.enc_snapshot, enc)
    if args.dec_snapshot is not None:
        print('Loading decoder: {}.'.format(args.dec_snapshot))
        chainer.serializers.load_npz(args.dec_snapshot, dec)
    if args.dis_snapshot is not None:
        print('Loading discriminator: {}.'.format(args.dis_snapshot))
        chainer.serializers.load_npz(args.dis_snapshot, dis)
    # # convenience function for optimizer:
    func_opt = lambda net: make_optimizer(net, comm, chmn=args.multiprocessing,
                                          alpha=config.adam['alpha'], beta1=config.adam['beta1'], 
                                          beta2=config.adam['beta2'])
    # Optimizer
    opt_enc = func_opt(enc)
    opt_dec = func_opt(dec)
    opt_dis = func_opt(dis)
    opts = {'opt_enc': opt_enc, 'opt_dec': opt_dec, 'opt_dis': opt_dis}
    # Dataset
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
        printtime('Length of dataset: {}.'.format(len(dataset)))
        if args.validation:
            # # add the validation db if we do perform validation.
            db_valid = yaml_utils.load_dataset(config, validation=True, valid_path=args.valid_fn)
    else:
        _ = yaml_utils.load_dataset(config)  # Dummy, for adding path to the dataset module
        dataset = None
        if args.validation:
            _ = yaml_utils.load_dataset(config, validation=True, valid_path=args.valid_fn)
            db_valid = None
    dataset = chainermn.scatter_dataset(dataset, comm)
    if args.validation:
        db_valid = chainermn.scatter_dataset(db_valid, comm)
    # Iterator
    multiprocessing.set_start_method('forkserver')
    if args.multiprocessing:
        # # In minoas this might fail with the forkserver.py error.
        iterator = chainer.iterators.MultiprocessIterator(dataset, config.batchsize,
                                                          n_processes=args.loaderjob)
        if args.validation:
            iter_val = chainer.iterators.MultiprocessIterator(db_valid, config.batchsize,
                                                              n_processes=args.loaderjob,
                                                              shuffle=False, repeat=False)
    else:
        iterator = chainer.iterators.SerialIterator(dataset, config.batchsize)
        if args.validation:
            iter_val = chainer.iterators.SerialIterator(db_valid, config.batchsize,
                                                        shuffle=False, repeat=False)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
        'mma1': mma1,
        'mma2': mma2,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    if not args.test:
        mainf = '{}_{}'.format(strftime('%Y_%m_%d__%H_%M_%S'), args.label)
        out = os.path.join(args.results_dir, mainf, '')
    else:
        out = 'results/test'
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    # # abbreviations below: inc -> inception, gadv -> grad_adv, lgen -> loss of the generator,
    # # {m, sd, b}[var] -> {mean, std, best position} [var],
    report_keys = ['loss_dis', 'lgen_adv', 'dis_real', 'dis_fake', 'loss_l1',
                   'mssim', 'sdssim', 'mmae', 'loss_projl', 'FID']

    if comm.rank == 0:
        # Set up logging
        for m in models.values():
            trainer.extend(extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
#         trainer.extend(extensions.snapshot_object(
#             mma.avg_model, mma.avg_model.__class__.__name__ + '_avgmodel_{.updater.iteration}.npz'),
#             trigger=(config.snapshot_interval, 'iteration'))
        trainer.extend(extensions.LogReport(trigger=(config.display_interval, 'iteration')))
        trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
        if args.validation:
            # # add the appropriate extension for validating the model.
            models_mma = {'enc': mma1.avg_model, 'dec': mma2.avg_model, 'dis': dis}
            trainer.extend(validation_trainer(models_mma, iter_val, n=len(db_valid), export_best=True, 
                                              pout=out, p_inc=args.inception_model_path, eval_fid=True,
                                              sfile=args.stats_fid),
                           trigger=(config.evaluation_interval, 'iteration'),
                           priority=extension.PRIORITY_WRITER)

        trainer.extend(extensions.ProgressBar(update_interval=config.display_interval))
        if imperialpc:
            # [ChainerUI] Observe learning rate
            trainer.extend(extensions.observe_lr(optimizer_name='opt_dis'))
            # [ChainerUI] enable to send commands from ChainerUI
            trainer.extend(CommandsExtension())
            # [ChainerUI] save 'args' to show experimental conditions
            save_args(args, out)

    # # convenience function for linearshift in optimizer:
    func_opt_shift = lambda optim1: extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                                           (config.iteration_decay_start, 
                                                            config.iteration), optim1)
    # # define the actual extensions (for optimizer shift).
    trainer.extend(func_opt_shift(opt_enc))
    trainer.extend(func_opt_shift(opt_dec))
    trainer.extend(func_opt_shift(opt_dis))

    if args.resume:
        print('Resume Trainer')
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    printtime('start training')
    trainer.run()
    plot_losses_log(out, savefig=True)
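The `func_opt_shift` lambda above builds one alpha-decay LinearShift per optimizer. A small named-helper variant is sketched below; it is only a refactoring sketch, and `alpha0`, `decay_start` and `stop` are placeholders for config.adam['alpha'], config.iteration_decay_start and config.iteration.

from chainer.training import extensions


def extend_alpha_decay(trainer, optimizer_list, alpha0, decay_start, stop):
    """Attach one LinearShift per optimizer, decaying Adam's alpha to 0."""
    for opt in optimizer_list:
        trainer.extend(extensions.LinearShift(
            'alpha', (alpha0, 0.), (decay_start, stop), opt))

# e.g. extend_alpha_decay(trainer, [opt_enc, opt_dec, opt_dis],
#                         config.adam['alpha'],
#                         config.iteration_decay_start, config.iteration)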
Example 18
    # --- Set Extensions --- #
    # Setting for Extensions
    print_list = [
        'epoch', 'train/loss/LP', 'train/loss/DC', 'train/loss/total',
        'validation/main/loss', 'GRL/lmd', 'lr', 'train/accuracy',
        'validation/main/accuracy', 'elapsed_time'
    ]
    loss_list = [
        'train/loss/LP', 'train/loss/DC', 'train/loss/total',
        'validation/main/loss'
    ]
    accuracy_list = ['train/accuracy', 'validation/main/accuracy']
    eval_model = chainer.links.Classifier(model)

    # Set Extensions
    trainer.extend(ex.Evaluator(valid, eval_model, device=0))
    trainer.extend(
        ex.dump_graph(root_name='train/loss/total',
                      out_name='cg.dot'))  # calc graph
    trainer.extend(ex.observe_lr())
    trainer.extend(ex.LogReport())
    trainer.extend(ex.LinearShift('lr', (lr, lr / 100.0), (30000, 100000)))
    trainer.extend(ex.PrintReport(print_list))
    trainer.extend(ex.ProgressBar(update_interval=1))
    trainer.extend(ex.PlotReport(loss_list, 'epoch', file_name='loss.png'))
    trainer.extend(
        ex.PlotReport(accuracy_list, 'epoch', file_name='accuracy.png'))

    # --- Start Training! --- #
    trainer.run()
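Note that the attribute name passed to LinearShift has to match the optimizer in use: this example schedules 'lr' (as exposed by SGD-family optimizers such as MomentumSGD) from lr down to lr/100, while the Adam-based examples schedule 'alpha'. A quick check of the attribute names, assuming stock Chainer optimizers:

import chainer

# MomentumSGD exposes its learning rate as `lr`; Adam's base step size is `alpha`.
print(hasattr(chainer.optimizers.MomentumSGD(), 'lr'))   # expected: True
print(hasattr(chainer.optimizers.Adam(), 'alpha'))       # expected: True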
Example 19
def train(args):
    with open(args.config, "r") as cfg:
        data_config = json.load(cfg)

    size = data_config["train"]["kwargs"]["size"]

    # set up GPUs if necessary
    device = -1
    comm = None

    if args.device[0] < 0:
        print("training on CPU (not recommended!)")
    else:
        if len(args.device) > 1:
            import chainermn
            comm = chainermn.create_communicator()
            device = args.device[comm.intra_rank]

            print("using multi-gpu training with GPU {}".format(device))
        elif args.device[0] >= 0:
            device = args.device[0]
            print("using single gpu training with GPU {}".format(device))

    disc, gen = setup_models(args, size, device)

    def setup_optimizer(opt, model, comm=None):
        if comm is not None:
            opt = chainermn.create_multi_node_optimizer(opt, comm)
        opt.setup(model)
        return opt

    opt_disc = setup_optimizer(
        chainer.optimizers.Adam(alpha=args.learning_rate), disc, comm)
    opt_gen = setup_optimizer(
        chainer.optimizers.Adam(alpha=args.learning_rate), gen, comm)

    # pretraining of the global generator needs half-size images
    if comm is None or comm.rank == 0:
        train_d = getattr(pix2pixHD, data_config["class_name"])(
            *data_config["train"]["args"],
            **data_config["train"]["kwargs"],
            one_hot=args.no_one_hot)
        test_d = getattr(pix2pixHD, data_config["class_name"])(
            *data_config["test"]["args"],
            **data_config["test"]["kwargs"],
            one_hot=args.no_one_hot,
            random_flip=False)
    else:
        train_d, test_d = None, None

    if comm is not None:
        train_d = chainermn.scatter_dataset(train_d, comm)
        test_d = chainermn.scatter_dataset(test_d, comm)
        multiprocessing.set_start_method('forkserver')

    train_iter = chainer.iterators.MultiprocessIterator(train_d,
                                                        args.batchsize,
                                                        n_processes=2)
    test_iter = chainer.iterators.SerialIterator(test_d,
                                                 args.batchsize,
                                                 shuffle=False)

    iterators = {'main': train_iter, 'test': test_iter}
    optimizers = {'discriminator': opt_disc, 'generator': opt_gen}

    updater = Pix2pixHDUpdater(iterators, optimizers, device=device)

    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.output)

    if comm is None or comm.rank == 0:
        trigger = (100, "iteration")
        if comm is None:
            # this is a hack... sorry
            trainer.extend(train_d.visualizer(n=args.num_vis,
                                              one_hot=args.no_one_hot),
                           trigger=trigger)  #(1, "epoch"))
        else:
            trainer.extend(train_d._dataset.visualizer(
                n=args.num_vis, one_hot=args.no_one_hot),
                           trigger=trigger)
        trainer.extend(extensions.LogReport(trigger=(10, "iteration")))
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'Dloss_real', 'Dloss_fake', 'Gloss',
            "feat_loss", "lr"
        ]),
                       trigger=(10, "iteration"))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(
            extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'),
            trigger=(args.epochs // 10, "epoch"))
        trainer.extend(extensions.snapshot_object(
            gen, 'generator_model_epoch_{.updater.epoch}'),
                       trigger=(args.epochs // 10, "epoch"))

    # decay the learning rate from halfway through training
    trainer.extend(extensions.LinearShift("alpha",
                                          value_range=(args.learning_rate,
                                                       0.0),
                                          time_range=(args.epochs // 2,
                                                      args.epochs),
                                          optimizer=opt_disc),
                   trigger=(1, "epoch"))
    trainer.extend(extensions.LinearShift("alpha",
                                          value_range=(args.learning_rate,
                                                       0.0),
                                          time_range=(args.epochs // 2,
                                                      args.epochs),
                                          optimizer=opt_gen),
                   trigger=(1, "epoch"))

    trainer.extend(extensions.observe_value(
        "lr",
        lambda trainer: trainer.updater.get_optimizer("discriminator").lr),
                   trigger=(10, "iteration"))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
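The observe_value extension above logs the discriminator optimizer's `lr`, which for Chainer's Adam is a property derived from `alpha` and the step count, whereas the two LinearShift extensions modify `alpha` itself. If the scheduled quantity itself should appear in the log, a helper along these lines could be used instead; this is a sketch, and the key name 'alpha' plus the optimizer name 'discriminator' simply follow the example above.

from chainer.training import extensions


def observe_alpha(optimizer_name='discriminator', key='alpha'):
    """observe_value extension reporting Adam's base alpha, i.e. the
    attribute the LinearShift extensions above actually modify."""
    return extensions.observe_value(
        key,
        lambda trainer: trainer.updater.get_optimizer(optimizer_name).alpha)

# e.g. trainer.extend(observe_alpha(), trigger=(10, 'iteration'))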
Example 20
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='config.ini')
    parser.add_argument('--resume')
    parser.add_argument('--plot_samples', type=int, default=0)
    args = parser.parse_args()

    config = configparser.ConfigParser()
    config.read(args.config_path, 'UTF-8')

    chainer.global_config.autotune = True
    chainer.cuda.set_max_workspace_size(11388608)

    # create result dir and copy file
    logger.info('> store file to result dir %s', config.get('result', 'dir'))
    save_files(config.get('result', 'dir'))

    logger.info('> set up devices')
    devices = setup_devices(config.get('training_param', 'gpus'))
    set_random_seed(devices, config.getint('training_param', 'seed'))

    logger.info('> get dataset')
    dataset_type = config.get('dataset', 'type')
    if dataset_type == 'coco':
        # force to set `use_cache = False`
        train_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'train_images'),
            annotations=config.get(dataset_type, 'train_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=False,
            do_augmentation=True,
        )
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=False,
        )
    elif dataset_type == 'mpii':
        train_set, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=config.getboolean(dataset_type, 'use_cache'),
            seed=config.getint('training_param', 'seed'),
        )
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))
    logger.info('dataset type: %s', dataset_type)
    logger.info('training images: %d', len(train_set))
    logger.info('validation images: %d', len(test_set))

    if args.plot_samples > 0:
        for i in range(args.plot_samples):
            data = train_set[i]
            visualize.plot('train-{}.png'.format(i), data['image'],
                           data['keypoints'], data['bbox'], data['is_labeled'],
                           data['edges'])
            data = test_set[i]
            visualize.plot('val-{}.png'.format(i), data['image'],
                           data['keypoints'], data['bbox'], data['is_labeled'],
                           data['edges'])

    logger.info('> load model')
    model = create_model(config, train_set)

    logger.info('> transform dataset')
    train_set = TransformDataset(train_set, model.encode)
    test_set = TransformDataset(test_set, model.encode)

    logger.info('> create iterators')
    train_iter = chainer.iterators.MultiprocessIterator(
        train_set,
        config.getint('training_param', 'batchsize'),
        n_processes=config.getint('training_param', 'num_process'))
    test_iter = chainer.iterators.SerialIterator(test_set,
                                                 config.getint(
                                                     'training_param',
                                                     'batchsize'),
                                                 repeat=False,
                                                 shuffle=False)

    logger.info('> setup optimizer')
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    logger.info('> setup trainer')
    updater = training.updaters.ParallelUpdater(train_iter,
                                                optimizer,
                                                devices=devices)
    trainer = training.Trainer(
        updater, (config.getint('training_param', 'train_iter'), 'iteration'),
        config.get('result', 'dir'))

    logger.info('> setup extensions')
    trainer.extend(extensions.LinearShift(
        'lr',
        value_range=(config.getfloat('training_param', 'learning_rate'), 0),
        time_range=(0, config.getint('training_param', 'train_iter'))),
                   trigger=(1, 'iteration'))

    trainer.extend(
        extensions.Evaluator(test_iter, model, device=devices['main']))
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'validation/main/loss',
            ],
                                  'epoch',
                                  file_name='loss.png'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch',
            'elapsed_time',
            'lr',
            'main/loss',
            'validation/main/loss',
            'main/loss_resp',
            'validation/main/loss_resp',
            'main/loss_iou',
            'validation/main/loss_iou',
            'main/loss_coor',
            'validation/main/loss_coor',
            'main/loss_size',
            'validation/main/loss_size',
            'main/loss_limb',
            'validation/main/loss_limb',
        ]))
    trainer.extend(extensions.ProgressBar())

    trainer.extend(
        extensions.snapshot(filename='best_snapshot'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(
        extensions.snapshot_object(model, filename='bestmodel.npz'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    logger.info('> start training')
    trainer.run()
Example 21
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path',
                        type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis = load_models(config)
    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)

    seed_weights(gen)
    seed_weights(dis)
    print(np.sum(gen.block2.c2.b))
    _ = input()

    models = {"gen": gen, "dis": dis}
    # Optimizer
    opt_gen = make_optimizer(gen,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}
    # Dataset
    # The CIFAR10 dataset handler does not take "root" as an argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))
    if gen.n_classes > 0:
        trainer.extend(sample_generate_conditional(gen,
                                                   out,
                                                   n_classes=gen.n_classes),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(gen, out),
                       trigger=(config.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(gen, out, rows=10, cols=10),
                   trigger=(config.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(gen,
                                  n_ims=5000,
                                  splits=1,
                                  path=args.inception_model_path),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Example 22
def main():
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "livedoor": LivedoorProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_print_test:
        raise ValueError("At least one of `do_train` or `do_eval` "
                         "or `do_print_test` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(model_file=FLAGS.model_file,
                                           vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    # TODO: use special Adam from "optimization.py"
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    bert = modeling.BertModel(config=bert_config)
    pretrained = modeling.BertPretrainer(bert)
    chainer.serializers.load_npz(FLAGS.init_checkpoint, pretrained)

    model = modeling.BertClassifier(pretrained.bert,
                                    num_labels=len(label_list))

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6,
            weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(train_examples,
                                                      FLAGS.train_batch_size)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        updater = training.updaters.StandardUpdater(train_iter,
                                                    optimizer,
                                                    converter=converter,
                                                    device=FLAGS.gpu)
        trainer = training.Trainer(updater, (num_train_steps, 'iteration'),
                                   out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(
            extensions.LinearShift(  # decay
                'eta', (lr_decay_init, 0.),
                (num_warmup_steps, num_train_steps)))
        trainer.extend(
            extensions.WarmupShift(  # warmup
                'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
                       trigger=(50, 'iteration'))  # logging

        trainer.extend(extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz'),
                       trigger=(num_train_steps, 'iteration'))
        trainer.extend(extensions.LogReport(trigger=(50, 'iteration')))
        trainer.extend(
            extensions.PrintReport(
                ['iteration', 'main/loss', 'main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        test_iter = chainer.iterators.SerialIterator(eval_examples,
                                                     FLAGS.train_batch_size *
                                                     2,
                                                     repeat=False,
                                                     shuffle=False)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        evaluator = extensions.Evaluator(test_iter,
                                         model,
                                         converter=converter,
                                         device=FLAGS.gpu)
        results = evaluator()
        print(results)

    # if you wanna see some output arrays for debugging
    if FLAGS.do_print_test:
        short_eval_examples = processor.get_dev_examples(FLAGS.data_dir)[:3]
        short_eval_examples = short_eval_examples[:FLAGS.eval_batch_size]
        short_test_iter = chainer.iterators.SerialIterator(
            short_eval_examples,
            FLAGS.eval_batch_size,
            repeat=False,
            shuffle=False)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        evaluator = extensions.Evaluator(short_test_iter,
                                         model,
                                         converter=converter,
                                         device=FLAGS.gpu)

        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                data = short_test_iter.__next__()
                out = model.bert.get_pooled_output(
                    *converter(data, FLAGS.gpu)[:-1])
                print(out)
                print(out.shape)
            print(converter(data, -1))
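In this example Adam's alpha is pinned to 1 and the `eta` hyperparameter is scheduled instead, so `eta` acts as the effective learning rate: WarmupShift ramps it from 0 up to the learning rate over the warmup steps, and LinearShift then decays it to 0, starting from the learning rate scaled by the fraction of steps that remain after warmup. A self-contained sketch of how the same pair could be built, reusing only the calls that appear above:

from chainer.training import extensions


def build_eta_schedule(learning_rate, num_warmup_steps, num_train_steps):
    """Sketch of the warmup + linear-decay pair used above for Adam's `eta`
    (alpha is fixed to 1, so `eta` behaves as the effective learning rate)."""
    lr_decay_init = learning_rate * (
        num_train_steps - num_warmup_steps) / num_train_steps
    decay = extensions.LinearShift(
        'eta', (lr_decay_init, 0.), (num_warmup_steps, num_train_steps))
    warmup = extensions.WarmupShift(
        'eta', 0., num_warmup_steps, learning_rate)
    return warmup, decay

# e.g. warmup, decay = build_eta_schedule(5e-5, 1000, 10000)
#      trainer.extend(decay); trainer.extend(warmup)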
Example 23
def main():
    if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.do_print_test:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if FLAGS.do_train:
        if not FLAGS.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if FLAGS.do_predict:
        if not FLAGS.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(input_file=FLAGS.train_file,
                                             is_training=True)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      FLAGS.max_seq_length,
                                                      FLAGS.doc_stride,
                                                      FLAGS.max_query_length,
                                                      is_training=True)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    bert = modeling.BertModel(config=bert_config)
    model = modeling.BertSQuAD(bert)
    if FLAGS.do_train:
        # If training, load BERT parameters only.
        ignore_names = ['output/W', 'output/b']
    else:
        # If only do_predict, load all parameters.
        ignore_names = None
    chainer.serializers.load_npz(FLAGS.init_checkpoint,
                                 model,
                                 ignore_names=ignore_names)

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6,
            weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(train_features,
                                                      FLAGS.train_batch_size)
        converter = Converter(is_training=True)
        updater = training.updaters.StandardUpdater(
            train_iter,
            optimizer,
            converter=converter,
            device=FLAGS.gpu,
            loss_func=model.compute_loss)
        trainer = training.Trainer(updater, (num_train_steps, 'iteration'),
                                   out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(
            extensions.LinearShift(  # decay
                'eta', (lr_decay_init, 0.),
                (num_warmup_steps, num_train_steps)))
        trainer.extend(
            extensions.WarmupShift(  # warmup
                'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
                       trigger=(100, 'iteration'))  # logging

        trainer.extend(extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz'),
                       trigger=(num_train_steps // 2, 'iteration'))  # TODO
        trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
        trainer.extend(
            extensions.PrintReport([
                'iteration', 'main/loss', 'main/accuracy', 'elapsed_time',
                'eta'
            ]))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(input_file=FLAGS.predict_file,
                                            is_training=False)
        eval_features = convert_examples_to_features(eval_examples,
                                                     tokenizer,
                                                     FLAGS.max_seq_length,
                                                     FLAGS.doc_stride,
                                                     FLAGS.max_query_length,
                                                     is_training=False)
        test_iter = chainer.iterators.SerialIterator(eval_features,
                                                     FLAGS.predict_batch_size,
                                                     repeat=False,
                                                     shuffle=False)
        converter = Converter(is_training=False)

        print('Evaluating ...')
        evaluate(eval_examples,
                 test_iter,
                 model,
                 converter=converter,
                 device=FLAGS.gpu,
                 predict_func=model.predict)
        print('Finished.')
Example 24
def main():
    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    def _get_text_file(text_dir):
        import glob
        #file_list = glob.glob(f'{text_dir}/**/*')
        # seq length 512
        #file_list = ['/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/all']
        # seq length 128
        file_list = ['/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/all_seq128']
        # debug
        #file_list = ['/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/AA/wiki_00']
        files = ",".join(file_list)
        return files
    input_files = _get_text_file(FLAGS.input_file).split(',')

   #  model_fn = model_fn_builder(
   #      bert_config=bert_config,
   #      init_checkpoint=FLAGS.init_checkpoint,
   #      learning_rate=FLAGS.learning_rate,
   #      num_train_steps=FLAGS.num_train_steps,
   #      num_warmup_steps=FLAGS.num_warmup_steps,
   #      use_tpu=FLAGS.use_tpu,
   #      use_one_hot_embeddings=FLAGS.use_tpu)

    if FLAGS.do_train:
        input_files = input_files
    bert = modeling.BertModel(config=bert_config)
    model = modeling.BertPretrainer(bert)
    if FLAGS.init_checkpoint:
        serializers.load_npz(FLAGS.init_checkpoint, model)
        model = modeling.BertPretrainer(model.bert)
    if FLAGS.gpu >= 0:
        pass
        #chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        #model.to_gpu()

    if FLAGS.do_train:
        """chainerでのpretrainを記述。BERTClassificationに変わるものを作成し、BERTの出力をこねこねしてmodel_fnが返すものと同じものを返すようにすれば良いか?"""
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6, weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        """ ConcatenatedDatasetはon memolyなため、巨大データセットのPickleを扱えない
        input_files = sorted(input_files)[:len(input_files) // 2]
        input_files = sorted(input_files)[:200]
        import concurrent.futures
        train_examples = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for train_example in executor.map(_load_data_using_dataset_api, input_files):
                train_examples.append(train_example)
        train_examples = ConcatenatedDataset(*train_examples)
        """
        train_examples = _load_data_using_dataset_api(input_files[0])

        train_iter = chainer.iterators.SerialIterator(
            train_examples, FLAGS.train_batch_size)
        converter = Converter()
        if False:
            updater = training.updaters.StandardUpdater(
                train_iter, optimizer,
                converter=converter,
                device=FLAGS.gpu)
        else:
            updater = training.updaters.ParallelUpdater(
                iterator=train_iter,
                optimizer=optimizer,
                converter=converter,
                # The device of the name 'main' is used as a "master", while others are
                # used as slaves. Names other than 'main' are arbitrary.
                devices={'main': 0,
                         '1': 1,
                         '2': 2,
                         '3': 3,
                         '4': 4,
                         '5': 5,
                         '6': 6,
                         '7': 7,
                         },
            )
        # learning rate (eta) scheduling in Adam
        num_warmup_steps = FLAGS.num_warmup_steps
        num_train_steps = FLAGS.num_train_steps
        trainer = training.Trainer(
            updater, (num_train_steps, 'iteration'), out=FLAGS.output_dir)
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(extensions.LinearShift(  # decay
            'eta', (lr_decay_init, 0.), (num_warmup_steps, num_train_steps)))
        trainer.extend(extensions.WarmupShift(  # warmup
            'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
            trigger=(50, 'iteration'))  # logging

        trainer.extend(extensions.snapshot_object(
            model, 'seq_128_model_snapshot_iter_{.updater.iteration}.npz'),
            trigger=(1000, 'iteration'))
        trainer.extend(extensions.LogReport(
            trigger=(1, 'iteration')))
        #trainer.extend(extensions.PlotReport(
        #    [
        #        'main/next_sentence_loss',
        #        'main/next_sentence_accuracy',
        #     ], (3, 'iteration'), file_name='next_sentence.png'))
        #trainer.extend(extensions.PlotReport(
        #    [
        #        'main/masked_lm_loss',
        #        'main/masked_lm_accuracy',
        #     ], (3, 'iteration'), file_name='masked_lm.png'))
        trainer.extend(extensions.PlotReport(
            y_keys=[
                'main/loss',
                'main/next_sentence_loss',
                'main/next_sentence_accuracy',
                'main/masked_lm_loss',
                'main/masked_lm_accuracy',
             ], x_key='iteration', trigger=(100, 'iteration'), file_name='loss.png'))
        trainer.extend(extensions.PrintReport(
            ['iteration',
             'main/loss',
             'main/masked_lm_loss', 'main/masked_lm_accuracy',
             'main/next_sentence_loss', 'main/next_sentence_accuracy',
             'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=20))

        trainer.run()

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example 25
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--results_dir', type=str, default='./results',
                        help='directory to save the results to')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    device = 0
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    multiprocessing.set_start_method('forkserver')

    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}

    # Optimizer
    opt_gen = make_optimizer(gen, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opt_dis = make_optimizer(dis, alpha=config.adam['alpha'], beta1=config.adam['beta1'], beta2=config.adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}

    # Dataset
    train, test = yaml_utils.load_dataset(config)

    # Iterator
    train_iter = chainer.iterators.MultiprocessIterator(train, config.batchsize, n_processes=args.loaderjob)
    test_iter = chainer.iterators.SerialIterator(test, config.batchsize_test, repeat=False, shuffle=False)

    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': train_iter,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'), out=out)
    report_keys = ["loss_dis", "loss_gen", "loss_l1", "psnr", "ssim"]
    eval_func = yaml_utils.load_eval_func(config)
    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(config.display_interval, 'iteration'))
    trainer.extend(eval_func(test_iter, gen, dst=out), trigger=(config.evaluation_interval, 'iteration'), priority=chainer.training.extension.PRIORITY_WRITER)
    for key in report_keys:
        trainer.extend(extensions.PlotReport(key, trigger=(config.evaluation_interval, 'iteration'), file_name='{}.png'.format(key)))
    trainer.extend(extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_gen = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_gen)
    ext_opt_dis = extensions.LinearShift('alpha', (config.adam['alpha'], 0.),
                                         (config.iteration_decay_start, config.iteration), opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Example 26
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='index of gpu to be used')
    parser.add_argument('--input_dir', type=str, default='./data/imagenet')
    parser.add_argument('--truth_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir',
                        type=str,
                        default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--snapshot',
                        type=str,
                        default='',
                        help='path to the snapshot')
    parser.add_argument('--gen_model',
                        type=str,
                        default='',
                        help='path to the generator .npz file')
    parser.add_argument('--dis_model',
                        type=str,
                        default='',
                        help='path to the discriminator .npz file')
    parser.add_argument('--loaderjob',
                        type=int,
                        help='number of parallel data loading processes')

    args = parser.parse_args()
    config = yaml_utils.Config(yaml.load(open(args.config_path)))
    chainer.cuda.get_device_from_id(args.gpu).use()
    gen, dis, enc = load_models(config)

    chainer.serializers.load_npz(args.gen_model, gen)
    chainer.serializers.load_npz(args.dis_model, dis)

    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)
    enc.to_gpu(device=args.gpu)
    models = {"gen": gen, "dis": dis, "enc": enc}

    opt_enc = make_optimizer(enc,
                             alpha=config.adam['alpha'],
                             beta1=config.adam['beta1'],
                             beta2=config.adam['beta2'])
    opts = {"opt_enc": opt_enc}
    # Dataset
    config['dataset']['args']['root_input'] = args.input_dir
    config['dataset']['args']['root_truth'] = args.truth_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)
    kwargs = config.updater['args'] if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)
    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss", "min_slope", "max_slope", "min_z", "max_z"]
    # Set up logging
    trainer.extend(extensions.snapshot(),
                   trigger=(config.snapshot_interval, 'iteration'))
    for m in models.values():
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(config.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(config.display_interval, 'iteration'))

    trainer.extend(sample_reconstruction(enc,
                                         gen,
                                         out,
                                         n_classes=gen.n_classes),
                   trigger=(config.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))
    ext_opt_enc = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_enc)
    trainer.extend(ext_opt_enc)
    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
Example 27
    def setUp(self):
        self.optimizer = mock.MagicMock()
        self.trainer = mock.MagicMock()
        self.extension = extensions.LinearShift('x', self.value_range,
                                                self.time_range,
                                                self.optimizer)
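The unit-test fragment above passes the optimizer to LinearShift directly, while several of the training scripts in this file omit it, in which case the extension is presumably applied to the trainer's main optimizer. A minimal sketch of both constructions; the value and time ranges are arbitrary.

from chainer import optimizers
from chainer.training import extensions

opt = optimizers.MomentumSGD(lr=0.1)
explicit = extensions.LinearShift('lr', (0.1, 0.0), (0, 100), opt)
implicit = extensions.LinearShift('lr', (0.1, 0.0), (0, 100))  # no optimizer given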
Example 28
def main():
    args = arguments()
    out = os.path.join(args.out, dt.now().strftime('%m%d_%H%M'))
    print(args)
    print("\nresults are saved under: ", out)
    save_args(args, out)

    if args.imgtype == "dcm":
        from dataset_dicom import Dataset as Dataset
    else:
        from dataset_jpg import DatasetOutMem as Dataset

    # CUDA
    if not chainer.cuda.available:
        print("CUDA required")
        exit()
    if len(args.gpu) == 1 and args.gpu[0] >= 0:
        chainer.cuda.get_device_from_id(args.gpu[0]).use()
#        cuda.cupy.cuda.set_allocator(cuda.cupy.cuda.MemoryPool().malloc)

    # Enable autotuner of cuDNN
    chainer.config.autotune = True
    chainer.config.dtype = dtypes[args.dtype]
    chainer.print_runtime_info()
    # Turn off type check
    #    chainer.config.type_check = False
    #    print('Chainer version: ', chainer.__version__)
    #    print('GPU availability:', chainer.cuda.available)
    #    print('cuDNN availablility:', chainer.cuda.cudnn_enabled)

    ## dataset iterator
    print("Setting up data iterators...")
    train_A_dataset = Dataset(path=os.path.join(args.root, 'trainA'),
                              args=args,
                              random=args.random_translate,
                              forceSpacing=0)
    train_B_dataset = Dataset(path=os.path.join(args.root, 'trainB'),
                              args=args,
                              random=args.random_translate,
                              forceSpacing=args.forceSpacing)
    test_A_dataset = Dataset(path=os.path.join(args.root, 'testA'),
                             args=args,
                             random=0,
                             forceSpacing=0)
    test_B_dataset = Dataset(path=os.path.join(args.root, 'testB'),
                             args=args,
                             random=0,
                             forceSpacing=args.forceSpacing)

    args.ch = train_A_dataset.ch
    args.out_ch = train_B_dataset.ch
    print("channels in A {}, channels in B {}".format(args.ch, args.out_ch))

    test_A_iter = chainer.iterators.SerialIterator(test_A_dataset,
                                                   args.nvis_A,
                                                   shuffle=False)
    test_B_iter = chainer.iterators.SerialIterator(test_B_dataset,
                                                   args.nvis_B,
                                                   shuffle=False)

    if args.batch_size > 1:
        train_A_iter = chainer.iterators.MultiprocessIterator(train_A_dataset,
                                                              args.batch_size,
                                                              n_processes=3)
        train_B_iter = chainer.iterators.MultiprocessIterator(train_B_dataset,
                                                              args.batch_size,
                                                              n_processes=3)
    else:
        train_A_iter = chainer.iterators.SerialIterator(
            train_A_dataset, args.batch_size)
        train_B_iter = chainer.iterators.SerialIterator(
            train_B_dataset, args.batch_size)

    # setup models
    enc_x = net.Encoder(args)
    enc_y = enc_x if args.single_encoder else net.Encoder(args)
    dec_x = net.Decoder(args)
    dec_y = net.Decoder(args)
    dis_x = net.Discriminator(args)
    dis_y = net.Discriminator(args)
    dis_z = net.Discriminator(
        args) if args.lambda_dis_z > 0 else chainer.links.Linear(1, 1)
    models = {
        'enc_x': enc_x,
        'dec_x': dec_x,
        'enc_y': enc_y,
        'dec_y': dec_y,
        'dis_x': dis_x,
        'dis_y': dis_y,
        'dis_z': dis_z
    }

    ## load learnt models
    if args.load_models:
        for e in models:
            m = args.load_models.replace('enc_x', e)
            try:
                serializers.load_npz(m, models[e])
                print('model loaded: {}'.format(m))
            except Exception:
                print("couldn't load {}".format(m))

    # select GPU
    if len(args.gpu) == 1:
        for e in models:
            models[e].to_gpu()
        print('using gpu {}, cuDNN {}'.format(args.gpu,
                                              chainer.cuda.cudnn_enabled))
    else:
        print("mandatory GPU use: currently only a single GPU can be used")
        exit()

    # Setup optimisers
    def make_optimizer(model, lr, opttype='Adam'):
        # eps = 1e-5 if args.dtype == np.float16 else 1e-8
        optimizer = optim[opttype](lr)
        # from profiled_optimizer import create_marked_profile_optimizer
        # optimizer = create_marked_profile_optimizer(optim[opttype](lr), sync=True, sync_level=2)
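        # Adam-family optimisers here expose weight_decay_rate directly; the
        # others get an explicit L2 (WeightDecay) or L1 (Lasso) hook instead.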
        if args.weight_decay > 0:
            if opttype in ['Adam', 'AdaBound', 'Eve']:
                optimizer.weight_decay_rate = args.weight_decay
            else:
                if args.weight_decay_norm == 'l2':
                    optimizer.add_hook(
                        chainer.optimizer.WeightDecay(args.weight_decay))
                else:
                    optimizer.add_hook(
                        chainer.optimizer_hooks.Lasso(args.weight_decay))
        optimizer.setup(model)
        return optimizer

    opt_enc_x = make_optimizer(enc_x, args.learning_rate_g, args.optimizer)
    opt_dec_x = make_optimizer(dec_x, args.learning_rate_g, args.optimizer)
    opt_enc_y = make_optimizer(enc_y, args.learning_rate_g, args.optimizer)
    opt_dec_y = make_optimizer(dec_y, args.learning_rate_g, args.optimizer)
    opt_x = make_optimizer(dis_x, args.learning_rate_d, args.optimizer)
    opt_y = make_optimizer(dis_y, args.learning_rate_d, args.optimizer)
    opt_z = make_optimizer(dis_z, args.learning_rate_d, args.optimizer)
    optimizers = {
        'opt_enc_x': opt_enc_x,
        'opt_dec_x': opt_dec_x,
        'opt_enc_y': opt_enc_y,
        'opt_dec_y': opt_dec_y,
        'opt_x': opt_x,
        'opt_y': opt_y,
        'opt_z': opt_z
    }
    if args.load_optimizer:
        for e in optimizers:
            try:
                m = args.load_models.replace('enc_x', e)
                serializers.load_npz(m, optimizers[e])
                print('optimiser loaded: {}'.format(m))
            except Exception:
                print("couldn't load {}".format(m))

    # Set up an updater: TODO: multi gpu updater
    print("Preparing updater...")
    updater = Updater(
        models=(enc_x, dec_x, enc_y, dec_y, dis_x, dis_y, dis_z),
        iterator={
            'main': train_A_iter,
            'train_B': train_B_iter,
        },
        optimizer=optimizers,
        #        converter=convert.ConcatWithAsyncTransfer(),
        device=args.gpu[0],
        params={'args': args})

    if args.snapinterval < 0:
        args.snapinterval = args.lrdecay_start + args.lrdecay_period
    log_interval = (200, 'iteration')
    model_save_interval = (args.snapinterval, 'epoch')
    plot_interval = (500, 'iteration')

    # Set up a trainer
    print("Preparing trainer...")
    if args.iteration:
        stop_trigger = (args.iteration, 'iteration')
    else:
        stop_trigger = (args.lrdecay_start + args.lrdecay_period, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=out)
    for e in models:
        trainer.extend(extensions.snapshot_object(models[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)
        # trainer.extend(extensions.ParameterStatistics(models[e]))  # very slow
    for e in optimizers:
        trainer.extend(extensions.snapshot_object(optimizers[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)

    log_keys = ['epoch', 'iteration', 'lr']
    log_keys_cycle = [
        'opt_enc_x/loss_cycle', 'opt_enc_y/loss_cycle', 'opt_dec_x/loss_cycle',
        'opt_dec_y/loss_cycle', 'myval/cycle_x_l1', 'myval/cycle_y_l1'
    ]
    log_keys_d = [
        'opt_x/loss_real', 'opt_x/loss_fake', 'opt_y/loss_real',
        'opt_y/loss_fake', 'opt_z/loss_x', 'opt_z/loss_y'
    ]
    log_keys_adv = [
        'opt_enc_y/loss_adv', 'opt_dec_y/loss_adv', 'opt_enc_x/loss_adv',
        'opt_dec_x/loss_adv'
    ]
    log_keys.extend(
        ['opt_enc_x/loss_reg', 'opt_enc_y/loss_reg', 'opt_dec_y/loss_tv'])
    if args.lambda_air > 0:
        log_keys.extend(['opt_dec_x/loss_air', 'opt_dec_y/loss_air'])
    if args.lambda_grad > 0:
        log_keys.extend(['opt_dec_x/loss_grad', 'opt_dec_y/loss_grad'])
    if args.lambda_identity_x > 0:
        log_keys.extend(['opt_dec_x/loss_id', 'opt_dec_y/loss_id'])
    if args.dis_reg_weighting > 0:
        log_keys_d.extend(
            ['opt_x/loss_reg', 'opt_y/loss_reg', 'opt_z/loss_reg'])
    if args.dis_wgan:
        log_keys_d.extend(['opt_x/loss_gp', 'opt_y/loss_gp', 'opt_z/loss_gp'])

    log_keys_all = log_keys + log_keys_d + log_keys_adv + log_keys_cycle
    trainer.extend(
        extensions.LogReport(keys=log_keys_all, trigger=log_interval))
    trainer.extend(extensions.PrintReport(log_keys_all), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.observe_lr(optimizer_name='opt_enc_x'),
                   trigger=log_interval)
    # learning rate scheduling
    decay_start_iter = len(train_A_dataset) * args.lrdecay_start
    decay_end_iter = len(train_A_dataset) * (args.lrdecay_start +
                                             args.lrdecay_period)
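    # LinearShift holds 'alpha' at the initial learning rate until
    # decay_start_iter, then lowers it linearly to 0 by decay_end_iter
    # (the epoch-based settings are converted to iteration counts above).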
    for e in [opt_enc_x, opt_enc_y, opt_dec_x, opt_dec_y]:
        trainer.extend(
            extensions.LinearShift('alpha', (args.learning_rate_g, 0),
                                   (decay_start_iter, decay_end_iter),
                                   optimizer=e))
    for e in [opt_x, opt_y, opt_z]:
        trainer.extend(
            extensions.LinearShift('alpha', (args.learning_rate_d, 0),
                                   (decay_start_iter, decay_end_iter),
                                   optimizer=e))
    ## dump graph
    if args.report_start < 1:
        if args.lambda_tv > 0:
            trainer.extend(
                extensions.dump_graph('opt_dec_y/loss_tv', out_name='dec.dot'))
        if args.lambda_reg > 0:
            trainer.extend(
                extensions.dump_graph('opt_enc_x/loss_reg',
                                      out_name='enc.dot'))
        trainer.extend(
            extensions.dump_graph('opt_x/loss_fake', out_name='dis.dot'))

    # ChainerUI
    # trainer.extend(CommandsExtension())

    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(log_keys[3:],
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_d,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_d.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_adv,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_adv.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_cycle,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_cyc.png'))

    ## visualisation
    vis_folder = os.path.join(out, "vis")
    os.makedirs(vis_folder, exist_ok=True)
    if not args.vis_freq:
        args.vis_freq = len(train_A_dataset) // 2
    s = (list(range(args.num_slices))
         if args.num_slices > 0 and args.imgtype == "dcm" else None)
    trainer.extend(VisEvaluator({
        "testA": test_A_iter,
        "testB": test_B_iter
    }, {
        "enc_x": enc_x,
        "enc_y": enc_y,
        "dec_x": dec_x,
        "dec_y": dec_y
    },
                                params={
                                    'vis_out': vis_folder,
                                    'slice': s
                                },
                                device=args.gpu[0]),
                   trigger=(args.vis_freq, 'iteration'))

    ## output filenames of training dataset
    with open(os.path.join(out, 'trainA.txt'), 'w') as output:
        for f in train_A_dataset.names:
            output.write("\n".join(f) + "\n")
    with open(os.path.join(out, 'trainB.txt'), 'w') as output:
        for f in train_B_dataset.names:
            output.write("\n".join(f) + "\n")

    # archive the scripts
    rundir = os.path.dirname(os.path.realpath(__file__))
    import zipfile
    with zipfile.ZipFile(os.path.join(out, 'script.zip'),
                         'w',
                         compression=zipfile.ZIP_DEFLATED) as new_zip:
        for f in [
                'train.py', 'net.py', 'updater.py', 'consts.py', 'losses.py',
                'arguments.py', 'convert.py'
        ]:
            new_zip.write(os.path.join(rundir, f), arcname=f)

    # Run the training
    trainer.run()
Example n. 29
def main(hpt):
    logger.info('load New Data From Matlab')
    if hpt.dataset.type == 'synthetic':
        treeFile = loadmat(hpt.dataset['matrixFile_struct'])
        matrixForData = treeFile.get('matrixForHS')
        hpt.dataset['depth'] = treeFile.get('depthToSave')[0, 0]
        hpt.dataset['depthReal'] = treeFile.get('depthToSaveReal')[0, 0]
        hpt.training['batch_size'] = treeFile.get('baches')[0, 0]
        print(treeFile.get('matrixForHS'))
        print(hpt.dataset.depth)
        print(hpt.training.batch_size)
    elif hpt.dataset.type == 'mnist':
        activityFile = loadmat(hpt.dataset['matrixFile_activity'])
        matrixForData = np.transpose(activityFile.get('roiActivity'))
        hpt.dataset['mnistShape'] = len(matrixForData[1, :])
        hpt.training['batch_size'] = len(matrixForData[:, 1])

    logger.info('build model')
    avg_elbo_loss = get_model(hpt)
    if hpt.general.gpu >= 0:
        avg_elbo_loss.to_gpu(hpt.general.gpu)

    logger.info('setup optimizer')
    if hpt.optimizer.type == 'adam':
        optimizer = chainer.optimizers.Adam(alpha=hpt.optimizer.lr)
    optimizer.setup(avg_elbo_loss)

    logger.info('load dataset')
    train, valid, test = dataset.get_dataset(hpt.dataset.type, matrixForData,
                                             **hpt.dataset)

    if hpt.general.test:
        train, _ = chainer.datasets.split_dataset(train, 100)
        valid, _ = chainer.datasets.split_dataset(valid, 100)
        test, _ = chainer.datasets.split_dataset(test, 100)

    train_iter = chainer.iterators.SerialIterator(train,
                                                  hpt.training.batch_size)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  hpt.training.batch_size,
                                                  repeat=False,
                                                  shuffle=False)

    logger.info('setup updater/trainer')
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=hpt.general.gpu,
                                                loss_func=avg_elbo_loss)

    if not hpt.training.early_stopping:
        trainer = training.Trainer(updater,
                                   (hpt.training.iteration, 'iteration'),
                                   out=po.namedir(output='str'))
    else:
        trainer = training.Trainer(updater,
                                   triggers.EarlyStoppingTrigger(
                                       monitor='validation/main/loss',
                                       patients=5,
                                       max_trigger=(hpt.training.iteration,
                                                    'iteration')),
                                   out=po.namedir(output='str'))

    if hpt.training.warm_up != -1:
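        # KL warm-up: anneal the 'beta' attribute of the loss linearly from
        # 0.1 up to its target value over the first `warm_up` iterations.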
        time_range = (0, hpt.training.warm_up)
        trainer.extend(
            extensions.LinearShift('beta',
                                   value_range=(0.1, hpt.loss.beta),
                                   time_range=time_range,
                                   optimizer=avg_elbo_loss))

    trainer.extend(
        extensions.Evaluator(valid_iter, avg_elbo_loss,
                             device=hpt.general.gpu))
    # trainer.extend(extensions.DumpGraph('main/loss'))
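    # Snapshot the model five times over the run (every iteration/5 updates).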
    trainer.extend(extensions.snapshot_object(
        avg_elbo_loss, 'avg_elbo_loss_snapshot_iter_{.updater.iteration}'),
                   trigger=(int(hpt.training.iteration / 5), 'iteration'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/reconstr', 'main/kl_penalty', 'main/beta', 'lr',
            'elapsed_time'
        ]))
    trainer.extend(extensions.ProgressBar())

    logger.info('run training')
    trainer.run()

    logger.info('save last model')
    extensions.snapshot_object(
        avg_elbo_loss,
        'avg_elbo_loss_snapshot_iter_{.updater.iteration}')(trainer)

    logger.info('evaluate')
    metrics = evaluate(hpt, train, test, avg_elbo_loss)
    for metric_name, metric in metrics.items():
        logger.info('{}: {:.4f}'.format(metric_name, metric))

    if hpt.general.noplot:
        return metrics

    logger.info('visualize images')
    visualize(hpt, train, test, avg_elbo_loss, treeFile)

    return metrics
Example n. 30
def train(train_file,
          test_file=None,
          format='tree',
          embed_file=None,
          n_epoch=20,
          batch_size=20,
          lr=0.001,
          limit=-1,
          l2_lambda=0.0,
          grad_clip=5.0,
          encoder_input=('char', 'postag'),
          model_config=None,
          device=-1,
          save_dir=None,
          seed=None,
          cache_dir='',
          refresh_cache=False,
          bert_model=0,
          bert_dir=''):
    if seed is not None:
        utils.set_random_seed(seed, device)
    logger = logging.getLogger()
    # logger.configure(filename='log.txt', logdir=save_dir)
    assert isinstance(logger, logging.AppLogger)
    if model_config is None:
        model_config = {}
    model_config['bert_model'] = bert_model
    model_config['bert_dir'] = bert_dir

    os.makedirs(save_dir, exist_ok=True)

    read_genia = format == 'genia'
    loader = dataset.DataLoader.build(
        postag_embed_size=model_config.get('postag_embed_size', 50),
        char_embed_size=model_config.get('char_embed_size', 10),
        word_embed_file=embed_file,
        filter_coord=(not read_genia),
        refresh_cache=refresh_cache,
        format=format,
        cache_options=dict(dir=cache_dir, mkdir=True, logger=logger),
        extra_ids=(git.hash(), ))

    use_external_postags = not read_genia
    cont_embed_file_ext = _get_cont_embed_file_ext(encoder_input)
    use_cont_embed = cont_embed_file_ext is not None

    train_dataset = loader.load_with_external_resources(
        train_file,
        train=True,
        bucketing=False,
        size=None if limit < 0 else limit,
        refresh_cache=refresh_cache,
        use_external_postags=use_external_postags,
        use_contextualized_embed=use_cont_embed,
        contextualized_embed_file_ext=cont_embed_file_ext)
    logging.info('{} samples loaded for training'.format(len(train_dataset)))
    test_dataset = None
    if test_file is not None:
        test_dataset = loader.load_with_external_resources(
            test_file,
            train=False,
            bucketing=False,
            size=None if limit < 0 else limit // 10,
            refresh_cache=refresh_cache,
            use_external_postags=use_external_postags,
            use_contextualized_embed=use_cont_embed,
            contextualized_embed_file_ext=cont_embed_file_ext)
        logging.info('{} samples loaded for validation'.format(
            len(test_dataset)))

    builder = models.CoordSolverBuilder(loader,
                                        inputs=encoder_input,
                                        **model_config)
    logger.info("{}".format(builder))
    model = builder.build()
    logger.trace("Model: {}".format(model))
    if device >= 0:
        chainer.cuda.get_device_from_id(device).use()
        model.to_gpu(device)

    if bert_model == 1:
        optimizer = chainer.optimizers.AdamW(alpha=lr)
        optimizer.setup(model)
        # optimizer.add_hook(chainer.optimizer.GradientClipping(1.))
    else:
        optimizer = chainer.optimizers.AdamW(alpha=lr,
                                             beta1=0.9,
                                             beta2=0.999,
                                             eps=1e-08)
        optimizer.setup(model)
        if l2_lambda > 0.0:
            optimizer.add_hook(chainer.optimizer.WeightDecay(l2_lambda))
        if grad_clip > 0.0:
            optimizer.add_hook(chainer.optimizer.GradientClipping(grad_clip))

    def _report(y, t):
        values = {}
        model.compute_accuracy(y, t)
        for k, v in model.result.items():
            if 'loss' in k:
                values[k] = float(chainer.cuda.to_cpu(v.data))
            elif 'accuracy' in k:
                values[k] = v
        training.report(values)

    trainer = training.Trainer(optimizer, model, loss_func=model.compute_loss)
    trainer.configure(utils.training_config)
    trainer.add_listener(
        training.listeners.ProgressBar(lambda n: tqdm(total=n)), priority=200)
    trainer.add_hook(training.BATCH_END,
                     lambda data: _report(data['ys'], data['ts']))
    if test_dataset:
        parser = parsers.build_parser(loader, model)
        evaluator = eval_module.Evaluator(parser,
                                          logger=logging,
                                          report_details=False)
        trainer.add_listener(evaluator)

    if bert_model == 2:
        num_train_steps = 20000 * 5 / 20
        num_warmup_steps = 10000 / 20
        learning_rate = 2e-5
        # learning rate (eta) scheduling in Adam
        lr_decay_init = learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
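        # WarmupShift ramps Adam's 'eta' from 0 towards learning_rate over the
        # first num_warmup_steps updates, while LinearShift decays it from
        # lr_decay_init to 0 by num_train_steps (a warmup-then-linear-decay
        # schedule in the style of BERT fine-tuning).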
        trainer.add_hook(
            training.BATCH_END,
            extensions.LinearShift(  # decay
                'eta', (lr_decay_init, 0.),
                (num_warmup_steps, num_train_steps),
                optimizer=optimizer))
        trainer.add_hook(
            training.BATCH_END,
            extensions.WarmupShift(  # warmup
                'eta',
                0.,
                num_warmup_steps,
                learning_rate,
                optimizer=optimizer))

    if save_dir is not None:
        accessid = logging.getLogger().accessid
        date = logging.getLogger().accesstime.strftime('%Y%m%d')
        # metric = 'whole' if isinstance(model, models.Teranishi17) else 'inner'
        metric = 'exact'
        trainer.add_listener(
            utils.Saver(
                model,
                basename="{}-{}".format(date, accessid),
                context=dict(App.context, builder=builder),
                directory=save_dir,
                logger=logger,
                save_best=True,
                evaluate=(lambda _: evaluator.get_overall_score(metric))))

    trainer.fit(train_dataset, test_dataset, n_epoch, batch_size)