示例#1
0
def printable_type(t: TypeProto) -> Text:
    """Render a TypeProto as a short human-readable string.

    * Unset type (no ``value`` oneof member) -> ``""``.
    * Tensor type -> the element type name, e.g. ``"FLOAT"``. If the shape
      field is present, this is followed by ``", "`` and the dimensions joined
      with ``"x"`` (each rendered by ``printable_dim``), or by ``", scalar"``
      when the shape has no dimensions.
    * Any other kind -> ``"Unknown type <kind>"``.
    """
    kind = t.WhichOneof('value')
    if kind is None:
        return ""
    if kind != "tensor_type":
        return 'Unknown type {}'.format(kind)

    tensor = t.tensor_type
    text = TensorProto.DataType.Name(tensor.elem_type)
    if not tensor.HasField('shape'):
        return text

    dims = tensor.shape.dim
    if not len(dims):
        # Shape explicitly present but with zero dimensions.
        return text + ', scalar'
    return text + ', ' + 'x'.join(printable_dim(d) for d in dims)
示例#2
0
    def test_attr_type_proto(self) -> None:
        """make_attribute wraps one TypeProto as TYPE_PROTO and a list as TYPE_PROTOS."""
        # A single TypeProto is stored in ``tp`` and tagged TYPE_PROTO.
        # (Local names avoid shadowing the builtin ``type``.)
        type_proto = TypeProto()
        attr = helper.make_attribute("type_proto", type_proto)
        self.assertEqual(attr.name, "type_proto")
        self.assertEqual(attr.tp, type_proto)
        self.assertEqual(attr.type, AttributeProto.TYPE_PROTO)

        # A list of TypeProtos is stored in ``type_protos`` and tagged TYPE_PROTOS.
        type_protos = [TypeProto(), TypeProto()]
        attr = helper.make_attribute("type_protos", type_protos)
        self.assertEqual(attr.name, "type_protos")
        self.assertEqual(list(attr.type_protos), type_protos)
        self.assertEqual(attr.type, AttributeProto.TYPE_PROTOS)
示例#3
0
def make_optional_type_proto(
        inner_type_proto: TypeProto,
) -> TypeProto:
    """Wrap *inner_type_proto* as the element type of a new optional TypeProto.

    Args:
        inner_type_proto: type of the value the optional may hold. It is
            copied into the result via ``CopyFrom``, not referenced.

    Returns:
        A fresh TypeProto whose ``optional_type.elem_type`` equals the input.
    """
    wrapper = TypeProto()
    elem_slot = wrapper.optional_type.elem_type
    elem_slot.CopyFrom(inner_type_proto)
    return wrapper
示例#4
0
def make_sequence_type_proto(
        inner_type_proto: TypeProto,
) -> TypeProto:
    """Wrap *inner_type_proto* as the element type of a new sequence TypeProto.

    Args:
        inner_type_proto: type of each sequence element. It is copied into
            the result via ``CopyFrom``, not referenced.

    Returns:
        A fresh TypeProto whose ``sequence_type.elem_type`` equals the input.
    """
    wrapper = TypeProto()
    elem_slot = wrapper.sequence_type.elem_type
    elem_slot.CopyFrom(inner_type_proto)
    return wrapper
示例#5
0
 def _load_proto(self, proto_filename: Text, target_list: List[Union[np.ndarray, List[Any]]], model_type_proto: TypeProto) -> None:
     """Parse a serialized proto file and append its decoded value to *target_list*.

     The protobuf message class and the ``numpy_helper`` converter are
     selected by whichever field of *model_type_proto* is set:

     * ``sequence_type`` -> ``onnx.SequenceProto`` / ``numpy_helper.to_list``
     * ``tensor_type``   -> ``onnx.TensorProto``   / ``numpy_helper.to_array``
     * ``optional_type`` -> ``onnx.OptionalProto`` / ``numpy_helper.to_optional``

     Any other kind (e.g. map or sparse tensor) is not loaded. A notice is
     printed and *target_list* is left unchanged.

     Args:
         proto_filename: path of the file holding the serialized message.
         target_list: receives the decoded value (mutated in place).
         model_type_proto: the expected type of the stored value.
     """
     # Read the raw bytes and release the file before decoding.
     with open(proto_filename, 'rb') as f:
         protobuf_content = f.read()

     if model_type_proto.HasField('sequence_type'):
         message, to_value = onnx.SequenceProto(), numpy_helper.to_list
     elif model_type_proto.HasField('tensor_type'):
         message, to_value = onnx.TensorProto(), numpy_helper.to_array
     elif model_type_proto.HasField('optional_type'):
         message, to_value = onnx.OptionalProto(), numpy_helper.to_optional
     else:
         print('Loading proto of that specific type (Map/Sparse Tensor) is currently not supported')
         return

     message.ParseFromString(protobuf_content)
     target_list.append(to_value(message))
def _extract_shape(type_proto: TypeProto):
    """Return every dimension of a tensor TypeProto's shape except the first.

    Each returned entry is the dimension's ``dim_value`` when that oneof
    member is set. Otherwise (e.g. a symbolic ``dim_param`` or an unset
    dimension) it is ``None``.

    Args:
        type_proto: a TypeProto expected to hold a ``tensor_type``.

    Returns:
        A list of ``int`` or ``None``, one per dimension after the first.

    Raises:
        ValueError: if *type_proto* does not hold a ``tensor_type``.
    """
    which_value = type_proto.WhichOneof('value')
    if which_value != 'tensor_type':
        # Name the kind actually received so callers can diagnose the input.
        raise ValueError(
            'Expected a TypeProto holding a tensor_type, got {!r}'.format(which_value))

    tensor = type_proto.tensor_type
    # The first dimension is deliberately skipped.
    # NOTE(review): presumably the batch axis; confirm with callers.
    return [
        dim.dim_value if dim.WhichOneof('value') == 'dim_value' else None
        for dim in tensor.shape.dim[1:]
    ]
示例#7
0
def make_sparse_tensor_type_proto(
        elem_type: int,
        shape: Optional[Sequence[Union[Text, int, None]]],
        shape_denotation: Optional[List[Text]] = None,
) -> TypeProto:
    """Build a TypeProto whose ``sparse_tensor_type`` has *elem_type* and *shape*.

    Args:
        elem_type: value stored in ``sparse_tensor_type.elem_type``.
        shape: one entry per dimension. An ``int`` sets ``dim_value``, a
            ``str`` sets ``dim_param``, and ``None`` adds the dimension without
            setting either. When *shape* itself is ``None``, no dimensions are
            written.
        shape_denotation: optional per-dimension denotations. When non-empty,
            it must match ``len(shape)`` and entry *i* is stored on dimension
            *i*. It is not consulted when *shape* is ``None``.

    Returns:
        The newly built TypeProto.

    Raises:
        ValueError: if *shape_denotation* and *shape* differ in length, or if a
            shape entry is not an ``int``, ``str`` or ``None``.
    """
    result = TypeProto()
    sparse = result.sparse_tensor_type
    sparse.elem_type = elem_type
    shape_msg = sparse.shape

    if shape is None:
        return result

    # Extending by [] is not a no-op for protobuf. A never-set field is
    # omitted from the serialized message, while an explicitly set empty
    # list produces an (empty) entry. Consumers can see the difference, so a
    # zero-length shape must still emit an empty shape.
    shape_msg.dim.extend([])

    if shape_denotation and len(shape_denotation) != len(shape):
        raise ValueError(
            'Invalid shape_denotation. '
            'Must be of the same length as shape.')

    for index, entry in enumerate(shape):
        dim = shape_msg.dim.add()
        if isinstance(entry, int):
            dim.dim_value = entry
        elif isinstance(entry, str):
            dim.dim_param = entry
        elif entry is not None:
            raise ValueError(
                'Invalid item in shape: {}. '
                'Needs to be of int or text.'.format(entry))

        if shape_denotation:
            dim.denotation = shape_denotation[index]

    return result