def test_to_from_config(equal_zenml_configs):
    """Round-trip a config through ``to_config`` / ``from_config``.

    The config written for ``config_root`` is read back and compared with
    the instance that produced it, using the ``equal_zenml_configs``
    fixture with ``loaded=True``.
    """
    # TODO: This is messed up
    original = ZenMLConfig(repo_path=config_root)
    written = original.to_config(path=config_root)
    restored = original.from_config(written)

    assert equal_zenml_configs(original, restored, loaded=True)
def __init__(self, path: Text = None):
    """
    Build the (single) reference to a ZenML repository.

    Args:
        path (str): Path to root of repository. When omitted, the
            directory tree is searched upwards from the current working
            directory; if that search fails the current working directory
            itself is used.

    Raises:
        Exception: If an instance already exists, or if the resolved path
            is not a directory.
    """
    # Singleton guard: refuse a second construction up front.
    if Repository.__instance__ is not None:
        raise Exception("You cannot create another Repository class!")

    if path is None:
        try:
            # Walk upwards from the cwd until a zenml config is found.
            path = Repository.get_zenml_dir(os.getcwd())
        except Exception:
            # No zenml config anywhere above us: fall back to the cwd.
            path = os.getcwd()

    path_is_directory = path_utils.is_dir(path)
    if not path_is_directory:
        raise Exception(f'{path} does not exist or is not a dir!')
    self.path = path

    # Hook up git; the path needs to contain a git folder.
    self.git_wrapper = GitWrapper(self.path)

    # Load the ZenML config. A missing config is tolerated for now
    # (needed by the GCP orchestrator) and recorded as None.
    try:
        loaded_config = ZenMLConfig(self.path)
    except InitializationException:
        loaded_config = None
    self.zenml_config = loaded_config

    Repository.__instance__ = self
def init_repo(repo_path: Text, artifact_store_path: Text = None,
              metadata_store: Optional[ZenMLMetadataStore] = None,
              pipelines_dir: Text = None,
              analytics_opt_in: bool = None):
    """
    Set up ZenML inside an existing git repository.

    Args:
        repo_path (str): path to root of a git repo.
        artifact_store_path (str): path where to store artifacts.
        metadata_store: metadata store definition.
        pipelines_dir (str): path where to store pipeline configs.
        analytics_opt_in (bool): opt-in flag for analytics code; left
            untouched in the global config when None.

    Raises:
        InvalidGitRepositoryError: If repository is not a git repository
            (presumably raised by GitWrapper).
        NoSuchPathError: If the repo_path does not exist (presumably
            raised by GitWrapper).
    """
    # Constructing the wrapper doubles as the "is this a git repo?" check.
    git = GitWrapper(repo_path)

    # Keep the ZenML directory out of version control.
    ignored_entries = [ZENML_DIR_NAME + '/']
    git.add_gitignore(ignored_entries)

    # Delegate writing the repo-level config to ZenMLConfig.
    ZenMLConfig.to_config(
        repo_path,
        artifact_store_path,
        metadata_store,
        pipelines_dir,
    )

    # The global config is always fetched; the analytics flag is only
    # written when the caller supplied one.
    settings = GlobalConfig.get_instance()
    if analytics_opt_in is None:
        return
    settings.set_analytics_opt_in(analytics_opt_in)
def test_is_zenml_dir():
    """``config_root`` is recognised as a ZenML dir, ``TEST_ROOT`` is not."""
    initialized, uninitialized = config_root, TEST_ROOT

    assert ZenMLConfig.is_zenml_dir(initialized)
    assert not ZenMLConfig.is_zenml_dir(uninitialized)
def test_zenml_config_init():
    """Constructing a config works in ``config_root`` and fails elsewhere."""
    # Inside the initialized root, construction must succeed.
    ZenMLConfig(config_root)

    # Outside of an initialized repo, construction must raise.
    with pytest.raises(InitializationException):
        ZenMLConfig(TEST_ROOT)
def test_is_zenml_dir():
    """``TEST_ROOT`` is recognised as a ZenML dir, ``pipelines_dir`` is not."""
    expectations = (
        (TEST_ROOT, True),
        (pipelines_dir, False),
    )
    for candidate, should_match in expectations:
        assert bool(ZenMLConfig.is_zenml_dir(candidate)) is should_match
def get_zenml_dir(path: Text):
    """
    Walk upwards from ``path`` until a directory with a zenml config is found.

    Args:
        path (str): starting path

    Returns:
        The first directory, starting at ``path`` and moving towards the
        filesystem root, for which ``ZenMLConfig.is_zenml_dir`` is true.

    Raises:
        Exception: If the root is reached without finding a ZenML dir.
    """
    candidate = path
    while not ZenMLConfig.is_zenml_dir(candidate):
        if path_utils.is_root(candidate):
            raise Exception(
                'Looks like you used ZenML outside of a ZenML repo. '
                'Please init a ZenML repo first before you using '
                'the framework.')
        # Step one level up and try again.
        candidate = str(Path(candidate).parent)
    return candidate
class Repository:
    """ZenML repository definition. This is a Singleton class.

    Every ZenML project exists inside a ZenML repository. The single
    instance is stored on the class and should be obtained through
    ``Repository.get_instance()``; calling the constructor a second time
    raises.
    """
    # Holds the one and only instance once it has been constructed.
    __instance__ = None

    def __init__(self, path: Text = None):
        """
        Construct a reference to a ZenML repository.

        Args:
            path (str): Path to root of repository. If None, the directory
                tree is searched upwards from the current working
                directory; if that search raises, the current working
                directory is used.

        Raises:
            Exception: If an instance already exists, or if the resolved
                path is not a directory.
        """
        if Repository.__instance__ is None:
            if path is None:
                try:
                    # Start from cwd and traverse up until a zenml config
                    # is found.
                    path = Repository.get_zenml_dir(os.getcwd())
                except Exception:
                    # If there isn't a zenml config, use the cwd.
                    path = os.getcwd()

            if not path_utils.is_dir(path):
                raise Exception(f'{path} does not exist or is not a dir!')
            self.path = path

            # Hook up git, path needs to have a git folder.
            self.git_wrapper = GitWrapper(self.path)

            # Load the ZenML config
            try:
                self.zenml_config = ZenMLConfig(self.path)
            except InitializationException:
                # We allow this because of the GCP orchestrator for now.
                # Methods that need the config guard against None through
                # _check_if_initialized().
                self.zenml_config = None

            Repository.__instance__ = self
        else:
            raise Exception("You cannot create another Repository class!")

    @staticmethod
    def get_zenml_dir(path: Text):
        """
        Recursive function to find the zenml config starting from path.

        Args:
            path (str): starting path

        Returns:
            The first path, going from ``path`` towards the root, for which
            ``ZenMLConfig.is_zenml_dir`` is true.

        Raises:
            Exception: If the root is reached without finding one.
        """
        if ZenMLConfig.is_zenml_dir(path):
            return path

        if path_utils.is_root(path):
            raise Exception(
                'Looks like you used ZenML outside of a ZenML repo. '
                'Please init a ZenML repo first before you using '
                'the framework.')
        # Not found here and not yet at the root: retry one level up.
        return Repository.get_zenml_dir(str(Path(path).parent))

    @staticmethod
    def get_instance(path: Text = None):
        """
        Static method to fetch the current instance.

        The instance is created on first use. ``path`` is only forwarded
        to the constructor when no instance exists yet.

        Args:
            path (str): Path to root of repository, used on first creation.
        """
        logger.debug('Repository instance fetched.')
        if not Repository.__instance__:
            # The constructor registers itself in Repository.__instance__.
            Repository(path)
        return Repository.__instance__

    @staticmethod
    @track(event=CREATE_REPO)
    def init_repo(repo_path: Text, artifact_store_path: Text = None,
                  metadata_store: Optional[ZenMLMetadataStore] = None,
                  pipelines_dir: Text = None,
                  analytics_opt_in: bool = None):
        """
        Initializes a git repo with zenml.

        Args:
            repo_path (str): path to root of a git repo
            artifact_store_path (str): path where to store artifacts
            metadata_store: metadata store definition
            pipelines_dir (str): path where to store pipeline configs.
            analytics_opt_in (bool): opt-in flag for analytics code. The
                global config is only updated when this is not None.

        Raises:
            InvalidGitRepositoryError: If repository is not a git
                repository (presumably raised by GitWrapper).
            NoSuchPathError: If the repo_path does not exist (presumably
                raised by GitWrapper).
        """
        # check whether its a git repo by initializing GitWrapper
        git_wrapper = GitWrapper(repo_path)
        # Do proper checks and add to .gitignore
        git_wrapper.add_gitignore([ZENML_DIR_NAME + '/'])

        # use the underlying ZenMLConfig class to create the config
        ZenMLConfig.to_config(
            repo_path, artifact_store_path, metadata_store, pipelines_dir)

        # create global config
        global_config = GlobalConfig.get_instance()
        if analytics_opt_in is not None:
            global_config.set_analytics_opt_in(analytics_opt_in)

    def get_default_artifact_store(self) -> Optional[ArtifactStore]:
        """Return the artifact store of the loaded ZenML config.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()
        return self.zenml_config.get_artifact_store()

    def get_default_metadata_store(self):
        """Return the metadata store of the loaded ZenML config.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()
        return self.zenml_config.get_metadata_store()

    def get_default_pipelines_dir(self) -> Text:
        """Return the pipelines directory of the loaded ZenML config.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()
        return self.zenml_config.get_pipelines_dir()

    def get_git_wrapper(self) -> GitWrapper:
        """Return the GitWrapper created for this repository's path.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()
        return self.git_wrapper

    @track(event=GET_STEP_VERSION)
    def get_step_by_version(self, step_type: Union[Type, Text], version: Text):
        """
        Gets a Step object by version. There might be many objects of a
        particular Step registered in many pipelines. This function just
        returns the first configuration that it matches.

        Args:
            step_type: either a string specifying full path to the step or a
            class path.
            version: either sha pin or standard ZenML version pin.

        Returns:
            The step built from the first matching step config, or None
            (implicitly) when no registered pipeline contains a match.
        """
        from zenml.utils import source_utils
        from zenml.core.steps.base_step import BaseStep

        type_str = source_utils.get_module_path_from_class(step_type)

        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            for step_name, step_config in c[keys.GlobalKeys.PIPELINE][
                    keys.PipelineKeys.STEPS].items():
                # Get version from source
                class_ = source_utils.get_class_path_from_source(
                    step_config[keys.StepKeys.SOURCE])
                source_version = source_utils.get_version_from_source(
                    step_config[keys.StepKeys.SOURCE])

                # Both the class path and the version pin have to match.
                if class_ == type_str and version == source_version:
                    return BaseStep.from_config(step_config)

    def get_step_versions_by_type(self, step_type: Union[Type, Text]):
        """
        List all registered versions of one step type in the repository.

        Args:
            step_type: either a string specifying full path to the step or a
            class path.

        Returns:
            The set of versions collected by ``get_step_versions`` for the
            type, or None (after logging a warning) if the type is unknown.
        """
        from zenml.utils import source_utils

        type_str = source_utils.get_module_path_from_class(step_type)

        steps_dict = self.get_step_versions()
        if type_str not in steps_dict:
            logger.warning(f'Type {type_str} not available. Available types: '
                           f'{list(steps_dict.keys())}')
            return
        return steps_dict[type_str]

    @track(event=GET_STEPS_VERSIONS)
    def get_step_versions(self):
        """List all registered steps in repository.

        Returns:
            dict mapping each step class path to the set of versions found
            for it across all pipeline YAML files.
        """
        from zenml.utils import source_utils
        steps_dict = {}
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            for step_name, step_config in c[keys.GlobalKeys.PIPELINE][
                    keys.PipelineKeys.STEPS].items():
                # Get version from source
                version = source_utils.get_version_from_source(
                    step_config[keys.StepKeys.SOURCE])
                class_ = source_utils.get_class_path_from_source(
                    step_config[keys.StepKeys.SOURCE])

                # Add to set of versions
                if class_ in steps_dict:
                    steps_dict[class_].add(version)
                else:
                    steps_dict[class_] = {version}
        return steps_dict

    def get_datasource_by_name(self, name: Text):
        """
        Get a single datasource of this repo by its name.

        Args:
            name (str): name of the datasource to look for.

        Returns:
            The first datasource from ``get_datasources`` whose ``name``
            equals ``name``, or None (implicitly) if there is none.
        """
        all_datasources = self.get_datasources()
        for d in all_datasources:
            if name == d.name:
                return d

    # NOTE(review): the return annotation says List, but the body returns
    # the value stored under DatasourceKeys.ID (or None implicitly when no
    # pipeline matches) - confirm the intended type.
    def get_datasource_id_by_name(self, name: Text) -> List:
        """
        Get ID of a datasource by just its name.

        Args:
            name (str): name of the datasource to look for.

        Returns:
            ID of datasource, taken from the first pipeline YAML whose
            datasource entry carries this name.
        """
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            src = c[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE]
            # Datasource entries without a name are skipped.
            if keys.DatasourceKeys.NAME in src:
                if name == src[keys.DatasourceKeys.NAME]:
                    return src[keys.DatasourceKeys.ID]

    def get_datasource_names(self) -> List:
        """
        Get the names of all datasources in this repo.

        Returns:
            List of datasource names used in this repo, de-duplicated
            through a set (so the order is not guaranteed).
        """
        n = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            src = c[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE]
            if keys.DatasourceKeys.NAME in src:
                n.append(src[keys.DatasourceKeys.NAME])
        return list(set(n))

    @track(event=GET_DATASOURCES)
    def get_datasources(self) -> List:
        """
        Get all datasources in this repo.

        Returns:
            list of datasources used in this repo, one per distinct name
            (the first occurrence wins).
        """
        from zenml.core.datasources.base_datasource import BaseDatasource

        datasources = []
        datasources_name = set()
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            ds = BaseDatasource.from_config(c[keys.GlobalKeys.PIPELINE])
            # Skip falsy results and names that were already collected.
            if ds and ds.name not in datasources_name:
                datasources.append(ds)
                datasources_name.add(ds.name)
        return datasources

    def get_pipeline_by_name(self, pipeline_name: Text = None):
        """
        Loads a pipeline just by its name.

        Args:
            pipeline_name (str): Name of pipeline.

        Returns:
            The pipeline built from the first YAML file whose derived name
            equals ``pipeline_name``, or None (implicitly) if none matches.
        """
        from zenml.core.pipelines.base_pipeline import BasePipeline
        yamls = self.get_pipeline_file_paths()
        for y in yamls:
            # The pipeline name is derived from the file's base name.
            n = BasePipeline.get_name_from_pipeline_name(os.path.basename(y))
            if n == pipeline_name:
                c = yaml_utils.read_yaml(y)
                return BasePipeline.from_config(c)

    def get_pipelines_by_type(self, type_filter: List[Text]) -> List:
        """
        Gets list of pipelines filtered by type.

        Args:
            type_filter (list): list of types to filter by.

        Returns:
            The pipelines whose ``PIPELINE_TYPE`` is contained in
            ``type_filter``.
        """
        pipelines = self.get_pipelines()
        return [p for p in pipelines if p.PIPELINE_TYPE in type_filter]

    def get_pipeline_names(self) -> Optional[List[Text]]:
        """Gets list of pipeline names, one per pipeline YAML file name."""
        from zenml.core.pipelines.base_pipeline import BasePipeline
        yamls = self.get_pipeline_file_paths(only_file_names=True)
        return [BasePipeline.get_name_from_pipeline_name(p) for p in yamls]

    def get_pipeline_file_paths(self, only_file_names: bool = False) -> \
            Optional[List[Text]]:
        """Gets list of pipeline file paths.

        Args:
            only_file_names (bool): forwarded to ``path_utils.list_dir``.

        Returns:
            The YAML entries of the configured pipelines directory, or an
            empty list if that directory does not exist.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()

        pipelines_dir = self.zenml_config.get_pipelines_dir()

        if not path_utils.is_dir(pipelines_dir):
            return []
        all_files = path_utils.list_dir(pipelines_dir, only_file_names)
        return [x for x in all_files if yaml_utils.is_yaml(x)]

    def get_pipelines_by_datasource(self, datasource):
        """
        Gets list of pipelines associated with datasource.

        Args:
            datasource (BaseDatasource): object of type BaseDatasource.

        Returns:
            Pipelines whose datasource ID equals ``datasource._id``.
        """
        from zenml.core.pipelines.base_pipeline import BasePipeline

        pipelines = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            # Pipelines whose datasource entry has no ID are skipped.
            if keys.DatasourceKeys.ID in c[keys.GlobalKeys.PIPELINE][
                    keys.PipelineKeys.DATASOURCE]:
                if c[keys.GlobalKeys.PIPELINE][keys.PipelineKeys.DATASOURCE][
                        keys.DatasourceKeys.ID] == datasource._id:
                    pipelines.append(BasePipeline.from_config(c))
        return pipelines

    @track(event=GET_PIPELINES)
    def get_pipelines(self) -> List:
        """Gets list of all pipelines, one per pipeline YAML file."""
        from zenml.core.pipelines.base_pipeline import BasePipeline

        pipelines = []
        for file_path in self.get_pipeline_file_paths():
            c = yaml_utils.read_yaml(file_path)
            pipelines.append(BasePipeline.from_config(c))
        return pipelines

    @track(event=REGISTER_PIPELINE)
    def register_pipeline(self, file_name: Text, config: Dict[Text, Any]):
        """
        Registers a pipeline by writing its config as a YAML file into the
        configured pipelines directory.

        Args:
            file_name (str): file name of pipeline
            config (dict): dict representation of ZenML config

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()

        pipelines_dir = self.zenml_config.get_pipelines_dir()

        # Create dir
        path_utils.create_dir_if_not_exists(pipelines_dir)

        # Write
        yaml_utils.write_yaml(os.path.join(pipelines_dir, file_name), config)

    def load_pipeline_config(self, file_name: Text) -> Dict[Text, Any]:
        """
        Loads a ZenML config from YAML.

        Args:
            file_name (str): file name of pipeline, relative to the
                configured pipelines directory.

        Raises:
            InitializationException: If no ZenML config was loaded.
        """
        self._check_if_initialized()

        pipelines_dir = self.zenml_config.get_pipelines_dir()
        return yaml_utils.read_yaml(os.path.join(pipelines_dir, file_name))

    def compare_training_pipelines(self):
        """Launch the compare app for all training pipelines in repo"""
        from zenml.utils.post_training.post_training_utils import \
            launch_compare_tool
        launch_compare_tool()

    def clean(self):
        """Deletes associated metadata store, pipelines dir and artifacts.

        Not implemented yet; always raises NotImplementedError.
        """
        raise NotImplementedError

    def _check_if_initialized(self):
        """Raise InitializationException if no ZenML config was loaded."""
        if self.zenml_config is None:
            raise InitializationException
def test_zenml_config_setters(equal_md_stores):
    """Each setter must persist its new value to the config file on disk.

    The config is modified, re-read from ``config_path`` and checked, then
    restored to its previous values so later tests see the original state.
    """
    config = ZenMLConfig(repo_path=config_root)

    # Remember the current values so they can be restored afterwards.
    previous_artifact_store = artifact_store_path
    previous_pipelines_dir = pipelines_dir
    previous_md_store = config.get_metadata_store()

    replacement_artifact_store = os.getcwd()
    replacement_pipelines_dir = "awfkoeghelk"
    replacement_md_store = MockMetadataStore()

    config.set_artifact_store(replacement_artifact_store)
    config.set_pipelines_dir(replacement_pipelines_dir)
    config.set_metadata_store(replacement_md_store)

    # Read the file back to verify what was actually written.
    on_disk = yaml_utils.read_yaml(config.config_path)
    md_store_on_disk = MockMetadataStore.from_config(
        on_disk[keys.GlobalKeys.METADATA_STORE])

    assert on_disk[keys.GlobalKeys.ARTIFACT_STORE] == \
        replacement_artifact_store
    assert on_disk[PIPELINES_DIR_KEY] == replacement_pipelines_dir
    assert equal_md_stores(replacement_md_store, md_store_on_disk)

    # revert changes
    config.set_artifact_store(previous_artifact_store)
    config.set_pipelines_dir(previous_pipelines_dir)
    config.set_metadata_store(previous_md_store)

    shutil.rmtree(replacement_pipelines_dir, ignore_errors=False)
def test_zenml_config_getters():
    """Every getter of an initialized config returns a truthy value."""
    config = ZenMLConfig(repo_path=config_root)

    getter_names = (
        'get_pipelines_dir',
        'get_artifact_store',
        'get_metadata_store',
    )
    for getter_name in getter_names:
        assert getattr(config, getter_name)()