Example #1
def test_to_from_config(equal_datasources):
    first_ds = BaseDatasource(name="my_datasource")

    config = {keys.PipelineKeys.STEPS: {}}
    config[keys.PipelineKeys.STEPS][keys.DataSteps.DATA] = {"args": {}}
    config[keys.PipelineKeys.DATASOURCE] = first_ds.to_config()
    second_ds = BaseDatasource.from_config(config)

    assert equal_datasources(first_ds, second_ds, loaded=True)
Example #2
def test_datasource_create(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    assert not first_ds._immutable

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._immutable
Example #3
def test_get_one_pipeline(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    with pytest.raises(exceptions.EmptyDatasourceException):
        _ = first_ds._get_one_pipeline()

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._get_one_pipeline()
Example #4
    def add_datasource(self, datasource: BaseDatasource):
        """
        Add datasource to pipeline.

        Args:
            datasource: a BaseDatasource instance to attach to the pipeline
        """
        self.datasource = datasource
        self.steps_dict[keys.TrainingSteps.DATA] = datasource.get_data_step()
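A minimal usage sketch of this method, using only names already visible in these examples; the pipeline and datasource names are placeholders, and Example #5 below tests the same call path with the full assertions.

# Placeholder names; a sketch, not one of the project's own examples.
p = BasePipeline(name="my_pipeline")
p.add_datasource(BaseDatasource(name="my_datasource"))

# The datasource is attached and its (possibly empty) data step is registered
# under keys.TrainingSteps.DATA.
assert p.datasource is not None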
Example #5
def test_add_datasource():
    name = "my_pipeline"
    p: BasePipeline = BasePipeline(name=name)

    p.add_datasource(BaseDatasource(name="my_datasource"))

    assert isinstance(p.datasource, BaseDatasource)

    # a freshly constructed datasource has no data step yet (see Example #13),
    # so the DATA step slot is still empty
    assert not p.steps_dict[keys.TrainingSteps.DATA]
Example #6
def test_get_datapoints(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    assert ds.get_datapoints() == len(csv_df.index)
Example #7
def test_get_data_file_paths(repo):
    first_ds = BaseDatasource(name="my_datasource")

    first_pipeline = BasePipeline(name="my_pipeline")

    first_pipeline.add_datasource(first_ds)

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    with pytest.raises(AssertionError):
        _ = second_ds._get_data_file_paths(first_pipeline)

    real_pipeline = second_ds._get_one_pipeline()
    paths = second_ds._get_data_file_paths(real_pipeline)

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(p)[-1] == ".gz" for p in paths)
Example #8
    def wrapper(ds1: BaseDatasource, ds2: BaseDatasource, loaded=True):
        # There can be a "None" datasource in a pipeline
        if ds1 is None and ds2 is None:
            return True
        if sum(d is None for d in [ds1, ds2]) == 1:
            return False

        # all fields must match for the datasources to be considered equal
        equal = True
        equal &= ds1.name == ds2.name
        equal &= ds1._id == ds2._id
        equal &= ds1._source == ds2._source
        equal &= equal_steps(ds1.get_data_step(),
                             ds2.get_data_step(),
                             loaded=loaded)
        # TODO[LOW]: Add more checks for constructor kwargs, __dict__ etc.
        if loaded:
            # a datasource loaded from config is immutable, a fresh one is not
            equal &= ds1._immutable != ds2._immutable
        else:
            equal &= ds1._immutable == ds2._immutable

        return equal
Example #9
def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails for floating point values other than 2.5 in GPA.
    #   Pandas floating point comp might be too strict
    assert sample_df.equals(csv_df)
Example #10
def test_get_pipelines_by_datasource(repo):
    # asserted in an earlier test
    ds = repo.get_datasource_by_name("my_csv_datasource")

    p_names = repo.get_pipeline_names()

    ds2 = BaseDatasource(name="ds_12254757")

    pipelines = repo.get_pipelines_by_datasource(ds)

    pipelines_2 = repo.get_pipelines_by_datasource(ds2)

    assert len(pipelines) == len(p_names)

    assert not pipelines_2
Example #11
    def from_config(cls, config: Dict):
        """
        Convert a pipeline config into a ZenML Pipeline object.

        All steps are also populated, and their configuration is set to the
        parameters given in the config file.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).

        Returns:
            The reconstructed pipeline object, marked as immutable.
        """
        # start with artifact store
        artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

        # metadata store
        metadata_store = ZenMLMetadataStore.from_config(
            config=config[keys.GlobalKeys.METADATA_STORE])

        # orchestration backend
        backend = OrchestratorBaseBackend.from_config(
            config[keys.GlobalKeys.BACKEND])

        # pipeline configuration
        p_config = config[keys.GlobalKeys.PIPELINE]
        kwargs = p_config[keys.PipelineKeys.ARGS]
        pipeline_name = kwargs.pop(keys.PipelineDetailKeys.NAME)
        pipeline_source = p_config[keys.PipelineKeys.SOURCE]

        # populate steps
        steps_dict: Dict = {}
        for step_key, step_config in p_config[keys.PipelineKeys.STEPS].items():
            steps_dict[step_key] = BaseStep.from_config(step_config)

        # datasource
        datasource = BaseDatasource.from_config(
            config[keys.GlobalKeys.PIPELINE])

        class_ = source_utils.load_source_path_class(pipeline_source)

        obj = class_(steps_dict=steps_dict,
                     backend=backend,
                     artifact_store=artifact_store,
                     metadata_store=metadata_store,
                     datasource=datasource,
                     pipeline_name=pipeline_name,
                     name=cls.get_name_from_pipeline_name(pipeline_name),
                     **kwargs)
        obj._immutable = True
        return obj
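A minimal round-trip sketch for this classmethod, assuming it lives on BasePipeline (as the surrounding examples suggest) and that yaml_utils is importable from zenml.utils; the import paths and the YAML path are placeholders, not taken from the examples.

# Sketch under the assumptions above.
from zenml.pipelines import BasePipeline   # assumed import path
from zenml.utils import yaml_utils         # assumed import path

cfg = yaml_utils.read_yaml("pipelines/my_pipeline.yaml")  # hypothetical path
pipeline = BasePipeline.from_config(cfg)

# A pipeline reconstructed from a saved config is marked immutable.
assert pipeline._immutable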
Example #12
    def get_datasources(self) -> List:
        """
        Get all datasources in this repo.

        Returns:
            List of datasources used in this repo.
        """
        from zenml.datasources import BaseDatasource

        datasources = []
        datasource_names = set()
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
            if ds and ds.name not in datasource_names:
                datasources.append(ds)
                datasource_names.add(ds.name)
        return datasources
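A short consumption sketch for get_datasources, assuming a Repository accessor such as Repository.get_instance(); that accessor and its import path are assumptions, only get_datasources() and .name come from the example above.

# Sketch: Repository.get_instance() and the import path are assumptions.
from zenml.repo import Repository

repo = Repository.get_instance()
for ds in repo.get_datasources():
    print(ds.name)  # each unique datasource name appears exactly once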
Example #13
def test_get_datastep():
    first_ds = BaseDatasource(name="my_datasource")

    assert not first_ds.get_data_step()