Example #1
def test_to_from_config(equal_datasources):
    first_ds = BaseDatasource(name="my_datasource")

    config = {keys.PipelineKeys.STEPS: {}}
    config[keys.PipelineKeys.STEPS][keys.DataSteps.DATA] = {"args": {}}
    config[keys.PipelineKeys.DATASOURCE] = first_ds.to_config()
    second_ds = BaseDatasource.from_config(config)

    assert equal_datasources(first_ds, second_ds, loaded=True)
Example #2
def test_datasource_create(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    assert not first_ds._immutable

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._immutable
Example #3
def test_get_one_pipeline(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    with pytest.raises(exceptions.EmptyDatasourceException):
        _ = first_ds._get_one_pipeline()

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._get_one_pipeline()
Example #4
    def add_datasource(self, datasource: BaseDatasource):
        """
        Add a datasource to the pipeline.

        Args:
            datasource: an instance of BaseDatasource
        """
        self.datasource = datasource
        self.steps_dict[keys.TrainingSteps.DATA] = datasource.get_data_step()
Example #5
def test_add_datasource():
    name = "my_pipeline"
    p: BasePipeline = BasePipeline(name=name)

    p.add_datasource(BaseDatasource(name="my_datasource"))

    assert isinstance(p.datasource, BaseDatasource)

    assert not p.steps_dict[keys.TrainingSteps.DATA]
Example #6
    def wrapper(ds1: BaseDatasource, ds2: BaseDatasource, loaded=True):
        # There can be a "None" datasource in a pipeline
        if ds1 is None and ds2 is None:
            return True
        if sum(d is None for d in [ds1, ds2]) == 1:
            return False

        # all checks must pass, so start from True and AND each condition
        equal = True
        equal &= ds1.name == ds2.name
        equal &= ds1.schema == ds2.schema
        equal &= ds1._id == ds2._id
        equal &= ds1._source == ds2._source
        equal &= equal_steps(ds1.get_data_step(), ds2.get_data_step(),
                             loaded=loaded)
        if loaded:
            # a datasource reloaded from config is immutable while the
            # original in-memory one is not, so the flags must differ
            equal &= ds1._immutable != ds2._immutable
        else:
            equal &= ds1._immutable == ds2._immutable

        return equal
Example #7
def test_get_datapoints(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    assert ds.get_datapoints() == len(csv_df.index)
Example #8
def test_get_data_file_paths(repo):
    first_ds = BaseDatasource(name="my_datasource")

    first_pipeline = BasePipeline(name="my_pipeline")

    first_pipeline.add_datasource(first_ds)

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    with pytest.raises(AssertionError):
        _ = second_ds._get_data_file_paths(first_pipeline)

    real_pipeline = second_ds._get_one_pipeline()
    paths = second_ds._get_data_file_paths(real_pipeline)

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(p)[-1] == ".gz" for p in paths)
Example #9
    def wrapper(ds1: BaseDatasource, ds2: BaseDatasource, loaded=True):
        # There can be a "None" datasource in a pipeline
        if ds1 is None and ds2 is None:
            return True
        if sum(d is None for d in [ds1, ds2]) == 1:
            return False

        # all checks must pass, so start from True and AND each condition
        equal = True
        equal &= ds1.name == ds2.name
        equal &= ds1._id == ds2._id
        equal &= ds1._source == ds2._source
        equal &= equal_steps(ds1.get_data_step(), ds2.get_data_step(),
                             loaded=loaded)
        # TODO[LOW]: Add more checks for constructor kwargs, __dict__ etc.
        if loaded:
            # a datasource reloaded from config is immutable while the
            # original in-memory one is not, so the flags must differ
            equal &= ds1._immutable != ds2._immutable
        else:
            equal &= ds1._immutable == ds2._immutable

        return equal
Example #10
    @classmethod
    def from_config(cls, config: Dict):
        """
        Convert from pipeline config to ZenML Pipeline object.

        All steps are also populated and their configuration is set to the
        parameters specified in the config file.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).
        """
        # start with artifact store
        artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

        # metadata store
        metadata_store = ZenMLMetadataStore.from_config(
            config=config[keys.GlobalKeys.METADATA_STORE]
        )

        # orchestration backend
        backend = OrchestratorBaseBackend.from_config(
            config[keys.GlobalKeys.BACKEND])

        # pipeline configuration
        p_config = config[keys.GlobalKeys.PIPELINE]
        pipeline_name = p_config[keys.PipelineKeys.NAME]
        pipeline_source = p_config[keys.PipelineKeys.SOURCE]

        # populate steps
        steps_dict: Dict = {}
        for step_key, step_config in p_config[keys.PipelineKeys.STEPS].items():
            steps_dict[step_key] = BaseStep.from_config(step_config)

        # datasource
        datasource = BaseDatasource.from_config(
            config[keys.GlobalKeys.PIPELINE])

        # enable cache
        enable_cache = p_config[keys.PipelineKeys.ENABLE_CACHE]

        class_ = source_utils.load_source_path_class(pipeline_source)

        obj = class_(
            name=cls.get_name_from_pipeline_name(pipeline_name),
            pipeline_name=pipeline_name,
            enable_cache=enable_cache,
            steps_dict=steps_dict,
            backend=backend,
            artifact_store=artifact_store,
            metadata_store=metadata_store,
            datasource=datasource)
        obj._immutable = True
        logger.debug(f'Pipeline {pipeline_name} loaded and is immutable.')
        return obj
Example #11
def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails for floating point values other than 2.5 in GPA.
    #   Pandas floating point comp might be too strict
    assert sample_df.equals(csv_df)
Example #12
def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails on the test csv because the age gets typed as
    #  a float in datasource.sample_data() method
    assert sample_df.equals(csv_df)
Example #13
def test_get_pipelines_by_datasource(repo):
    # asserted in an earlier test
    ds = repo.get_datasource_by_name("my_csv_datasource")

    p_names = repo.get_pipeline_names()

    ds2 = BaseDatasource(name="ds_12254757")

    pipelines = repo.get_pipelines_by_datasource(ds)

    pipelines_2 = repo.get_pipelines_by_datasource(ds2)

    assert len(pipelines) == len(p_names)

    assert not pipelines_2
Example #14
    @classmethod
    def from_config(cls, config: Dict):
        """
        Convert from pipeline config to ZenML Pipeline object.

        All steps are also populated and their configuration is set to the
        parameters specified in the config file.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).
        """
        # populate steps
        steps_dict: Dict = {}
        for step_key, step_config in config[keys.GlobalKeys.STEPS].items():
            steps_dict[step_key] = BaseStep.from_config(step_config)

        env = config[keys.GlobalKeys.ENV]
        pipeline_name = env[keys.EnvironmentKeys.EXPERIMENT_NAME]
        name = BasePipeline.get_name_from_pipeline_name(
            pipeline_name=pipeline_name)

        backends_dict: Dict = {}
        for backend_key, backend_config in env[
                keys.EnvironmentKeys.BACKENDS].items():
            backends_dict[backend_key] = BaseBackend.from_config(
                backend_key, backend_config)

        artifact_store = ArtifactStore(
            env[keys.EnvironmentKeys.ARTIFACT_STORE])
        metadata_store = ZenMLMetadataStore.from_config(
            config=env[METADATA_KEY])

        datasource = BaseDatasource.from_config(config)

        from zenml.core.pipelines.pipeline_factory import pipeline_factory
        pipeline_type = BasePipeline.get_type_from_pipeline_name(pipeline_name)
        class_ = pipeline_factory.get_pipeline_by_type(pipeline_type)

        # TODO: [MEDIUM] Perhaps move some of the logic in the init block here
        #  especially regarding inferring immutability.

        return class_(name=name,
                      pipeline_name=pipeline_name,
                      enable_cache=env[keys.EnvironmentKeys.ENABLE_CACHE],
                      steps_dict=steps_dict,
                      backends_dict=backends_dict,
                      artifact_store=artifact_store,
                      metadata_store=metadata_store,
                      datasource=datasource)
Example #15
    def get_datasources(self) -> List:
        """
        Get all datasources in this repo.

        Returns: list of datasources used in this repo
        """
        from zenml.core.datasources.base_datasource import BaseDatasource

        datasources = []
        datasources_name = set()
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
            if ds and ds.name not in datasources_name:
                datasources.append(ds)
                datasources_name.add(ds.name)
        return datasources
Example #16
def test_get_datastep():
    first_ds = BaseDatasource(name="my_datasource")

    assert not first_ds.get_data_step()