def test_merging():
    """Merging two configs with different steps keeps the steps of both."""
    a = Config.parse([echo_step])
    b = Config.parse([complex_step])
    c = a.merge_with(b)
    assert len(c.steps) == 2
    # Iterate the union of the step names, not the intersection: the two
    # source configs are expected to hold differently named steps (that is
    # what makes the merged length 2), so an intersection would be empty and
    # the membership assertion below would never run.
    for step in a.steps.keys() | b.steps.keys():
        assert step in c.steps
def lint(yaml) -> LintResult:
    """
    Lint the given YAML data and collect the findings into a `LintResult`.

    Schema validation runs first. Every validation error is recorded, ordered
    by `relevance(error)` and then by the repr of the error's path. If there
    was at least one such error the result is returned right away. Otherwise
    the data is parsed into a `Config` and the config's own `lint` is run
    against the same result object.

    :param yaml: YAML data; handed to `read_yaml` unchanged
    :return: LintResult holding whatever errors were recorded
    """
    result = LintResult()
    data = read_yaml(yaml)

    schema_errors = list(get_validator().iter_errors(data))
    schema_errors.sort(key=lambda err: (relevance(err), repr(err.path)))

    template = ' {validator} validation on {schema_path}: {message} ({path})'
    for err in schema_errors:
        # Drop the trailing schema path element and the structural
        # 'properties' / 'items' hops so the reported path stays short.
        schema_parts = [
            part
            for part in list(err.relative_schema_path)[:-1]
            if part not in ('properties', 'items')
        ]
        object_parts = list(map(str, err.path))

        styled_validator = style(err.validator.title(), bold=True)
        styled_schema_path = style('.'.join(schema_parts), bold=True)
        styled_message = style(err.message, fg='red')
        styled_path = style('.'.join(object_parts), bold=True)
        result.add_error(
            template.format(
                validator=styled_validator,
                schema_path=styled_schema_path,
                message=styled_message,
                path=styled_path,
            )
        )

    if schema_errors:
        return result

    try:
        config = Config.parse(data)
    except ValidationError as err:
        # Parsing itself can fail before there is a Config to lint.
        result.add_error(str(err), exception=err)
        return result

    config.lint(result, context={})
    return result
def get_pipeline_from_source(source_path: str, old_config: Config) -> Config:
    """Build a pipeline Config by running main() from a Python source file.

    The source file is expected to define something like:

        def main(config) -> Pipeline:
            pipeline = Pipeline(name="foo", config=config)
            # ... Pipeline definition ...
            return pipeline

    :param source_path: Path of the Python source file that defines the pipeline
    :param old_config: Current config of the Valohai project; passed to main()
    :return: Config whose pipelines list holds the to_yaml() result of the
        object returned by main()
    :raises ValueError: If no import spec could be created for the file, or
        the spec has no loader
    :raises AttributeError: If the loaded module has no usable main attribute
    """
    from importlib.util import module_from_spec, spec_from_file_location

    module_spec = spec_from_file_location(
        name="pipeline_source",
        location=source_path,
    )
    if not module_spec:
        raise ValueError(f"Could not find import spec from {source_path}")

    pipeline_module = module_from_spec(module_spec)
    module_loader: Optional[Loader] = module_spec.loader  # type: ignore
    if not module_loader:
        raise ValueError("Spec has no loader")
    # Executing the module is what makes its top-level names (main) available.
    module_loader.exec_module(pipeline_module)

    entry_point: Optional[Callable[[Config], Papi]] = getattr(
        pipeline_module, "main", None
    )
    if not entry_point:
        raise AttributeError(f"{source_path} is missing main() method!")

    return Config(pipelines=[entry_point(old_config).to_yaml()])
def test_pipeline(pipeline_config: Config):
    """Check lint validity, expected edges and deployment nodes of the pipelines."""

    def has_edge(pipeline_name, **expected):
        # True when some edge of the named pipeline matches every expected
        # attribute value (attributes are compared in the order given).
        return any(
            all(getattr(edge, attr) == value for attr, value in expected.items())
            for edge in pipeline_config.pipelines[pipeline_name].edges
        )

    def deployment_node(pipeline, node_name):
        # First deployment node with the given name; IndexError if absent.
        matches = [
            node
            for node in pipeline.nodes
            if node.type == 'deployment' and node.name == node_name
        ]
        return matches[0]

    lint_result = pipeline_config.lint()
    assert lint_result.is_valid()

    assert has_edge(
        "My little pipeline",
        source_node="batch1",
        source_type="parameter",
        source_key="aspect-ratio",
        target_node="batch2",
        target_type="parameter",
        target_key="aspect-ratio",
    )
    assert has_edge(
        "My deployment pipeline",
        source_node="train",
        source_type="output",
        source_key="model",
        target_node="deploy-predictor",
        target_type="file",
        target_key="predict-digit.model",
    )

    deployment_pipeline = pipeline_config.pipelines["My deployment pipeline"]

    predictor_node = deployment_node(deployment_pipeline, 'deploy-predictor')
    assert "predictor-staging" in predictor_node.aliases
    assert "predict-digit" in predictor_node.endpoints

    no_preset_node = deployment_node(deployment_pipeline, 'deploy-no-presets')
    assert no_preset_node.aliases == []
    assert no_preset_node.endpoints == []

    assert has_edge(
        "My medium pipeline",
        source_type="output",
        source_key="model.pb",
    )
def config_to_yaml(config: Config):
    """Dump a Valohai Config as YAML.

    The config is first turned into plain data with `config.serialize()` and
    then dumped with flow style disabled.

    :param config: valohai_yaml.objs.Config object
    :return: Whatever `yaml.dump` returns for the serialized config
    """
    serialized = config.serialize()
    return yaml.dump(serialized, default_flow_style=False)
def parse(yaml, validate=True):
    """
    Turn the given YAML data into a `Config` object.

    When `validate` is truthy, the data is validated before parsing and a
    validation failure is raised as an exception.

    :param yaml: YAML data (either a string, a stream, or pre-parsed Python dict/list)
    :type yaml: list|dict|str|file
    :param validate: Whether to validate the data before attempting to parse it.
    :type validate: bool
    :return: Config object
    :rtype: valohai_yaml.objs.Config
    """
    data = read_yaml(yaml)
    if validate:  # pragma: no branch
        # Imported under an alias so the `validate` flag is not rebound.
        from .validation import validate as validate_data

        validate_data(data, raise_exc=True)
    return Config.parse(data)
def generate_config(
    *,
    relative_source_path: str,
    step: str,
    image: str,
    parameters: ParameterDict,
    inputs: InputDict,
) -> Config:
    """Create a fresh Config that contains exactly one generated step.

    All arguments are keyword-only and are forwarded unchanged to
    `generate_step`; the resulting step is stored in the config's `steps`
    mapping under the step's own name.

    :return: New Config holding the generated step
    """
    generated = generate_step(
        relative_source_path=relative_source_path,
        step=step,
        image=image,
        parameters=parameters,
        inputs=inputs,
    )
    result = Config()
    result.steps[generated.name] = generated
    return result
def test_get_step_by_non_existing_attribute():
    """Filtering on an attribute that steps do not have finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(gorilla='greeting')
    assert not found
def test_get_step_by_command():
    """Steps can be looked up by their command."""
    cfg = Config.parse([echo_step, list_step])

    found_echo = cfg.get_step_by(command='echo HELLO WORLD')
    assert echo_step['step'] == found_echo.serialize()

    found_list = cfg.get_step_by(command='ls')
    assert list_step['step'] == found_list.serialize()
def test_get_step_by_name_doesnt_exist():
    """Looking up a name that no step carries finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(name='not found')
    assert not found
def test_get_step_by_index():
    """Index 1 selects the second parsed step."""
    cfg = Config.parse([echo_step, list_step])
    second = cfg.get_step_by(index=1)
    assert list_step['step'] == second.serialize()
def test_merging_conflict():
    """Merging the two complex-step variants serializes like the expected merged fixture."""
    base = Config.parse([complex_step])
    alternative = Config.parse([complex_step_alt])
    merged = base.merge_with(alternative)
    expected = Config.parse([complex_steps_merged])
    merged_data = merged.serialize()
    expected_data = expected.serialize()
    assert merged_data == expected_data
def test_get_step_by_simple_name():
    """A step can be looked up by its plain name."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(name='list files')
    assert list_step['step'] == found.serialize()
def test_get_step_from_empty_config():
    """A config parsed from an empty list yields no step for a name lookup."""
    empty_cfg = Config.parse([])
    found = empty_cfg.get_step_by(name='greeting')
    assert not found
def test_get_step_by_nothing_returns_none():
    """Calling get_step_by without any filter gives a falsy result."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by()
    assert not found
def test_get_step_by_too_big_index():
    """An index past the last of the two steps finds nothing."""
    cfg = Config.parse([echo_step, list_step])
    found = cfg.get_step_by(index=2)
    assert not found
def test_get_step_by_name_and_command():
    """Name and command filters must both match for a step to be found."""
    cfg = Config.parse([echo_step, list_step])

    # Right name, wrong command.
    wrong_command = cfg.get_step_by(name='greeting', command='echo HELLO MORDOR')
    assert not wrong_command

    # Wrong name, right command.
    wrong_name = cfg.get_step_by(name='farewell', command='echo HELLO WORLD')
    assert not wrong_name

    # Both match.
    match = cfg.get_step_by(name='greeting', command='echo HELLO WORLD')
    assert echo_step['step'] == match.serialize()