def test_load_config(repo, equal_pipelines):
    import random
    from zenml.pipelines import BasePipeline

    p1 = random.choice(repo.get_pipelines())
    pipeline_config = p1.load_config()
    p2 = BasePipeline.from_config(pipeline_config)
    assert equal_pipelines(p1, p2, loaded=True)
def get_pipelines(self) -> List:
    """Gets list of all pipelines."""
    from zenml.pipelines import BasePipeline

    pipelines = []
    for file_path in self.get_pipeline_file_paths():
        c = yaml_utils.read_yaml(file_path)
        pipelines.append(BasePipeline.from_config(c))
    return pipelines
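# Usage sketch for get_pipelines (a minimal, hypothetical example: the
# `Repository` import path and its `get_instance()` singleton accessor are
# assumptions based on the legacy zenml layout, not confirmed by the
# snippets here):
from zenml.repo import Repository  # assumed import path

repo = Repository.get_instance()  # assumed singleton accessor
for p in repo.get_pipelines():
    print(p.name)  # `name` is evidenced by BasePipeline(name=...) below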
def test_to_from_config(equal_pipelines):
    from zenml.datasources import ImageDatasource  # import path assumed
    from zenml.pipelines import BasePipeline

    p1: BasePipeline = BasePipeline(name="my_pipeline")
    ds = ImageDatasource(name="my_datasource")
    p1.add_datasource(ds)
    p2 = BasePipeline.from_config(p1.to_config())
    assert equal_pipelines(p1, p2, loaded=True)
def get_tfx_pipeline(config: Dict[Text, Any]) -> pipeline.Pipeline:
    """
    Converts a ZenML config dict to a TFX pipeline.

    Args:
        config: A ZenML config dict.

    Returns:
        tfx_pipeline: A TFX pipeline object.
    """
    from zenml.pipelines import BasePipeline
    zen_pipeline: BasePipeline = BasePipeline.from_config(config)

    # Get component list
    component_list = zen_pipeline.get_tfx_component_list(config)

    # Get pipeline metadata
    pipeline_name = zen_pipeline.pipeline_name
    metadata_connection_config = \
        zen_pipeline.metadata_store.get_tfx_metadata_config()
    artifact_store = zen_pipeline.artifact_store

    # Pipeline settings
    pipeline_root = os.path.join(
        artifact_store.path, artifact_store.unique_id)
    pipeline_log = os.path.join(pipeline_root, 'logs', pipeline_name)

    # Resolve execution backend: default to ProcessingBaseBackend, then
    # take the first step backend that is a proper subclass of it.
    execution = ProcessingBaseBackend()  # default
    for e in zen_pipeline.steps_dict.values():
        if e.backend and issubclass(
                e.backend.__class__, ProcessingBaseBackend) and \
                e.backend.__class__ != ProcessingBaseBackend:
            execution = e.backend
            break
    beam_args = execution.get_beam_args(pipeline_name, pipeline_root)

    tfx_pipeline = pipeline.Pipeline(
        components=component_list,
        beam_pipeline_args=beam_args,
        metadata_connection_config=metadata_connection_config,
        pipeline_name=zen_pipeline.artifact_store.unique_id,  # for caching
        pipeline_root=pipeline_root,
        log_root=pipeline_log,
        enable_cache=zen_pipeline.enable_cache)

    # Ensure that the run_id is the ZenML pipeline_name
    tfx_pipeline.pipeline_info.run_id = zen_pipeline.pipeline_name
    return tfx_pipeline
def get_pipeline_by_name(self, pipeline_name: Text = None):
    """
    Loads a pipeline just by its name.

    Args:
        pipeline_name (str): Name of pipeline.

    Returns:
        The matching BasePipeline, or None if no pipeline with that
        name exists in the repository.
    """
    from zenml.pipelines import BasePipeline
    yamls = self.get_pipeline_file_paths()
    for y in yamls:
        n = BasePipeline.get_name_from_pipeline_name(os.path.basename(y))
        if n == pipeline_name:
            c = yaml_utils.read_yaml(y)
            return BasePipeline.from_config(c)
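# get_pipeline_by_name falls through and returns None when no YAML matches,
# so callers should guard against a missing pipeline (hypothetical usage):
from zenml.repo import Repository  # assumed import path

p = Repository.get_instance().get_pipeline_by_name("does_not_exist")
if p is None:
    raise ValueError("No pipeline registered under that name.")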
def test_get_pipeline_by_name(repo, equal_pipelines):
    import random
    from zenml.pipelines import BasePipeline
    from zenml.utils import yaml_utils  # import path assumed

    p_names = repo.get_pipeline_names()
    random_name = random.choice(p_names)
    cfg_list = [y for y in repo.get_pipeline_file_paths()
                if random_name in y]
    cfg = yaml_utils.read_yaml(cfg_list[0])

    p1 = repo.get_pipeline_by_name(random_name)
    p2 = BasePipeline.from_config(cfg)
    assert equal_pipelines(p1, p2, loaded=True)
def get_pipelines_by_datasource(self, datasource):
    """
    Gets a list of pipelines associated with the given datasource.

    Args:
        datasource (BaseDatasource): object of type BaseDatasource.
    """
    from zenml.pipelines import BasePipeline

    pipelines = []
    for file_path in self.get_pipeline_file_paths():
        c = yaml_utils.read_yaml(file_path)
        ds_config = c[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE]
        if keys.DatasourceKeys.ID in ds_config and \
                ds_config[keys.DatasourceKeys.ID] == datasource._id:
            pipelines.append(BasePipeline.from_config(c))
    return pipelines
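# Usage sketch for get_pipelines_by_datasource: find every pipeline built on
# a given datasource. `get_datasources()` is an assumed Repository accessor;
# the rest is evidenced above.
from zenml.repo import Repository  # assumed import path

repo = Repository.get_instance()
ds = repo.get_datasources()[0]  # assumed accessor returning BaseDatasources
for p in repo.get_pipelines_by_datasource(ds):
    print(p.name)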