示例#1
0
def test_merging():
    """Merging configs with distinct steps keeps every step from both inputs."""
    a = Config.parse([echo_step])
    b = Config.parse([complex_step])
    c = a.merge_with(b)
    assert len(c.steps) == 2
    # Iterate the union, not the intersection: the two inputs share no step
    # names (hence the 2 above), so an intersection would check nothing.
    for step in a.steps.keys() | b.steps.keys():
        assert step in c.steps
示例#2
0
def lint(yaml) -> LintResult:
    """
    Lint the given YAML data and return a `LintResult` describing any problems.

    The data is first checked with the schema validator from `get_validator()`.
    Every schema error is added to the result, sorted by `relevance` and then by
    the repr of its document path. If there are any, the result is returned
    right away. Otherwise the data is parsed into a `Config`, and that config's
    own `lint` is run against the same result object.

    :param yaml: YAML data, passed as-is to `read_yaml`
    :return: LintResult holding every error that was found
    """
    lr = LintResult()

    data = read_yaml(yaml)
    validator = get_validator()
    errors = sorted(
        validator.iter_errors(data),
        key=lambda error: (relevance(error), repr(error.path)),
    )
    for error in errors:
        # Drop the final schema path element and the structural
        # 'properties'/'items' hops to get a readable schema location.
        # Path elements may be integers (array/subschema indices), so they are
        # stringified before joining, just like the object path below.
        simplified_schema_path = [
            str(el) for el in list(error.relative_schema_path)[:-1]
            if el not in ('properties', 'items')
        ]
        obj_path = [str(el) for el in error.path]
        lr.add_error(
            '  {validator} validation on {schema_path}: {message} ({path})'.format(
                validator=style(error.validator.title(), bold=True),
                schema_path=style('.'.join(simplified_schema_path), bold=True),
                message=style(error.message, fg='red'),
                path=style('.'.join(obj_path), bold=True),
            )
        )

    if errors:
        # Schema-invalid data is not parsed further.
        return lr

    try:
        config = Config.parse(data)
    except ValidationError as err:  # Could happen before we get to linting things
        lr.add_error(str(err), exception=err)
    else:
        config.lint(lr, context={})
    return lr
示例#3
0
def get_pipeline_from_source(source_path: str, old_config: Config) -> Config:
    """Load a pipeline definition by importing a Python source file and calling its main().

    The file is expected to contain:

    def main(config) -> Pipeline:
        pipeline = Pipeline(name="foo", config=config)
        # ... Pipeline definition ...
        return pipeline

    :param source_path: Path of the Python source code file containing the pipeline definition
    :param old_config: Current config for the Valohai project, passed to main() (used for validation)
    :return: A new Config whose pipelines list holds the `to_yaml()` form of main()'s result
    :raises ValueError: If no import spec or no loader can be obtained for source_path
    :raises AttributeError: If the loaded module has no (truthy) `main` attribute
    """
    import importlib.util

    module_spec = importlib.util.spec_from_file_location(
        name="pipeline_source", location=source_path
    )
    if not module_spec:
        raise ValueError(f"Could not find import spec from {source_path}")
    pipeline_module = importlib.util.module_from_spec(module_spec)

    spec_loader: Optional[Loader] = module_spec.loader  # type: ignore
    if not spec_loader:
        raise ValueError("Spec has no loader")
    # Executes the file's top-level code inside the fresh module object.
    spec_loader.exec_module(pipeline_module)

    entry_point: Optional[Callable[[Config], Papi]] = getattr(pipeline_module, "main", None)
    if not entry_point:
        raise AttributeError(f"{source_path} is missing main() method!")

    built_pipeline = entry_point(old_config)
    return Config(pipelines=[built_pipeline.to_yaml()])
示例#4
0
def test_pipeline(pipeline_config: Config):
    """The example pipelines lint cleanly and contain the expected edges and deployment nodes."""
    lint_result = pipeline_config.lint()
    assert lint_result.is_valid()

    little = pipeline_config.pipelines["My little pipeline"]
    assert any(
        edge.source_node == "batch1"
        and edge.source_type == "parameter"
        and edge.source_key == "aspect-ratio"
        and edge.target_node == "batch2"
        and edge.target_type == "parameter"
        and edge.target_key == "aspect-ratio"
        for edge in little.edges
    )

    deployment = pipeline_config.pipelines["My deployment pipeline"]
    assert any(
        edge.source_node == "train"
        and edge.source_type == "output"
        and edge.source_key == "model"
        and edge.target_node == "deploy-predictor"
        and edge.target_type == "file"
        and edge.target_key == "predict-digit.model"
        for edge in deployment.edges
    )

    # A deployment node configured with aliases and endpoints.
    predictor_node = [
        candidate for candidate in deployment.nodes
        if candidate.type == 'deployment' and candidate.name == 'deploy-predictor'
    ][0]
    assert "predictor-staging" in predictor_node.aliases
    assert "predict-digit" in predictor_node.endpoints

    # A deployment node without presets ends up with empty lists.
    bare_node = [
        candidate for candidate in deployment.nodes
        if candidate.type == 'deployment' and candidate.name == 'deploy-no-presets'
    ][0]
    assert bare_node.aliases == []
    assert bare_node.endpoints == []

    medium = pipeline_config.pipelines["My medium pipeline"]
    assert any(
        edge.source_type == "output" and edge.source_key == "model.pb"
        for edge in medium.edges
    )
示例#5
0
def config_to_yaml(config: Config):
    """Render a Valohai Config as YAML text.

    The config is converted to plain data with `serialize()` and dumped with
    `default_flow_style=False`.

    :param config: valohai_yaml.objs.Config object
    :return: whatever `yaml.dump` returns when given no stream
    """
    plain_data = config.serialize()
    return yaml.dump(plain_data, default_flow_style=False)
示例#6
0
def parse(yaml, validate=True):
    """
    Build a `Config` object from YAML data, validating it against the schema first unless disabled.

    :param yaml: YAML data (either a string, a stream, or pre-parsed Python dict/list)
    :type yaml: list|dict|str|file
    :param validate: Whether to validate the data before attempting to parse it.
    :type validate: bool
    :return: Config object
    :rtype: valohai_yaml.objs.Config
    """
    loaded = read_yaml(yaml)
    if validate:  # pragma: no branch
        # Aliased so the imported function does not shadow the `validate` flag.
        from .validation import validate as validate_data
        validate_data(loaded, raise_exc=True)
    return Config.parse(loaded)
示例#7
0
def generate_config(
    *,
    relative_source_path: str,
    step: str,
    image: str,
    parameters: ParameterDict,
    inputs: InputDict,
) -> Config:
    """Create a fresh `Config()` and add one step built by `generate_step`.

    Every keyword argument is forwarded unchanged to `generate_step`. The
    resulting step is stored in the config's `steps` under the step's own name.

    :return: the new Config containing the generated step
    """
    new_step = generate_step(
        relative_source_path=relative_source_path,
        step=step,
        image=image,
        parameters=parameters,
        inputs=inputs,
    )
    result = Config()
    result.steps[new_step.name] = new_step
    return result
示例#8
0
def test_get_step_by_non_existing_attribute():
    """Looking up by an attribute no step has finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(gorilla='greeting')
    assert not found
示例#9
0
def test_get_step_by_command():
    """Steps can be looked up by their command."""
    cfg = Config.parse([echo_step, list_step])
    echo_found = cfg.get_step_by(command='echo HELLO WORLD')
    assert echo_step['step'] == echo_found.serialize()
    list_found = cfg.get_step_by(command='ls')
    assert list_step['step'] == list_found.serialize()
示例#10
0
def test_get_step_by_name_doesnt_exist():
    """Looking up an unknown step name finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(name='not found')
    assert not found
示例#11
0
def test_get_step_by_index():
    """Index 1 refers to the second parsed step."""
    cfg = Config.parse([echo_step, list_step])
    second = cfg.get_step_by(index=1)
    assert list_step['step'] == second.serialize()
示例#12
0
def test_merging_conflict():
    """Merging `complex_step` with `complex_step_alt` serializes like `complex_steps_merged`."""
    base = Config.parse([complex_step])
    override = Config.parse([complex_step_alt])
    merged = base.merge_with(override)
    reference = Config.parse([complex_steps_merged])
    assert merged.serialize() == reference.serialize()
示例#13
0
def test_get_step_by_simple_name():
    """A step can be looked up by its name."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(name='list files')
    assert list_step['step'] == found.serialize()
示例#14
0
def test_get_step_from_empty_config():
    """A config parsed from an empty list has no steps to find."""
    empty_cfg = Config.parse([])
    found = empty_cfg.get_step_by(name='greeting')
    assert not found
示例#15
0
def test_get_step_by_nothing_returns_none():
    """Calling get_step_by with no criteria finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by()
    assert not found
示例#16
0
def test_get_step_by_too_big_index():
    """An index past the last step finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(index=2)
    assert not found
示例#17
0
def test_get_step_by_name_and_command():
    """When both name and command are given, a step is found only if both match."""
    cfg = Config.parse([echo_step, list_step])
    wrong_command = cfg.get_step_by(name='greeting', command='echo HELLO MORDOR')
    assert not wrong_command
    wrong_name = cfg.get_step_by(name='farewell', command='echo HELLO WORLD')
    assert not wrong_name
    both_match = cfg.get_step_by(name='greeting', command='echo HELLO WORLD')
    assert echo_step['step'] == both_match.serialize()