Ejemplo n.º 1
0
def read_onnx_from_protobuf(filename):
    # type: (str)->ONNXGraph
    """Load a binary ONNX model file and convert its graph.

    The whole file is read in binary mode and parsed into an
    ``onnx_pb2.ModelProto``. Version problems are only reported with printed
    warnings and do not abort loading:

    * a missing ``ir_version``, or one newer than
      ``LAST_SUPPORTED_IR_VERSION``;
    * no opset imports at all;
    * an opset import from a non-default domain;
    * a default-domain opset newer than ``LAST_SUPPORTED_OPSET_VERSION``.

    Args:
        filename: Path of the serialized ``ModelProto`` file.

    Returns:
        The result of ``_get_graph`` applied to ``model_proto.graph``.

    Errors raised by ``open`` or by ``ModelProto.ParseFromString``
    (for example, on a corrupt file) propagate to the caller.
    """
    # Domains that name the standard ONNX operator set. The ONNX IR spec
    # treats the empty string and 'ai.onnx' as the same default domain.
    default_domains = ('', 'ai.onnx')

    model_proto = onnx_pb2.ModelProto()

    with open(filename, 'rb') as f:
        model_proto.ParseFromString(f.read())

    if not model_proto.HasField('ir_version'):
        print('Warning: ModelProto has no ir_version!')
    elif model_proto.ir_version > LAST_SUPPORTED_IR_VERSION:
        print(
            'Warning: ModelProto has newer ir_version than what we support. ({} > {})'
            .format(model_proto.ir_version, LAST_SUPPORTED_IR_VERSION))

    if not model_proto.opset_import:
        print('Warning: ModelProto has no opset import!')

    # The loop body does nothing when the list is empty, so no else-branch is needed.
    for opset in model_proto.opset_import:
        if opset.domain not in default_domains:
            print(
                'Warning: ModelProto has an unsupported opset domain: {}'.
                format(opset.domain))
        elif opset.version > LAST_SUPPORTED_OPSET_VERSION:
            print(
                'Warning: ModelProto has newer opset than what we support. ({} > {})'
                .format(opset.version, LAST_SUPPORTED_OPSET_VERSION))

    return _get_graph(model_proto.graph)
Ejemplo n.º 2
0
def build_model(graph):
    # type: (ONNXGraph)->onnx_pb2.ModelProto
    """Wrap an ONNXGraph in a newly created ``onnx_pb2.ModelProto``.

    The function first asks the graph to fill in any missing names; this
    mutates ``graph``. It then fills in the model header:

    * IR version;
    * a single default-domain opset import at ``OUTPUT_OPSET_VERSION``;
    * producer name and version;
    * the graph's optional ``domain`` and ``version``.

    Finally, ``build_graph`` serializes the graph into the model's ``graph``
    field.

    Args:
        graph: The graph to serialize.

    Returns:
        The populated ``ModelProto``.
    """
    # Names must be complete before anything reads the graph.
    graph.generate_missing_names()

    proto = onnx_pb2.ModelProto()

    # Fixed header fields.
    proto.ir_version = OUTPUT_IR_VERSION
    proto.producer_name = PRODUCER_NAME
    proto.producer_version = PRODUCER_VERSION

    # Exactly one opset import: the default ('') ONNX domain.
    default_opset = proto.opset_import.add()
    default_opset.domain = ''
    default_opset.version = OUTPUT_OPSET_VERSION

    # Optional metadata is only copied when the graph provides it, so that
    # unset fields stay absent from the proto.
    model_domain, model_version = graph.domain, graph.version
    if model_domain is not None:
        proto.domain = model_domain
    if model_version is not None:
        proto.model_version = model_version

    build_graph(graph, proto.graph)
    return proto