示例#1
0
def test_to_from_config(equal_steps):
    """Check that a BaseStep survives a to_config/from_config round trip."""
    step_kwargs = {"a": 1, "name": "my_backend", "grade": 1.5}
    original = BaseStep(**step_kwargs)

    # Serialize the step to its config and rebuild a second step from it.
    restored = BaseStep.from_config(original.to_config())

    assert equal_steps(original, restored, loaded=True)
示例#2
0
def test_with_backend(equal_steps):
    """Check that a step with an attached backend round-trips via its config."""
    shared_kwargs = {"a": 1, "name": "my_backend", "grade": 1.5}

    # The backend and the step are built from the same keyword arguments.
    backend = BaseBackend(**shared_kwargs)
    original = BaseStep(**shared_kwargs).with_backend(backend)

    # After with_backend() the step must expose a truthy backend.
    assert bool(original.backend)

    restored = BaseStep.from_config(original.to_config())

    assert equal_steps(original, restored, loaded=True)
示例#3
0
    def get_step_by_version(self, step_type: Union[Type, Text], version: Text):
        """
        Look up a Step by type and version across the pipeline config files.

        Every path from `self.get_pipeline_file_paths()` is read as YAML and
        the steps of its pipeline section are scanned in order. The first
        step whose source yields the requested class path and version is
        rebuilt with `BaseStep.from_config` and returned; later matches are
        never inspected.

        Args:
            step_type: the step class, or a string specifying the full path
            to the step.
            version: either sha pin or standard ZenML version pin.

        Returns:
            The first matching step, or None when no configuration matches.
        """
        from zenml.utils import source_utils
        from zenml.core.steps.base_step import BaseStep

        wanted_class_path = source_utils.get_module_path_from_class(step_type)

        for config_path in self.get_pipeline_file_paths():
            pipeline_section = yaml_utils.read_yaml(config_path)[
                keys.GlobalKeys.PIPELINE]

            # Only the step configs matter here; the step names are unused.
            for step_config in pipeline_section[
                    keys.PipelineKeys.STEPS].values():
                # Class path and version are both derived from the source.
                source = step_config[keys.StepKeys.SOURCE]
                class_path = source_utils.get_class_path_from_source(source)
                source_version = source_utils.get_version_from_source(source)

                is_match = (class_path == wanted_class_path
                            and version == source_version)
                if is_match:
                    return BaseStep.from_config(step_config)

        return None
示例#4
0
def test_get_steps_config():
    """Check that get_steps_config() reports a registered step's config."""
    # TODO: Expand this to more steps
    pipeline: BasePipeline = BasePipeline(name="my_pipeline")

    # Register a single step directly under the key "test".
    step = BaseStep(number=1, description="abcdefg")
    pipeline.steps_dict["test"] = step

    steps_cfg = pipeline.get_steps_config()[keys.PipelineKeys.STEPS]

    # avoid missing args / type inconsistencies
    assert steps_cfg["test"] == step.to_config()
示例#5
0
    def from_config(cls, config: Dict):
        """
        Convert from pipeline config to ZenML Pipeline object.

        All steps are also populated and configuration set to parameters set
        in the config file. The returned object has `_immutable` set to True.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).

        Returns:
            An instance of the class loaded from the pipeline's source path,
            built from the artifact store, metadata store, backend, steps and
            datasource described in `config`.
        """
        # start with artifact store
        artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

        # metadata store
        metadata_store = ZenMLMetadataStore.from_config(
            config=config[keys.GlobalKeys.METADATA_STORE]
        )

        # orchestration backend
        backend = OrchestratorBaseBackend.from_config(
            config[keys.GlobalKeys.BACKEND])

        # pipeline configuration
        p_config = config[keys.GlobalKeys.PIPELINE]
        pipeline_name = p_config[keys.PipelineKeys.NAME]
        pipeline_source = p_config[keys.PipelineKeys.SOURCE]

        # populate steps: one Step object per entry, keyed as in the config
        steps_dict: Dict = {
            step_key: BaseStep.from_config(step_config)
            for step_key, step_config
            in p_config[keys.PipelineKeys.STEPS].items()
        }

        # datasource is built from the pipeline section of the config
        datasource = BaseDatasource.from_config(
            config[keys.GlobalKeys.PIPELINE])

        # enable cache
        enable_cache = p_config[keys.PipelineKeys.ENABLE_CACHE]

        # resolve the concrete pipeline class from its recorded source path
        class_ = source_utils.load_source_path_class(pipeline_source)

        obj = class_(
            name=cls.get_name_from_pipeline_name(pipeline_name),
            pipeline_name=pipeline_name,
            enable_cache=enable_cache,
            steps_dict=steps_dict,
            backend=backend,
            artifact_store=artifact_store,
            metadata_store=metadata_store,
            datasource=datasource)
        obj._immutable = True
        logger.debug(f'Pipeline {pipeline_name} loaded and is immutable.')
        return obj
示例#6
0
    def from_config(cls, config: Dict):
        """
        Build a ZenML Pipeline object from a pipeline config.

        Steps and backends are instantiated from their config entries, the
        artifact store, metadata store and datasource are created, and the
        concrete pipeline class is looked up via the pipeline factory from
        the type encoded in the pipeline name.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).

        Returns:
            An instance of the pipeline class resolved by the pipeline
            factory, constructed from the objects described in `config`.
        """
        # populate steps
        steps_dict: Dict = {
            step_key: BaseStep.from_config(step_config)
            for step_key, step_config in config[keys.GlobalKeys.STEPS].items()
        }

        env = config[keys.GlobalKeys.ENV]
        pipeline_name = env[keys.EnvironmentKeys.EXPERIMENT_NAME]
        name = BasePipeline.get_name_from_pipeline_name(
            pipeline_name=pipeline_name)

        # each backend is built from its key together with its config
        backends_dict: Dict = {
            backend_key: BaseBackend.from_config(backend_key, backend_config)
            for backend_key, backend_config
            in env[keys.EnvironmentKeys.BACKENDS].items()
        }

        artifact_store = ArtifactStore(
            env[keys.EnvironmentKeys.ARTIFACT_STORE])
        metadata_store = ZenMLMetadataStore.from_config(
            config=env[METADATA_KEY])

        # the datasource receives the full config, not just the env section
        datasource = BaseDatasource.from_config(config)

        from zenml.core.pipelines.pipeline_factory import pipeline_factory
        class_ = pipeline_factory.get_pipeline_by_type(
            BasePipeline.get_type_from_pipeline_name(pipeline_name))

        # TODO: [MEDIUM] Perhaps move some of the logic in the init block here
        #  especially regarding inferring immutability.

        constructor_kwargs = dict(
            name=name,
            pipeline_name=pipeline_name,
            enable_cache=env[keys.EnvironmentKeys.ENABLE_CACHE],
            steps_dict=steps_dict,
            backends_dict=backends_dict,
            artifact_store=artifact_store,
            metadata_store=metadata_store,
            datasource=datasource,
        )
        return class_(**constructor_kwargs)