Example #1
0
def ray_train_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    config_path: Path = Arg(..., help="Path to config file", exists=True),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory or remote storage URL for saving trained pipeline"),
    num_workers: int = Opt(1, "--n-workers", "-w", help="Number of workers"),
    ray_address: Optional[str] = Opt(None, "--address", "-a", help="Address of ray cluster"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
    Train a spaCy pipeline using Ray for parallel training.
    """
    # NOTE(review): this function presumably is registered as a Typer command,
    # in which case the docstring above is the `--help` text, so developer
    # notes live in comments instead.
    #
    # Flow: configure log level and GPU/CPU, parse any extra CLI args in
    # ctx.args as config overrides, load the config without interpolation,
    # then hand it to ray_train with the cluster address, worker count,
    # GPU id and optional code path.
    import warnings

    # TODO: wire up output path
    if output_path is not None:
        # output_path is not forwarded to ray_train yet, so tell the user
        # instead of silently dropping the option. warnings.warn is used
        # instead of logger.warning because the logger is set to ERROR
        # below unless --verbose is given, which would hide the message.
        warnings.warn(
            f"--output ({output_path}) is not supported yet and will be ignored.",
            UserWarning,
        )
    logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
    setup_gpu(use_gpu)
    # Extra, unrecognized CLI arguments are interpreted as config overrides.
    overrides = parse_config_overrides(ctx.args)
    with show_validation_error(config_path):
        config = load_config(config_path, overrides=overrides, interpolate=False)
    ray_train(
        config,
        ray_address=ray_address,
        num_workers=num_workers,
        use_gpu=use_gpu,
        code_path=code_path,
    )
Example #2
0
def test_parse_cli_overrides():
    """Check that config overrides are read from the overrides env var.

    Covers both ``--key value`` and ``--key=value`` forms (with values coerced,
    e.g. ``12`` -> int and ``false`` -> bool), ``env_var=None`` disabling the
    environment lookup, and malformed env values causing ``SystemExit``.
    """
    env_key = ENV_VARS.CONFIG_OVERRIDES
    # os.environ is process-global: remember any pre-existing value and restore
    # it in `finally`, so a failing assertion cannot leak overrides into other
    # tests and a value set outside the test is not clobbered.
    previous = os.environ.get(env_key)
    try:
        os.environ[env_key] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello"
        result = parse_config_overrides([])
        assert len(result) == 4
        assert result["x.foo"] == "bar"
        assert result["x.bar"] == 12
        assert result["x.baz"] is False
        assert result["y.foo"] == "hello"

        # A dangling option with no value is invalid, but is ignored entirely
        # when the env var lookup is disabled.
        os.environ[env_key] = "--x"
        assert parse_config_overrides([], env_var=None) == {}
        with pytest.raises(SystemExit):
            parse_config_overrides([])

        # Positional words without a leading "--" are also rejected.
        os.environ[env_key] = "hello world"
        with pytest.raises(SystemExit):
            parse_config_overrides([])
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
Example #3
0
def test_parse_config_overrides_invalid_2(args):
    """Malformed override arguments make the parser exit via SystemExit."""
    pytest.raises(SystemExit, parse_config_overrides, args)
Example #4
0
def test_parse_config_overrides_invalid(args):
    """Override arguments the parser does not accept raise NoSuchOption."""
    pytest.raises(NoSuchOption, parse_config_overrides, args)
Example #5
0
def test_parse_config_overrides(args, expected):
    """Parsing the given CLI args yields exactly the expected override mapping."""
    parsed = parse_config_overrides(args)
    assert parsed == expected