Пример #1
0
 def test_without_init(self):
     """Run StepShift without an explicit ``init`` value.

     The hyperparameter is seeded directly on the optimizer's first
     parameter group and ``init`` is deliberately omitted, so the schedule
     has to start from the value already stored on the optimizer.
     """
     # Seed the starting value on the optimizer instead of the extension.
     self.optimizer.param_groups[0]['x'] = self.init
     # No ``init=`` here: passing it would hide the fallback this test is
     # named after (assumes ``init`` is optional on StepShift).
     extension = extensions.StepShift('x',
                                      self.gamma,
                                      self.step,
                                      target=self.target)
     self._run_trainer(extension, self.expect)
Пример #2
0
    def setUp(self):
        """Prepare a mocked optimizer, a StepShift extension and a trainer.

        The optimizer is a ``MagicMock`` with a single parameter group whose
        ``'x'`` entry starts as ``None``. The trainer's mocked updater is
        wired to hand back that same optimizer.
        """
        fake_optimizer = mock.MagicMock()
        fake_optimizer.param_groups = [{'x': None}]
        self.optimizer = fake_optimizer
        self.extension = extensions.StepShift(
            'x', self.gamma, self.step, self.init, self.target,
            fake_optimizer)

        self.interval = 1
        # Repeat every expected value once per iteration of the interval.
        repeated = []
        for value in self.expect:
            repeated.extend([value] * self.interval)
        self.expect = repeated
        self.trigger = training.get_trigger((self.interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        trainer.updater.get_optimizer.return_value = fake_optimizer
        self.trainer = trainer
Пример #3
0
    def test_resume(self):
        """Check that a reloaded StepShift restores the optimizer value.

        The original trainer is run, its state is transferred into a fresh
        trainer through ``testing.save_and_load_pth``, and the fresh
        extension is initialized. The value it writes to its own optimizer
        must equal the one left on the original optimizer and be a float.
        """
        resumed_optimizer = mock.Mock()
        resumed_optimizer.param_groups = [{'x': None}]
        resumed_extension = extensions.StepShift(
            'x', self.gamma, self.step, self.init, self.target,
            resumed_optimizer)

        # Drive the original trainer so there is state worth saving.
        self.trainer.extend(self.extension)
        self.trainer.run()

        resumed_trainer = testing.get_trainer_with_mock_updater(
            (5, 'iteration'))
        resumed_trainer.extend(resumed_extension)
        testing.save_and_load_pth(self.trainer, resumed_trainer)

        resumed_extension.initialize(resumed_trainer)
        restored = resumed_optimizer.param_groups[0]['x']
        self.assertEqual(restored, self.optimizer.param_groups[0]['x'])
        self.assertIsInstance(restored, float)
Пример #4
0
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a ready-to-run ``Trainer`` from a plain configuration mapping.

    The configuration is parsed, validated with ``assert_config`` and dumped
    to ``<output>/config.yaml``. The model, data iterators, optimizer and
    updater are then created, and the evaluation, snapshot and reporting
    extensions are registered on the trainer.

    Args:
        config_dict: Raw configuration, handed to ``Config.from_dict``.
        output: Directory receiving the config dump, ``struct.txt`` and the
            trainer output. It is created when missing.

    Returns:
        The configured ``Trainer``; it is not started here.

    Raises:
        ValueError: If the optimizer name is neither ``"adam"`` nor
            ``"sgd"`` (compared case-insensitively).
    """
    # --- configuration ---------------------------------------------------
    config = Config.from_dict(config_dict)
    config.add_git_info()
    assert_config(config)

    output.mkdir(exist_ok=True, parents=True)
    config_path = output / "config.yaml"
    with config_path.open(mode="w") as config_file:
        yaml.safe_dump(config.to_dict(), config_file)

    # --- model -----------------------------------------------------------
    predictor = create_predictor(config.network)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )
    initializer_name = config.train.weight_initializer
    if initializer_name is not None:
        init_weights(model, name=initializer_name)

    # NOTE: the device is fixed to CUDA.
    device = torch.device("cuda")
    model.to(device)

    # --- datasets and iterators -----------------------------------------
    def _build_iterator(dataset, **flags):
        # Shared iterator settings come from the training configuration.
        return create_iterator(
            dataset,
            batch_size=config.train.batchsize,
            eval_batch_size=config.train.eval_batchsize,
            num_processes=config.train.num_processes,
            use_multithread=config.train.use_multithread,
            **flags,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _build_iterator(
        datasets["train"], for_train=True, for_eval=False)
    test_iter = _build_iterator(
        datasets["test"], for_train=False, for_eval=False)
    eval_iter = _build_iterator(
        datasets["eval"], for_train=False, for_eval=True)

    # The validation split is optional.
    valid_dataset = datasets["valid"]
    valid_iter = (
        _build_iterator(valid_dataset, for_train=False, for_eval=True)
        if valid_dataset is not None else None)

    # Escalate iterator timeout warnings to exceptions.
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # --- optimizer -------------------------------------------------------
    # Work on a copy so popping "name" leaves the config untouched.
    optimizer_kwargs: Dict[str, Any] = copy(config.train.optimizer)
    optimizer_name = optimizer_kwargs.pop("name").lower()

    optimizer_classes = {"adam": optim.Adam, "sgd": optim.SGD}
    if optimizer_name not in optimizer_classes:
        raise ValueError(optimizer_name)
    optimizer: Optimizer = optimizer_classes[optimizer_name](
        model.parameters(), **optimizer_kwargs)

    # --- updater ---------------------------------------------------------
    updater_class = AmpUpdater if config.train.use_amp else StandardUpdater
    updater = updater_class(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # --- trainer ---------------------------------------------------------
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    stop_iteration = config.train.stop_iteration
    trigger_stop = (
        None if stop_iteration is None else (stop_iteration, "iteration"))

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    # At most one shift extension is registered; when both are configured
    # the step shift replaces the linear one.
    shift_extension = None
    if config.train.linear_shift is not None:
        shift_extension = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_extension = extensions.StepShift(**config.train.step_shift)
    if shift_extension is not None:
        trainer.extend(shift_extension)

    trainer.extend(
        extensions.Evaluator(test_iter, model, device=device),
        name="test",
        trigger=trigger_log,
    )

    # Fall back to the training batch size when no eval batch size is set.
    eval_batchsize = config.train.eval_batchsize
    max_batch_size = (
        config.train.batchsize if eval_batchsize is None else eval_batchsize)
    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
        max_batch_size=max_batch_size,
        use_fast_inference=False,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.time_length_evaluate,
        local_padding_time_length=(
            config.dataset.local_padding_time_length_evaluate),
    )
    trainer.extend(
        extensions.Evaluator(eval_iter, generate_evaluator, device=device),
        name="eval",
        trigger=trigger_eval,
    )
    if valid_iter is not None:
        trainer.extend(
            extensions.Evaluator(
                valid_iter, generate_evaluator, device=device),
            name="valid",
            trigger=trigger_eval,
        )

    # Retain a tenth of the planned evaluation rounds, or 10 snapshots when
    # no stop iteration is configured.
    if stop_iteration is None:
        saving_model_num = 10
    else:
        saving_model_num = int(
            stop_iteration / config.train.eval_iteration / 10)
    predictor_snapshot = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        predictor_snapshot,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    # --- reporting -------------------------------------------------------
    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    tensorboard_report = TensorboardReport(
        writer=SummaryWriter(Path(output)))
    trainer.extend(tensorboard_report, trigger=trigger_log)

    if config.project.category is not None:
        wandb_report = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(wandb_report, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # Snapshot of the whole trainer, registered with autoload enabled.
    trainer_snapshot = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(trainer_snapshot, trigger=trigger_snapshot)

    return trainer
Пример #5
0
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Create a configured ``Trainer`` from a raw configuration dictionary.

    The configuration is parsed and written to ``<output>/config.yaml``.
    The model, iterators, optimizer and updater are created, and the
    evaluation, snapshot and reporting extensions are registered.

    Args:
        config_dict: Raw configuration, handed to ``Config.from_dict``.
        output: Directory receiving the config dump, ``struct.txt`` and the
            trainer output. It is created when missing.

    Returns:
        The configured ``Trainer``; it is not started here.

    Raises:
        ValueError: If the optimizer name is neither ``"adam"`` nor
            ``"sgd"`` (compared case-insensitively).
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    # NOTE: the device is fixed to CUDA and weights always use the
    # "orthogonal" initializer.
    device = torch.device("cuda")
    predictor = create_predictor(config.network)
    model = Model(
        model_config=config.model,
        predictor=predictor,
        local_padding_length=config.dataset.local_padding_length,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)
    eval_iter = _create_iterator(datasets["eval"],
                                 for_train=False,
                                 for_eval=True)

    # Escalate iterator timeout warnings to exceptions.
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    # Work on a copy so popping "name" leaves the config untouched.
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    # An unset ``use_amp`` falls back to the module-level ``amp_exist`` flag.
    use_amp = config.train.use_amp if config.train.use_amp is not None else amp_exist
    if use_amp:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    # # error at randint
    # sample_data = datasets["train"][0]
    # writer.add_graph(
    #     model,
    #     input_to_model=(
    #         sample_data["wave"].unsqueeze(0).to(device),
    #         sample_data["local"].unsqueeze(0).to(device),
    #         sample_data["speaker_id"].unsqueeze(0).to(device)
    #         if predictor.with_speaker
    #         else None,
    #     ),
    # )

    # Both shift extensions may be registered when both are configured.
    if config.train.multistep_shift is not None:
        trainer.extend(
            extensions.MultistepShift(**config.train.multistep_shift))
    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        noise_schedule_config=NoiseScheduleModelConfig(start=1e-4,
                                                       stop=0.05,
                                                       num=50),
        predictor=predictor,
        sampling_rate=config.dataset.sampling_rate,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        local_padding_time_second=config.dataset.
        evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    # Retain a tenth of the planned evaluation rounds, or 10 snapshots when
    # no stop iteration is configured.
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10

    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    # The TensorBoard reporter is the extension; it runs on the log trigger.
    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # Snapshot of the whole trainer, registered with autoload enabled; this
    # variant saves it on the evaluation trigger.
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Пример #6
0
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a ``Trainer`` for a model with an optional discriminator.

    The configuration is parsed and written to ``<output>/config.yaml``.
    The networks, models, iterators, optimizers and updater are created,
    and the evaluation, snapshot and reporting extensions are registered.

    Args:
        config_dict: Raw configuration, handed to ``Config.from_dict``.
        output: Directory receiving the config dump, the structure dumps and
            the trainer output. It must not exist yet.

    Returns:
        The configured ``Trainer``; it is not started here.

    Raises:
        FileExistsError: If ``output`` already exists, because the directory
            is created without ``exist_ok``.
    """
    # --- configuration ---------------------------------------------------
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    config_path = output / "config.yaml"
    with config_path.open(mode="w") as config_file:
        yaml.safe_dump(config.to_dict(), config_file)

    # --- models ----------------------------------------------------------
    # NOTE: the device is fixed to CUDA.
    device = torch.device("cuda")

    networks = create_network(config.network)
    model = Model(
        model_config=config.model,
        networks=networks,
        local_padding_length=config.dataset.local_padding_length,
    )
    model.to(device)

    # The discriminator model exists only when an input type is configured.
    discriminator_model = None
    if config.model.discriminator_input_type is not None:
        discriminator_model = DiscriminatorModel(
            model_config=config.model,
            networks=networks,
            local_padding_length=config.dataset.local_padding_length,
        )
        discriminator_model.to(device)

    # --- datasets and iterators -----------------------------------------
    def _build_iterator(dataset, for_train: bool):
        # Repeating and shuffling are both tied to the training flag.
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=300,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _build_iterator(datasets["train"], for_train=True)
    test_iter = _build_iterator(datasets["test"], for_train=False)
    test_eval_iter = _build_iterator(datasets["test_eval"], for_train=False)

    # Escalate iterator timeout warnings to exceptions.
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # --- optimizers ------------------------------------------------------
    optimizer = create_optimizer(config.train.optimizer, model)
    discriminator_optimizer = None
    if config.train.discriminator_optimizer is not None:
        discriminator_optimizer = create_optimizer(
            config.train.discriminator_optimizer, discriminator_model)

    # --- updater ---------------------------------------------------------
    updater = Updater(
        iterator=train_iter,
        optimizer=optimizer,
        discriminator_model=discriminator_model,
        model=model,
        discriminator_optimizer=discriminator_optimizer,
        device=device,
    )

    # --- trainer ---------------------------------------------------------
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    stop_iteration = config.train.stop_iteration
    trigger_stop = (
        None if stop_iteration is None else (stop_iteration, "iteration"))

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    # One evaluator per available model, each registered under "test".
    evaluated_models = [model]
    if discriminator_model is not None:
        evaluated_models.append(discriminator_model)
    for evaluated in evaluated_models:
        trainer.extend(
            extensions.Evaluator(test_iter, evaluated, device=device),
            name="test",
            trigger=trigger_log,
        )

    generator = Generator(
        config=config, predictor=networks.predictor, use_gpu=True)
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.evaluate_time_second,
        local_padding_time_length=(
            config.dataset.evaluate_local_padding_time_second),
    )
    trainer.extend(
        extensions.Evaluator(
            test_eval_iter, generate_evaluator, device=device),
        name="eval",
        trigger=trigger_snapshot,
    )

    predictor_snapshot = extensions.snapshot_object(
        networks.predictor, filename="predictor_{.updater.iteration}.pth")
    trainer.extend(predictor_snapshot, trigger=trigger_snapshot)
    # Trainer and discriminator snapshots are currently disabled:
    # ext = extensions.snapshot_object(
    #     trainer, filename="trainer_{.updater.iteration}.pth"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)
    # if networks.discriminator is not None:
    #     ext = extensions.snapshot_object(
    #         networks.discriminator, filename="discriminator_{.updater.iteration}.pth"
    #     )
    #     trainer.extend(ext, trigger=trigger_snapshot)

    # --- reporting -------------------------------------------------------
    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    tensorboard_report = TensorboardReport(
        writer=SummaryWriter(Path(output)))
    trainer.extend(tensorboard_report, trigger=trigger_log)

    # Dump the textual structure of each model next to the config.
    (output / "struct.txt").write_text(repr(model))
    if discriminator_model is not None:
        discriminator_struct = output / "discriminator_struct.txt"
        discriminator_struct.write_text(repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Пример #7
0
 def test_with_optimizer(self):
     """Run StepShift against an optimizer passed to it explicitly.

     A dedicated mock optimizer whose ``'x'`` entry starts at ``0`` is given
     to both the extension and ``_run_trainer``.
     """
     dedicated_optimizer = mock.Mock()
     dedicated_optimizer.param_groups = [{'x': 0}]
     shift = extensions.StepShift(
         'x', self.gamma, self.step, self.init, self.target,
         dedicated_optimizer)
     self._run_trainer(shift, self.expect, dedicated_optimizer)
Пример #8
0
def create_trainer(
    config: Config,
    output: Path,
):
    """Build a ready-to-run ``Trainer`` from an already parsed ``Config``.

    The configuration is written to ``<output>/config.yaml``. The model,
    iterators, optimizer and updater are created, and the evaluation,
    snapshot and reporting extensions are registered.

    Args:
        config: Parsed configuration. Git information is added to it in
            place through ``config.add_git_info()``.
        output: Directory receiving the config dump, ``struct.txt`` and the
            trainer output. It is created when missing.

    Returns:
        The configured ``Trainer``; it is not started here.
    """
    # --- configuration ---------------------------------------------------
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    config_path = output / "config.yaml"
    with config_path.open(mode="w") as config_file:
        yaml.safe_dump(config.to_dict(), config_file)

    # --- model -----------------------------------------------------------
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    initializer_name = config.train.weight_initializer
    if initializer_name is not None:
        init_weights(model, name=initializer_name)

    # NOTE: the device is fixed to CUDA.
    device = torch.device("cuda")
    model.to(device)

    # --- datasets and iterators -----------------------------------------
    # Settings shared by every iterator created below.
    iterator_options = dict(
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = create_iterator(
        datasets["train"], for_train=True, **iterator_options)
    test_iter = create_iterator(
        datasets["test"], for_train=False, **iterator_options)
    # The generation evaluator also reads the "test" split, in eval mode.
    eval_iter = create_iterator(
        datasets["test"], for_train=False, for_eval=True, **iterator_options)

    # The validation split is optional.
    valid_dataset = datasets["valid"]
    valid_iter = (
        create_iterator(
            valid_dataset, for_train=False, for_eval=True,
            **iterator_options)
        if valid_dataset is not None else None)

    # Escalate iterator timeout warnings to exceptions.
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # --- optimizer -------------------------------------------------------
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # --- updater ---------------------------------------------------------
    updater_class = AmpUpdater if config.train.use_amp else StandardUpdater
    updater = updater_class(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # --- trainer ---------------------------------------------------------
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    stop_iteration = config.train.stop_iteration
    trigger_stop = (
        None if stop_iteration is None else (stop_iteration, "iteration"))

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    trainer.extend(
        extensions.Evaluator(test_iter, model, device=device),
        name="test",
        trigger=trigger_log,
    )

    generator = Generator(config=config, predictor=predictor, use_gpu=True)
    generate_evaluator = GenerateEvaluator(generator=generator)
    trainer.extend(
        extensions.Evaluator(eval_iter, generate_evaluator, device=device),
        name="eval",
        trigger=trigger_eval,
    )
    if valid_iter is not None:
        trainer.extend(
            extensions.Evaluator(
                valid_iter, generate_evaluator, device=device),
            name="valid",
            trigger=trigger_eval,
        )

    # Predictor snapshots are gated by LowValueTrigger on "eval/main/mcd".
    predictor_snapshot = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=5,
    )
    trainer.extend(
        predictor_snapshot,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    # --- reporting -------------------------------------------------------
    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    tensorboard_report = TensorboardReport(
        writer=SummaryWriter(Path(output)))
    trainer.extend(tensorboard_report, trigger=trigger_log)

    if config.project.category is not None:
        wandb_report = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(wandb_report, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # Snapshot of the whole trainer, registered with autoload enabled.
    trainer_snapshot = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(trainer_snapshot, trigger=trigger_snapshot)

    return trainer