def ray_train_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    config_path: Path = Arg(..., help="Path to config file", exists=True),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory or remote storage URL for saving trained pipeline"),
    num_workers: int = Opt(1, "--n-workers", "-w", help="Number of workers"),
    ray_address: Optional[str] = Opt(None, "--address", "-a", help="Address of ray cluster"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
    Train a spaCy pipeline using Ray for parallel training.
    """
    # The docstring above doubles as the CLI help text, so it is left as-is.
    # TODO: wire up output path
    # NOTE(review): ``output_path`` is accepted on the command line but is not
    # forwarded anywhere below, so --output currently has no effect.

    # Verbose mode shows debug logs; otherwise only errors are reported.
    log_level = logging.DEBUG if verbose else logging.ERROR
    logger.setLevel(log_level)
    setup_gpu(use_gpu)

    # Any additional arguments collected on the Typer context are treated as
    # config overrides (e.g. "--x.foo bar").
    overrides = parse_config_overrides(ctx.args)
    with show_validation_error(config_path):
        config = load_config(config_path, overrides=overrides, interpolate=False)

    ray_train(
        config,
        num_workers=num_workers,
        ray_address=ray_address,
        code_path=code_path,
        use_gpu=use_gpu,
    )
def test_parse_cli_overrides():
    """Check that config overrides are read from the overrides env variable.

    Covers:
    * valid overrides, including coercion of values to int / bool;
    * ``env_var=None``, which ignores the env variable entirely;
    * malformed env-variable contents, which must abort with ``SystemExit``.

    The env variable is restored in a ``finally`` block so that a failing
    assertion cannot leak the override into other tests.
    """
    key = ENV_VARS.CONFIG_OVERRIDES
    # Remember any pre-existing value so it can be put back afterwards.
    previous = os.environ.get(key)
    try:
        os.environ[key] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello"
        result = parse_config_overrides([])
        assert len(result) == 4
        assert result["x.foo"] == "bar"
        assert result["x.bar"] == 12
        assert result["x.baz"] is False
        assert result["y.foo"] == "hello"

        # A dangling flag with no value: ignored when env_var=None,
        # but an error when the env variable is actually read.
        os.environ[key] = "--x"
        assert parse_config_overrides([], env_var=None) == {}
        with pytest.raises(SystemExit):
            parse_config_overrides([])

        # Positional words that are not "--key value" pairs are rejected.
        os.environ[key] = "hello world"
        with pytest.raises(SystemExit):
            parse_config_overrides([])
    finally:
        # Always clean up, even if an assertion above failed.
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
def test_parse_config_overrides_invalid_2(args):
    """Malformed override arguments must abort with ``SystemExit``."""
    # Function-call form of pytest.raises: calls parse_config_overrides(args)
    # and fails the test unless SystemExit is raised.
    pytest.raises(SystemExit, parse_config_overrides, args)
def test_parse_config_overrides_invalid(args):
    """Invalid override arguments must raise ``NoSuchOption``."""
    # Function-call form of pytest.raises: calls parse_config_overrides(args)
    # and fails the test unless NoSuchOption is raised.
    pytest.raises(NoSuchOption, parse_config_overrides, args)
def test_parse_config_overrides(args, expected):
    """Parsing ``args`` yields exactly the ``expected`` overrides mapping."""
    parsed = parse_config_overrides(args)
    assert parsed == expected