def train_phase(predictor, train, valid, args):

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    model = Classifier(predictor)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=device))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
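
For reference, a minimal sketch of how train_phase above might be driven; the MNIST dataset, the small predictor, and the argument defaults are assumptions for illustration and are not part of the original snippet.

import argparse

import torch.nn as nn
from torchvision import datasets, transforms

parser = argparse.ArgumentParser()
parser.add_argument('--batchsize', type=int, default=100)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--frequency', type=int, default=-1)
parser.add_argument('--decay', type=float, default=0.0)
parser.add_argument('--gpu', default='cuda')  # passed to torch.device() above
parser.add_argument('--out', default='result')
parser.add_argument('--plot', action='store_true')
parser.add_argument('--resume', default=None)
args = parser.parse_args()

# Any nn.Module with matching input/output shapes works as the predictor.
predictor = nn.Sequential(nn.Flatten(), nn.Linear(784, 100), nn.ReLU(),
                          nn.Linear(100, 10))

transform = transforms.ToTensor()
train = datasets.MNIST('data', train=True, download=True, transform=transform)
valid = datasets.MNIST('data', train=False, transform=transform)

train_phase(predictor, train, valid, args)
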
    def _setup(self, stream=None, delete_flush=False):
        self.logreport = mock.MagicMock(spec=extensions.LogReport(
            ['epoch'], trigger=(1, 'iteration'), log_name=None))
        if stream is None:
            self.stream = mock.MagicMock()
            if delete_flush:
                del self.stream.flush
        else:
            self.stream = stream
        self.report = extensions.PrintReport(['epoch'],
                                             log_report=self.logreport,
                                             out=self.stream)

        self.trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(1, 'iteration'))
        self.trainer.extend(self.logreport)
        self.trainer.extend(self.report)
        self.logreport.log = [{'epoch': 0}]
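
The _setup helper above wires a PrintReport to a mocked LogReport and stream; a hedged test sketch of how it might be exercised, assuming the extension is invoked by calling it with the trainer, as in the Chainer-style test suites this port mirrors.

    def test_print_writes_report(self):
        # Hedged sketch: assumes PrintReport, like Chainer extensions, is
        # callable with the trainer and writes its table to the given stream.
        self._setup(delete_flush=True)
        self.report(self.trainer)
        self.stream.write.assert_called()
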
Example #3
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()
    assert_config(config)

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)
    eval_iter = _create_iterator(datasets["eval"],
                                 for_train=False,
                                 for_eval=True)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"],
                                      for_train=False,
                                      for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
        trainer.extend(shift_ext)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
        max_batch_size=(config.train.eval_batchsize
                        if config.train.eval_batchsize is not None else
                        config.train.batchsize),
        use_fast_inference=False,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.time_length_evaluate,
        local_padding_time_length=config.dataset.local_padding_time_length_evaluate,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)
    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter,
                                   generate_evaluator,
                                   device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorboardReport(writer=SummaryWriter(Path(output))),
                   trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
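
A hedged sketch of a command-line entry point for these create_trainer functions; loading the config from a YAML file and the argument names are assumptions, not part of the original code.

import argparse
from pathlib import Path

import yaml

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('config_yaml_path', type=Path)
    parser.add_argument('output', type=Path)
    args = parser.parse_args()

    with args.config_yaml_path.open() as f:
        config_dict = yaml.safe_load(f)

    trainer = create_trainer(config_dict=config_dict, output=args.output)
    trainer.run()

if __name__ == '__main__':
    main()
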
Example #4
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    predictor = create_predictor(config.network)
    model = Model(
        model_config=config.model,
        predictor=predictor,
        local_padding_length=config.dataset.local_padding_length,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)
    eval_iter = _create_iterator(datasets["eval"],
                                 for_train=False,
                                 for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    use_amp = config.train.use_amp if config.train.use_amp is not None else amp_exist
    if use_amp:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    # # error at randint
    # sample_data = datasets["train"][0]
    # writer.add_graph(
    #     model,
    #     input_to_model=(
    #         sample_data["wave"].unsqueeze(0).to(device),
    #         sample_data["local"].unsqueeze(0).to(device),
    #         sample_data["speaker_id"].unsqueeze(0).to(device)
    #         if predictor.with_speaker
    #         else None,
    #     ),
    # )

    if config.train.multistep_shift is not None:
        trainer.extend(
            extensions.MultistepShift(**config.train.multistep_shift))
    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        noise_schedule_config=NoiseScheduleModelConfig(start=1e-4,
                                                       stop=0.05,
                                                       num=50),
        predictor=predictor,
        sampling_rate=config.dataset.sampling_rate,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        local_padding_time_second=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10

    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    trainer.extend(TensorboardReport(writer=writer), trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Example #5
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda" if config.train.use_gpu else "cpu")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        networks.predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
Example #6
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        converter=list_concat,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    sample_data = datasets["train"][0]
    writer.add_graph(
        model,
        input_to_model=(
            [sample_data["f0"].to(device)],
            [sample_data["phoneme"].to(device)],
            [sample_data["phoneme_list"].to(device)],
            ([sample_data["speaker_id"].to(device)]
             if predictor.with_speaker else None),
        ),
    )

    ext = extensions.Evaluator(test_iter,
                               model,
                               converter=list_concat,
                               device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.snapshot_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
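
list_concat, used as the converter above, is not defined in this snippet; a hedged sketch of a converter that keeps each field of the batch as a list of per-sample tensors moved to the device (suiting variable-length data such as phoneme lists), under the assumption that each sample is a dict of tensors.

# Hedged sketch only; the real list_concat may differ.
def list_concat(batch, device=None):
    # batch: list of dict[str, Tensor]; keep variable-length fields as lists
    # instead of stacking them into a single padded tensor.
    return {
        key: [sample[key].to(device) for sample in batch]
        for key in batch[0]
    }
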
Example #7
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")

    networks = create_network(config.network)
    model = Model(
        model_config=config.model,
        networks=networks,
        local_padding_length=config.dataset.local_padding_length,
    )
    model.to(device)

    if config.model.discriminator_input_type is not None:
        discriminator_model = DiscriminatorModel(
            model_config=config.model,
            networks=networks,
            local_padding_length=config.dataset.local_padding_length,
        )
        discriminator_model.to(device)
    else:
        discriminator_model = None

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=300,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    test_eval_iter = _create_iterator(datasets["test_eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = create_optimizer(config.train.optimizer, model)
    if config.train.discriminator_optimizer is not None:
        discriminator_optimizer = create_optimizer(
            config.train.discriminator_optimizer, discriminator_model)
    else:
        discriminator_optimizer = None

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=optimizer,
        discriminator_model=discriminator_model,
        model=model,
        discriminator_optimizer=discriminator_optimizer,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if discriminator_model is not None:
        ext = extensions.Evaluator(test_iter,
                                   discriminator_model,
                                   device=device)
        trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(config=config,
                          predictor=networks.predictor,
                          use_gpu=True)
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.evaluate_time_second,
        local_padding_time_length=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(test_eval_iter,
                               generate_evaluator,
                               device=device)
    trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(
        networks.predictor, filename="predictor_{.updater.iteration}.pth")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     trainer, filename="trainer_{.updater.iteration}.pth"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)
    # if networks.discriminator is not None:
    #     ext = extensions.snapshot_object(
    #         networks.discriminator, filename="discriminator_{.updater.iteration}.pth"
    #     )
    #     trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))
    if discriminator_model is not None:
        (output / "discriminator_struct.txt").write_text(
            repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Example #8
def train_phase(predictor, train, valid, args):

    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(),
                         y=train.y.ravel(),
                         color='blue',
                         s=55,
                         alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Observation'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    lossfun = noised_mean_squared_error
    accfun = lambda y, t: F.l1_loss(y[0], t)

    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=device))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch',
                file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
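
noised_mean_squared_error is not defined in this snippet; a minimal sketch of a Gaussian negative-log-likelihood style loss, under the assumption (suggested by accfun indexing y[0] and the reported main/predictor/sigma) that the predictor returns a (mean, log_sigma) pair.

import torch

# Hedged sketch only; the real noised_mean_squared_error may differ.
def noised_mean_squared_error(y, t):
    mean, log_sigma = y
    # Negative log-likelihood of t under N(mean, exp(log_sigma)**2),
    # dropping the constant term.
    return torch.mean(0.5 * (t - mean) ** 2 * torch.exp(-2.0 * log_sigma)
                      + log_sigma)
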
Example #9
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
    dataset_dir: Optional[Path],
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    networks = create_network(config.network)
    generator_model = GeneratorModel(model_config=config.model,
                                     networks=networks).to(device)
    moving_generator_model = deepcopy(generator_model).to(device)
    discriminator_model = DiscriminatorModel(model_config=config.model,
                                             networks=networks).to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60 * 15,
        )

    datasets = create_dataset(config.dataset, dataset_dir=dataset_dir)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    style_transfer_optimizer = create_optimizer(
        config=config.train.style_transfer_optimizer,
        model=networks.style_transfer)
    mapping_network_optimizer = create_optimizer(
        config=config.train.mapping_network_optimizer,
        model=networks.mapping_network)
    style_encoder_optimizer = create_optimizer(
        config=config.train.style_encoder_optimizer,
        model=networks.style_encoder)
    discriminator_optimizer = create_optimizer(
        config=config.train.discriminator_optimizer,
        model=networks.discriminator)

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=dict(
            style_transfer=style_transfer_optimizer,
            mapping_network=mapping_network_optimizer,
            style_encoder=style_encoder_optimizer,
            discriminator=discriminator_optimizer,
        ),
        model=dict(
            generator=generator_model,
            discriminator=discriminator_model,
            moving_generator=moving_generator_model,
        ),
        moving_average_rate=config.train.moving_average_rate,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    def eval_func(**kwargs):
        generator_model.forward_with_latent(**kwargs)
        generator_model.forward_with_reference(**kwargs)
        discriminator_model.forward_with_latent(**kwargs)
        discriminator_model.forward_with_reference(**kwargs)
        moving_generator_model.forward_with_latent(**kwargs)
        moving_generator_model.forward_with_reference(**kwargs)

    ext = extensions.Evaluator(
        test_iter,
        target=dict(
            generator=generator_model,
            discriminator=discriminator_model,
            moving_generator=moving_generator_model,
        ),
        eval_func=eval_func,
        device=device,
    )
    trainer.extend(ext, name="test", trigger=trigger_log)

    def add_snapshot_object(target, name):
        ext = extensions.snapshot_object(
            target, filename=name + "_{.updater.iteration}.pth")
        trainer.extend(ext, trigger=trigger_snapshot)

    add_snapshot_object(networks.style_transfer, "style_transfer")
    add_snapshot_object(networks.mapping_network, "mapping_network")
    add_snapshot_object(networks.style_encoder, "style_encoder")

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport([
            "iteration", "generator/latent/loss", "test/generator/latent/loss"
        ]),
        trigger=trigger_log,
    )

    if config.train.model_config_linear_shift is not None:
        ext = ObjectLinearShift(target=config.model,
                                **config.train.model_config_linear_shift)
        trainer.extend(
            ext,
            trigger=(1, "iteration"),
        )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "generator_struct.txt").write_text(repr(generator_model))
    (output / "discriminator_struct.txt").write_text(repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Example #10
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / 'config.yaml').open(mode='w') as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)

    device = torch.device('cuda')
    model.to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets['train'], for_train=True)
    test_iter = _create_iterator(datasets['test'], for_train=False)
    train_test_iter = _create_iterator(datasets['train_test'], for_train=False)

    warnings.simplefilter('error', MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop('name').lower()

    if n == 'adam':
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == 'sgd':
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, 'iteration')
    trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
    trigger_stop = (
        config.train.stop_iteration,
        'iteration') if config.train.stop_iteration is not None else None

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name='test', trigger=trigger_log)
    ext = extensions.Evaluator(train_test_iter, model, device=device)
    trainer.extend(ext, name='train', trigger=trigger_log)

    ext = extensions.snapshot_object(
        networks.predictor, filename='predictor_{.updater.iteration}.npz')
    trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(extensions.PrintReport(
        ['iteration', 'main/loss', 'test/main/loss']),
                   trigger=trigger_log)

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / 'struct.txt').write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Example #11
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(config=config.model, networks=networks)
    init_orthogonal(model)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"],
                                      for_train=False,
                                      for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    elif n == "ranger":
        optimizer = Ranger(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, model, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    for field in dataclasses.fields(Networks):
        ext = extensions.snapshot_object(
            getattr(networks, field.name),
            filename=field.name + "_{.updater.iteration}.pth",
            n_retains=saving_model_num,
        )
        trainer.extend(
            ext,
            trigger=HighValueTrigger(
                ("valid/main/phoneme_accuracy"
                 if valid_iter is not None else "test/main/phoneme_accuracy"),
                trigger=trigger_eval,
            ),
        )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Example #12
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency',
                        '-f',
                        type=int,
                        default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device',
                        '-d',
                        type=str,
                        default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload',
                        action='store_true',
                        help='Automatically load trainer snapshots in case'
                        ' of preemption or other temporary system failure')
    parser.add_argument('--unit',
                        '-u',
                        type=int,
                        default=1000,
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu',
                       '-g',
                       dest='device',
                       type=int,
                       nargs='?',
                       const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    device = torch.device(args.device)

    print('Device: {}'.format(device))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(784, args.unit, 10))
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    transform = transforms.ToTensor()
    train = datasets.MNIST('data',
                           train=True,
                           download=True,
                           transform=transform)
    test = datasets.MNIST('data', train=False, transform=transform)

    train_iter = pytorch_trainer.iterators.SerialIterator(
        train, args.batchsize)
    test_iter = pytorch_trainer.iterators.SerialIterator(test,
                                                         args.batchsize,
                                                         repeat=False,
                                                         shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=device),
                   call_before_training=True)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    # trainer.extend(extensions.DumpGraph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    # Take a snapshot every ``frequency`` epochs, delete stale snapshots, and
    # automatically load from a snapshot file if one already exists in the
    # result directory.
    trainer.extend(extensions.snapshot(n_retains=1, autoload=args.autoload),
                   trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                         'epoch',
                                         file_name='loss.png'),
                   call_before_training=True)
    trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        'epoch',
        file_name='accuracy.png'),
                   call_before_training=True)

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy', 'elapsed_time'
    ]),
                   call_before_training=True)

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot (Note: this loaded model is to be
        # overwritten by --autoload option, autoloading snapshots, if
        # any snapshots exist in output directory)
        trainer.load_state_dict(torch.load(args.resume))

    # Run the training
    trainer.run()