def test_to_from_config(equal_datasources):
    first_ds = BaseDatasource(name="my_datasource")

    config = {keys.PipelineKeys.STEPS: {}}
    config[keys.PipelineKeys.STEPS][keys.DataSteps.DATA] = {"args": {}}
    config[keys.PipelineKeys.DATASOURCE] = first_ds.to_config()

    second_ds = BaseDatasource.from_config(config)

    assert equal_datasources(first_ds, second_ds, loaded=True)

def test_get_datapoints(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    assert ds.get_datapoints() == len(csv_df.index)

def test_datasource_create(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)
    assert not first_ds._immutable

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])
    assert second_ds._immutable

def from_config(cls, config: Dict):
    """
    Convert from pipeline config to ZenML Pipeline object.

    All steps are also populated and their configuration is set to the
    parameters specified in the config file.

    Args:
        config: a ZenML config in dict-form (probably loaded from YAML).
    """
    # start with artifact store
    artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

    # metadata store
    metadata_store = ZenMLMetadataStore.from_config(
        config=config[keys.GlobalKeys.METADATA_STORE]
    )

    # orchestration backend
    backend = OrchestratorBaseBackend.from_config(
        config[keys.GlobalKeys.BACKEND])

    # pipeline configuration
    p_config = config[keys.GlobalKeys.PIPELINE]
    pipeline_name = p_config[keys.PipelineKeys.NAME]
    pipeline_source = p_config[keys.PipelineKeys.SOURCE]

    # populate steps
    steps_dict: Dict = {}
    for step_key, step_config in p_config[keys.PipelineKeys.STEPS].items():
        steps_dict[step_key] = BaseStep.from_config(step_config)

    # datasource
    datasource = BaseDatasource.from_config(
        config[keys.GlobalKeys.PIPELINE])

    # enable cache
    enable_cache = p_config[keys.PipelineKeys.ENABLE_CACHE]

    class_ = source_utils.load_source_path_class(pipeline_source)

    obj = class_(
        name=cls.get_name_from_pipeline_name(pipeline_name),
        pipeline_name=pipeline_name,
        enable_cache=enable_cache,
        steps_dict=steps_dict,
        backend=backend,
        artifact_store=artifact_store,
        metadata_store=metadata_store,
        datasource=datasource)
    obj._immutable = True
    logger.debug(f'Pipeline {pipeline_name} loaded and is immutable.')
    return obj

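# Hedged usage sketch (not from the source): one way the from_config
# classmethod above could be driven from a pipeline YAML written to disk
# earlier. The file path and the exact import locations of BasePipeline and
# yaml_utils are assumptions for illustration only.
from zenml.core.pipelines.base_pipeline import BasePipeline
from zenml.utils import yaml_utils

config = yaml_utils.read_yaml("pipelines/training_pipeline.yaml")  # hypothetical path
pipeline = BasePipeline.from_config(config)

# from_config marks reloaded pipelines as immutable (see obj._immutable above).
assert pipeline._immutable
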
def test_get_one_pipeline(repo):
    name = "my_datasource"
    first_ds = BaseDatasource(name=name)

    with pytest.raises(exceptions.EmptyDatasourceException):
        _ = first_ds._get_one_pipeline()

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])
    assert second_ds._get_one_pipeline()

def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails for floating point values other than 2.5 in GPA.
    #  Pandas floating point comparison might be too strict.
    assert sample_df.equals(csv_df)

def test_sample_data(repo):
    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    ds: BaseDatasource = BaseDatasource.from_config(
        cfg[keys.GlobalKeys.PIPELINE])

    sample_df = ds.sample_data()
    csv_df = pd.read_csv(
        os.path.join(TEST_ROOT, "test_data", "my_dataframe.csv"))

    # TODO: This fails on the test csv because the age gets typed as
    #  a float in the datasource.sample_data() method.
    assert sample_df.equals(csv_df)

def from_config(cls, config: Dict):
    """
    Convert from pipeline config to ZenML Pipeline object.

    All steps are also populated and their configuration is set to the
    parameters specified in the config file.

    Args:
        config: a ZenML config in dict-form (probably loaded from YAML).
    """
    # populate steps
    steps_dict: Dict = {}
    for step_key, step_config in config[keys.GlobalKeys.STEPS].items():
        steps_dict[step_key] = BaseStep.from_config(step_config)

    env = config[keys.GlobalKeys.ENV]
    pipeline_name = env[keys.EnvironmentKeys.EXPERIMENT_NAME]
    name = BasePipeline.get_name_from_pipeline_name(
        pipeline_name=pipeline_name)

    backends_dict: Dict = {}
    for backend_key, backend_config in env[
            keys.EnvironmentKeys.BACKENDS].items():
        backends_dict[backend_key] = BaseBackend.from_config(
            backend_key, backend_config)

    artifact_store = ArtifactStore(
        env[keys.EnvironmentKeys.ARTIFACT_STORE])

    metadata_store = ZenMLMetadataStore.from_config(
        config=env[METADATA_KEY])

    datasource = BaseDatasource.from_config(config)

    from zenml.core.pipelines.pipeline_factory import pipeline_factory
    pipeline_type = BasePipeline.get_type_from_pipeline_name(pipeline_name)
    class_ = pipeline_factory.get_pipeline_by_type(pipeline_type)

    # TODO: [MEDIUM] Perhaps move some of the logic in the init block here,
    #  especially regarding inferring immutability.
    return class_(name=name,
                  pipeline_name=pipeline_name,
                  enable_cache=env[keys.EnvironmentKeys.ENABLE_CACHE],
                  steps_dict=steps_dict,
                  backends_dict=backends_dict,
                  artifact_store=artifact_store,
                  metadata_store=metadata_store,
                  datasource=datasource)

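# Hedged sketch (illustration only): the dict layout this legacy from_config
# variant reads, derived from the key accesses above. The example experiment
# name and artifact store path are made up, nested configs are left empty,
# and METADATA_KEY is assumed to be a module-level constant holding the
# metadata store key.
legacy_config_shape = {
    keys.GlobalKeys.STEPS: {
        keys.DataSteps.DATA: {"args": {}},  # each entry is fed to BaseStep.from_config
    },
    keys.GlobalKeys.ENV: {
        keys.EnvironmentKeys.EXPERIMENT_NAME: "my_pipeline_1234",  # hypothetical
        keys.EnvironmentKeys.ENABLE_CACHE: True,
        keys.EnvironmentKeys.BACKENDS: {},  # backend_key -> backend config
        keys.EnvironmentKeys.ARTIFACT_STORE: "/path/to/artifact_store",
        # plus a METADATA_KEY entry holding the metadata store config
    },
}
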
def get_datasources(self) -> List:
    """
    Get all datasources in this repo.

    Returns:
        list of datasources used in this repo
    """
    from zenml.core.datasources.base_datasource import BaseDatasource

    datasources = []
    datasources_name = set()
    for file_path in self.get_pipeline_file_paths():
        c = yaml_utils.read_yaml(file_path)
        ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
        if ds and ds.name not in datasources_name:
            datasources.append(ds)
            datasources_name.add(ds.name)
    return datasources

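# Hedged usage sketch (not from the source): listing the de-duplicated
# datasources that get_datasources() collects from the repo's pipeline
# configs. The Repository import path and the get_instance() accessor are
# assumptions about how a repo handle is obtained.
from zenml.core.repo.repo import Repository

repo = Repository.get_instance()
for ds in repo.get_datasources():
    print(ds.name, ds.get_datapoints())
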
def test_get_data_file_paths(repo):
    first_ds = BaseDatasource(name="my_datasource")

    first_pipeline = BasePipeline(name="my_pipeline")
    first_pipeline.add_datasource(first_ds)

    # reload a datasource from a saved config
    p_config = random.choice(repo.get_pipeline_file_paths())
    cfg = yaml_utils.read_yaml(p_config)
    second_ds = BaseDatasource.from_config(cfg[keys.GlobalKeys.PIPELINE])

    with pytest.raises(AssertionError):
        _ = second_ds._get_data_file_paths(first_pipeline)

    real_pipeline = second_ds._get_one_pipeline()
    paths = second_ds._get_data_file_paths(real_pipeline)

    # TODO: Find a better way of asserting TFRecords
    assert all(os.path.splitext(p)[-1] == ".gz" for p in paths)