def write_tflite_graph_to_flatbuffers(graph, filename):
    """Serialize ``graph`` as a single-subgraph TFLite flatbuffer file.

    Note: ``graph.sort()`` is called first, so ``graph`` is modified in place.

    Buffer layout written to the model:
      * index 0: an empty buffer;
      * indices 1..k: one buffer per tensor whose ``data`` is not None, in
        ``graph.tensors`` order;
      * index k + 1: the ``min_runtime_version`` metadata string ("1.14.0").

    Args:
        graph: graph object exposing ``sort()``, ``tensors``, ``operations``,
            ``inputs``, ``outputs`` and ``name`` (``name`` may be None).
        filename: path of the output file; it is opened in ``'wb'`` mode and
            therefore overwritten if it exists.

    Raises:
        KeyError: if a graph input/output tensor is not in ``graph.tensors``.
    """

    def _offset_vector(builder, start_vector, offsets):
        # Flatbuffers vectors are built back to front, hence reversed().
        start_vector(builder, len(offsets))
        for offset in reversed(offsets):
            builder.PrependUOffsetTRelative(offset)
        return builder.EndVector(len(offsets))

    def _int32_vector(builder, start_vector, values):
        # Same as _offset_vector, but for plain int32 scalars.
        start_vector(builder, len(values))
        for value in reversed(values):
            builder.PrependInt32(value)
        return builder.EndVector(len(values))

    graph.sort()
    builder = flatbuffers.Builder(0)

    # Buffer 0: an empty buffer with a zero-length data vector.
    tflite_fb.BufferStartDataVector(builder, 0)
    empty_data = builder.EndVector(0)
    tflite_fb.BufferStart(builder)
    tflite_fb.BufferAddData(builder, empty_data)
    buffer_offsets = [tflite_fb.BufferEnd(builder)]

    # One buffer per tensor that carries constant data.
    for tensor in graph.tensors:
        if tensor.data is None:
            continue
        tensor_data = tensor.data
        if isinstance(tensor_data, (list, tuple)):
            numpy_dtype = _TensorDtypeAsNumpy[
                _TensorTypeValueByName[tensor.dtype]]
            tensor_data = np.array(tensor_data, dtype=numpy_dtype)
        # Flatten and reinterpret the element storage as raw bytes.
        # NOTE(review): .view() uses host byte order — confirm that only
        # little-endian hosts are expected.
        raw_bytes = tensor_data.reshape([-1]).view(np.uint8)
        buffer_offsets.append(_build_buffer(builder, raw_bytes))

    # Last buffer: the minimum runtime version, referenced by the metadata.
    metadata_index = len(buffer_offsets)
    buffer_offsets.append(
        _build_buffer(builder, np.frombuffer(b'1.14.0', dtype=np.uint8)))
    buffers = _offset_vector(
        builder, tflite_fb.ModelStartBuffersVector, buffer_offsets)

    # buffer_index is the buffer assigned above to the next data-carrying
    # tensor. Data-less tensors also receive it; presumably _build_tensor
    # maps those to buffer 0 — TODO confirm.
    buffer_index = 1
    tensor_offsets = []
    tensor_index = {}
    for tensor in graph.tensors:
        tensor_index[tensor] = len(tensor_offsets)
        tensor_offsets.append(_build_tensor(builder, tensor, buffer_index))
        if tensor.data is not None:
            buffer_index += 1
    tensors = _offset_vector(
        builder, tflite_fb.SubGraphStartTensorsVector, tensor_offsets)

    # Deduplicated operator codes, in order of first use.
    op_code_offsets = []
    op_code_index = {}
    for operation in graph.operations:
        codes = _builtin_code_and_custom_code(operation)
        if codes not in op_code_index:
            op_code_index[codes] = len(op_code_offsets)
            op_code_offsets.append(_build_operator_code(builder, *codes))
    op_codes = _offset_vector(
        builder, tflite_fb.ModelStartOperatorCodesVector, op_code_offsets)

    operator_offsets = [
        _build_operator(builder, operation, op_code_index, tensor_index)
        for operation in graph.operations
    ]
    operators = _offset_vector(
        builder, tflite_fb.SubGraphStartOperatorsVector, operator_offsets)

    name = builder.CreateString(graph.name) if graph.name is not None else None

    # Resolve indices before opening each vector, so a missing tensor raises
    # KeyError without leaving the builder in the middle of a vector.
    input_indices = [tensor_index[tensor] for tensor in graph.inputs]
    inputs = _int32_vector(
        builder, tflite_fb.SubGraphStartInputsVector, input_indices)
    output_indices = [tensor_index[tensor] for tensor in graph.outputs]
    outputs = _int32_vector(
        builder, tflite_fb.SubGraphStartOutputsVector, output_indices)

    tflite_fb.SubGraphStart(builder)
    if name is not None:
        tflite_fb.SubGraphAddName(builder, name)
    tflite_fb.SubGraphAddTensors(builder, tensors)
    tflite_fb.SubGraphAddOperators(builder, operators)
    tflite_fb.SubGraphAddInputs(builder, inputs)
    tflite_fb.SubGraphAddOutputs(builder, outputs)
    subgraph = tflite_fb.SubGraphEnd(builder)
    subgraphs = _offset_vector(
        builder, tflite_fb.ModelStartSubgraphsVector, [subgraph])

    metadata_name = builder.CreateString("min_runtime_version")
    tflite_fb.MetadataStart(builder)
    tflite_fb.MetadataAddName(builder, metadata_name)
    tflite_fb.MetadataAddBuffer(builder, metadata_index)
    metadata = tflite_fb.MetadataEnd(builder)
    metadata_vector = _offset_vector(
        builder, tflite_fb.ModelStartMetadataVector, [metadata])

    tflite_fb.ModelStart(builder)
    tflite_fb.ModelAddVersion(builder, OUTPUT_SCHEMA_VERSION)
    tflite_fb.ModelAddBuffers(builder, buffers)
    tflite_fb.ModelAddOperatorCodes(builder, op_codes)
    tflite_fb.ModelAddSubgraphs(builder, subgraphs)
    tflite_fb.ModelAddMetadata(builder, metadata_vector)
    model = tflite_fb.ModelEnd(builder)
    FinishWithFileIdentifier(builder, model, OUTPUT_FILE_IDENTIFIER)

    with open(filename, 'wb') as output_file:
        output_file.write(builder.Output())
def _build_buffer(builder, bytes):
    """Append a TFLite ``Buffer`` table whose ``data`` field holds ``bytes``.

    Args:
        builder: the flatbuffers builder being written to.
        bytes: array of raw byte values, passed unchanged to
            ``_CreateNumpyVector``.

    Returns:
        The offset of the finished ``Buffer`` table within ``builder``.
    """
    # The data vector must be finished before the Buffer table is started.
    payload_offset = _CreateNumpyVector(builder, bytes)
    tflite_fb.BufferStart(builder)
    tflite_fb.BufferAddData(builder, payload_offset)
    table_offset = tflite_fb.BufferEnd(builder)
    return table_offset