Beispiel #1
0
def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
    """Translate a TorchScript type object into an ``onnx.TypeProto``.

    The conversion is driven by ``t.kind()``:

    * ``"NoneType"``   -> an empty ``TypeProto`` (no denotation is set).
    * ``"ListType"``   -> a sequence type whose element type is obtained by
      converting ``t.getElementType()`` recursively.
    * ``"IntType"``    -> a tensor type with INT64 elements and a shape that
      has no dimensions.
    * ``"TensorType"`` -> a tensor type whose element type is looked up in
      ``sym_hel.cast_pytorch_to_onnx`` (``UNDEFINED`` when ``t.scalarType()``
      is ``None``) and whose shape gets one ``dim_value`` per entry of
      ``t.sizes()`` when that is not ``None``.

    Every non-``NoneType`` result carries ``repr(t)`` as its denotation.

    Raises:
        AssertionError: if the kind is none of the four handled above.
    """
    kind = t.kind()
    if kind == "NoneType":
        return onnx.TypeProto()

    proto: onnx.TypeProto = onnx.TypeProto()
    proto.denotation = repr(t)

    if kind == "ListType":
        element_type = cast(torch._C.TensorType, t.getElementType())
        proto.sequence_type.elem_type.CopyFrom(_type_to_proto(element_type))
        return proto

    # Both remaining supported kinds populate the tensor_type sub-message.
    tensor = proto.tensor_type

    if kind == "IntType":
        tensor.elem_type = onnx.TensorProto.DataType.INT64
        tensor.shape.CopyFrom(onnx.TensorShapeProto())
        return proto

    assert kind == "TensorType", f"Not Tensor type(actual: {t.kind()}): {t}"

    scalar_type = t.scalarType()
    if scalar_type is None:
        tensor.elem_type = onnx.TensorProto.DataType.UNDEFINED
    else:
        onnx_dtype = sym_hel.cast_pytorch_to_onnx[scalar_type]  # type: ignore[index]
        tensor.elem_type = int(onnx_dtype)  # type: ignore

    # Copy in an empty shape first so the "shape" field is present even when
    # no sizes are available.
    tensor.shape.CopyFrom(onnx.TensorShapeProto())
    sizes = t.sizes()
    if sizes is not None:
        for size in sizes:  # type: ignore
            tensor.shape.dim.add().dim_value = size

    assert proto.tensor_type.HasField("shape")

    return proto
Beispiel #2
0
def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto:
    """Ensure every graph initializer is also listed as a graph input.

    For each tensor in ``model.graph.initializer`` whose name is not yet
    among ``model.graph.input``, a ``ValueInfoProto`` is appended to the
    inputs. Its tensor type takes the element type from the initializer's
    ``data_type`` and its shape from the initializer's ``dims``.

    Args:
        model: The model to update; it is modified in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    graph = model.graph
    # Collect the known input names once instead of rebuilding a list for
    # every initializer; the set is kept in sync as new inputs are appended,
    # so an initializer name that occurs twice is still only added once.
    known_inputs = {value_info.name for value_info in graph.input}

    for initializer in graph.initializer:
        if initializer.name in known_inputs:
            continue

        shape = onnx.TensorShapeProto()
        shape.dim.extend(
            [onnx.TensorShapeProto.Dimension(dim_value=dim)
             for dim in initializer.dims])

        tensor_type = onnx.TypeProto.Tensor(
            elem_type=initializer.data_type, shape=shape)
        graph.input.extend(
            [onnx.ValueInfoProto(name=initializer.name,
                                 type=onnx.TypeProto(tensor_type=tensor_type))])
        known_inputs.add(initializer.name)

    return model
def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto:
    """Append a graph input for each initializer that does not have one.

    The new ``ValueInfoProto`` uses the initializer's name, its
    ``data_type`` as the tensor element type, and one ``dim_value`` per
    entry of its ``dims`` as the shape.

    Args:
        model: The model to update; it is modified in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Due to an onnx bug, https://github.com/onnx/onnx/issues/2417, we need
    # to add missing initializers into inputs.
    graph = model.graph
    # Gather the existing input names a single time. The set is updated as
    # inputs are appended, so repeated initializer names are added only once
    # without rescanning graph.input on every iteration.
    seen_names = {graph_input.name for graph_input in graph.input}

    for tensor in graph.initializer:
        if tensor.name in seen_names:
            continue

        shape = onnx.TensorShapeProto()
        shape.dim.extend([
            onnx.TensorShapeProto.Dimension(dim_value=dim)
            for dim in tensor.dims
        ])
        graph.input.extend([
            onnx.ValueInfoProto(
                name=tensor.name,
                type=onnx.TypeProto(tensor_type=onnx.TypeProto.Tensor(
                    elem_type=tensor.data_type, shape=shape)))
        ])
        seen_names.add(tensor.name)

    return model