Exemple #1
0
def test_naming():
    """Round-trip a pipeline's name through its generated pipeline_name."""
    expected = "my_pipeline"
    pipeline = BasePipeline(name=expected)
    generated = pipeline.pipeline_name

    recovered = pipeline.get_name_from_pipeline_name(generated)
    assert recovered == expected
Exemple #2
0
def test_to_from_config(equal_pipelines):
    """A pipeline rebuilt from its own config compares equal to the original."""
    original: BasePipeline = BasePipeline(name="my_pipeline")
    original.add_datasource(ImageDatasource(name="my_datasource"))

    config = original.to_config()
    restored = BasePipeline.from_config(config)

    assert equal_pipelines(original, restored, loaded=True)
Exemple #3
0
def test_get_status(repo):
    """Check get_status() for a fresh pipeline and for an already-run one."""
    fresh = BasePipeline(name="my_pipeline")

    # A pipeline that is absent from the metadata store is (semantically)
    # reported as not started yet.
    assert fresh.get_status() == PipelineStatusTypes.NotStarted.name

    # Pipelines provided by the repo fixture are asserted to have succeeded.
    executed = random.choice(repo.get_pipelines())
    assert executed.get_status() == PipelineStatusTypes.Succeeded.name
Exemple #4
0
    def get_pipeline_by_name(self, pipeline_name: Text = None):
        """
        Loads a pipeline just by its name.

        Each YAML path from ``get_pipeline_file_paths()`` is mapped back to a
        pipeline name via its file name; the first match is loaded.

        Args:
            pipeline_name (str): Name of pipeline.

        Returns:
            The BasePipeline built from the first matching YAML config, or
            None when no config file maps to ``pipeline_name``.
        """
        from zenml.pipelines import BasePipeline
        for path in self.get_pipeline_file_paths():
            name = BasePipeline.get_name_from_pipeline_name(
                os.path.basename(path))
            if name == pipeline_name:
                config = yaml_utils.read_yaml(path)
                return BasePipeline.from_config(config)
        return None
Exemple #5
0
def test_run_base(delete_config):
    """Test pipeline.run() on a mock orchestrator, without store changes.

    After the run the pipeline must be immutable. The config file at
    ``p.file_name`` is removed in a ``finally`` so that a failing assertion
    does not leave a stale config behind for later tests.
    """
    p = BasePipeline(name="my_pipeline")

    class MockBackend(OrchestratorBaseBackend):
        # Orchestrator stub: triggers nothing, only returns a message.
        def run(self, config):
            return {"message": "Run triggered!"}

    backend = MockBackend()

    p.run(backend=backend)

    try:
        assert p._immutable
    finally:
        # Clean up even if the assertion above fails.
        delete_config(p.file_name)
Exemple #6
0
def test_executed(repo):
    """A new pipeline is not executed in the metadata store; a repo one is."""
    unregistered = BasePipeline(name="my_pipeline")
    assert not unregistered.is_executed_in_metadata_store

    executed = random.choice(repo.get_pipelines())
    assert executed.is_executed_in_metadata_store
Exemple #7
0
def test_load_config(repo, equal_pipelines):
    """Rebuilding a repo pipeline from load_config() yields an equal one."""
    source = random.choice(repo.get_pipelines())
    rebuilt = BasePipeline.from_config(source.load_config())
    assert equal_pipelines(source, rebuilt, loaded=True)
Exemple #8
0
 def get_pipelines(self) -> List:
     """Gets list of all pipelines.

     Returns:
         One BasePipeline per YAML config path from
         ``get_pipeline_file_paths()``, in the same order.
     """
     from zenml.pipelines import BasePipeline
     return [
         BasePipeline.from_config(yaml_utils.read_yaml(path))
         for path in self.get_pipeline_file_paths()
     ]
Exemple #9
0
def test_add_datasource():
    """Adding a datasource sets it on the pipeline; the DATA step stays empty."""
    pipeline: BasePipeline = BasePipeline(name="my_pipeline")
    pipeline.add_datasource(BaseDatasource(name="my_datasource"))

    assert isinstance(pipeline.datasource, BaseDatasource)
    assert not pipeline.steps_dict[keys.TrainingSteps.DATA]
Exemple #10
0
def test_get_data_file_paths(repo):
    """Check _get_data_file_paths for a foreign and for a real pipeline."""
    # A fresh datasource attached to a pipeline that has never been run.
    fresh_ds = BaseDatasource(name="my_datasource")
    unrun_pipeline = BasePipeline(name="my_pipeline")
    unrun_pipeline.add_datasource(fresh_ds)

    # Rebuild a datasource from a randomly chosen saved pipeline config.
    config_path = random.choice(repo.get_pipeline_file_paths())
    pipeline_cfg = yaml_utils.read_yaml(config_path)[keys.GlobalKeys.PIPELINE]
    loaded_ds = BaseDatasource.from_config(pipeline_cfg)

    # The loaded datasource is expected to raise for the unrelated pipeline.
    with pytest.raises(AssertionError):
        loaded_ds._get_data_file_paths(unrun_pipeline)

    paths = loaded_ds._get_data_file_paths(loaded_ds._get_one_pipeline())

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(path)[-1] == ".gz" for path in paths)
    def get_tfx_pipeline(config: Dict[Text, Any]) -> pipeline.Pipeline:
        """
        Converts ZenML config dict to TFX pipeline.

        The pipeline root is ``<artifact_store.path>/<artifact_store.unique_id>``
        and logs go to ``<pipeline_root>/logs/<pipeline_name>``. Beam args come
        from the first step backend that is a strict subclass of
        ProcessingBaseBackend; a plain ProcessingBaseBackend is the fallback.

        Args:
            config: A ZenML config dict

        Returns:
            tfx_pipeline: A TFX pipeline object.
        """
        from zenml.pipelines import BasePipeline
        zen_pipeline: BasePipeline = BasePipeline.from_config(config)

        # Get component list
        component_list = zen_pipeline.get_tfx_component_list(config)

        # Get pipeline metadata
        pipeline_name = zen_pipeline.pipeline_name
        metadata_connection_config = \
            zen_pipeline.metadata_store.get_tfx_metadata_config()
        artifact_store = zen_pipeline.artifact_store

        # Pipeline settings
        pipeline_root = os.path.join(
            artifact_store.path, artifact_store.unique_id)
        pipeline_log = os.path.join(pipeline_root, 'logs', pipeline_name)

        # Resolve execution backend: take the first step backend that is a
        # ProcessingBaseBackend subclass but not the base class itself.
        # steps_dict may hold falsy placeholder entries (e.g. an unset DATA
        # step); skip those instead of dereferencing `.backend` on them.
        execution = ProcessingBaseBackend()  # default
        for step in zen_pipeline.steps_dict.values():
            if not step:
                continue
            backend = step.backend
            if backend and isinstance(backend, ProcessingBaseBackend) \
                    and type(backend) is not ProcessingBaseBackend:
                execution = backend
                break

        beam_args = execution.get_beam_args(pipeline_name, pipeline_root)

        tfx_pipeline = pipeline.Pipeline(
            components=component_list,
            beam_pipeline_args=beam_args,
            metadata_connection_config=metadata_connection_config,
            pipeline_name=zen_pipeline.artifact_store.unique_id,  # for caching
            pipeline_root=pipeline_root,
            log_root=pipeline_log,
            enable_cache=zen_pipeline.enable_cache)

        # Ensure that the run_id is ZenML pipeline_name
        tfx_pipeline.pipeline_info.run_id = zen_pipeline.pipeline_name
        return tfx_pipeline
Exemple #12
0
def test_register_pipeline(repo, delete_config):
    """The registration check passes for a new pipeline, fails for an old one."""
    pipeline_name = "my_pipeline"
    new_pipeline: BasePipeline = BasePipeline(name=pipeline_name)

    # Not registered yet: the check must not raise.
    new_pipeline._check_registered()

    existing = random.choice(repo.get_pipelines())
    with pytest.raises(exceptions.AlreadyExistsException):
        existing._check_registered()

    new_pipeline.register_pipeline({"name": pipeline_name})
    delete_config(new_pipeline.file_name)
Exemple #13
0
def test_get_pipeline_by_name(repo, equal_pipelines):
    """Loading by name equals building from that name's config file.

    The config file is located by parsing each file name back into a pipeline
    name, as ``get_pipeline_by_name`` itself does. A plain substring test
    could pick another pipeline's config whenever ``random_name`` occurs
    inside a different file path.
    """
    import os

    random_name = random.choice(repo.get_pipeline_names())
    cfg_list = [
        y for y in repo.get_pipeline_file_paths()
        if BasePipeline.get_name_from_pipeline_name(
            os.path.basename(y)) == random_name
    ]
    assert cfg_list, f"no config file found for pipeline {random_name!r}"

    cfg = yaml_utils.read_yaml(cfg_list[0])

    p1 = repo.get_pipeline_by_name(random_name)

    p2 = BasePipeline.from_config(cfg)

    assert equal_pipelines(p1, p2, loaded=True)
Exemple #14
0
def test_get_steps_config():
    """A step placed in steps_dict is serialized exactly as its to_config()."""
    # TODO: Expand this to more steps
    pipeline: BasePipeline = BasePipeline(name="my_pipeline")

    step = BaseStep(number=1, description="abcdefg")
    pipeline.steps_dict["test"] = step

    steps_cfg = pipeline.get_steps_config()[keys.PipelineKeys.STEPS]

    # Comparing to to_config() catches missing args / type inconsistencies.
    assert steps_cfg["test"] == step.to_config()
Exemple #15
0
def test_get_pipeline_config():
    """Check only the pipeline block that accompanies the steps config."""
    pipeline: BasePipeline = BasePipeline(name="my_pipeline")

    config = pipeline.get_pipeline_config()
    expected_name = pipeline.pipeline_name
    args = config[keys.PipelineKeys.ARGS]

    assert args[keys.PipelineDetailKeys.NAME] == expected_name
    assert args[keys.PipelineDetailKeys.ENABLE_CACHE] is True
    assert config[keys.PipelineKeys.DATASOURCE] == {}

    # The source string may carry an "@<version>" suffix; compare the path.
    source_path = config[keys.PipelineKeys.SOURCE].split("@")[0]
    assert source_path == "zenml.pipelines.base_pipeline.BasePipeline"
Exemple #16
0
    def get_pipelines_by_datasource(self, datasource):
        """
        Gets list of pipelines associated with datasource.

        A pipeline matches when its config's datasource block contains an ID
        equal to ``datasource._id``; blocks without an ID are skipped.

        Args:
            datasource (BaseDatasource): object of type BaseDatasource.

        Returns:
            A list of BasePipeline objects, in config-file order.
        """
        from zenml.pipelines import BasePipeline
        id_key = keys.DatasourceKeys.ID
        matched = []
        for path in self.get_pipeline_file_paths():
            config = yaml_utils.read_yaml(path)
            ds_config = config[keys.GlobalKeys.PIPELINE][
                keys.PipelineKeys.DATASOURCE]
            if id_key in ds_config and ds_config[id_key] == datasource._id:
                matched.append(BasePipeline.from_config(config))
        return matched
Exemple #17
0
 def get_pipeline_names(self) -> Optional[List[Text]]:
     """Gets list of pipeline (unique) names.

     Returns:
         One name per pipeline config file, parsed from its file name.
     """
     from zenml.pipelines import BasePipeline
     file_names = self.get_pipeline_file_paths(only_file_names=True)
     return list(map(BasePipeline.get_name_from_pipeline_name, file_names))
Exemple #18
0
def test_run_config():
    """run_config: no backend runs, a non-backend raises, a mock backend runs."""
    pipeline = BasePipeline(name="my_pipeline")

    class MockBackend(OrchestratorBaseBackend):
        # Orchestrator stub: run() only returns a message.
        def run(self, config):
            return {"message": "Run triggered!"}

    # No backend specified: must run without raising.
    pipeline.run_config(pipeline.to_config())

    # A plain string is not a backend subclass and must be rejected.
    pipeline.backend = "123"
    with pytest.raises(Exception):
        pipeline.run_config(pipeline.to_config())

    pipeline.backend = MockBackend()
    assert not pipeline.run_config(pipeline.to_config())