예제 #1
0
def _build_uoffset_vector(builder, start_vector, offsets):
    """Build a FlatBuffers vector of table offsets, keeping ``offsets`` order.

    ``start_vector`` is the generated ``<Table>Start<Field>Vector`` function.
    """
    start_vector(builder, len(offsets))
    # FlatBuffers builds back-to-front, so prepend in reverse to keep order.
    for offset in reversed(offsets):
        builder.PrependUOffsetTRelative(offset)
    return builder.EndVector(len(offsets))


def _build_int32_vector(builder, start_vector, values):
    """Build a FlatBuffers vector of int32 values, keeping ``values`` order."""
    start_vector(builder, len(values))
    for value in reversed(values):
        builder.PrependInt32(value)
    return builder.EndVector(len(values))


def _tensor_data_as_uint8(tensor):
    """Return the constant data of ``tensor`` as a flat uint8 array view.

    List/tuple data is first converted to a numpy array whose dtype matches
    ``tensor.dtype``; array data is used as-is.
    """
    tensor_data = tensor.data
    if isinstance(tensor_data, (list, tuple)):
        tensor_data = np.array(
            tensor_data,
            dtype=_TensorDtypeAsNumpy[_TensorTypeValueByName[tensor.dtype]])
    return tensor_data.reshape([-1]).view(np.uint8)


def write_tflite_graph_to_flatbuffers(graph, filename):
    """Serialize ``graph`` as a single-subgraph TFLite model file.

    The model's buffer table is laid out as:

    * index 0: an empty sentinel buffer, used by every tensor that has no
      constant data;
    * indices 1..K: one buffer per tensor whose ``data`` is not None, in
      ``graph.tensors`` order;
    * index K+1: the ``min_runtime_version`` metadata value (``"1.14.0"``).

    Args:
        graph: graph to serialize. ``graph.sort()`` is called first, which
            presumably reorders the graph in place (callers see that).
        filename: output path; the file is created or overwritten.
    """
    graph.sort()
    builder = flatbuffers.Builder(0)

    # Buffer 0: the empty sentinel buffer referenced by data-less tensors.
    tflite_fb.BufferStartDataVector(builder, 0)
    empty_data = builder.EndVector(0)
    tflite_fb.BufferStart(builder)
    tflite_fb.BufferAddData(builder, empty_data)
    buffer_offsets = [tflite_fb.BufferEnd(builder)]

    for tensor in graph.tensors:
        if tensor.data is not None:
            buffer_offsets.append(
                _build_buffer(builder, _tensor_data_as_uint8(tensor)))

    # Metadata buffer goes after all tensor buffers.
    metadata_buffer_index = len(buffer_offsets)
    buffer_offsets.append(
        _build_buffer(builder, np.frombuffer(b'1.14.0', dtype=np.uint8)))

    buffers = _build_uoffset_vector(
        builder, tflite_fb.ModelStartBuffersVector, buffer_offsets)

    tensor_offsets = []
    tensor_index = {}
    next_buffer_index = 1  # buffer 0 is the empty sentinel
    for tensor in graph.tensors:
        tensor_index[tensor] = len(tensor_offsets)
        if tensor.data is not None:
            buffer_index = next_buffer_index
            next_buffer_index += 1
        else:
            # Data-less tensors must point at the empty sentinel, not at the
            # slot reserved for the next constant tensor.
            buffer_index = 0
        tensor_offsets.append(_build_tensor(builder, tensor, buffer_index))

    tensors = _build_uoffset_vector(
        builder, tflite_fb.SubGraphStartTensorsVector, tensor_offsets)

    # Deduplicate operator codes; each operator refers to its code by index.
    op_code_offsets = []
    op_code_index = {}
    for operation in graph.operations:
        codes = _builtin_code_and_custom_code(operation)
        if codes not in op_code_index:
            op_code_index[codes] = len(op_code_offsets)
            op_code_offsets.append(_build_operator_code(builder, *codes))

    op_codes = _build_uoffset_vector(
        builder, tflite_fb.ModelStartOperatorCodesVector, op_code_offsets)

    operator_offsets = [
        _build_operator(builder, operation, op_code_index, tensor_index)
        for operation in graph.operations
    ]
    operators = _build_uoffset_vector(
        builder, tflite_fb.SubGraphStartOperatorsVector, operator_offsets)

    name = builder.CreateString(graph.name) if graph.name is not None else None

    inputs = _build_int32_vector(
        builder, tflite_fb.SubGraphStartInputsVector,
        [tensor_index[tensor] for tensor in graph.inputs])
    outputs = _build_int32_vector(
        builder, tflite_fb.SubGraphStartOutputsVector,
        [tensor_index[tensor] for tensor in graph.outputs])

    tflite_fb.SubGraphStart(builder)
    if name is not None:
        tflite_fb.SubGraphAddName(builder, name)
    tflite_fb.SubGraphAddTensors(builder, tensors)
    tflite_fb.SubGraphAddOperators(builder, operators)
    tflite_fb.SubGraphAddInputs(builder, inputs)
    tflite_fb.SubGraphAddOutputs(builder, outputs)
    subgraph = tflite_fb.SubGraphEnd(builder)

    subgraphs = _build_uoffset_vector(
        builder, tflite_fb.ModelStartSubgraphsVector, [subgraph])

    metadata_name = builder.CreateString("min_runtime_version")
    tflite_fb.MetadataStart(builder)
    tflite_fb.MetadataAddName(builder, metadata_name)
    tflite_fb.MetadataAddBuffer(builder, metadata_buffer_index)
    metadata = tflite_fb.MetadataEnd(builder)

    metadata_vector = _build_uoffset_vector(
        builder, tflite_fb.ModelStartMetadataVector, [metadata])

    tflite_fb.ModelStart(builder)
    tflite_fb.ModelAddVersion(builder, OUTPUT_SCHEMA_VERSION)
    tflite_fb.ModelAddBuffers(builder, buffers)
    tflite_fb.ModelAddOperatorCodes(builder, op_codes)
    tflite_fb.ModelAddSubgraphs(builder, subgraphs)
    tflite_fb.ModelAddMetadata(builder, metadata_vector)
    model = tflite_fb.ModelEnd(builder)

    FinishWithFileIdentifier(builder, model, OUTPUT_FILE_IDENTIFIER)

    with open(filename, 'wb') as output_file:
        output_file.write(builder.Output())
예제 #2
0
def _build_buffer(builder, bytes):
    """Serialize ``bytes`` as a TFLite ``Buffer`` table.

    Args:
        builder: the FlatBuffers builder being written to.
        bytes: array-like payload handed to ``_CreateNumpyVector``.

    Returns:
        The builder offset of the finished ``Buffer`` table.
    """
    # The data vector is created before the table is opened; FlatBuffers
    # does not allow building a vector while a table is in progress.
    payload_vector = _CreateNumpyVector(builder, bytes)

    tflite_fb.BufferStart(builder)
    tflite_fb.BufferAddData(builder, payload_vector)
    buffer_offset = tflite_fb.BufferEnd(builder)
    return buffer_offset