예제 #1
0
    def load(self, path, shape_dict=None):
        """Load a TFLite flatbuffer file and convert it into a Relay module.

        Parameters
        ----------
        path : str
            Path to the TFLite model file; it is read as raw bytes.
        shape_dict : dict, optional
            Input shapes merged over (and taking precedence over) the input
            shapes reported by ``TFLiteFrontend._input_type``.

        Returns
        -------
        mod, params
            The module and parameters produced by
            ``relay.frontend.from_tflite``.

        Raises
        ------
        TVMCException
            If the model version cannot be read from the file, or if the
            version is not 3.
        """
        # pylint: disable=C0415
        import tflite.Model as model

        with open(path, "rb") as tf_graph:
            content = tf_graph.read()

        # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
        try:
            tflite_model = model.Model.GetRootAsModel(content, 0)
        except AttributeError:
            tflite_model = model.GetRootAsModel(content, 0)

        try:
            version = tflite_model.Version()
        except Exception as err:
            # Any failure while reading the version field is treated as "this
            # is not a TFLite file"; keep the original error as the cause so
            # the real reason is not lost.
            raise TVMCException("input file not tflite") from err
        logger.debug("tflite version %s", version)

        if version != 3:
            raise TVMCException("input file not tflite version 3")

        logger.debug("tflite_input_type")
        input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
        # User-supplied shapes override the ones inferred from the model.
        if shape_dict is not None:
            input_shapes.update(shape_dict)

        logger.debug(
            "parse TFLite model and convert into Relay computation graph")
        mod, params = relay.frontend.from_tflite(tflite_model,
                                                 shape_dict=input_shapes,
                                                 dtype_dict=dtype_dict)
        return mod, params
예제 #2
0
def guess_frontend(path):
    """
    Infer which frontend should handle a model file from its extension.

    The extension of `path` is lower-cased and stripped of its leading dot.
    It is then looked up in the ``suffixes()`` of each entry of
    ``ALL_FRONTENDS``, in order. The first frontend that lists it is
    instantiated and returned.

    Parameters
    ----------
    path : str
        The path to the model file.

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first frontend whose suffixes include the
        file extension of `path`.

    Raises
    ------
    TVMCException
        If no frontend recognises the extension.

    """
    extension = Path(path).suffix.lower()
    if extension[:1] == ".":
        extension = extension[1:]

    matching = next(
        (
            candidate
            for candidate in ALL_FRONTENDS
            if extension in candidate.suffixes()
        ),
        None,
    )
    if matching is None:
        raise TVMCException(
            "failed to infer the model format. Please specify --model-format")
    return matching()
예제 #3
0
def get_frontend_by_name(name: str):
    """
    This function will try to get a frontend instance, based
    on the name provided.

    Parameters
    ----------
    name : str
        the name of a given frontend, compared exactly against the
        ``name()`` of each entry in ``ALL_FRONTENDS``

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first frontend whose ``name()`` equals
        `name`.

    Raises
    ------
    TVMCException
        If no frontend has the given name; the message lists the
        available frontend names.

    """

    for frontend in ALL_FRONTENDS:
        if name == frontend.name():
            return frontend()

    raise TVMCException(
        "unrecognized frontend '{0}'. Choose from: {1}".format(name, get_frontend_names())
    )
예제 #4
0
    def load(self, path, shape_dict=None, **kwargs):
        """Load a Keras model file and convert it into a Relay module.

        Parameters
        ----------
        path : str
            Path passed to ``keras.models.load_model``.
        shape_dict : dict, optional
            Input shapes merged over (and taking precedence over) the shapes
            derived from the model's input layers.
        **kwargs
            Forwarded to ``relay.frontend.from_keras``; ``layout`` defaults
            to ``"NHWC"`` when not given.

        Returns
        -------
        The result of ``relay.frontend.from_keras``.

        Raises
        ------
        TVMCException
            If ``keras.models.load_model`` raises ``ValueError``.
        """
        # pylint: disable=C0103
        tf, keras = import_keras()

        # tvm build currently imports keras directly instead of tensorflow.keras
        try:
            model = keras.models.load_model(path)
        except ValueError as err:
            raise TVMCException(str(err)) from err

        # There are two flavours of keras model, sequential and
        # functional, TVM expects a functional model, so convert
        # if required:
        if self.is_sequential_p(model):
            model = self.sequential_to_functional(model)

        eager = tf.executing_eagerly()
        in_shapes = []
        for layer in model._input_layers:
            dims = layer.input.shape
            if not eager:
                # Outside eager mode the shape entries expose their value via
                # `.value` (which may be None); unwrap them first.
                dims = [dim.value for dim in dims]
            # Unknown (None) dimensions are pinned to 1 so the shape is concrete.
            in_shapes.append(tuple(1 if dim is None else int(dim) for dim in dims))

        # Use the computed shapes directly: there is no need to allocate sample
        # input tensors (and consume global RNG state) just to read them back.
        input_shapes = dict(zip(model.input_names, in_shapes))
        if shape_dict is not None:
            input_shapes.update(shape_dict)
        kwargs.setdefault("layout", "NHWC")
        return relay.frontend.from_keras(model, input_shapes, **kwargs)
예제 #5
0
    def load(self, path, shape_dict=None, **kwargs):
        """Load a TorchScript model file and convert it into a Relay module.

        Parameters
        ----------
        path : str
            Path passed to ``torch.jit.load``.
        shape_dict : dict
            Required mapping of input names to shapes; its items are handed
            to the PyTorch frontend as a list of ``(name, shape)`` pairs.
        **kwargs
            Forwarded unchanged to ``relay.frontend.from_pytorch``.

        Returns
        -------
        The result of ``relay.frontend.from_pytorch``.

        Raises
        ------
        TVMCException
            If `shape_dict` is None.
        """
        # pylint: disable=C0415
        import torch

        if shape_dict is None:
            message = "--input-shapes must be specified for %s" % self.name()
            raise TVMCException(message)

        scripted = torch.jit.load(path)
        # Put the loaded module into inference mode before conversion.
        scripted.eval()

        # The PyTorch frontend expects an ordered list of (name, shape) pairs.
        named_shapes = [(input_name, shape) for input_name, shape in shape_dict.items()]

        logger.debug("parse Torch model and convert into Relay computation graph")
        return relay.frontend.from_pytorch(scripted, named_shapes, **kwargs)
예제 #6
0
    def load(self, path, shape_dict=None, **kwargs):
        """Load a TFLite flatbuffer file and convert it into a Relay module.

        Parameters
        ----------
        path : str
            Path to the TFLite model file; it is read as raw bytes.
        shape_dict : dict, optional
            Passed as ``shape_dict`` to ``relay.frontend.from_tflite``.
        **kwargs
            Forwarded unchanged to ``relay.frontend.from_tflite``.

        Returns
        -------
        mod, params
            The module and parameters produced by
            ``relay.frontend.from_tflite``.

        Raises
        ------
        TVMCException
            If the model version cannot be read from the file, or if the
            version is not 3.
        """
        model = lazy_import("tflite.Model")

        with open(path, "rb") as tf_graph:
            content = tf_graph.read()

        # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
        try:
            tflite_model = model.Model.GetRootAsModel(content, 0)
        except AttributeError:
            tflite_model = model.GetRootAsModel(content, 0)

        try:
            version = tflite_model.Version()
        except Exception as err:
            # Any failure while reading the version field is treated as "this
            # is not a TFLite file"; keep the original error as the cause so
            # the real reason is not lost.
            raise TVMCException("input file not tflite") from err
        logger.debug("tflite version %s", version)

        if version != 3:
            raise TVMCException("input file not tflite version 3")

        logger.debug("parse TFLite model and convert into Relay computation graph")
        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
        return mod, params