Example #1
    def test_without_init(self):
        self.optimizer.x = self.init
        extension = extensions.StepShift('x',
                                         self.gamma,
                                         self.step,
                                         init=self.init,
                                         target=self.target)
        self._run_trainer(extension, self.expect)
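
Examples #1 and #4 call a `_run_trainer` helper that is not part of this excerpt. A plausible sketch of it, assuming the `setUp` from Example #2 (the actual helper in the test suite may differ):

    def _run_trainer(self, extension, expect, optimizer=None):
        # Hypothetical reconstruction: drive the mock updater, applying the
        # extension whenever the trigger fires, and compare observed values.
        if optimizer is None:
            optimizer = self.optimizer
        extension.initialize(self.trainer)
        actual = []
        for _ in expect:
            self.trainer.updater.update()
            actual.append(optimizer.x)
            if self.trigger(self.trainer):
                extension(self.trainer)
        testing.assert_allclose(actual, expect)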
Example #2
    def setUp(self):
        self.optimizer = mock.MagicMock()
        self.extension = extensions.StepShift('x', self.gamma, self.step,
                                              self.init, self.target,
                                              self.optimizer)

        self.interval = 1
        self.expect = [e for e in self.expect for _ in range(self.interval)]
        self.trigger = training.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = self.optimizer
Example #3
    def test_resume(self):
        new_optimizer = mock.Mock()
        new_extension = extensions.StepShift('x', self.gamma, self.step,
                                             self.init, self.target,
                                             new_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        new_trainer = testing.get_trainer_with_mock_updater((5, 'iteration'))
        new_trainer.extend(new_extension)
        testing.save_and_load_npz(self.trainer, new_trainer)

        new_extension.initialize(new_trainer)
        self.assertEqual(new_optimizer.x, self.optimizer.x)
        self.assertIsInstance(new_optimizer.x, float)
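
Example #3 relies on `StepShift.serialize`, which carries the extension's internal iteration counter and last attribute value through a snapshot; `initialize` then re-applies that value to the new optimizer. A minimal round-trip sketch (hypothetical values; `_t` and `_last_value` are internal attributes, poked here only for illustration):

from chainer import serializers
from chainer.training import extensions

src = extensions.StepShift('x', 0.5, 3, init=1.0)
src._t = 7                # pretend seven iterations have run
src._last_value = 0.25

serializers.save_npz('step_shift.npz', src)
dst = extensions.StepShift('x', 0.5, 3, init=1.0)
serializers.load_npz('step_shift.npz', dst)
assert dst._t == 7
assert dst._last_value == 0.25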
Example #4
    def test_with_optimizer(self):
        optimizer = mock.Mock()
        optimizer.x = 0
        extension = extensions.StepShift('x', self.gamma, self.step, self.init,
                                         self.target, optimizer)
        self._run_trainer(extension, self.expect, optimizer)
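
For reference, `StepShift` multiplies the tracked optimizer attribute by `gamma` once every `step` iterations: after `t` calls the value is `init * gamma ** floor(t / step)`, clamped at `target` if one is given. A standalone sketch with hypothetical numbers:

from unittest import mock

from chainer.training import extensions

optimizer = mock.MagicMock()
extension = extensions.StepShift('x', gamma=0.5, step=3, init=1.0,
                                 optimizer=optimizer)
extension.initialize(None)   # trainer is unused when an optimizer is given
values = []
for _ in range(9):
    extension(None)          # value = init * gamma ** floor(t / step)
    values.append(float(optimizer.x))
print(values)
# [1.0, 1.0, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.125]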
Example #5
from copy import copy, deepcopy
from pathlib import Path
from typing import Any, Dict

import chainer
from chainer import cuda, optimizer_hooks, optimizers, training
from chainer.iterators import MultiprocessIterator
from chainer.training import extensions
from chainer.training.updaters import ParallelUpdater, StandardUpdater
from tensorboardX import SummaryWriter  # assumed TensorBoard writer

# Config, Model, assert_config, create_predictor, create_dataset,
# concat_optional, ExponentialMovingAverage, Generator, GenerateEvaluator and
# TensorBoardReport are project-local and come from the surrounding package
# (not shown in this excerpt).


def create_trainer(
    config: Config,
    output: Path,
):
    assert_config(config)
    if output.exists():
        raise Exception(f"output directory {output} already exists.")

    # model
    predictor = create_predictor(config.model)
    if config.train.trained_model is not None:
        chainer.serializers.load_npz(
            config.train.trained_model["predictor_path"], predictor)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )

    model.to_gpu(config.train.gpu[0])
    cuda.get_device_from_id(config.train.gpu[0]).use()

    # dataset
    dataset = create_dataset(config.dataset)
    # ParallelUpdater scatters each full training batch across the GPUs, so
    # only the single-device evaluation iterators use the per-device size.
    batchsize_divided = config.train.batchsize // len(config.train.gpu)
    train_iter = MultiprocessIterator(dataset["train"], config.train.batchsize)
    test_iter = MultiprocessIterator(dataset["test"],
                                     batchsize_divided,
                                     repeat=False,
                                     shuffle=True)
    train_test_iter = MultiprocessIterator(dataset["train_test"],
                                           batchsize_divided,
                                           repeat=False,
                                           shuffle=True)

    if dataset["test_eval"] is not None:
        test_eval_iter = MultiprocessIterator(dataset["test_eval"],
                                              batchsize_divided,
                                              repeat=False,
                                              shuffle=True)
    else:
        test_eval_iter = None

    # optimizer
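    # config.train.optimizer is expected to be a mapping with a "name" key
    # ("adam" or "sgd", case-insensitive) plus keyword arguments for that
    # optimizer class, e.g. {"name": "adam", "alpha": 2e-4} (the values
    # here are hypothetical).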
    def create_optimizer(model):
        cp: Dict[str, Any] = copy(config.train.optimizer)
        n = cp.pop("name").lower()

        if n == "adam":
            optimizer = optimizers.Adam(**cp)
        elif n == "sgd":
            optimizer = optimizers.SGD(**cp)
        else:
            raise ValueError(n)

        optimizer.setup(model)

        if config.train.optimizer_gradient_clipping is not None:
            optimizer.add_hook(
                optimizer_hooks.GradientClipping(
                    config.train.optimizer_gradient_clipping))

        return optimizer

    optimizer = create_optimizer(model)
    if config.train.trained_model is not None:
        chainer.serializers.load_npz(
            config.train.trained_model["optimizer_path"], optimizer)

    # updater
    if len(config.train.gpu) <= 1:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            converter=concat_optional,
            device=config.train.gpu[0],
        )
    else:
        updater = ParallelUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            converter=concat_optional,
            devices={
                "main" if i == 0 else f"gpu{gpu}": gpu
                for i, gpu in enumerate(config.train.gpu)
            },
        )
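    # When fine-tuning from a trained model, fast-forward the iteration
    # counter so triggers and logging continue from the optimizer's step count.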
    if config.train.trained_model is not None:
        updater.iteration = optimizer.t

    # trainer
    output.mkdir()
    config.save_as_json((output / "config.json").absolute())

    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = training.Trainer(updater, stop_trigger=trigger_stop, out=output)
    tb_writer = SummaryWriter(Path(output))

    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
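        # Restore the extension's internal iteration counter when resuming,
        # so the shift schedule continues where the previous run left off.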
        if config.train.trained_model is not None:
            shift_ext._t = optimizer.t
        trainer.extend(shift_ext)

    if config.train.ema_decay is not None:
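        # Keep training the original predictor, but maintain an exponentially
        # smoothed copy; that copy is what the generator and the snapshot
        # extension below receive.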
        train_predictor = predictor
        predictor = deepcopy(predictor)
        ext = ExponentialMovingAverage(target=train_predictor,
                                       ema_target=predictor,
                                       decay=config.train.ema_decay)
        trainer.extend(ext, trigger=(1, "iteration"))

    ext = extensions.Evaluator(test_iter,
                               model,
                               concat_optional,
                               device=config.train.gpu[0])
    trainer.extend(ext, name="test", trigger=trigger_log)
    ext = extensions.Evaluator(train_test_iter,
                               model,
                               concat_optional,
                               device=config.train.gpu[0])
    trainer.extend(ext, name="train", trigger=trigger_log)

    if test_eval_iter is not None:
        generator = Generator(config=config,
                              model=predictor,
                              max_batch_size=config.train.batchsize)
        generate_evaluator = GenerateEvaluator(
            generator=generator,
            time_length=config.dataset.time_length_evaluate,
            local_padding_time_length=(
                config.dataset.local_padding_time_length_evaluate),
        )
        ext = extensions.Evaluator(
            test_eval_iter,
            generate_evaluator,
            concat_optional,
            device=config.train.gpu[0],
        )
        trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(predictor,
                                     filename="main_{.updater.iteration}.npz")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     optimizer, filename="optimizer_{.updater.iteration}.npz"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorBoardReport(writer=tb_writer), trigger=trigger_log)

    trainer.extend(extensions.dump_graph(root_name="main/loss"))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
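
A hypothetical invocation of `create_trainer` (the `Config` loader is an assumption about the surrounding project, not part of this excerpt):

config = Config.from_json("config.json")   # assumed loader
trainer = create_trainer(config, Path("out/experiment"))
trainer.run()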