Example #1
    def test_savefun_and_writer_exclusive(self):
        # savefun and writer arguments cannot be specified together.
        def savefun(*args, **kwargs):
            assert False

        writer = extensions.snapshot_writers.SimpleWriter()
        with pytest.raises(TypeError):
            extensions.snapshot(savefun=savefun, writer=writer)

        trainer = mock.MagicMock()
        with pytest.raises(TypeError):
            extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
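For reference, a minimal sketch of the two call forms the test treats as valid on their own (same `extensions` namespace as above; `my_savefun` is a hypothetical serializer, not part of the library):

def my_savefun(*args, **kwargs):
    ...  # hypothetical custom serializer; exact signature depends on the library

# Either a custom save function...
extensions.snapshot(savefun=my_savefun)
# ...or a writer that brings its own serializer, but never both at once.
extensions.snapshot(writer=extensions.snapshot_writers.SimpleWriter())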
Example #2
    def test_clean_up_tempdir(self):
        snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
        snapshot(self.trainer)

        left_tmps = [
            fn for fn in os.listdir('.') if fn.startswith('tmpmyfile.dat')
        ]
        self.assertEqual(len(left_tmps), 0)
Example #3
    def test_save_file(self):
        w = extensions.snapshot_writers.SimpleWriter()
        snapshot = extensions.snapshot_object(self.trainer,
                                              'myfile.dat',
                                              writer=w)
        snapshot(self.trainer)

        self.assertTrue(os.path.exists('myfile.dat'))
Example #4
    def test_on_error(self):
        class TheOnlyError(Exception):
            pass

        @training.make_extension(trigger=(1, 'iteration'), priority=100)
        def exception_raiser(trainer):
            raise TheOnlyError()

        self.trainer.extend(exception_raiser)

        snapshot = extensions.snapshot_object(self.trainer,
                                              self.filename,
                                              snapshot_on_error=True)
        self.trainer.extend(snapshot)

        self.assertFalse(os.path.exists(self.filename))

        with self.assertRaises(TheOnlyError):
            self.trainer.run()

        self.assertTrue(os.path.exists(self.filename))
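The same pattern outside a test, as a minimal hedged sketch (assumes an already configured `trainer` and the `extensions` API used above; the trigger value is illustrative):

snapshot = extensions.snapshot_object(trainer,
                                      'trainer_on_error.pth',
                                      snapshot_on_error=True)
trainer.extend(snapshot, trigger=(1, 'epoch'))
# If an extension raises during run(), the snapshot is still written,
# which is exactly what the test above asserts.
trainer.run()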
Example #5
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()
    assert_config(config)

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)
    eval_iter = _create_iterator(datasets["eval"],
                                 for_train=False,
                                 for_eval=True)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"],
                                      for_train=False,
                                      for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
        trainer.extend(shift_ext)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
        max_batch_size=(config.train.eval_batchsize
                        if config.train.eval_batchsize is not None else
                        config.train.batchsize),
        use_fast_inference=False,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.time_length_evaluate,
        local_padding_time_length=(
            config.dataset.local_padding_time_length_evaluate),
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)
    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter,
                                   generate_evaluator,
                                   device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorboardReport(writer=SummaryWriter(Path(output))),
                   trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
Example #6
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    predictor = create_predictor(config.network)
    model = Model(
        model_config=config.model,
        predictor=predictor,
        local_padding_length=config.dataset.local_padding_length,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)
    eval_iter = _create_iterator(datasets["eval"],
                                 for_train=False,
                                 for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    use_amp = config.train.use_amp if config.train.use_amp is not None else amp_exist
    if use_amp:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    # # error at randint
    # sample_data = datasets["train"][0]
    # writer.add_graph(
    #     model,
    #     input_to_model=(
    #         sample_data["wave"].unsqueeze(0).to(device),
    #         sample_data["local"].unsqueeze(0).to(device),
    #         sample_data["speaker_id"].unsqueeze(0).to(device)
    #         if predictor.with_speaker
    #         else None,
    #     ),
    # )

    if config.train.multistep_shift is not None:
        trainer.extend(
            extensions.MultistepShift(**config.train.multistep_shift))
    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        noise_schedule_config=NoiseScheduleModelConfig(start=1e-4,
                                                       stop=0.05,
                                                       num=50),
        predictor=predictor,
        sampling_rate=config.dataset.sampling_rate,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        local_padding_time_second=(
            config.dataset.evaluate_local_padding_time_second),
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10

    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    trainer.extend(TensorboardReport(writer=writer), trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Example #7
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda") if config.train.use_gpu else torch.device(
        "cpu")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        networks.predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
Example #8
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        converter=list_concat,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    sample_data = datasets["train"][0]
    writer.add_graph(
        model,
        input_to_model=(
            [sample_data["f0"].to(device)],
            [sample_data["phoneme"].to(device)],
            [sample_data["phoneme_list"].to(device)],
            ([sample_data["speaker_id"].to(device)]
             if predictor.with_speaker else None),
        ),
    )

    ext = extensions.Evaluator(test_iter,
                               model,
                               converter=list_concat,
                               device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.snapshot_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Example #9
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")

    networks = create_network(config.network)
    model = Model(
        model_config=config.model,
        networks=networks,
        local_padding_length=config.dataset.local_padding_length,
    )
    model.to(device)

    if config.model.discriminator_input_type is not None:
        discriminator_model = DiscriminatorModel(
            model_config=config.model,
            networks=networks,
            local_padding_length=config.dataset.local_padding_length,
        )
        discriminator_model.to(device)
    else:
        discriminator_model = None

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=300,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    test_eval_iter = _create_iterator(datasets["test_eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = create_optimizer(config.train.optimizer, model)
    if config.train.discriminator_optimizer is not None:
        discriminator_optimizer = create_optimizer(
            config.train.discriminator_optimizer, discriminator_model)
    else:
        discriminator_optimizer = None

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=optimizer,
        discriminator_model=discriminator_model,
        model=model,
        discriminator_optimizer=discriminator_optimizer,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if discriminator_model is not None:
        ext = extensions.Evaluator(test_iter,
                                   discriminator_model,
                                   device=device)
        trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(config=config,
                          predictor=networks.predictor,
                          use_gpu=True)
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.evaluate_time_second,
        local_padding_time_length=(
            config.dataset.evaluate_local_padding_time_second),
    )
    ext = extensions.Evaluator(test_eval_iter,
                               generate_evaluator,
                               device=device)
    trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(
        networks.predictor, filename="predictor_{.updater.iteration}.pth")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     trainer, filename="trainer_{.updater.iteration}.pth"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)
    # if networks.discriminator is not None:
    #     ext = extensions.snapshot_object(
    #         networks.discriminator, filename="discriminator_{.updater.iteration}.pth"
    #     )
    #     trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))
    if discriminator_model is not None:
        (output / "discriminator_struct.txt").write_text(
            repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Example #10
def train_phase(predictor, train, valid, args):

    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=True)

    # setup a model
    class_weight = None  # NOTE: please set if you have..

    lossfun = partial(softmax_cross_entropy,
                      normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)

    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)

    frequency = (max(args.iteration // 20, 1)
                 if args.frequency == -1 else max(1, args.frequency))

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }

    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}

    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter,
                             model,
                             valid_file,
                             visualizer=visualizer,
                             n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy'
    ]

    trainer.extend(LogReport(keys=log_keys))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
Example #11
def add_snapshot_object(target, name):
    ext = extensions.snapshot_object(
        target, filename=name + "_{.updater.iteration}.pth")
    trainer.extend(ext, trigger=trigger_snapshot)
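A hypothetical call site for this helper (the `networks` object, its `predictor` attribute, and the surrounding `trainer`/`trigger_snapshot` names are assumptions mirroring the other examples):

add_snapshot_object(networks.predictor, "predictor")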
Example #12
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / 'config.yaml').open(mode='w') as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)

    device = torch.device('cuda')
    model.to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets['train'], for_train=True)
    test_iter = _create_iterator(datasets['test'], for_train=False)
    train_test_iter = _create_iterator(datasets['train_test'], for_train=False)

    warnings.simplefilter('error', MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop('name').lower()

    if n == 'adam':
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == 'sgd':
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, 'iteration')
    trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
    trigger_stop = ((config.train.stop_iteration, 'iteration')
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name='test', trigger=trigger_log)
    ext = extensions.Evaluator(train_test_iter, model, device=device)
    trainer.extend(ext, name='train', trigger=trigger_log)

    ext = extensions.snapshot_object(
        networks.predictor, filename='predictor_{.updater.iteration}.npz')
    trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(extensions.PrintReport(
        ['iteration', 'main/loss', 'test/main/loss']),
                   trigger=trigger_log)

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / 'struct.txt').write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
Example #13
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(config=config.model, networks=networks)
    init_orthogonal(model)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"],
                                  for_train=True,
                                  for_eval=False)
    test_iter = _create_iterator(datasets["test"],
                                 for_train=False,
                                 for_eval=False)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"],
                                      for_train=False,
                                      for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    elif n == "ranger":
        optimizer = Ranger(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, model, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    for field in dataclasses.fields(Networks):
        ext = extensions.snapshot_object(
            getattr(networks, field.name),
            filename=field.name + "_{.updater.iteration}.pth",
            n_retains=saving_model_num,
        )
        trainer.extend(
            ext,
            trigger=HighValueTrigger(
                ("valid/main/phoneme_accuracy"
                 if valid_iter is not None else "test/main/phoneme_accuracy"),
                trigger=trigger_eval,
            ),
        )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
Example #14
def train_phase(generator, train, valid, args):

    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)

    model.to(device)
    discriminator.to(device)

    # setup an optimizer
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    frequency = (max(args.iteration // 80, 1)
                 if args.frequency == -1 else max(1, args.frequency))

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer

    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']

    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))


    # train
    trainer.run()