def test_naming():
    """Check that the user-facing name survives the pipeline-name encoding."""
    expected = "my_pipeline"
    pipeline = BasePipeline(name=expected)
    encoded = pipeline.pipeline_name
    decoded = pipeline.get_name_from_pipeline_name(encoded)
    assert decoded == expected
def test_to_from_config(equal_pipelines):
    """Round-trip a pipeline with a datasource through to_config/from_config."""
    original = BasePipeline(name="my_pipeline")
    original.add_datasource(ImageDatasource(name="my_datasource"))
    restored = BasePipeline.from_config(original.to_config())
    assert equal_pipelines(original, restored, loaded=True)
def test_get_status(repo):
    """A fresh pipeline is NotStarted; a pipeline loaded from the repo Succeeded."""
    fresh = BasePipeline(name="my_pipeline")
    # Passing this check means that a pipeline which is absent from the
    # metadata store is (semantically) treated as not started yet.
    assert fresh.get_status() == PipelineStatusTypes.NotStarted.name

    executed = random.choice(repo.get_pipelines())
    assert executed.get_status() == PipelineStatusTypes.Succeeded.name
def get_pipeline_by_name(self, pipeline_name: Text = None):
    """
    Loads a pipeline just by its name.

    Scans the config files returned by ``get_pipeline_file_paths()`` and
    builds a pipeline from the first one whose decoded file name equals
    ``pipeline_name``.

    Args:
        pipeline_name (str): Name of pipeline.

    Returns:
        The pipeline built via ``BasePipeline.from_config`` from the first
        matching config file, or ``None`` if no file matches.
    """
    from zenml.pipelines import BasePipeline

    for yaml_path in self.get_pipeline_file_paths():
        file_name = os.path.basename(yaml_path)
        if BasePipeline.get_name_from_pipeline_name(file_name) == pipeline_name:
            return BasePipeline.from_config(yaml_utils.read_yaml(yaml_path))
    return None
def test_run_base(delete_config):
    """Run a pipeline on a mock orchestrator backend.

    The artifact and metadata stores are left unchanged. After ``run`` the
    pipeline must be immutable. The pipeline's config file is deleted even
    when that assertion fails, so a failure here does not leave a stale
    config behind for later tests.
    """
    p = BasePipeline(name="my_pipeline")

    class MockBackend(OrchestratorBaseBackend):
        def run(self, config):
            return {"message": "Run triggered!"}

    p.run(backend=MockBackend())
    try:
        assert p._immutable
    finally:
        # Previously this ran only on success, leaking the config on failure.
        delete_config(p.file_name)
def test_executed(repo):
    """Only pipelines that have been run are flagged as executed in the store."""
    unrun = BasePipeline(name="my_pipeline")
    assert not unrun.is_executed_in_metadata_store

    executed = random.choice(repo.get_pipelines())
    assert executed.is_executed_in_metadata_store
def test_load_config(repo, equal_pipelines):
    """A pipeline rebuilt from its own loaded config equals the original."""
    source_pipeline = random.choice(repo.get_pipelines())
    rebuilt = BasePipeline.from_config(source_pipeline.load_config())
    assert equal_pipelines(source_pipeline, rebuilt, loaded=True)
def get_pipelines(self) -> List:
    """Gets list of all pipelines.

    Every config file returned by ``get_pipeline_file_paths()`` is read
    with ``yaml_utils.read_yaml`` and turned into a pipeline via
    ``BasePipeline.from_config``.

    Returns:
        A list of pipelines, in the order of ``get_pipeline_file_paths()``.
    """
    from zenml.pipelines import BasePipeline

    return [
        BasePipeline.from_config(yaml_utils.read_yaml(file_path))
        for file_path in self.get_pipeline_file_paths()
    ]
def test_add_datasource():
    """Adding a plain BaseDatasource sets it but leaves the DATA step empty."""
    pipeline = BasePipeline(name="my_pipeline")
    pipeline.add_datasource(BaseDatasource(name="my_datasource"))

    assert isinstance(pipeline.datasource, BaseDatasource)
    # The DATA entry of steps_dict is expected to be falsy for this datasource.
    assert not pipeline.steps_dict[keys.TrainingSteps.DATA]
def test_get_data_file_paths(repo):
    """Data file paths resolve only through a pipeline tied to the datasource."""
    unrelated_ds = BaseDatasource(name="my_datasource")
    unrelated_pipeline = BasePipeline(name="my_pipeline")
    unrelated_pipeline.add_datasource(unrelated_ds)

    # Rebuild a datasource from a randomly chosen saved pipeline config.
    config_path = random.choice(repo.get_pipeline_file_paths())
    pipeline_cfg = yaml_utils.read_yaml(config_path)
    loaded_ds = BaseDatasource.from_config(pipeline_cfg[keys.GlobalKeys.PIPELINE])

    # A fresh, unrelated pipeline is expected to be rejected.
    with pytest.raises(AssertionError):
        loaded_ds._get_data_file_paths(unrelated_pipeline)

    owning_pipeline = loaded_ds._get_one_pipeline()
    data_paths = loaded_ds._get_data_file_paths(owning_pipeline)
    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(path)[-1] == ".gz" for path in data_paths)
def get_tfx_pipeline(config: Dict[Text, Any]) -> pipeline.Pipeline:
    """
    Converts ZenML config dict to TFX pipeline.

    The config is first turned into a ZenML ``BasePipeline``. Its TFX
    component list, metadata store connection config, artifact store
    location and processing backend are then used to assemble the TFX
    pipeline.

    Args:
        config: A ZenML config dict

    Returns:
        tfx_pipeline: A TFX pipeline object. Its TFX ``pipeline_name`` is the
        artifact store's unique id (used for caching), and its run id is set
        to the ZenML pipeline name.
    """
    from zenml.pipelines import BasePipeline

    zen_pipeline: BasePipeline = BasePipeline.from_config(config)

    # Get component list
    component_list = zen_pipeline.get_tfx_component_list(config)

    # Get pipeline metadata
    pipeline_name = zen_pipeline.pipeline_name
    metadata_connection_config = \
        zen_pipeline.metadata_store.get_tfx_metadata_config()
    artifact_store = zen_pipeline.artifact_store

    # Pipeline settings: root is <artifact store path>/<artifact store id>.
    pipeline_root = os.path.join(
        artifact_store.path, artifact_store.unique_id)
    pipeline_log = os.path.join(pipeline_root, 'logs', pipeline_name)

    # Resolve execution backend: take the first step backend that is a
    # ProcessingBaseBackend *subclass* (not the base class itself); otherwise
    # fall back to the default ProcessingBaseBackend.
    execution = ProcessingBaseBackend()  # default
    for step in zen_pipeline.steps_dict.values():
        # steps_dict may hold empty (falsy / None) entries, e.g. an unset
        # DATA step, so read the backend defensively instead of `step.backend`.
        backend = getattr(step, 'backend', None)
        if backend and isinstance(backend, ProcessingBaseBackend) and \
                backend.__class__ is not ProcessingBaseBackend:
            execution = backend
            break

    beam_args = execution.get_beam_args(pipeline_name, pipeline_root)

    tfx_pipeline = pipeline.Pipeline(
        components=component_list,
        beam_pipeline_args=beam_args,
        metadata_connection_config=metadata_connection_config,
        pipeline_name=zen_pipeline.artifact_store.unique_id,  # for caching
        pipeline_root=pipeline_root,
        log_root=pipeline_log,
        enable_cache=zen_pipeline.enable_cache)

    # Ensure that the run_id is ZenML pipeline_name
    tfx_pipeline.pipeline_info.run_id = zen_pipeline.pipeline_name
    return tfx_pipeline
def test_register_pipeline(repo, delete_config):
    """Registration works for a new pipeline and is refused for a stored one."""
    name = "my_pipeline"
    fresh = BasePipeline(name=name)
    # Not registered yet, so the check must pass without raising.
    fresh._check_registered()

    stored = random.choice(repo.get_pipelines())
    with pytest.raises(exceptions.AlreadyExistsException):
        stored._check_registered()

    fresh.register_pipeline({"name": name})
    delete_config(fresh.file_name)
def test_get_pipeline_by_name(repo, equal_pipelines):
    """Loading by name must equal loading that pipeline's config file directly."""
    p_names = repo.get_pipeline_names()
    random_name = random.choice(p_names)
    # Match on the decoded file name, the same way get_pipeline_by_name does.
    # A plain substring test (`random_name in path`) can select the config of
    # a different pipeline whose path happens to contain `random_name`.
    cfg_list = [
        y for y in repo.get_pipeline_file_paths()
        if BasePipeline.get_name_from_pipeline_name(
            os.path.basename(y)) == random_name
    ]
    assert cfg_list, f"no config file found for pipeline {random_name!r}"
    cfg = yaml_utils.read_yaml(cfg_list[0])
    p1 = repo.get_pipeline_by_name(random_name)
    p2 = BasePipeline.from_config(cfg)
    assert equal_pipelines(p1, p2, loaded=True)
def test_get_steps_config():
    """A step's entry in the steps config equals the step's own to_config()."""
    # TODO: Expand this to more steps
    pipeline = BasePipeline(name="my_pipeline")
    step = BaseStep(number=1, description="abcdefg")
    pipeline.steps_dict["test"] = step

    steps_cfg = pipeline.get_steps_config()[keys.PipelineKeys.STEPS]
    # Comparing with to_config() catches missing args / type inconsistencies.
    assert steps_cfg["test"] == step.to_config()
def test_get_pipeline_config():
    """Check only the pipeline-level block that get_pipeline_config() adds."""
    pipeline = BasePipeline(name="my_pipeline")
    config = pipeline.get_pipeline_config()
    expected_name = pipeline.pipeline_name
    args = config[keys.PipelineKeys.ARGS]

    assert args[keys.PipelineDetailKeys.NAME] == expected_name
    # NOTE: a TYPE check (expecting "base") is currently disabled.
    assert args[keys.PipelineDetailKeys.ENABLE_CACHE] is True
    assert config[keys.PipelineKeys.DATASOURCE] == {}

    source_path = config[keys.PipelineKeys.SOURCE].split("@")[0]
    assert source_path == "zenml.pipelines.base_pipeline.BasePipeline"
def get_pipelines_by_datasource(self, datasource):
    """
    Gets list of pipelines associated with datasource.

    A pipeline config is associated when its datasource section contains an
    ID entry equal to ``datasource._id``. Configs without such an ID entry
    are skipped.

    Args:
        datasource (BaseDatasource): object of type BaseDatasource.

    Returns:
        List of matching pipelines built via ``BasePipeline.from_config``, in
        the order of ``get_pipeline_file_paths()``.
    """
    from zenml.pipelines import BasePipeline

    matches = []
    for file_path in self.get_pipeline_file_paths():
        config = yaml_utils.read_yaml(file_path)
        ds_config = \
            config[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE]
        if keys.DatasourceKeys.ID not in ds_config:
            continue
        if ds_config[keys.DatasourceKeys.ID] == datasource._id:
            matches.append(BasePipeline.from_config(config))
    return matches
def get_pipeline_names(self) -> List[Text]:
    """Gets list of pipeline (unique) names.

    Names are decoded with ``BasePipeline.get_name_from_pipeline_name`` from
    the file names returned by ``get_pipeline_file_paths(only_file_names=True)``.

    Returns:
        A (possibly empty) list of pipeline names. The previous
        ``Optional[...]`` annotation was misleading, because ``None`` is never
        returned.
    """
    from zenml.pipelines import BasePipeline

    file_names = self.get_pipeline_file_paths(only_file_names=True)
    return [BasePipeline.get_name_from_pipeline_name(f) for f in file_names]
def test_run_config():
    """run_config works without a backend, rejects a non-backend value, and
    returns a falsy value when an orchestrator backend is set."""
    pipeline = BasePipeline(name="my_pipeline")

    class MockBackend(OrchestratorBaseBackend):
        def run(self, config):
            return {"message": "Run triggered!"}

    # No backend specified: must not raise.
    pipeline.run_config(pipeline.to_config())

    # A plain string is not a backend subclass and must be rejected.
    pipeline.backend = "123"
    with pytest.raises(Exception):
        pipeline.run_config(pipeline.to_config())

    pipeline.backend = MockBackend()
    assert not pipeline.run_config(pipeline.to_config())