示例#1
0
def test_load_config(repo, equal_pipelines):
    """Round-tripping a pipeline through its loaded config yields an equal pipeline."""
    original = random.choice(repo.get_pipelines())

    restored = BasePipeline.from_config(original.load_config())

    assert equal_pipelines(original, restored, loaded=True)
示例#2
0
 def get_pipelines(self) -> List:
     """Gets list of all pipelines.

     Reads every pipeline config YAML tracked by this repository and
     deserializes each into a ``BasePipeline`` instance.

     Returns:
         List of ``BasePipeline`` objects, one per pipeline config file.
     """
     from zenml.pipelines import BasePipeline
     # Comprehension replaces the manual append loop (same order, one pass).
     return [BasePipeline.from_config(yaml_utils.read_yaml(file_path))
             for file_path in self.get_pipeline_file_paths()]
示例#3
0
def test_to_from_config(equal_pipelines):
    """Serializing a pipeline to config and back produces an equal pipeline."""
    pipeline_in: BasePipeline = BasePipeline(name="my_pipeline")

    datasource = ImageDatasource(name="my_datasource")
    pipeline_in.add_datasource(datasource)

    pipeline_out = BasePipeline.from_config(pipeline_in.to_config())

    assert equal_pipelines(pipeline_in, pipeline_out, loaded=True)
    def get_tfx_pipeline(config: Dict[Text, Any]) -> pipeline.Pipeline:
        """
        Converts ZenML config dict to TFX pipeline.

        Args:
            config: A ZenML config dict

        Returns:
            tfx_pipeline: A TFX pipeline object.
        """
        from zenml.pipelines import BasePipeline
        zen_pipeline: BasePipeline = BasePipeline.from_config(config)

        components = zen_pipeline.get_tfx_component_list(config)

        # Pipeline metadata and on-disk layout.
        name = zen_pipeline.pipeline_name
        metadata_cfg = zen_pipeline.metadata_store.get_tfx_metadata_config()
        store = zen_pipeline.artifact_store
        root = os.path.join(store.path, store.unique_id)
        log_root = os.path.join(root, 'logs', name)

        # Resolve execution backend: the first step backend that is a
        # *strict* subclass of ProcessingBaseBackend wins; otherwise fall
        # back to a plain ProcessingBaseBackend.
        execution = next(
            (step.backend for step in zen_pipeline.steps_dict.values()
             if step.backend
             and issubclass(step.backend.__class__, ProcessingBaseBackend)
             and step.backend.__class__ != ProcessingBaseBackend),
            ProcessingBaseBackend())

        beam_args = execution.get_beam_args(name, root)

        tfx_pipeline = pipeline.Pipeline(
            components=components,
            beam_pipeline_args=beam_args,
            metadata_connection_config=metadata_cfg,
            pipeline_name=store.unique_id,  # for caching
            pipeline_root=root,
            log_root=log_root,
            enable_cache=zen_pipeline.enable_cache)

        # Ensure that the run_id is ZenML pipeline_name
        tfx_pipeline.pipeline_info.run_id = name
        return tfx_pipeline
示例#5
0
    def get_pipeline_by_name(self, pipeline_name: Text = None):
        """
        Loads a pipeline just by its name.

        Args:
            pipeline_name (str): Name of pipeline.
        """
        from zenml.pipelines import BasePipeline
        for file_path in self.get_pipeline_file_paths():
            candidate = BasePipeline.get_name_from_pipeline_name(
                os.path.basename(file_path))
            if candidate == pipeline_name:
                # First match wins; falls through to implicit None on no match.
                return BasePipeline.from_config(
                    yaml_utils.read_yaml(file_path))
示例#6
0
def test_get_pipeline_by_name(repo, equal_pipelines):
    """Looking a pipeline up by name matches loading its config file directly."""
    chosen_name = random.choice(repo.get_pipeline_names())

    matching = [path for path in repo.get_pipeline_file_paths()
                if chosen_name in path]
    expected = BasePipeline.from_config(yaml_utils.read_yaml(matching[0]))

    actual = repo.get_pipeline_by_name(chosen_name)

    assert equal_pipelines(actual, expected, loaded=True)
示例#7
0
    def get_pipelines_by_datasource(self, datasource):
        """
        Gets list of pipelines associated with datasource.

        Args:
            datasource (BaseDatasource): object of type BaseDatasource.
        """
        from zenml.pipelines import BasePipeline
        matches = []
        for file_path in self.get_pipeline_file_paths():
            config = yaml_utils.read_yaml(file_path)
            ds_config = config[keys.GlobalKeys.PIPELINE][
                keys.PipelineKeys.DATASOURCE]
            # Short-circuit `and` preserves the original membership-then-compare
            # order, so missing IDs never raise.
            if keys.DatasourceKeys.ID in ds_config and \
                    ds_config[keys.DatasourceKeys.ID] == datasource._id:
                matches.append(BasePipeline.from_config(config))
        return matches