def test_collect_files(self):
    """Test the collection of files from a file tree."""
    from aiida.tools.dbexporters.tcod import _collect_files
    from aiida.common.folders import SandboxFolder
    from io import StringIO  # Python 3 replacement for the legacy ``StringIO`` module

    sf = SandboxFolder()
    sf.get_subfolder('out', create=True)
    sf.get_subfolder('pseudo', create=True)
    sf.get_subfolder('save', create=True)
    sf.get_subfolder('save/1', create=True)
    sf.get_subfolder('save/2', create=True)

    # Every file carries the same payload 'test', so all files share one md5/sha1 pair.
    for filename in ['aiida.in', 'aiida.out', '_aiidasubmit.sh', '_.out', 'out/out', 'save/1/log.log']:
        sf.create_file_from_filelike(StringIO('test'), filename)

    md5 = '098f6bcd4621d373cade4e832627b4f6'
    sha1 = 'a94a8fe5ccb19ba61c4c0873d391e987982fbbd3'
    self.assertEqual(
        _collect_files(sf.abspath),
        [
            {'name': '_.out', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': '_aiidasubmit.sh', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': 'aiida.in', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': 'aiida.out', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': 'out/', 'type': 'folder'},
            {'name': 'out/out', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': 'pseudo/', 'type': 'folder'},
            {'name': 'save/', 'type': 'folder'},
            {'name': 'save/1/', 'type': 'folder'},
            {'name': 'save/1/log.log', 'contents': 'test', 'md5': md5, 'sha1': sha1, 'type': 'file'},
            {'name': 'save/2/', 'type': 'folder'},
        ])
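# The hard-coded digests in the test above are simply the standard hashes of
# the payload b'test'; this standalone sanity check (a sketch independent of
# AiiDA) reproduces them with hashlib from the standard library:
import hashlib

assert hashlib.md5(b'test').hexdigest() == '098f6bcd4621d373cade4e832627b4f6'
assert hashlib.sha1(b'test').hexdigest() == 'a94a8fe5ccb19ba61c4c0873d391e987982fbbd3'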
class ReaderJsonBase(ArchiveReaderAbstract):
    """A reader base for the JSON compressed formats."""

    FILENAME_DATA = 'data.json'
    FILENAME_METADATA = 'metadata.json'
    REPO_FOLDER = NODES_EXPORT_SUBFOLDER

    def __init__(self, filename: str, sandbox_in_repo: bool = False, **kwargs: Any):
        """A reader for JSON compressed archives.

        :param filename: the filename (possibly including the absolute path)
            of the file on which to export.
        :param sandbox_in_repo: create the temporary uncompressed folder within the aiida repository

        """
        super().__init__(filename, **kwargs)

        self._metadata = None
        self._data = None
        # a temporary folder used to extract the file tree
        self._sandbox: Optional[SandboxFolder] = None
        self._sandbox_in_repo = sandbox_in_repo

    @property
    def file_format_verbose(self) -> str:
        raise NotImplementedError()

    @property
    def compatible_export_version(self) -> str:
        return EXPORT_VERSION

    def __enter__(self):
        super().__enter__()
        self._sandbox = SandboxFolder(self._sandbox_in_repo)
        return self

    def __exit__(
        self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException],
        exctb: Optional[TracebackType]
    ):
        self._sandbox.erase()  # type: ignore
        self._sandbox = None
        self._metadata = None
        self._data = None
        super().__exit__(exctype, excinst, exctb)

    def _get_metadata(self):
        """Retrieve the metadata JSON."""
        raise NotImplementedError()

    def _get_data(self):
        """Retrieve the data JSON."""
        raise NotImplementedError()

    def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None]):
        """Extract repository data to a temporary folder.

        :param path_prefix: Only extract paths starting with this prefix.
        :param callback: a callback to report on the process, ``callback(action, value)``,
            with the following callback signatures:

            - ``callback('init', {'total': <int>, 'description': <str>})``,
              to signal the start of a process, its total iterations and description
            - ``callback('update', <int>)``,
              to signal an update to the process and the number of iterations to progress

        :raises TypeError: if parameter types are not respected

        """
        raise NotImplementedError()

    @property
    def export_version(self) -> str:
        metadata = self._get_metadata()
        if 'export_version' not in metadata:
            raise CorruptArchive('export_version missing from metadata.json')
        return metadata['export_version']

    @property
    def metadata(self) -> ArchiveMetadata:
        metadata = self._get_metadata()
        export_parameters = metadata.get('export_parameters', {})
        output = {
            'export_version': metadata['export_version'],
            'aiida_version': metadata['aiida_version'],
            'all_fields_info': metadata['all_fields_info'],
            'unique_identifiers': metadata['unique_identifiers'],
            'graph_traversal_rules': export_parameters.get('graph_traversal_rules', None),
            'entities_starting_set': export_parameters.get('entities_starting_set', None),
            'include_comments': export_parameters.get('include_comments', None),
            'include_logs': export_parameters.get('include_logs', None),
            'conversion_info': metadata.get('conversion_info', [])
        }
        try:
            return ArchiveMetadata(**output)
        except TypeError as error:
            raise CorruptArchive(f'Metadata invalid: {error}')

    def entity_count(self, name: str) -> int:
        data = self._get_data().get('export_data', {}).get(name, {})
        return len(data)

    @property
    def link_count(self) -> int:
        return len(self._get_data()['links_uuid'])

    def iter_entity_fields(self,
                           name: str,
                           fields: Optional[Tuple[str, ...]] = None) -> Iterator[Tuple[int, Dict[str, Any]]]:
        if name not in self.entity_names:
            raise ValueError(f'Unknown entity name: {name}')
        data = self._get_data()['export_data'].get(name, {})
        if name == NODE_ENTITY_NAME:
            # here we merge in the attributes and extras before yielding
            attributes = self._get_data().get('node_attributes', {})
            extras = self._get_data().get('node_extras', {})
            for pk, all_fields in data.items():
                if pk not in attributes:
                    raise CorruptArchive(f'Unable to find attributes info for Node with Pk={pk}')
                if pk not in extras:
                    raise CorruptArchive(f'Unable to find extra info for Node with Pk={pk}')
                all_fields = {**all_fields, **{'attributes': attributes[pk], 'extras': extras[pk]}}
                if fields is not None:
                    all_fields = {k: v for k, v in all_fields.items() if k in fields}
                yield int(pk), all_fields
        else:
            for pk, all_fields in data.items():
                if fields is not None:
                    all_fields = {k: v for k, v in all_fields.items() if k in fields}
                yield int(pk), all_fields

    def iter_node_uuids(self) -> Iterator[str]:
        for _, fields in self.iter_entity_fields(NODE_ENTITY_NAME, fields=('uuid',)):
            yield fields['uuid']

    def iter_group_uuids(self) -> Iterator[Tuple[str, Set[str]]]:
        group_uuids = self._get_data()['groups_uuid']
        for _, fields in self.iter_entity_fields(GROUP_ENTITY_NAME, fields=('uuid',)):
            key = fields['uuid']
            yield key, set(group_uuids.get(key, set()))

    def iter_link_data(self) -> Iterator[dict]:
        for value in self._get_data()['links_uuid']:
            yield value

    def iter_node_repos(
        self,
        uuids: Iterable[str],
        callback: Callable[[str, Any], None] = null_callback,
    ) -> Iterator[Folder]:
        path_prefixes = [os.path.join(self.REPO_FOLDER, export_shard_uuid(uuid)) for uuid in uuids]

        if not path_prefixes:
            return
        self.assert_within_context()
        assert self._sandbox is not None  # required by mypy

        # unarchive the common folder if it does not exist
        common_prefix = os.path.commonpath(path_prefixes)
        if not self._sandbox.get_subfolder(common_prefix).exists():
            self._extract(path_prefix=common_prefix, callback=callback)

        callback('init', {'total': len(path_prefixes), 'description': 'Iterating node repositories'})
        for uuid, path_prefix in zip(uuids, path_prefixes):
            callback('update', 1)
            subfolder = self._sandbox.get_subfolder(path_prefix)
            if not subfolder.exists():
                raise CorruptArchive(f'Unable to find the repository folder for Node with UUID={uuid} in the exported file')
            yield subfolder
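# A minimal usage sketch, not part of the module above: it assumes a concrete
# subclass (here called ReaderJsonZip) that implements _get_metadata(),
# _get_data() and _extract() for zip-compressed archives, and an archive file
# named 'export.aiida'; both names are illustrative assumptions. The context
# manager is essential: __enter__ creates the sandbox folder and __exit__
# erases it, so the reader must only be used inside the ``with`` block.
with ReaderJsonZip('export.aiida') as reader:
    print(reader.export_version)                  # version string from metadata.json
    print(reader.entity_count(NODE_ENTITY_NAME))  # number of exported node rows
    for pk, fields in reader.iter_entity_fields(NODE_ENTITY_NAME, fields=('uuid',)):
        print(pk, fields['uuid'])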