Example #1
0
    def load(self,
             source: Union[str, BinaryIO, "Model"] = None,
             cache_dir: str = None,
             backend: StorageBackend = None,
             lazy=False) -> "Model":
        """
        Build a new Model instance.

        :param source: UUID, file system path, file object or an URL; None means auto.
        :param cache_dir: The directory where to store the downloaded model. A directory \
                          passed by the caller is never deleted; only the temporary \
                          directory created here (generic model, no ``cache_dir``) is.
        :param backend: Remote storage backend to use if ``source`` is a UUID or a URL.
        :param lazy: Do not really load numpy arrays into memory. Instead, mmap() them. \
                     User is expected to call Model.close() when the tree is no longer needed.
        :return: ``self``, populated from ``source``.
        :raises TypeError: ``source`` is a Model of an incompatible type, or ``backend`` \
                           is not a StorageBackend instance.
        :raises ValueError: a backend is required but missing, a generic model was given \
                            no explicit source, or the stored model type does not match.
        :raises FileNotFoundError: a generic model's UUID is absent from the backend index.
        """
        if isinstance(source, Model):
            if not isinstance(source, type(self)):
                raise TypeError("Incompatible model instance: %s <> %s" %
                                (type(source), type(self)))
            # Share the state of the already loaded instance instead of re-reading it.
            self.__dict__ = source.__dict__
            return self

        if backend is not None and not isinstance(backend, StorageBackend):
            raise TypeError("backend must be an instance of "
                            "modelforge.storage_backend.StorageBackend")
        self._source = str(source)
        generic = self.NAME == self.GENERIC_NAME
        # Only a directory created by this call may be removed at the end; a
        # caller-supplied cache_dir must survive.
        tmp_dir = None
        try:
            if source is None or (isinstance(source, str)
                                  and not os.path.isfile(source)):
                # The source is not a local file: resolve default model / UUID / URL.
                if cache_dir is None:
                    if not generic:
                        cache_dir = os.path.join(self.cache_dir(), self.NAME)
                    else:
                        # Generic models have no per-type cache: use a scratch dir.
                        tmp_dir = tempfile.mkdtemp(prefix="modelforge-")
                        cache_dir = tmp_dir
                try:
                    uuid.UUID(source)
                    is_uuid = True
                except (TypeError, ValueError):
                    # TypeError for None, ValueError for a non-UUID string.
                    is_uuid = False
                model_id = self.DEFAULT_NAME if not is_uuid else source
                file_name = model_id + self.DEFAULT_FILE_EXT
                file_name = os.path.join(os.path.expanduser(cache_dir),
                                         file_name)
                # NOTE(review): for a non-UUID string such as a URL, model_id is
                # DEFAULT_NAME, so an already cached default model is picked here and a
                # fetched URL is saved under the default model's file name. Confirm intended.
                if os.path.exists(file_name) and (not source or
                                                  not os.path.exists(source)):
                    source = file_name
                elif source is None or is_uuid:
                    if backend is None:
                        raise ValueError(
                            "The backend must be set to load a UUID or the default "
                            "model.")
                    index = backend.index.contents
                    config = index["models"]
                    if not generic:
                        if not is_uuid:
                            # Map the default model name to its UUID via the index meta.
                            model_id = index["meta"][self.NAME][model_id]
                        source = config[self.NAME][model_id]
                    else:
                        if not is_uuid:
                            raise ValueError(
                                "File path, URL or UUID is needed.")
                        # A generic model may be of any type: search every section.
                        for models in config.values():
                            if source in models:
                                source = models[source]
                                break
                        else:
                            raise FileNotFoundError("Model %s not found." %
                                                    source)
                    source = source["url"]
                if re.match(r"\w+://", source):
                    if backend is None:
                        raise ValueError(
                            "The backend must be set to load a URL.")
                    backend.fetch_model(source, file_name)
                    # Remember where the model actually came from.
                    self._source = source
                    source = file_name
            if isinstance(source, str):
                size = os.stat(source).st_size
            else:
                self._source = "<file object>"
                # Size of the remaining stream; restore the position afterwards.
                pos = source.tell()
                size = source.seek(0, os.SEEK_END) - pos
                source.seek(pos, os.SEEK_SET)
            self._log.info("Reading %s (%s)...", source,
                           humanize.naturalsize(size))
            model = asdf.open(source, copy_arrays=not lazy, lazy_load=lazy)
            try:
                tree = model.tree
                self._meta = tree["meta"]
                # Snapshot of the version read from the file (self.version is
                # defined outside this view).
                self._initial_version = list(self.version)
                if not generic:
                    meta_name = self._meta["model"]
                    matched = self.NAME == meta_name
                    if not matched:
                        # Direct subclasses' model names are accepted as well.
                        needed = {self.NAME}
                        for child in type(self).__subclasses__():
                            needed.add(child.NAME)
                            matched |= child.NAME == meta_name
                        if not matched:
                            raise ValueError(
                                "The supplied model is of the wrong type: needed "
                                "%s, got %s." % (needed, meta_name))
                self._load_tree(tree)
            except BaseException:
                # Do not leak the handle of a model that failed to load, lazy or not.
                model.close()
                raise
            if lazy:
                # Arrays stay memory-mapped; the caller closes via Model.close().
                self._asdf = model
            else:
                model.close()
        finally:
            if tmp_dir is not None:
                # NOTE(review): with lazy=True the open asdf handle still refers to a
                # file inside tmp_dir; removing it relies on the OS permitting that.
                shutil.rmtree(tmp_dir)
        self._size = size
        return self
Example #2
0
File: model.py  Project: afcarl/modelforge
    def load(self,
             source: Union[str, BinaryIO, "Model"] = None,
             cache_dir: str = None,
             backend: StorageBackend = None) -> "Model":
        """
        Initializes a new Model instance.

        :param source: UUID, file system path, file object or an URL; None means auto.
        :param cache_dir: The directory where to store the downloaded model. A directory \
                          passed by the caller is never deleted; only the temporary \
                          directory created here (``NAME is None``, no ``cache_dir``) is.
        :param backend: Remote storage backend to use if ``source`` is a UUID or a URL.
        :return: ``self``, populated from ``source``.
        :raises TypeError: ``source`` is a Model of an incompatible type, or ``backend`` \
                           is not a StorageBackend instance.
        :raises ValueError: a backend is required but missing, a model without ``NAME`` \
                            was given no explicit source, or the stored type mismatches.
        :raises FileNotFoundError: a UUID is absent from the backend index (``NAME is None``).
        """
        if isinstance(source, Model):
            if not isinstance(source, type(self)):
                raise TypeError("Incompatible model instance: %s <> %s" %
                                (type(source), type(self)))
            # Share the state of the already loaded instance instead of re-reading it.
            self.__dict__ = source.__dict__
            return self

        if backend is not None and not isinstance(backend, StorageBackend):
            raise TypeError("backend must be an instance of "
                            "modelforge.storage_backend.StorageBackend")
        self._source = str(source)
        # Only a directory created by this call may be removed at the end; a
        # caller-supplied cache_dir must survive.
        tmp_dir = None
        try:
            if source is None or (isinstance(source, str)
                                  and not os.path.isfile(source)):
                # The source is not a local file: resolve default model / UUID / URL.
                if cache_dir is None:
                    if self.NAME is not None:
                        cache_dir = os.path.join(self.cache_dir(), self.NAME)
                    else:
                        # Models without NAME have no per-type cache: use a scratch dir.
                        tmp_dir = tempfile.mkdtemp(prefix="modelforge-")
                        cache_dir = tmp_dir
                try:
                    uuid.UUID(source)
                    is_uuid = True
                except (TypeError, ValueError):
                    # TypeError for None, ValueError for a non-UUID string.
                    is_uuid = False
                model_id = self.DEFAULT_NAME if not is_uuid else source
                file_name = model_id + self.DEFAULT_FILE_EXT
                file_name = os.path.join(os.path.expanduser(cache_dir),
                                         file_name)
                # NOTE(review): for a non-UUID string such as a URL, model_id is
                # DEFAULT_NAME, so an already cached default model is picked here and a
                # fetched URL is saved under the default model's file name. Confirm intended.
                if os.path.exists(file_name) and (not source or
                                                  not os.path.exists(source)):
                    source = file_name
                elif source is None or is_uuid:
                    if backend is None:
                        raise ValueError(
                            "The backend must be set to load a UUID or the default "
                            "model.")
                    index = backend.index.contents
                    config = index["models"]
                    if self.NAME is not None:
                        if not is_uuid:
                            # Map the default model name to its UUID via the index meta.
                            model_id = index["meta"][self.NAME][model_id]
                        source = config[self.NAME][model_id]
                    else:
                        if not is_uuid:
                            raise ValueError(
                                "File path, URL or UUID is needed.")
                        # Without NAME the model may be of any type: search every section.
                        for models in config.values():
                            if source in models:
                                source = models[source]
                                break
                        else:
                            raise FileNotFoundError("Model %s not found." %
                                                    source)
                    source = source["url"]
                if re.match(r"\w+://", source):
                    if backend is None:
                        raise ValueError(
                            "The backend must be set to load a URL.")
                    backend.fetch_model(source, file_name)
                    # Remember where the model actually came from.
                    self._source = source
                    source = file_name
            self._log.info("Reading %s...", source)
            with asdf.open(source) as model:
                tree = model.tree
                self._meta = tree["meta"]
                if self.NAME is not None:
                    meta_name = self._meta["model"]
                    matched = self.NAME == meta_name
                    if not matched:
                        # Direct subclasses' model names are accepted as well.
                        needed = {self.NAME}
                        for child in type(self).__subclasses__():
                            needed.add(child.NAME)
                            matched |= child.NAME == meta_name
                        if not matched:
                            raise ValueError(
                                "The supplied model is of the wrong type: needed "
                                "%s, got %s." % (needed, meta_name))
                self._load_tree(tree)
        finally:
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir)
        return self