def infer_outputs(op_type, inputs, outputs=None, initializer=None,
                  target_opset=None, **atts):
    """
    Infers outputs type and shapes given an ONNX operator.

    Builds a throwaway single-graph model around the operator (or the
    nodes of an already-built operator object), runs ONNX shape
    inference on it and returns the inferred value infos as variables.

    :param op_type: operator name (str) or an object exposing ``.nodes``
    :param inputs: list of inputs (``Variable``, ``TensorProto``,
        ``AttributeProto`` or ready-made value infos)
    :param outputs: optional requested outputs (objects with
        ``onnx_name`` or plain strings)
    :param initializer: optional additional inputs appended to *inputs*
    :param target_opset: int or dict ``{domain: version}``; when absent,
        the latest tested opset version is used
    :param atts: node attributes forwarded to ``make_node``
    :return: list of variables built from the inferred value infos
    :raises TypeError: when a requested output is neither named nor a str
    :raises RuntimeError: when *op_type* is unusable or inference
        produces no shape
    """
    if isinstance(op_type, str):
        required_outputs = []
        if outputs:
            for o in outputs:
                if hasattr(o, 'onnx_name'):
                    required_outputs.append(o.onnx_name)
                elif isinstance(o, str):
                    required_outputs.append(o)
                else:
                    raise TypeError("Unable to require output {}.".format(o))
        node = make_node(op_type, [i.onnx_name for i in inputs],
                         required_outputs, **atts)
        node = [node]
    elif hasattr(op_type, 'nodes'):
        node = op_type.nodes
    else:
        raise RuntimeError("Unable to build ONNX nodes from type {}.".format(
            type(op_type)))

    input_init = inputs.copy()
    if initializer:
        input_init.extend(initializer)

    # Normalize every input into a ValueInfoProto for the temporary graph.
    onnx_inputs = []
    for inp in input_init:  # renamed: do not shadow the builtin ``input``
        if isinstance(inp, Variable):
            onnx_type = inp.type.to_onnx_type()
            tensor_type = onnx_type.tensor_type
            shape = [d.dim_value for d in tensor_type.shape.dim]
            value_info = make_tensor_value_info(
                inp.onnx_name, tensor_type.elem_type, tuple(shape))
            onnx_inputs.append(value_info)
        elif isinstance(inp, onnx.TensorProto):
            # ``data_type`` is a plain int enum value; use it directly
            # (the previous ``.real`` was a no-op on an int).
            value_info = make_tensor_value_info(
                inp.name, inp.data_type, list(inp.dims))
            onnx_inputs.append(value_info)
        elif isinstance(inp, onnx.AttributeProto):
            value_info = ValueInfoProto()
            value_info.name = inp.name
            onnx_type = onnx_proto.TypeProto()
            onnx_type.tensor_type.elem_type = inp.type
            value_info.type.CopyFrom(onnx_type)
            onnx_inputs.append(value_info)
        else:
            # Assumed already a ValueInfoProto-compatible object.
            onnx_inputs.append(inp)

    graph = make_graph(node, 'infer_shapes', onnx_inputs, [])
    original_model = make_model(graph, producer_name='skl2onnx')

    # Collect, per domain, the highest opset version the nodes require.
    domains = {}
    for n in node:
        domains[n.domain] = max(domains.get(n.domain, 1),
                                getattr(n, 'op_version', 1))
    for i, (k, v) in enumerate(domains.items()):
        # Reuse the default opset entry make_model created, if any.
        if i == 0 and len(original_model.opset_import) == 1:
            op_set = original_model.opset_import[0]
        else:
            op_set = original_model.opset_import.add()
        op_set.domain = k
        if target_opset:
            if isinstance(target_opset, dict):
                op_set.version = target_opset.get(
                    k, get_latest_tested_opset_version())
            else:
                op_set.version = target_opset
        else:
            op_set.version = get_latest_tested_opset_version()

    inferred_model = shape_inference.infer_shapes(original_model)
    shapes = Variable.from_pb(inferred_model.graph.value_info)
    if len(shapes) == 0:
        # Note: fixed a stray trailing quote in the original message.
        raise RuntimeError("Shape inference fails.\n"
                           "*Inputs*\n{}\n*Model*\n{}".format(
                               onnx_inputs, original_model))
    return shapes
def to_onnx_type(self):
    """Return an ONNX ``TypeProto`` describing a string tensor of shape (1,)."""
    proto = onnx_proto.TypeProto()
    tensor = proto.tensor_type
    tensor.elem_type = onnx_proto.TensorProto.STRING
    dim = tensor.shape.dim.add()
    dim.dim_value = 1
    return proto
def to_onnx_type(self):
    """Return an ONNX ``TypeProto`` for a sequence of this element type."""
    inner = self.element_type.to_onnx_type()
    proto = onnx_proto.TypeProto()
    proto.sequence_type.elem_type.CopyFrom(inner)
    return proto