Example #1
    def with_config(
        self: T, config_file: str, overwrite_step_parameters: bool = False
    ) -> T:
        """Configures this pipeline using a yaml file.

        Args:
            config_file: Path to a yaml file which contains configuration
                options for running this pipeline. See
                https://docs.zenml.io/guides/pipeline-configuration for details
                regarding the specification of this file.
            overwrite_step_parameters: If set to `True`, values from the
                configuration file will overwrite configuration parameters
                passed in code.

        Returns:
            The pipeline object that this method was called on.
        """
        config_yaml = yaml_utils.read_yaml(config_file)

        if PipelineConfigurationKeys.STEPS in config_yaml:
            self._read_config_steps(
                config_yaml[PipelineConfigurationKeys.STEPS],
                overwrite=overwrite_step_parameters,
            )

        return self
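
A minimal usage sketch for `with_config`, assuming a hypothetical config.yaml laid out with the `steps` key this method reads; the pipeline instance and parameter names are made up for illustration:

# config.yaml (hypothetical):
#
# steps:
#   trainer:
#     parameters:
#       learning_rate: 0.001

# With overwrite_step_parameters=True, the learning_rate from the file
# wins over any value passed when the step was constructed in code.
my_pipeline = my_pipeline.with_config(
    "config.yaml", overwrite_step_parameters=True
)
my_pipeline.run()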
Example #2
def test_zenml_config_setters(equal_md_stores):
    cfg1 = ZenMLConfig(repo_path=config_root)

    old_store_path = artifact_store_path
    old_pipelines_dir = pipelines_dir
    old_sqlite = cfg1.get_metadata_store()

    new_store_path = os.getcwd()
    new_pipelines_dir = "awfkoeghelk"
    new_mdstore = MockMetadataStore()

    cfg1.set_artifact_store(new_store_path)
    cfg1.set_pipelines_dir(new_pipelines_dir)
    cfg1.set_metadata_store(new_mdstore)

    updated_cfg = yaml_utils.read_yaml(cfg1.config_path)

    loaded_md_store = MockMetadataStore.from_config(
        updated_cfg[keys.GlobalKeys.METADATA_STORE])

    assert updated_cfg[keys.GlobalKeys.ARTIFACT_STORE] == new_store_path
    assert updated_cfg[PIPELINES_DIR_KEY] == new_pipelines_dir

    assert equal_md_stores(new_mdstore, loaded_md_store)

    # revert changes
    cfg1.set_artifact_store(old_store_path)
    cfg1.set_pipelines_dir(old_pipelines_dir)
    cfg1.set_metadata_store(old_sqlite)

    shutil.rmtree(new_pipelines_dir, ignore_errors=False)
Example #3
    def get_step_by_version(self, step_type: Union[Type, Text], version: Text):
        """
        Gets a Step object by version. There might be many objects of a
        particular Step registered in many pipelines. This function just
        returns the first configuration that it matches.

        Args:
            step_type: either a string specifying full path to the step or a
            class path.
            version: either sha pin or standard ZenML version pin.
        """
        from zenml.utils import source_utils
        from zenml.steps.base_step import BaseStep

        type_str = source_utils.get_module_path_from_class(step_type)

        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            for step_name, step_config in c[keys.GlobalKeys.PIPELINE][
                    keys.PipelineKeys.STEPS].items():
                # Get version from source
                class_ = source_utils.get_class_path_from_source(
                    step_config[keys.StepKeys.SOURCE])
                source_version = source_utils.get_version_from_source(
                    step_config[keys.StepKeys.SOURCE])

                if class_ == type_str and version == source_version:
                    return BaseStep.from_config(step_config)
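
A hedged usage sketch for `get_step_by_version`; the repository instance `repo`, the step class `TrainerStep`, its module path, and the version pin are all hypothetical:

# Look up the first matching registered configuration of the step,
# either by class or by its full module path string:
step = repo.get_step_by_version(TrainerStep, "0.3.8")
same_step = repo.get_step_by_version("my_module.steps.TrainerStep", "0.3.8")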
Example #4
    def __init__(self, repo_path: Text):
        """
        Construct class instance for ZenML Config.

        Args:
            repo_path (str): path to root of repository.
        """
        self.repo_path = repo_path
        if not ZenMLConfig.is_zenml_dir(self.repo_path):
            raise AssertionError(f'This is not a ZenML repository, as it does '
                                 f'not contain the {ZENML_CONFIG_NAME} '
                                 f'config file. Please initialize your repo '
                                 f'with `zenml init` using the ZenML CLI.')

        self.config_dir = os.path.join(repo_path, ZENML_DIR_NAME)
        self.config_path = os.path.join(self.config_dir, ZENML_CONFIG_NAME)

        self.raw_config = yaml_utils.read_yaml(self.config_path)

        # Load self vars in init to be clean
        self.metadata_store: Optional[ZenMLMetadataStore] = None
        self.artifact_store: Optional[ArtifactStore] = None
        self.pipelines_dir: Text = ''

        # Override these using load_config
        self.load_config(self.raw_config)
Example #5
    def get_pipelines(self) -> List:
        """Gets list of all pipelines."""
        from zenml.pipelines import BasePipeline
        pipelines = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            pipelines.append(BasePipeline.from_config(c))
        return pipelines
Example #6
File: conftest.py  Project: syllogy/zenml
    def wrapper():
        repo: Repository = Repository.get_instance()
        repo.zenml_config.set_pipelines_dir(pipeline_root)

        for p_config in path_utils.list_dir(pipeline_root):
            y = yaml_utils.read_yaml(p_config)
            p: TrainingPipeline = TrainingPipeline.from_config(y)
            p.run()
Example #7
    def load_pipeline_config(self, file_name: Text) -> Dict[Text, Any]:
        """
        Loads a ZenML config from YAML.

        Args:
            file_name (str): file name of pipeline
        """
        self._check_if_initialized()
        pipelines_dir = self.zenml_config.get_pipelines_dir()
        return yaml_utils.read_yaml(os.path.join(pipelines_dir, file_name))
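
Every example in this listing funnels through `yaml_utils.read_yaml`, whose implementation is not shown here. A minimal sketch of what such a helper typically looks like, assuming plain PyYAML and a local file path:

import yaml

def read_yaml(file_path: str) -> dict:
    # Parse the YAML file at file_path into a Python dict. The real
    # zenml.utils.yaml_utils helper may additionally validate the path
    # or support remote filesystems; this sketch does neither.
    with open(file_path, "r") as f:
        return yaml.safe_load(f)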
Example #8
    def get_datasource_names(self) -> List:
        """
        Get all datasources in this repo.

        Returns: List of datasource names used in this repo.
        """
        n = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            n.append(c[keys.GlobalKeys.DATASOURCE][keys.DatasourceKeys.NAME])
        return list(set(n))
Example #9
def test_get_datapoints(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    assert ds.get_datapoints() == len(csv_df.index)
Example #10
def test_datasource_create(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    assert not first_ds._immutable

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._immutable
Example #11
    def get_datasource_id_by_name(self, name: Text) -> Optional[Text]:
        """
        Gets the ID of a datasource by its name.

        Args:
            name (str): name of the datasource.

        Returns: ID of the datasource, or None if no match is found.
        """
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            src = c[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE]
            if keys.DatasourceKeys.NAME in src:
                if name == src[keys.DatasourceKeys.NAME]:
                    return src[keys.DatasourceKeys.ID]
Example #12
def test_get_pipeline_by_name(equal_pipelines):
    p_names = repo.get_pipeline_names()

    random_name = random.choice(p_names)
    cfg_list = [y for y in repo.get_pipeline_file_paths() if random_name in y]

    cfg = yaml_utils.read_yaml(cfg_list[0])

    p1 = repo.get_pipeline_by_name(random_name)

    p2 = BasePipeline.from_config(cfg)

    assert equal_pipelines(p1, p2, loaded=True)
Example #13
def test_get_one_pipeline(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    with pytest.raises(exceptions.EmptyDatasourceException):
        _ = first_ds._get_one_pipeline()

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._get_one_pipeline()
Example #14
    def get_pipelines(self) -> List:
        """
        Gets list of all pipelines.

        Args:
            type_filter (list): list of types to filter by.
        """
        from zenml.core.pipelines.base_pipeline import BasePipeline
        pipelines = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            pipelines.append(BasePipeline.from_config(c))
        return pipelines
Example #15
def test_sample_data():
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails on the test csv because the age gets typed as
    #  a float in datasource.sample_data() method
    assert sample_df.equals(csv_df)
Example #16
    def get_pipeline_by_name(self, pipeline_name: Text = None):
        """
        Loads a pipeline just by its name.

        Args:
            pipeline_name (str): Name of pipeline.
        """
        from zenml.pipelines import BasePipeline
        yamls = self.get_pipeline_file_paths()
        for y in yamls:
            n = BasePipeline.get_name_from_pipeline_name(os.path.basename(y))
            if n == pipeline_name:
                c = yaml_utils.read_yaml(y)
                return BasePipeline.from_config(c)
Example #17
def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails for floating point values other than 2.5 in GPA.
    #   Pandas floating point comp might be too strict
    assert sample_df.equals(csv_df)
Example #18
def run_pipeline(python_file: str, config_path: str) -> None:
    """Runs pipeline specified by the given config YAML object.

    Args:
        python_file: Path to the python file that defines the pipeline.
        config_path: Path to configuration YAML file.
    """
    module = source_utils.import_python_file(python_file)
    config = yaml_utils.read_yaml(config_path)
    PipelineConfigurationKeys.key_check(config)

    pipeline_name = config[PipelineConfigurationKeys.NAME]
    pipeline_class = _get_module_attribute(module, pipeline_name)

    steps = {}
    for step_name, step_config in config[
            PipelineConfigurationKeys.STEPS].items():
        StepConfigurationKeys.key_check(step_config)
        step_class = _get_module_attribute(
            module, step_config[StepConfigurationKeys.SOURCE_])
        step_instance = step_class()
        materializers_config = step_config.get(
            StepConfigurationKeys.MATERIALIZERS_, None)
        if materializers_config:
            # We need to differentiate whether it's a single materializer
            # or a dictionary mapping output names to materializers
            if isinstance(materializers_config, str):
                materializers = _get_module_attribute(module,
                                                      materializers_config)
            elif isinstance(materializers_config, dict):
                materializers = {
                    output_name: _get_module_attribute(module, source)
                    for output_name, source in materializers_config.items()
                }
            else:
                raise PipelineConfigurationError(
                    f"Only `str` and `dict` values are allowed for "
                    f"'materializers' attribute of a step configuration. You "
                    f"tried to pass in `{materializers_config}` (type: "
                    f"`{type(materializers_config).__name__}`).")
            step_instance = step_instance.with_return_materializers(
                materializers)

        steps[step_name] = step_instance

    pipeline_instance = pipeline_class(**steps).with_config(
        config_path, overwrite_step_parameters=True)
    logger.debug("Finished setting up pipeline '%s' from CLI", pipeline_name)
    pipeline_instance.run()
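
For reference, a hypothetical config YAML that would satisfy the key checks in `run_pipeline` above: `name` must resolve to a pipeline class defined in the given python file, each step's `source` to a step class, and `materializers` may be either a single string or a mapping from output names to materializer classes. All class names below are made up:

name: MyPipeline
steps:
  importer:
    source: ImporterStep
  trainer:
    source: TrainerStep
    materializers:
      model: ModelMaterializer
      metrics: MetricsMaterializer

Such a file would then be run with `run_pipeline("my_pipeline.py", "config.yaml")`.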
Example #19
    def get_pipeline_by_name(self, pipeline_name: Text = None):
        """
        Loads a pipeline just by its name.

        Args:
            pipeline_name (str): Name of pipeline.
        """
        from zenml.core.pipelines.base_pipeline import BasePipeline
        yamls = self.get_pipeline_file_paths()
        for y in yamls:
            n = BasePipeline.get_name_from_pipeline_name(y)
            if n == pipeline_name:
                c = yaml_utils.read_yaml(y)
                return BasePipeline.from_config(c)
        raise Exception(f'No pipeline called {pipeline_name}')
Example #20
    def get_pipelines_by_datasource(self, datasource):
        """
        Gets list of pipelines associated with datasource.

        Args:
            datasource (BaseDatasource): object of type BaseDatasource.
        """
        from zenml.core.pipelines.base_pipeline import BasePipeline
        pipelines = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            if c[keys.GlobalKeys.DATASOURCE][keys.DatasourceKeys.ID] == \
                    datasource._id:
                pipelines.append(BasePipeline.from_config(c))
        return pipelines
Example #21
def run_pipeline(path_to_config: Text):
    """
    Runs pipeline specified by the given config YAML object.

    Args:
        path_to_config: Path to config of the designated pipeline.
         Has to be matching the YAML file name.
    """
    # config has metadata store, backends and artifact store,
    # so no need to specify them
    try:
        config = read_yaml(path_to_config)
        p: TrainingPipeline = TrainingPipeline.from_config(config)
        p.run()
    except Exception as e:
        error(e)
Example #22
    def get_datasources(self) -> List:
        """
        Get all datasources in this repo.

        Returns: list of datasources used in this repo
        """
        from zenml.datasources import BaseDatasource

        datasources = []
        datasources_name = set()
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
            if ds and ds.name not in datasources_name:
                datasources.append(ds)
                datasources_name.add(ds.name)
        return datasources
Example #23
    def get_step_versions(self):
        """List all registered steps in repository"""
        from zenml.utils import source_utils
        steps_dict = {}
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            for step_name, step_config in c[keys.GlobalKeys.STEPS].items():
                # Get version from source
                version = source_utils.get_version_from_source(
                    step_config[keys.StepKeys.SOURCE])
                class_ = source_utils.get_class_path_from_source(
                    step_config[keys.StepKeys.SOURCE])

                # Add to set of versions
                if class_ in steps_dict:
                    steps_dict[class_].add(version)
                else:
                    steps_dict[class_] = {version}
        return steps_dict
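
A short usage sketch for `get_step_versions`; the repository instance and the printed format are illustrative only:

# Each key is a step class path; each value is the set of version pins
# found across all pipeline configs in the repo.
for class_path, versions in repo.get_step_versions().items():
    print(class_path, sorted(versions))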
Example #24
def test_get_data_file_paths(repo):
    first_ds = BaseDatasource(name="my_datasource")

    first_pipeline = BasePipeline(name="my_pipeline")

    first_pipeline.add_datasource(first_ds)

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    with pytest.raises(AssertionError):
        _ = second_ds._get_data_file_paths(first_pipeline)

    real_pipeline = second_ds._get_one_pipeline()
    paths = second_ds._get_data_file_paths(real_pipeline)

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(p)[-1] == ".gz" for p in paths)