def load(self, path):
    """Load a saved Keras model from ``path`` and pack it into this artifact.

    The name of the Keras implementation module is read from the file at
    ``self._keras_module_name_path(path)`` and imported dynamically. If
    ``self.custom_objects`` is unset and a custom-objects file exists, it is
    unpickled and passed to Keras as ``custom_objects``. The model is then
    restored inside the ``self.graph`` / ``self.sess`` default contexts,
    either from a JSON architecture file plus a weights file (when
    ``self._store_as_json_and_weights`` is truthy) or from a single model
    file via ``models.load_model``.

    Args:
        path: Location the artifact was saved to; passed to the
            ``_*_path`` helpers to locate the individual files.

    Returns:
        The result of ``self.pack(model)``.

    Raises:
        BentoMLArtifactLoadingException: If the module-name file is missing
            or the recorded Keras module cannot be imported.
    """
    module_name_path = self._keras_module_name_path(path)
    if not os.path.isfile(module_name_path):
        # Fail fast: without this file ``keras_module`` would never be bound
        # and loading would crash later with an opaque UnboundLocalError.
        raise BentoMLArtifactLoadingException(
            "Missing Keras module name file '{}' when loading saved "
            "KerasModelArtifact".format(module_name_path))

    with open(module_name_path, "rb") as text_file:
        # strip() tolerates a trailing newline in the stored module name.
        keras_module_name = text_file.read().decode("utf-8").strip()
    try:
        keras_module = importlib.import_module(keras_module_name)
    except ImportError as e:
        raise BentoMLArtifactLoadingException(
            "Failed to import '{}' module when loading saved "
            "KerasModelArtifact".format(keras_module_name)) from e

    self.creat_session()

    if self.custom_objects is None and os.path.isfile(
            self._custom_objects_path(path)):
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # artifacts. The ``with`` closes the file handle deterministically.
        with open(self._custom_objects_path(path), 'rb') as pickle_file:
            self.custom_objects = cloudpickle.load(pickle_file)

    with self.graph.as_default(), self.sess.as_default():
        if self._store_as_json_and_weights:
            # Rebuild the architecture from JSON, then restore the weights.
            with open(self._model_json_path(path), 'r') as json_file:
                model_json = json_file.read()
            model = keras_module.models.model_from_json(
                model_json, custom_objects=self.custom_objects)
            model.load_weights(self._model_weights_path(path))
        else:
            model = keras_module.models.load_model(
                self._model_file_path(path),
                custom_objects=self.custom_objects)
    return self.pack(model)
def load(self, path):
    """Load a pickled PyTorch model from ``path`` and pack it into this artifact.

    Args:
        path: Location the artifact was saved to; passed to
            ``self._file_path`` to locate the model file.

    Returns:
        The result of ``self.pack(model)``.

    Raises:
        ImportError: If ``torch`` cannot be imported.
        TypeError: If the unpickled object is not a ``torch.nn.Module``.
    """
    try:
        import torch
    except ImportError as e:
        raise ImportError(
            "torch package is required to use PytorchModelArtifact") from e

    # NOTE: unpickling can execute arbitrary code; only load trusted artifacts.
    # The ``with`` block closes the file handle deterministically.
    with open(self._file_path(path), 'rb') as model_file:
        model = cloudpickle.load(model_file)

    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            "Expecting PytorchModelArtifact loaded object type to be "
            "'torch.nn.Module' but actually it is {}".format(type(model)))
    return self.pack(model)
def load(self, path):
    """Load a saved Keras model from ``path`` and pack it into this artifact.

    If ``self.custom_objects`` is unset and a custom-objects file exists, it
    is unpickled and passed to ``keras.models.load_model`` as
    ``custom_objects``. The model is restored inside the ``self.graph`` and
    ``self.sess`` default contexts.

    Args:
        path: Location the artifact was saved to; passed to the
            ``_*_path`` helpers to locate the individual files.

    Returns:
        The result of ``self.pack(model)``.

    Raises:
        ImportError: If ``tf`` is ``None`` (TensorFlow is not available).
    """
    if tf is None:
        raise ImportError(
            "Tensorflow package is required to use KerasModelArtifact.")

    self.creat_session()

    if self.custom_objects is None and os.path.isfile(
            self._custom_objects_path(path)):
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # artifacts. The ``with`` closes the file handle deterministically.
        with open(self._custom_objects_path(path), 'rb') as pickle_file:
            self.custom_objects = cloudpickle.load(pickle_file)

    with self.graph.as_default(), self.sess.as_default():
        model = keras.models.load_model(
            self._model_file_path(path), custom_objects=self.custom_objects)
    return self.pack(model)
def load(self, path):
    """Load a saved PyTorch model from ``path`` and pack it into this artifact.

    A model file that is a zip archive is treated as TorchScript and
    restored with ``torch.jit.load``; any other file is unpickled with
    cloudpickle.

    Args:
        path: Location the artifact was saved to; passed to
            ``self._file_path`` to locate the model file.

    Returns:
        The result of ``self.pack(model)``.

    Raises:
        MissingDependencyException: If ``torch`` cannot be imported.
        InvalidArgument: If the loaded object is not a ``torch.nn.Module``.
    """
    try:
        import torch
    except ImportError as e:
        raise MissingDependencyException(
            "torch package is required to use PytorchModelArtifact") from e

    model_path = self._file_path(path)
    # TorchScript Models are saved as zip files
    if zipfile.is_zipfile(model_path):
        model = torch.jit.load(model_path)
    else:
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # artifacts. The ``with`` closes the file handle deterministically.
        with open(model_path, 'rb') as model_file:
            model = cloudpickle.load(model_file)

    if not isinstance(model, torch.nn.Module):
        # Adjacent literals are concatenated; the previous backslash inside
        # the string leaked a stray "\" / indentation into the message.
        raise InvalidArgument(
            "Expecting PytorchModelArtifact loaded object type to be "
            "'torch.nn.Module' or 'torch.jit.ScriptModule' "
            "but actually it is {}".format(type(model)))
    return self.pack(model)