def __get_entities_from_config(self, stream):
    """Collect the entities exposed by every remote listed in config.yaml.

    Args:
        stream (bytes): The stream of config.yaml file.

    Returns:
        list of objects of class Entity, one per spec file found in the
        repositories of the configured remotes.
    """
    parsed_config = Config(yaml_load_str(stream))
    found = []
    # Empty/unset remote entries are ignored.
    for remote in filter(None, parsed_config.remotes.values()):
        repo = self._manager.find_repository(remote)
        found.extend(
            Entity(repo, yaml_load_str(self._manager.get_file_content(repo, path)))
            for path in self._manager.search_file(repo, SPEC_EXTENSION)
        )
    self._manager.alert_rate_limits()
    return found
def _get_linked_entities(self, name, version, spec_path, repository):
    """Get a list of linked entities found for an entity version.

    The repository tags are scanned for one whose last two '__'-separated
    segments are the entity name and the version. The spec stored at the
    first such tag that has content is parsed and its related entities are
    returned.

    Args:
        name (str): The name of the entity you want to get the linked entities.
        version (str): The version of the entity you want to get the linked entities.
        spec_path (str): The path where the entity spec is located.
        repository (Repository): The instance of github.Repository.Repository.

    Returns:
        The value of SpecVersion.get_related_entities_info() (a list of
        LinkedEntity) for the first matching tag with spec content, or
        None when no tag matches.
    """
    wanted_version = str(version)
    for tag in repository.get_tags():
        tag_parts = tag.name.split('__')
        # A tag without a '__' separator has no name segment to compare;
        # indexing [-2] on it would raise IndexError, so it is skipped.
        if len(tag_parts) < 2:
            continue
        if tag_parts[-2] != name or tag_parts[-1] != wanted_version:
            continue

        content = self._manager.get_file_content(repository, spec_path, tag.name)
        if not content:
            continue

        spec_tag_yaml = yaml_load_str(content)
        return SpecVersion(spec_tag_yaml).get_related_entities_info()
    return None
def test_yaml_load_str(self):
    """yaml_load_str parses the sample string into the expected nested mapping."""
    loaded = yaml_load_str(self.yaml_str_sample)
    bucket_settings = loaded[STORAGE_CONFIG_KEY][S3H]['bucket_test']
    self.assertEqual(bucket_settings['aws-credentials']['profile'], 'profile_test')
    self.assertEqual(bucket_settings['region'], 'region_test')
def _get_entity_versions(self, name, spec_path, repository):
    """Get a list of spec versions found for a specific entity.

    Every repository tag whose second-to-last '__'-separated segment equals
    ``name`` is inspected; the spec stored at that tag is parsed and wrapped
    in a SpecVersion.

    Args:
        name (str): The name of the entity you want to get the versions.
        spec_path (str): The path where the entity spec is located.
        repository (Repository): The instance of github.Repository.Repository.

    Returns:
        list of class SpecVersion, in the order the tags are returned by
        the repository. Empty when no tag matches.
    """
    versions = []
    for tag in repository.get_tags():
        tag_parts = tag.name.split('__')
        # A tag without a '__' separator has no name segment to compare;
        # indexing [-2] on it would raise IndexError, so it is skipped.
        if len(tag_parts) < 2 or tag_parts[-2] != name:
            continue

        content = self._manager.get_file_content(repository, spec_path, tag.name)
        if not content:
            continue

        versions.append(SpecVersion(yaml_load_str(content)))
    return versions
def get_entities(self):
    """Get a list of entities found in config.yaml.

    For each EntityType a manager is initialised; when one is available,
    the tree of the head commit of the repository at the manager path is
    traversed and every object whose name contains the spec extension is
    loaded as an Entity. Any error is logged at debug level and whatever
    was collected up to that point is returned.

    Returns:
        list of class Entity.
    """
    repository_cls = namedtuple(
        'Repository', ['private', 'full_name', 'ssh_url', 'html_url', 'owner'])
    owner_cls = namedtuple('Owner', ['email', 'name'])
    found = []
    try:
        for type_entity in EntityType:
            self.__init_manager(type_entity.value)
            if not self._manager:
                continue

            # Local repositories carry no remote metadata, so a blank stand-in is used.
            placeholder_repo = repository_cls(False, '', '', '', owner_cls('', ''))
            head_tree = Repo(self._manager.path).head.commit.tree
            for obj in head_tree.traverse():
                if SPEC_EXTENSION not in obj.name:
                    continue
                raw_spec = io.BytesIO(obj.data_stream.read())
                candidate = Entity(placeholder_repo, yaml_load_str(raw_spec))
                if candidate.type not in type_entity.value:
                    continue
                if candidate not in found:
                    found.append(candidate)
    except Exception as error:
        log.debug(
            output_messages['DEBUG_ENTITIES_RELATIONSHIP'].format(error),
            class_name=LocalEntityManager.__name__)
    return found
def _get_spec_content(self, spec, sha):
    """Return the parsed spec file of ``spec`` as read from reference ``sha``.

    The spec path is '<entity dir>/<spec><SPEC_EXTENSION>' with the entity
    directory converted through posix_path.
    """
    spec_dir = posix_path(
        get_entity_dir(self.__repo_type, spec, root_path=self.__path))
    spec_file = spec + SPEC_EXTENSION
    raw_spec = self._get_spec_content_from_ref(sha, '/'.join([spec_dir, spec_file]))
    return yaml_load_str(raw_spec)
def test_yaml_load_str(self):
    """yaml_load_str parses the sample string into the expected nested mapping."""
    loaded = yaml_load_str(self.yaml_str_sample)
    bucket_settings = loaded['store']['s3h']['bucket_test']
    self.assertEqual(bucket_settings['aws-credentials']['profile'], 'profile_test')
    self.assertEqual(bucket_settings['region'], 'region_test')
def get_sample_config_spec(bucket, profile, region):
    """Build a sample storage configuration and return it parsed.

    The YAML document nests, under STORAGE_CONFIG_KEY and 's3h', one bucket
    entry holding 'aws-credentials' -> 'profile' and 'region'.

    Args:
        bucket (str): Name used as the bucket key.
        profile (str): Value stored under aws-credentials/profile.
        region (str): Value stored under region.

    Returns:
        The object produced by yaml_load_str for the document (expected to
        be the nested mapping described above).
    """
    # Each nesting level needs its own indented line: a YAML block mapping
    # cannot chain "key: key: value" on a single line.
    template = (
        '%s:\n'
        '  s3h:\n'
        '    %s:\n'
        '      aws-credentials:\n'
        '        profile: %s\n'
        '      region: %s\n'
    )
    return yaml_load_str(template % (STORAGE_CONFIG_KEY, bucket, profile, region))
def get_sample_config_spec(bucket, profile, region):
    """Build a sample storage configuration and return it parsed.

    The YAML document nests, under 'store' and 's3h', one bucket entry
    holding 'aws-credentials' -> 'profile' and 'region'.

    Args:
        bucket (str): Name used as the bucket key.
        profile (str): Value stored under aws-credentials/profile.
        region (str): Value stored under region.

    Returns:
        The object produced by yaml_load_str for the document (expected to
        be the nested mapping described above).
    """
    # Each nesting level needs its own indented line: a YAML block mapping
    # cannot chain "key: key: value" on a single line.
    template = (
        'store:\n'
        '  s3h:\n'
        '    %s:\n'
        '      aws-credentials:\n'
        '        profile: %s\n'
        '      region: %s\n'
    )
    return yaml_load_str(template % (bucket, profile, region))
def get_specs_to_compare(self, spec):
    """Yield (current manifest, base manifest) pairs for each tag of ``spec``.

    For every tag returned by ``self.list_tags(spec, True)`` the spec file is
    read at the tag's commit and at that commit's first parent. When the
    commit has no parents, the base manifest is an empty dict.
    """
    repo_type = self.__repo_type
    manifest_key = 'manifest'
    spec_tags = self.list_tags(spec, True)
    spec_dir = get_entity_dir(repo_type, spec, root_path=self.__path)
    spec_path = '/'.join([posix_path(spec_dir), spec + SPEC_EXTENSION])

    for tag in spec_tags:
        commit = tag.commit
        parent_commits = commit.parents
        if parent_commits:
            previous = yaml_load_str(
                self._get_spec_content_from_ref(parent_commits[0], spec_path))
        else:
            # Root commit: nothing to compare against.
            previous = {repo_type: {manifest_key: {}}}
        current = yaml_load_str(
            self._get_spec_content_from_ref(commit, spec_path))
        yield current[repo_type][manifest_key], previous[repo_type][manifest_key]
def test_get_spec_content_from_ref(self):
    """The spec read back from a tagged commit equals the spec file on disk."""
    spec_name = 'dataset-ex'
    metadata_root = os.path.join(self.test_dir, 'mdata', DATASETS, 'metadata')
    metadata = Metadata(spec_name, self.test_dir, config, DATASETS)
    metadata.init()

    spec_dir = os.path.join(metadata_root, spec_name)
    ensure_path_exists(spec_dir)
    committed_spec = spec_dir + '/dataset-ex.spec'
    shutil.copy('hdata/dataset-ex.spec', committed_spec)

    sha = metadata.commit(committed_spec, spec_name)
    tag = metadata.tag_add(sha)

    from_ref = yaml_load_str(
        metadata._get_spec_content_from_ref(tag.commit, 'dataset-ex/dataset-ex.spec'))
    from_disk = yaml_load(committed_spec)
    self.assertEqual(from_ref, from_disk)
def get_linked_entities(self, name, version, type_entity):
    """Get a list of linked entities found for an entity version.

    Args:
        name (str): The name of the entity you want to get the linked entities.
        version (str): The version of the entity you want to get the linked entities.
        type_entity (str): The type of the ml-entity (datasets, models, labels).

    Returns:
        list of LinkedEntity, or None when no manager is available or no
        spec content is produced for the requested name/version.
    """
    self.__init_manager(type_entity)
    if not self._manager:
        return None

    for spec_content in self.__get_spec_each_tag(name, version):
        # Only the first spec produced is used.
        spec_version = SpecVersion(yaml_load_str(spec_content))
        return spec_version.get_related_entities_info()
    return None
def get_project_entities_relationships(self, config_repo_name, export_type=FileType.JSON.value, export_path=None):
    """Get a list of relationships for all project entities.

    Args:
        config_repo_name (str): The repository name where the config is located in GitHub.
        export_type (str): Set the format of the return (json, csv, dot) [default: json].
        export_path (str): Set the path to export metrics to a file.

    Returns:
        list of EntityVersionRelationships, or None when the config
        repository is not found or is not a config repository.
    """
    entities = self.get_entities(config_repo_name=config_repo_name)

    config_repo = self._manager.find_repository(config_repo_name)
    if config_repo is None or self.__is_config_repo(config_repo) is False:
        return None

    raw_config = self._manager.get_file_content(config_repo, self.MLGIT_CONFIG_FILE)
    config = Config(yaml_load_str(raw_config))

    # Map each entity name to the relationships resolved on its type's remote.
    relationships = {
        entity.name: self.get_entity_relationships(
            entity.name, config.get_entity_type_remote(entity.type))[entity.name]
        for entity in entities
    }

    if export_type == FileType.CSV.value:
        relationships = export_relationships_to_csv(entities, relationships, export_path)
    elif export_type == FileType.DOT.value:
        relationships = export_relationships_to_dot(entities, relationships, export_path)

    self._manager.alert_rate_limits()
    return relationships
def get_entity_versions(self, name, type_entity):
    """Get a list of spec versions found for a specific entity.

    Args:
        name (str): The name of the entity you want to get the versions.
        type_entity (str): The type of the ml-entity (datasets, models, labels).

    Returns:
        list of class SpecVersion ordered by descending version, or None
        when no manager is available for the entity type.
    """
    self.__init_manager(type_entity)
    if not self._manager:
        return None

    spec_versions = [
        SpecVersion(yaml_load_str(spec_content))
        for spec_content in self.__get_spec_each_tag(name)
    ]
    spec_versions.sort(key=lambda spec_version: spec_version.version, reverse=True)
    return spec_versions
def get_sample_spec(bucket, repotype=DATASET_SPEC_KEY):
    """Return the sample spec document for ``bucket``/``repotype`` parsed by yaml_load_str."""
    spec_document = get_sample_spec_doc(bucket, repotype)
    return yaml_load_str(spec_document)
def get_sample_spec(bucket, repotype='dataset'):
    """Return the sample spec document for ``bucket``/``repotype`` parsed by yaml_load_str."""
    spec_document = get_sample_spec_doc(bucket, repotype)
    return yaml_load_str(spec_document)