def test_monophone_config(basic_corpus_dir, basic_dict_path, temp_dir):
    """Check the realignment schedule a default monophone trainer computes.

    The aligner is cleaned up in a ``finally`` block so that a failing
    assertion does not skip ``cleanup()``.
    """
    am_trainer = TrainableAligner(
        corpus_directory=basic_corpus_dir,
        dictionary_path=basic_dict_path,
        temporary_directory=temp_dir,
    )
    try:
        config = MonophoneTrainer(identifier="mono", worker=am_trainer)
        config.compute_calculated_properties()
        # Expected schedule: every iteration from 0 to 10, then every 2nd
        # iteration up to 20, then every 3rd iteration up to 38.
        assert config.realignment_iterations == [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
            12, 14, 16, 18, 20,
            23, 26, 29, 32, 35, 38,
        ]
    finally:
        am_trainer.cleanup()
def test_triphone_config(basic_corpus_dir, basic_dict_path, temp_dir):
    """Check the realignment schedule a default triphone trainer computes.

    The aligner is cleaned up in a ``finally`` block so that a failing
    assertion does not skip ``cleanup()``.
    """
    am_trainer = TrainableAligner(
        corpus_directory=basic_corpus_dir,
        dictionary_path=basic_dict_path,
        temporary_directory=temp_dir,
    )
    try:
        config = TriphoneTrainer(identifier="tri", worker=am_trainer)
        config.compute_calculated_properties()
        assert config.realignment_iterations == [10, 20, 30]
    finally:
        am_trainer.cleanup()
def test_load_mono_train(basic_corpus_dir, basic_dict_path, temp_dir,
                         mono_train_config_path):
    """Check that options parsed from the monophone config file are applied.

    Every training configuration and the top-level aligner must end up with
    ``use_mp`` disabled and ``use_energy`` enabled. The aligner is cleaned up
    in a ``finally`` block so a failing assertion does not skip ``cleanup()``.
    """
    params = TrainableAligner.parse_parameters(mono_train_config_path)
    am_trainer = TrainableAligner(corpus_directory=basic_corpus_dir,
                                  dictionary_path=basic_dict_path,
                                  temporary_directory=temp_dir,
                                  **params)
    try:
        for t in am_trainer.training_configs.values():
            assert not t.use_mp
            assert t.use_energy
        assert not am_trainer.use_mp
        assert am_trainer.use_energy
    finally:
        am_trainer.cleanup()
def test_lda_mllt_config(basic_corpus_dir, basic_dict_path, temp_dir):
    """Check default beam settings and the LDA trainer's MLLT schedule.

    An aligner built without config parameters must report ``beam == 10`` and
    ``retry_beam == 40``, both as attributes and in ``align_options``. An LDA
    trainer attached to it must compute MLLT iterations ``[2, 4, 6, 12]``.
    The aligner is cleaned up in a ``finally`` block so a failing assertion
    does not skip ``cleanup()``.
    """
    am_trainer = TrainableAligner(
        corpus_directory=basic_corpus_dir,
        dictionary_path=basic_dict_path,
        temporary_directory=temp_dir,
    )
    try:
        assert am_trainer.beam == 10
        assert am_trainer.retry_beam == 40
        assert am_trainer.align_options["beam"] == 10
        assert am_trainer.align_options["retry_beam"] == 40

        config = LdaTrainer(identifier="lda", worker=am_trainer)
        config.compute_calculated_properties()
        assert config.mllt_iterations == [2, 4, 6, 12]
    finally:
        am_trainer.cleanup()
def test_load_basic_train(basic_corpus_dir, basic_dict_path, temp_dir,
                          basic_train_config_path):
    """Check that beam settings from the basic train config reach every trainer.

    The aligner and each of its training configurations must report
    ``beam == 100`` and ``retry_beam == 400``, both as attributes and in
    ``align_options``. The aligner is cleaned up in a ``finally`` block so a
    failing assertion does not skip ``cleanup()``.
    """
    params = TrainableAligner.parse_parameters(basic_train_config_path)
    am_trainer = TrainableAligner(corpus_directory=basic_corpus_dir,
                                  dictionary_path=basic_dict_path,
                                  temporary_directory=temp_dir,
                                  **params)
    try:
        assert am_trainer.beam == 100
        assert am_trainer.retry_beam == 400
        assert am_trainer.align_options["beam"] == 100
        assert am_trainer.align_options["retry_beam"] == 400

        for trainer in am_trainer.training_configs.values():
            assert trainer.beam == 100
            assert trainer.retry_beam == 400
            assert trainer.align_options["beam"] == 100
            assert trainer.align_options["retry_beam"] == 400
    finally:
        am_trainer.cleanup()
# Example #6
def test_typing(basic_corpus_dir, basic_dict_path, temp_dir):
    """Check the class hierarchy of ``SatTrainer`` and ``TrainableAligner``.

    A ``SatTrainer`` must be a ``TrainerMixin``, an ``AlignMixin`` and an
    ``MfaWorker``, and the aligner itself must be an ``MfaWorker``.
    """
    am_trainer = TrainableAligner(
        corpus_directory=basic_corpus_dir,
        dictionary_path=basic_dict_path,
        temporary_directory=temp_dir,
    )
    try:
        trainer = SatTrainer(identifier="sat", worker=am_trainer)
        assert type(trainer).__name__ == "SatTrainer"
        assert isinstance(trainer, TrainerMixin)
        assert isinstance(trainer, AlignMixin)
        assert isinstance(trainer, MfaWorker)
        assert isinstance(am_trainer, MfaWorker)
    finally:
        # Release the aligner even when an assertion fails.
        am_trainer.cleanup()
def test_load(basic_corpus_dir, basic_dict_path, temp_dir, config_directory):
    """Check the trainers built from config files in ``config_directory``.

    ``basic_train_config.yaml`` must yield four training configurations: a
    ``MonophoneTrainer`` under ``"monophone"``, a ``TriphoneTrainer`` under
    ``"triphone"`` and a ``SatTrainer`` under the aligner's
    ``final_identifier``. Parsing ``out_of_order_config.yaml`` must raise
    ``ConfigError``. The aligner is cleaned up in a ``finally`` block so a
    failing assertion does not skip ``cleanup()``.
    """
    path = os.path.join(config_directory, "basic_train_config.yaml")
    params = TrainableAligner.parse_parameters(path)
    am_trainer = TrainableAligner(corpus_directory=basic_corpus_dir,
                                  dictionary_path=basic_dict_path,
                                  temporary_directory=temp_dir,
                                  **params)
    try:
        configs = am_trainer.training_configs
        assert len(configs) == 4
        assert isinstance(configs["monophone"], MonophoneTrainer)
        assert isinstance(configs["triphone"], TriphoneTrainer)
        assert isinstance(configs[am_trainer.final_identifier], SatTrainer)

        out_of_order_path = os.path.join(config_directory,
                                         "out_of_order_config.yaml")
        with pytest.raises(ConfigError):
            # Only the raised exception matters; the result is never used.
            TrainableAligner.parse_parameters(out_of_order_path)
    finally:
        am_trainer.cleanup()
def train_acoustic_model(args: Namespace, unknown_args: Optional[List[str]] = None) -> None:
    """
    Train an acoustic model from command line arguments and export the results

    The model is exported only when ``args.output_model_path`` is not None.
    Files are exported to ``args.output_directory`` only when that is not None.
    The trainer is always cleaned up. If any step raises, the trainer is
    flagged as dirty before the exception propagates.

    Parameters
    ----------
    args: :class:`~argparse.Namespace`
        Command line arguments
    unknown_args: list[str]
        Optional arguments that will be passed to configuration objects
    """
    corpus_dir = args.corpus_directory
    dict_path = args.dictionary_path
    tmp_dir = args.temporary_directory
    config_params = TrainableAligner.parse_parameters(args.config_path, args, unknown_args)
    trainer = TrainableAligner(
        corpus_directory=corpus_dir,
        dictionary_path=dict_path,
        temporary_directory=tmp_dir,
        **config_params,
    )

    try:
        trainer.train()

        model_path = args.output_model_path
        if model_path is not None:
            trainer.export_model(model_path)

        export_dir = args.output_directory
        if export_dir is not None:
            # These options may be absent from the namespace, so use defaults.
            fmt = getattr(args, "output_format", None)
            keep_original_text = getattr(args, "include_original_text", False)
            trainer.export_files(
                export_dir,
                fmt,
                include_original_text=keep_original_text,
            )
    except Exception:
        # NOTE(review): presumably tells cleanup() the temporary state is not
        # reusable; confirm against TrainableAligner.
        trainer.dirty = True
        raise
    finally:
        trainer.cleanup()