def test_to_from_config(equal_datasources):
    first_ds = BaseDatasource(name="my_datasource")

    config = {keys.PipelineKeys.STEPS: {}}
    config[keys.PipelineKeys.STEPS][keys.DataSteps.DATA] = {"args": {}}
    config[keys.PipelineKeys.DATASOURCE] = first_ds.to_config()

    second_ds = BaseDatasource.from_config(config)

    assert equal_datasources(first_ds, second_ds, loaded=True)
def test_datasource_create(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    assert not first_ds._immutable

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._immutable
def test_get_one_pipeline(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    with pytest.raises(exceptions.EmptyDatasourceException):
        _ = first_ds._get_one_pipeline()

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    assert second_ds._get_one_pipeline()
def add_datasource(self, datasource: BaseDatasource):
    """
    Add a datasource to the pipeline.

    Args:
        datasource: a BaseDatasource instance.
    """
    self.datasource = datasource
    self.steps_dict[keys.TrainingSteps.DATA] = datasource.get_data_step()
def test_add_datasource():
    name = "my_pipeline"
    p: BasePipeline = BasePipeline(name=name)

    p.add_datasource(BaseDatasource(name="my_datasource"))

    assert isinstance(p.datasource, BaseDatasource)
    # the base datasource has no data step of its own, so the DATA entry
    # added by add_datasource stays empty
    assert not p.steps_dict[keys.TrainingSteps.DATA]
def test_get_datapoints(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    assert ds.get_datapoints() == len(csv_df.index)
def test_get_data_file_paths(repo):
    first_ds = BaseDatasource(name="my_datasource")

    first_pipeline = BasePipeline(name="my_pipeline")
    first_pipeline.add_datasource(first_ds)

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    with pytest.raises(AssertionError):
        _ = second_ds._get_data_file_paths(first_pipeline)

    real_pipeline = second_ds._get_one_pipeline()
    paths = second_ds._get_data_file_paths(real_pipeline)

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(p)[-1] == ".gz" for p in paths)
def wrapper(ds1: BaseDatasource, ds2: BaseDatasource, loaded=True):
    # There can be a "None" datasource in a pipeline
    if ds1 is None and ds2 is None:
        return True
    if sum(d is None for d in [ds1, ds2]) == 1:
        return False

    # every check below has to hold for the datasources to be equal
    equal = True
    equal &= ds1.name == ds2.name
    equal &= ds1._id == ds2._id
    equal &= ds1._source == ds2._source
    equal &= equal_steps(ds1.get_data_step(),
                         ds2.get_data_step(),
                         loaded=loaded)

    # TODO[LOW]: Add more checks for constructor kwargs, __dict__ etc.
    if loaded:
        # a datasource loaded from config is immutable, the original is not
        equal &= ds1._immutable != ds2._immutable
    else:
        equal &= ds1._immutable == ds2._immutable

    return equal
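# Sketch (assumption, not taken from the source): test_to_from_config above
# requests an `equal_datasources` fixture, so the `wrapper` comparator is
# presumably returned from a pytest fixture roughly like the one below. The
# `equal_steps` parameter is an assumption: a sibling fixture supplying the
# step comparator that `wrapper` calls.
@pytest.fixture
def equal_datasources(equal_steps):
    def wrapper(ds1: BaseDatasource, ds2: BaseDatasource, loaded=True):
        ...  # comparator body as defined above, closing over `equal_steps`

    return wrapper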
def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails for floating point values other than 2.5 in GPA.
    #  Pandas floating point comp might be too strict
    assert sample_df.equals(csv_df)
def test_get_pipelines_by_datasource(repo):
    # asserted in an earlier test
    ds = repo.get_datasource_by_name("my_csv_datasource")
    p_names = repo.get_pipeline_names()

    ds2 = BaseDatasource(name="ds_12254757")

    pipelines = repo.get_pipelines_by_datasource(ds)
    pipelines_2 = repo.get_pipelines_by_datasource(ds2)

    assert len(pipelines) == len(p_names)
    assert not pipelines_2
def from_config(cls, config: Dict):
    """
    Convert from a pipeline config to a ZenML Pipeline object.

    All steps are also populated, and their configuration is set to the
    parameters defined in the config file.

    Args:
        config: a ZenML config in dict form (probably loaded from YAML).
    """
    # start with artifact store
    artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

    # metadata store
    metadata_store = ZenMLMetadataStore.from_config(
        config=config[keys.GlobalKeys.METADATA_STORE])

    # orchestration backend
    backend = OrchestratorBaseBackend.from_config(
        config[keys.GlobalKeys.BACKEND])

    # pipeline configuration
    p_config = config[keys.GlobalKeys.PIPELINE]
    kwargs = p_config[keys.PipelineKeys.ARGS]
    pipeline_name = kwargs.pop(keys.PipelineDetailKeys.NAME)
    pipeline_source = p_config[keys.PipelineKeys.SOURCE]

    # populate steps
    steps_dict: Dict = {}
    for step_key, step_config in p_config[keys.PipelineKeys.STEPS].items():
        steps_dict[step_key] = BaseStep.from_config(step_config)

    # datasource
    datasource = BaseDatasource.from_config(
        config[keys.GlobalKeys.PIPELINE])

    class_ = source_utils.load_source_path_class(pipeline_source)

    obj = class_(steps_dict=steps_dict,
                 backend=backend,
                 artifact_store=artifact_store,
                 metadata_store=metadata_store,
                 datasource=datasource,
                 pipeline_name=pipeline_name,
                 name=cls.get_name_from_pipeline_name(pipeline_name),
                 **kwargs)
    obj._immutable = True
    return obj
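# Rough sketch (assumption, not from the source) of the dict shape that
# from_config above expects, derived only from the key lookups in the method.
# The concrete key strings and nested contents come from the generated
# pipeline YAML and are not reproduced here; placeholder values are used.
example_config = {
    keys.GlobalKeys.ARTIFACT_STORE: "...",  # passed to ArtifactStore(...)
    keys.GlobalKeys.METADATA_STORE: {},     # consumed by ZenMLMetadataStore.from_config
    keys.GlobalKeys.BACKEND: {},            # consumed by OrchestratorBaseBackend.from_config
    keys.GlobalKeys.PIPELINE: {
        keys.PipelineKeys.ARGS: {keys.PipelineDetailKeys.NAME: "..."},
        keys.PipelineKeys.SOURCE: "...",    # resolved via load_source_path_class
        keys.PipelineKeys.STEPS: {},        # each entry fed to BaseStep.from_config
        keys.PipelineKeys.DATASOURCE: {},   # consumed by BaseDatasource.from_config
    },
}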
def get_datasources(self) -> List:
    """
    Get all datasources in this repo.

    Returns:
        List of datasources used in this repo.
    """
    from zenml.datasources import BaseDatasource

    datasources = []
    datasource_names = set()
    for file_path in self.get_pipeline_file_paths():
        c = yaml_utils.read_yaml(file_path)
        ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
        if ds and ds.name not in datasource_names:
            datasources.append(ds)
            datasource_names.add(ds.name)
    return datasources
def test_get_datastep():
    first_ds = BaseDatasource(name="my_datasource")

    # the base datasource does not define a data step of its own
    assert not first_ds.get_data_step()