Esempio n. 1
0
def test_import():
    """Parse the bundled MobileNet v2 TFLite file and check its version.

    Reads ``mobilenet_v2_1.0_224.tflite`` from ``../assets/tests`` relative to
    this file, parses it with ``Model.GetRootAsModel`` and asserts that
    ``Model.Version()`` returns 3.
    """
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Join path components instead of concatenating strings so the separator
    # is handled by the platform's path module.
    tflm_dir = os.path.abspath(os.path.join(cur_dir, '..', 'assets', 'tests'))
    tflm_name = 'mobilenet_v2_1.0_224.tflite'
    path = os.path.join(tflm_dir, tflm_name)
    with open(path, 'rb') as f:
        buf = f.read()
    # The buffer is already fully in memory, so parsing can happen after close.
    model = Model.GetRootAsModel(buf, 0)
    assert model.Version() == 3
    def __init_op_info(self, model_path):
        """Load a TFLite model and record the operators of its first subgraph.

        Resets and then fills two parallel lists:
        ``self.__tflite_ops`` receives the operator objects of subgraph 0, and
        ``self.__tflite_op_types`` receives each operator's ``BuiltinCode()``,
        in the same order.

        Args:
            model_path: Path of the ``.tflite`` file to read.
        """
        # Reset first, as before, so a failed open still leaves empty lists.
        self.__tflite_ops = []
        self.__tflite_op_types = []

        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(model_path, "rb") as model_file:
            data = model_file.read()
        raw_model = Model.GetRootAsModel(bytearray(data), 0)

        # Only the first subgraph is inspected.
        tflite_graph = raw_model.Subgraphs(0)

        for idx in range(tflite_graph.OperatorsLength()):
            op = tflite_graph.Operators(idx)
            # The operator stores an index into the model-level OperatorCodes
            # table; resolve it to the builtin code.
            op_type = raw_model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()

            self.__tflite_ops.append(op)
            self.__tflite_op_types.append(op_type)
Esempio n. 3
0
def __analyze_tflite():
    """Print a summary of the TFLite model whose path is ``sys.argv[1]``.

    Output order:
    1. a banner line;
    2. the model version and description;
    3. one line per operator code, marked custom or builtin;
    4. for every subgraph, its name followed by the names of its tensors.
    """
    print('vistflite...')
    with open(sys.argv[1], 'rb') as model_file:
        raw = model_file.read()
    model = Model.GetRootAsModel(raw, 0)
    print(model.Version())
    print(model.Description())

    # Operator-code table: custom ops carry a CustomCode, builtins do not.
    for code_idx in range(model.OperatorCodesLength()):
        opcode = model.OperatorCodes(code_idx)
        custom_code = opcode.CustomCode()
        if not custom_code:
            print('Builtin OP:', opcode.Version(), opcode.BuiltinCode(),
                  getBuiltinOperatorStringName(opcode.BuiltinCode()))
            continue
        print('Custom OP:', opcode.Version(), custom_code)

    # Subgraphs and the tensors each one declares.
    print('SubGraphs:')
    for graph_idx in range(model.SubgraphsLength()):
        graph = model.Subgraphs(graph_idx)
        print('Name:', graph.Name())
        print('Tensors:')
        tensor_names = (graph.Tensors(t).Name()
                        for t in range(graph.TensorsLength()))
        for tensor_name in tensor_names:
            print(tensor_name)
Esempio n. 4
0
# Download the quantized MobileNet v1 archive, then extract it to obtain the
# .tflite model file.
model_path = download_testdata(model_url,
                               "mobilenet_v1_1.0_224_quant.tgz",
                               module=['tf', 'official'])
model_dir = os.path.dirname(model_path)
extract(model_path)

# The extracted mobilenet_v1_1.0_224_quant.tflite is expected in the same
# directory as the downloaded archive; read its raw bytes and close the file
# handle as soon as the read completes.
tflite_model_file = os.path.join(model_dir,
                                 "mobilenet_v1_1.0_224_quant.tflite")
with open(tflite_model_file, "rb") as tflite_model_fh:
    tflite_model_buf = tflite_model_fh.read()

# Get the TFLite model from the buffer.
# NOTE(review): this import was originally tagged "edit here" -- presumably a
# local adaptation to the installed tflite package layout; confirm.
from tflite.Model import Model

tflite_model = Model.GetRootAsModel(tflite_model_buf, 0)

#######################################################################
# Generic run functions for TVM & TFLite
# --------------------------------------
target = tvm.target.riscv_cpu("spike")
# Name, shape and dtype of the model input passed to the Relay frontend.
input_tensor = "input"
input_shape = (1, 224, 224, 3)
input_dtype = "uint8"

# Parse TFLite model and convert it to a Relay module
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={input_tensor: input_shape},
    dtype_dict={input_tensor: input_dtype})
""" tensorize flow
Esempio n. 5
0
def read_tflite_model(path):
    """Read a ``.tflite`` file and return its root ``Model`` object.

    Args:
        path: Filesystem path of the TFLite model.

    Returns:
        The result of ``Model.GetRootAsModel`` on a ``bytearray`` copy of the
        file contents, at offset 0.
    """
    # Close the handle deterministically rather than leaking it.
    with open(path, "rb") as model_file:
        data = model_file.read()
    return Model.GetRootAsModel(bytearray(data), 0)
Esempio n. 6
0
def read_model(key):
    """Return the parsed TFLite ``Model`` for the test asset named by *key*.

    The file location is resolved with ``util_for_test.getPath(key)``; its
    bytes are handed to ``Model.GetRootAsModel`` at offset 0.
    """
    asset_path = util_for_test.getPath(key)
    with open(asset_path, 'rb') as handle:
        contents = handle.read()
    return Model.GetRootAsModel(contents, 0)
Esempio n. 7
0
def read_tflite_model(file):
    """Read the ``.tflite`` file at *file* and return its root ``Model``.

    Args:
        file: Filesystem path of the TFLite model (the name is kept for
            compatibility with existing keyword callers).

    Returns:
        The result of ``Model.GetRootAsModel`` on a ``bytearray`` copy of the
        file contents, at offset 0.
    """
    # Use a context manager so the file handle is not leaked.
    with open(file, "rb") as model_file:
        buf = bytearray(model_file.read())
    return Model.GetRootAsModel(buf, 0)
Esempio n. 8
0
def load_model(model_path):
    """Load the TFLite ``Model`` stored in the ``.tflite`` file at *model_path*.

    The file's bytes are copied into a ``bytearray`` and passed to
    ``Model.GetRootAsModel`` with offset 0.
    """
    with open(model_path, 'rb') as model_file:
        raw = bytearray(model_file.read())
    return Model.GetRootAsModel(raw, 0)
Esempio n. 9
0
def test_import():
    """Parse the MobileNet v2 model obtained via ``shrub.testing.download``.

    Asserts that the parsed model's ``Version()`` is 3.
    """
    model_file = shrub.testing.download('mobilenet_v2_1.0_224.tflite')
    with open(model_file, 'rb') as fh:
        contents = fh.read()
    parsed = Model.GetRootAsModel(contents, 0)
    assert parsed.Version() == 3