def load(self, path, shape_dict=None):
    """Read a TFLite file and convert it into a Relay module.

    Parameters
    ----------
    path : str
        Path of the TFLite model file; it is opened in binary mode.
    shape_dict : dict, optional
        Entries merged on top of the input shapes reported by
        ``TFLiteFrontend._input_type``. An entry given here replaces the
        discovered entry stored under the same key.

    Returns
    -------
    mod, params
        The pair returned by ``relay.frontend.from_tflite``.

    Raises
    ------
    TVMCException
        If the model version cannot be read from the parsed file, or if
        the version is different from 3.
    """
    # pylint: disable=C0415
    import tflite.Model as model

    with open(path, "rb") as tf_graph:
        content = tf_graph.read()

    # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
    try:
        tflite_model = model.Model.GetRootAsModel(content, 0)
    except AttributeError:
        tflite_model = model.GetRootAsModel(content, 0)

    # Only the version query is guarded; the underlying error is kept as
    # the cause so the traceback still shows why parsing failed.
    try:
        version = tflite_model.Version()
    except Exception as err:
        raise TVMCException("input file not tflite") from err
    logger.debug("tflite version %s", version)

    if version != 3:
        raise TVMCException("input file not tflite version 3")

    logger.debug("tflite_input_type")
    input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
    if shape_dict is not None:
        # Caller-supplied shapes take precedence over the discovered ones.
        input_shapes.update(shape_dict)

    logger.debug("parse TFLite model and convert into Relay computation graph")
    mod, params = relay.frontend.from_tflite(
        tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
    )
    return mod, params
def guess_frontend(path):
    """
    Infer the frontend to use from the extension of a model file.

    Parameters
    ----------
    path : str
        The path to the model file.

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first entry in ``ALL_FRONTENDS`` whose
        ``suffixes()`` contain the lower-cased file extension (compared
        without its leading dot).

    Raises
    ------
    TVMCException
        If no entry in ``ALL_FRONTENDS`` lists the extension.
    """
    extension = Path(path).suffix.lower()
    if extension[:1] == ".":
        # Frontends are matched against the extension without the dot.
        extension = extension[1:]

    matched = next(
        (candidate for candidate in ALL_FRONTENDS if extension in candidate.suffixes()),
        None,
    )
    if matched is None:
        raise TVMCException(
            "failed to infer the model format. Please specify --model-format"
        )
    return matched()
def get_frontend_by_name(name: str):
    """
    Look up a frontend by its name and return a new instance of it.

    Parameters
    ----------
    name : str
        The name of the requested frontend, compared for equality with
        each frontend's ``name()``.

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first entry in ``ALL_FRONTENDS`` whose
        ``name()`` equals `name`.

    Raises
    ------
    TVMCException
        If no frontend has that name; the message lists the result of
        ``get_frontend_names()``.
    """
    matched = next(
        (candidate for candidate in ALL_FRONTENDS if name == candidate.name()),
        None,
    )
    if matched is not None:
        return matched()

    message = "unrecognized frontend '{0}'. Choose from: {1}".format(
        name, get_frontend_names()
    )
    raise TVMCException(message)
def load(self, path, shape_dict=None, **kwargs):
    """Load a Keras model file and convert it into a Relay module.

    Parameters
    ----------
    path : str
        Path handed to ``keras.models.load_model``.
    shape_dict : dict, optional
        Entries merged on top of the shapes derived from the model's
        input layers. An entry given here replaces the derived entry
        stored under the same key.
    **kwargs
        Forwarded to ``relay.frontend.from_keras``. ``layout`` defaults
        to ``"NHWC"`` when the caller does not supply it.

    Returns
    -------
    The value returned by ``relay.frontend.from_keras``.

    Raises
    ------
    TVMCException
        If ``keras.models.load_model`` raises ``ValueError``; the message
        is the text of that error, which is kept as the cause.
    """
    # pylint: disable=C0103
    tf, keras = import_keras()

    # tvm build currently imports keras directly instead of tensorflow.keras
    try:
        model = keras.models.load_model(path)
    except ValueError as err:
        raise TVMCException(str(err)) from err

    # There are two flavours of keras model, sequential and
    # functional, TVM expects a functional model, so convert
    # if required:
    if self.is_sequential_p(model):
        model = self.sequential_to_functional(model)

    # The execution mode does not change while iterating, so query it once.
    eager = tf.executing_eagerly()

    in_shapes = []
    for layer in model._input_layers:
        dims = layer.input.shape
        if not eager:
            # Outside eager mode each dimension is read through `.value`.
            dims = (dim.value for dim in dims)
        # Unknown dimensions (None) are replaced by 1.
        in_shapes.append(tuple(1 if dim is None else int(dim) for dim in dims))

    # Pair input names with their shapes directly; no placeholder tensors
    # need to be allocated just to read their shape back.
    input_shapes = dict(zip(model.input_names, in_shapes))
    if shape_dict is not None:
        input_shapes.update(shape_dict)

    kwargs.setdefault("layout", "NHWC")
    return relay.frontend.from_keras(model, input_shapes, **kwargs)
def load(self, path, shape_dict=None, **kwargs):
    """Load a model with ``torch.jit.load`` and convert it into a Relay module.

    Parameters
    ----------
    path : str
        Path handed to ``torch.jit.load``.
    shape_dict : dict
        Input shapes. Despite the ``None`` default this argument is
        required; its items are passed to the frontend as a list of pairs.
    **kwargs
        Forwarded to ``relay.frontend.from_pytorch``.

    Returns
    -------
    The value returned by ``relay.frontend.from_pytorch``.

    Raises
    ------
    TVMCException
        If `shape_dict` is ``None``.
    """
    # pylint: disable=C0415
    import torch

    if shape_dict is None:
        raise TVMCException("--input-shapes must be specified for %s" % self.name())

    scripted = torch.jit.load(path)
    scripted.eval()  # Switch to inference mode

    # The PyTorch frontend is given a list of pairs, not the dictionary itself.
    shape_pairs = [*shape_dict.items()]

    logger.debug("parse Torch model and convert into Relay computation graph")
    return relay.frontend.from_pytorch(scripted, shape_pairs, **kwargs)
def load(self, path, shape_dict=None, **kwargs):
    """Read a TFLite file and convert it into a Relay module.

    Parameters
    ----------
    path : str
        Path of the TFLite model file; it is opened in binary mode.
    shape_dict : dict, optional
        Passed unchanged as ``shape_dict`` to
        ``relay.frontend.from_tflite``.
    **kwargs
        Forwarded to ``relay.frontend.from_tflite``.

    Returns
    -------
    mod, params
        The pair returned by ``relay.frontend.from_tflite``.

    Raises
    ------
    TVMCException
        If the model version cannot be read from the parsed file, or if
        the version is different from 3.
    """
    model = lazy_import("tflite.Model")

    with open(path, "rb") as tf_graph:
        content = tf_graph.read()

    # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
    try:
        tflite_model = model.Model.GetRootAsModel(content, 0)
    except AttributeError:
        tflite_model = model.GetRootAsModel(content, 0)

    # Only the version query is guarded; the underlying error is kept as
    # the cause so the traceback still shows why parsing failed.
    try:
        version = tflite_model.Version()
    except Exception as err:
        raise TVMCException("input file not tflite") from err
    logger.debug("tflite version %s", version)

    if version != 3:
        raise TVMCException("input file not tflite version 3")

    logger.debug("parse TFLite model and convert into Relay computation graph")
    mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
    return mod, params