예제 #1
0
def config_file(tmp_path: Path):
    """Dry-run the enhancement task and return the expected config path.

    Args:
        tmp_path: Directory passed to the task as ``--output_dir``.

    Returns:
        ``tmp_path / "config.yaml"``. The dry run is relied upon to have
        written the default configuration there.
    """
    dry_run_argv = ["--dry_run", "true", "--output_dir", str(tmp_path)]
    EnhancementTask.main(cmd=dry_run_argv)
    return tmp_path / "config.yaml"
예제 #2
0
def main(cmd=None):
    r"""Enhancement frontend training.

    Thin entry point that forwards ``cmd`` unchanged to
    ``EnhancementTask.main``.

    Args:
        cmd: Argument list handed to ``EnhancementTask.main``; defaults to
            ``None``. NOTE(review): presumably ``None`` means "use the
            process command line" -- confirm against ``EnhancementTask.main``.

    Example:

        % python enh_train.py asr --print_config --optim adadelta \
                > conf/train_enh.yaml
        % python enh_train.py --config conf/train_enh.yaml
    """
    EnhancementTask.main(cmd=cmd)
예제 #3
0
def config_file(tmp_path: Path):
    """Dry-run the enhancement task, patch the STFT defaults, return the path.

    The task is run with ``--dry_run true`` and ``--output_dir`` set to
    ``tmp_path / "enh"``. The ``config.yaml`` found there is loaded, and if
    it selects the ``stft`` encoder with an empty ``encoder_conf``, that
    entry is replaced by ``get_default_kwargs(STFTEncoder)``. The config is
    then written back to the same file.

    Args:
        tmp_path: Base directory; everything is placed under its ``enh``
            subdirectory.

    Returns:
        Path of the rewritten ``config.yaml``.
    """
    out_dir = tmp_path / "enh"
    config_path = out_dir / "config.yaml"

    # Write default configuration file
    EnhancementTask.main(cmd=["--dry_run", "true", "--output_dir", str(out_dir)])

    with open(config_path, "r") as stream:
        config = yaml.safe_load(stream)

    # Only fill in encoder defaults when the STFT encoder has none configured.
    uses_stft = config["encoder"] == "stft"
    if uses_stft and len(config["encoder_conf"]) == 0:
        config["encoder_conf"] = get_default_kwargs(STFTEncoder)

    with open(config_path, "w") as stream:
        yaml_no_alias_safe_dump(config, stream, indent=4, sort_keys=False)

    return config_path
예제 #4
0
def test_main_with_no_args():
    """Calling the task with an empty argument list must raise SystemExit."""
    empty_argv = []
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=empty_argv)
예제 #5
0
def test_main_print_config():
    """Passing ``--print_config`` to the task must raise SystemExit."""
    print_config_argv = ["--print_config"]
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=print_config_argv)
예제 #6
0
def test_main_help():
    """Passing ``--help`` to the task must raise SystemExit."""
    help_argv = ["--help"]
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=help_argv)