Example 1
    def test_collect_files(self):
        """
        Test that ``_collect_files`` returns a sorted listing of the file
        tree, with per-file contents, md5/sha1 checksums and type tags, and
        trailing-slash entries for folders (including empty ones).
        """
        from aiida.tools.dbexporters.tcod import _collect_files
        from aiida.common.folders import SandboxFolder
        import StringIO

        sf = SandboxFolder()
        # Build the folder hierarchy; 'pseudo' and 'save/2' stay empty on
        # purpose so the listing of empty folders is exercised too.
        for folder in ('out', 'pseudo', 'save', 'save/1', 'save/2'):
            sf.get_subfolder(folder, create=True)

        # All files share the same 4-byte contents, so one md5/sha1 pair
        # below describes every file entry.
        for filename in ('aiida.in', 'aiida.out', '_aiidasubmit.sh',
                         '_.out', 'out/out', 'save/1/log.log'):
            sf.create_file_from_filelike(StringIO.StringIO("test"), filename)

        # Checksums of the literal string "test".
        md5 = '098f6bcd4621d373cade4e832627b4f6'
        sha1 = 'a94a8fe5ccb19ba61c4c0873d391e987982fbbd3'
        # assertEqual: assertEquals is a deprecated alias.
        self.assertEqual(
            _collect_files(sf.abspath),
            [{'name': '_.out', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': '_aiidasubmit.sh', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': 'aiida.in', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': 'aiida.out', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': 'out/', 'type': 'folder'},
             {'name': 'out/out', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': 'pseudo/', 'type': 'folder'},
             {'name': 'save/', 'type': 'folder'},
             {'name': 'save/1/', 'type': 'folder'},
             {'name': 'save/1/log.log', 'contents': 'test', 'md5': md5,
              'sha1': sha1, 'type': 'file'},
             {'name': 'save/2/', 'type': 'folder'}])
Example 2
class ReaderJsonBase(ArchiveReaderAbstract):
    """A reader base for the JSON compressed formats.

    Subclasses implement the actual retrieval (``_get_metadata``,
    ``_get_data``) and extraction (``_extract``) for a concrete archive
    layout (e.g. zip or tar).
    """

    FILENAME_DATA = 'data.json'
    FILENAME_METADATA = 'metadata.json'
    REPO_FOLDER = NODES_EXPORT_SUBFOLDER

    def __init__(self,
                 filename: str,
                 sandbox_in_repo: bool = False,
                 **kwargs: Any):
        """A reader for JSON compressed archives.

        :param filename: the filename (possibly including the absolute path)
            of the file on which to export.
        :param sandbox_in_repo: Create the temporary uncompressed folder within the aiida repository

        """
        super().__init__(filename, **kwargs)
        # Lazily populated caches for the two JSON payloads.
        self._metadata = None
        self._data = None
        # a temporary folder used to extract the file tree
        self._sandbox: Optional[SandboxFolder] = None
        self._sandbox_in_repo = sandbox_in_repo

    @property
    def file_format_verbose(self) -> str:
        raise NotImplementedError()

    @property
    def compatible_export_version(self) -> str:
        return EXPORT_VERSION

    def __enter__(self):
        """Create the extraction sandbox on context entry."""
        super().__enter__()
        self._sandbox = SandboxFolder(self._sandbox_in_repo)
        return self

    def __exit__(self, exctype: Optional[Type[BaseException]],
                 excinst: Optional[BaseException],
                 exctb: Optional[TracebackType]):
        """Erase the sandbox and drop cached JSON on context exit."""
        # Guard instead of `# type: ignore`: be robust to a partially
        # initialised context (e.g. __enter__ raised before assignment).
        if self._sandbox is not None:
            self._sandbox.erase()
        self._sandbox = None
        self._metadata = None
        self._data = None
        super().__exit__(exctype, excinst, exctb)

    def _get_metadata(self):
        """Retrieve the metadata JSON."""
        raise NotImplementedError()

    def _get_data(self):
        """Retrieve the data JSON."""
        raise NotImplementedError()

    def _extract(self, *, path_prefix: str, callback: Callable[[str, Any],
                                                               None]):
        """Extract repository data to a temporary folder.

        :param path_prefix: Only extract paths starting with this prefix.
        :param callback: a callback to report on the process, ``callback(action, value)``,
            with the following callback signatures:

            - ``callback('init', {'total': <int>, 'description': <str>})``,
               to signal the start of a process, its total iterations and description
            - ``callback('update', <int>)``,
               to signal an update to the process and the number of iterations to progress

        :raises TypeError: if parameter types are not respected
        """
        raise NotImplementedError()

    @property
    def export_version(self) -> str:
        """Return the export version recorded in the archive metadata.

        :raises CorruptArchive: if the key is missing from metadata.json
        """
        metadata = self._get_metadata()
        if 'export_version' not in metadata:
            raise CorruptArchive('export_version missing from metadata.json')
        return metadata['export_version']

    @property
    def metadata(self) -> ArchiveMetadata:
        """Return the archive metadata as an ``ArchiveMetadata`` object.

        Required keys are read directly from metadata.json; optional keys
        come from the nested ``export_parameters`` mapping and default to
        ``None`` (or ``[]`` for ``conversion_info``) when absent.

        :raises CorruptArchive: if the metadata does not fit the
            ``ArchiveMetadata`` signature
        """
        metadata = self._get_metadata()
        export_parameters = metadata.get('export_parameters', {})
        output = {
            'export_version':
            metadata['export_version'],
            'aiida_version':
            metadata['aiida_version'],
            'all_fields_info':
            metadata['all_fields_info'],
            'unique_identifiers':
            metadata['unique_identifiers'],
            'graph_traversal_rules':
            export_parameters.get('graph_traversal_rules', None),
            'entities_starting_set':
            export_parameters.get('entities_starting_set', None),
            'include_comments':
            export_parameters.get('include_comments', None),
            'include_logs':
            export_parameters.get('include_logs', None),
            'conversion_info':
            metadata.get('conversion_info', [])
        }
        try:
            return ArchiveMetadata(**output)
        except TypeError as error:
            # Chain the original TypeError so the underlying cause of the
            # invalid metadata is preserved in the traceback.
            raise CorruptArchive(f'Metadata invalid: {error}') from error

    def entity_count(self, name: str) -> int:
        """Return the number of entities of type ``name`` in the archive."""
        data = self._get_data().get('export_data', {}).get(name, {})
        return len(data)

    @property
    def link_count(self) -> int:
        """Return the number of links in the archive."""
        return len(self._get_data()['links_uuid'])

    def iter_entity_fields(
        self,
        name: str,
        fields: Optional[Tuple[str, ...]] = None
    ) -> Iterator[Tuple[int, Dict[str, Any]]]:
        """Yield ``(pk, fields_dict)`` for each entity of type ``name``.

        :param name: entity type name; must be one of ``self.entity_names``
        :param fields: if given, restrict the yielded dict to these keys
        :raises ValueError: for an unknown entity name
        :raises CorruptArchive: if a node is missing attributes or extras
        """
        if name not in self.entity_names:
            raise ValueError(f'Unknown entity name: {name}')
        data = self._get_data()['export_data'].get(name, {})
        if name == NODE_ENTITY_NAME:
            # here we merge in the attributes and extras before yielding
            attributes = self._get_data().get('node_attributes', {})
            extras = self._get_data().get('node_extras', {})
            for pk, all_fields in data.items():
                if pk not in attributes:
                    raise CorruptArchive(
                        f'Unable to find attributes info for Node with Pk={pk}'
                    )
                if pk not in extras:
                    raise CorruptArchive(
                        f'Unable to find extra info for Node with Pk={pk}')
                all_fields = {
                    **all_fields,
                    **{
                        'attributes': attributes[pk],
                        'extras': extras[pk]
                    }
                }
                if fields is not None:
                    all_fields = {
                        k: v
                        for k, v in all_fields.items() if k in fields
                    }
                yield int(pk), all_fields
        else:
            for pk, all_fields in data.items():
                if fields is not None:
                    all_fields = {
                        k: v
                        for k, v in all_fields.items() if k in fields
                    }
                yield int(pk), all_fields

    def iter_node_uuids(self) -> Iterator[str]:
        """Yield the UUID of every node in the archive."""
        for _, fields in self.iter_entity_fields(NODE_ENTITY_NAME,
                                                 fields=('uuid', )):
            yield fields['uuid']

    def iter_group_uuids(self) -> Iterator[Tuple[str, Set[str]]]:
        """Yield ``(group_uuid, set_of_member_node_uuids)`` per group."""
        group_uuids = self._get_data()['groups_uuid']
        for _, fields in self.iter_entity_fields(GROUP_ENTITY_NAME,
                                                 fields=('uuid', )):
            key = fields['uuid']
            yield key, set(group_uuids.get(key, set()))

    def iter_link_data(self) -> Iterator[dict]:
        """Yield one dict per link stored in the archive."""
        for value in self._get_data()['links_uuid']:
            yield value

    def iter_node_repos(
        self,
        uuids: Iterable[str],
        callback: Callable[[str, Any], None] = null_callback,
    ) -> Iterator[Folder]:
        """Yield the extracted repository folder for each node UUID.

        :param uuids: UUIDs of the nodes whose repositories to extract
        :param callback: progress callback, see :meth:`_extract`
        :raises CorruptArchive: if a node's repository folder is missing
        """
        # Materialize once: `uuids` may be a one-shot generator, and it is
        # iterated both to build the prefixes and again in the zip below.
        uuids = list(uuids)
        path_prefixes = [
            os.path.join(self.REPO_FOLDER, export_shard_uuid(uuid))
            for uuid in uuids
        ]

        if not path_prefixes:
            return
        self.assert_within_context()
        assert self._sandbox is not None  # required by mypy

        # unarchive the common folder if it does not exist
        common_prefix = os.path.commonpath(path_prefixes)
        if not self._sandbox.get_subfolder(common_prefix).exists():
            self._extract(path_prefix=common_prefix, callback=callback)

        callback(
            'init', {
                'total': len(path_prefixes),
                'description': 'Iterating node repositories'
            })
        for uuid, path_prefix in zip(uuids, path_prefixes):
            callback('update', 1)
            subfolder = self._sandbox.get_subfolder(path_prefix)
            if not subfolder.exists():
                raise CorruptArchive(
                    f'Unable to find the repository folder for Node with UUID={uuid} in the exported file'
                )
            yield subfolder