Exemplo n.º 1
0
    def load(self, path):
        """Load a saved Keras model found under ``path`` and pack it.

        The Keras implementation to use is imported by the module name stored
        in the file at ``self._keras_module_name_path(path)``. The model is
        then restored inside ``self.graph`` / ``self.sess``, either from a
        JSON architecture file plus a weights file (when
        ``self._store_as_json_and_weights`` is truthy) or from a single model
        file via ``load_model``.

        Args:
            path: Location handed to the ``self._*_path`` helpers to resolve
                the individual saved files.

        Returns:
            The result of ``self.pack(model)``.

        Raises:
            BentoMLArtifactLoadingException: If the module-name file does not
                exist, or the module it names cannot be imported.
        """
        module_name_path = self._keras_module_name_path(path)
        if not os.path.isfile(module_name_path):
            # Without this guard `keras_module` would stay unbound and the
            # method would fail much later with an unrelated-looking
            # UnboundLocalError when the model is actually loaded.
            raise BentoMLArtifactLoadingException(
                "Failed to read keras module name from '{}' when loading "
                "saved KerasModelArtifact".format(module_name_path))

        with open(module_name_path, "rb") as text_file:
            keras_module_name = text_file.read().decode("utf-8")

        # The import happens after the file is closed; it does not need it.
        try:
            keras_module = importlib.import_module(keras_module_name)
        except ImportError:
            raise BentoMLArtifactLoadingException(
                "Failed to import '{}' module when loading saved "
                "KerasModelArtifact".format(keras_module_name))

        self.creat_session()

        # Custom objects supplied by the caller take precedence; only fall
        # back to the pickled ones when none were given.
        if self.custom_objects is None:
            custom_objects_path = self._custom_objects_path(path)
            if os.path.isfile(custom_objects_path):
                # NOTE(review): unpickling can execute arbitrary code; only
                # load artifacts that come from a trusted source.
                with open(custom_objects_path, 'rb') as custom_objects_file:
                    self.custom_objects = cloudpickle.load(custom_objects_file)

        with self.graph.as_default():
            with self.sess.as_default():
                # load keras model via json and weights if requested
                if self._store_as_json_and_weights:
                    with open(self._model_json_path(path), 'r') as json_file:
                        model_json = json_file.read()
                    model = keras_module.models.model_from_json(
                        model_json, custom_objects=self.custom_objects)
                    model.load_weights(self._model_weights_path(path))
                # otherwise, load keras model via standard load_model
                else:
                    model = keras_module.models.load_model(
                        self._model_file_path(path),
                        custom_objects=self.custom_objects)
        return self.pack(model)
Exemplo n.º 2
0
    def load(self, path):
        """Unpickle a PyTorch model saved under ``path`` and pack it.

        Args:
            path: Location handed to ``self._file_path`` to resolve the
                pickled model file.

        Returns:
            The result of ``self.pack(model)``.

        Raises:
            ImportError: If the ``torch`` package is not installed.
            TypeError: If the unpickled object is not a ``torch.nn.Module``.
        """
        try:
            import torch
        except ImportError:
            raise ImportError(
                "torch package is required to use PytorchModelArtifact")

        # Use a context manager so the file handle is closed even when
        # unpickling fails.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # artifacts that come from a trusted source.
        with open(self._file_path(path), 'rb') as model_file:
            model = cloudpickle.load(model_file)

        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                "Expecting PytorchModelArtifact loaded object type to be "
                "'torch.nn.Module' but actually it is {}".format(type(model)))

        return self.pack(model)
Exemplo n.º 3
0
    def load(self, path):
        """Load a saved Keras model found under ``path`` and pack it.

        The model is restored with ``keras.models.load_model`` while
        ``self.graph`` and ``self.sess`` are set as the defaults.

        Args:
            path: Location handed to the ``self._*_path`` helpers to resolve
                the saved model file and the optional custom-objects file.

        Returns:
            The result of ``self.pack(model)``.

        Raises:
            ImportError: If the module-level ``tf`` is ``None``.
        """
        if tf is None:
            raise ImportError(
                "Tensorflow package is required to use KerasModelArtifact.")

        self.creat_session()

        # Custom objects supplied by the caller take precedence; only fall
        # back to the pickled ones when none were given.
        if self.custom_objects is None:
            custom_objects_path = self._custom_objects_path(path)
            if os.path.isfile(custom_objects_path):
                # NOTE(review): unpickling can execute arbitrary code; only
                # load artifacts that come from a trusted source.
                with open(custom_objects_path, 'rb') as custom_objects_file:
                    self.custom_objects = cloudpickle.load(custom_objects_file)

        with self.graph.as_default():
            with self.sess.as_default():
                model = keras.models.load_model(
                    self._model_file_path(path),
                    custom_objects=self.custom_objects)
        return self.pack(model)
Exemplo n.º 4
0
    def load(self, path):
        """Load a PyTorch model saved under ``path`` and pack it.

        A file that is a zip archive is loaded with ``torch.jit.load``; any
        other file is unpickled with ``cloudpickle``.

        Args:
            path: Location handed to ``self._file_path`` to resolve the saved
                model file.

        Returns:
            The result of ``self.pack(model)``.

        Raises:
            MissingDependencyException: If the ``torch`` package is not
                installed.
            InvalidArgument: If the loaded object is not a
                ``torch.nn.Module``.
        """
        try:
            import torch
        except ImportError:
            raise MissingDependencyException(
                "torch package is required to use PytorchModelArtifact")

        model_file_path = self._file_path(path)

        # TorchScript Models are saved as zip files
        if zipfile.is_zipfile(model_file_path):
            model = torch.jit.load(model_file_path)
        else:
            # Use a context manager so the file handle is closed even when
            # unpickling fails.
            # NOTE(review): unpickling can execute arbitrary code; only load
            # artifacts that come from a trusted source.
            with open(model_file_path, 'rb') as model_file:
                model = cloudpickle.load(model_file)

        if not isinstance(model, torch.nn.Module):
            # Adjacent string literals instead of a backslash continuation,
            # which embedded the next line's indentation in the message.
            raise InvalidArgument(
                "Expecting PytorchModelArtifact loaded object type to be "
                "'torch.nn.Module' or 'torch.jit.ScriptModule' "
                "but actually it is {}".format(type(model)))

        return self.pack(model)