Exemplo n.º 1
0
def infer_outputs(op_type,
                  inputs,
                  outputs=None,
                  initializer=None,
                  target_opset=None,
                  **atts):
    """
    Infers outputs type and shapes given an ONNX operator.

    A temporary ONNX graph is built around the operator (graph name
    ``'infer_shapes'``, no declared outputs), ONNX shape inference is run
    on it, and the inferred ``value_info`` entries are converted with
    ``Variable.from_pb``.

    :param op_type: either an ONNX operator type name (``str``), in which
        case a single node is created with *atts* as attributes, or an
        object exposing a ``nodes`` attribute holding already built nodes
    :param inputs: sequence of graph inputs; when *op_type* is a string,
        every item must expose ``onnx_name`` (used as node input names)
    :param outputs: optional output names required on the created node;
        each item is a ``str`` or an object with ``onnx_name``
        (only used when *op_type* is a string)
    :param initializer: optional extra graph inputs appended after *inputs*
    :param target_opset: opset version applied to every node domain, or a
        dict ``{domain: version}``; a falsy value (or a domain missing from
        the dict) falls back to ``get_latest_tested_opset_version()``
    :param atts: node attributes forwarded to ``make_node``
    :return: whatever ``Variable.from_pb`` returns for the inferred
        ``value_info`` entries
    :raises TypeError: if an item of *outputs* is neither a ``str`` nor
        exposes ``onnx_name``
    :raises RuntimeError: if *op_type* cannot be turned into nodes, or if
        shape inference produced no ``value_info``
    """
    if isinstance(op_type, str):
        required_outputs = []
        for o in outputs or []:
            if hasattr(o, 'onnx_name'):
                required_outputs.append(o.onnx_name)
            elif isinstance(o, str):
                required_outputs.append(o)
            else:
                raise TypeError("Unable to require output {}.".format(o))
        nodes = [make_node(op_type, [i.onnx_name for i in inputs],
                           required_outputs, **atts)]
    elif hasattr(op_type, 'nodes'):
        nodes = op_type.nodes
    else:
        raise RuntimeError("Unable to build ONNX nodes from type {}.".format(
            type(op_type)))

    # list() rather than .copy(): accepts tuples as well as lists and never
    # mutates the caller's sequence.
    input_init = list(inputs)
    if initializer:
        input_init.extend(initializer)

    onnx_inputs = []
    for inp in input_init:
        if isinstance(inp, Variable):
            tensor_type = inp.type.to_onnx_type().tensor_type
            # NOTE(review): a dimension without a dim_value (e.g. symbolic)
            # reads as 0 here — confirm this is intended.
            shape = tuple(d.dim_value for d in tensor_type.shape.dim)
            onnx_inputs.append(make_tensor_value_info(
                inp.onnx_name, tensor_type.elem_type, shape))
        elif isinstance(inp, onnx.TensorProto):
            onnx_inputs.append(make_tensor_value_info(
                inp.name, inp.data_type, list(inp.dims)))
        elif isinstance(inp, onnx.AttributeProto):
            value_info = ValueInfoProto()
            value_info.name = inp.name
            onnx_type = onnx_proto.TypeProto()
            # NOTE(review): AttributeProto.type is the attribute kind, which
            # is stored here as a tensor element type — confirm intended.
            onnx_type.tensor_type.elem_type = inp.type
            value_info.type.CopyFrom(onnx_type)
            onnx_inputs.append(value_info)
        else:
            # Presumably already a ValueInfoProto; passed through unchanged.
            onnx_inputs.append(inp)

    graph = make_graph(nodes, 'infer_shapes', onnx_inputs, [])
    original_model = make_model(graph, producer_name='skl2onnx')

    # Distinct node domains, in first-seen order.
    domains = list(dict.fromkeys(n.domain for n in nodes))
    for i, domain in enumerate(domains):
        # Reuse the single opset entry already present on the model for the
        # first domain instead of appending a second one.
        if i == 0 and len(original_model.opset_import) == 1:
            op_set = original_model.opset_import[0]
        else:
            op_set = original_model.opset_import.add()
        op_set.domain = domain
        if not target_opset:
            op_set.version = get_latest_tested_opset_version()
        elif isinstance(target_opset, dict):
            op_set.version = target_opset.get(
                domain, get_latest_tested_opset_version())
        else:
            op_set.version = target_opset

    inferred_model = shape_inference.infer_shapes(original_model)
    shapes = Variable.from_pb(inferred_model.graph.value_info)
    if len(shapes) == 0:
        raise RuntimeError("Shape inference fails.\n"
                           "*Inputs*\n{}\n*Model*\n{}".format(
                               onnx_inputs, original_model))
    return shapes
Exemplo n.º 2
0
 def to_onnx_type(self):
     """
     Returns a TypeProto describing a STRING tensor whose shape has a
     single dimension of value 1.
     """
     type_proto = onnx_proto.TypeProto()
     tensor = type_proto.tensor_type
     tensor.elem_type = onnx_proto.TensorProto.STRING
     tensor.shape.dim.add().dim_value = 1
     return type_proto
Exemplo n.º 3
0
 def to_onnx_type(self):
     """
     Returns a TypeProto describing a sequence whose element type is a copy
     of ``self.element_type.to_onnx_type()``.
     """
     element_proto = self.element_type.to_onnx_type()
     result = onnx_proto.TypeProto()
     result.sequence_type.elem_type.CopyFrom(element_proto)
     return result