Пример #1
0
    def test_attr_type_proto(self):  # type: () -> None
        # Case 1: a single TypeProto value. The attribute is expected to
        # expose it through `tp` and be tagged as TYPE_PROTO.
        single_proto = TypeProto()
        single_attr = helper.make_attribute("type_proto", single_proto)
        self.assertEqual(single_attr.name, "type_proto")
        self.assertEqual(single_attr.tp, single_proto)
        self.assertEqual(single_attr.type, AttributeProto.TYPE_PROTO)

        # Case 2: a list of TypeProto values. The attribute is expected to
        # expose them through `type_protos` and be tagged as TYPE_PROTOS.
        proto_list = [TypeProto(), TypeProto()]
        list_attr = helper.make_attribute("type_protos", proto_list)
        self.assertEqual(list_attr.name, "type_protos")
        self.assertEqual(list(list_attr.type_protos), proto_list)
        self.assertEqual(list_attr.type, AttributeProto.TYPE_PROTOS)
Пример #2
0
def make_sequence_type_proto(
        inner_type_proto: TypeProto,
) -> TypeProto:
    """Build a TypeProto whose sequence element type is ``inner_type_proto``.

    Arguments:
        inner_type_proto: the TypeProto to use as the element type. It is
            copied into the result via ``CopyFrom``.

    Returns:
        A newly created TypeProto with ``sequence_type.elem_type`` populated
        from ``inner_type_proto``.
    """
    sequence_proto = TypeProto()
    element_slot = sequence_proto.sequence_type.elem_type
    element_slot.CopyFrom(inner_type_proto)
    return sequence_proto
Пример #3
0
def make_optional_type_proto(
        inner_type_proto: TypeProto,
) -> TypeProto:
    """Build a TypeProto whose optional element type is ``inner_type_proto``.

    Arguments:
        inner_type_proto: the TypeProto to use as the element type. It is
            copied into the result via ``CopyFrom``.

    Returns:
        A newly created TypeProto with ``optional_type.elem_type`` populated
        from ``inner_type_proto``.
    """
    optional_proto = TypeProto()
    element_slot = optional_proto.optional_type.elem_type
    element_slot.CopyFrom(inner_type_proto)
    return optional_proto
Пример #4
0
def make_sparse_tensor_type_proto(
        elem_type: int,
        shape: Optional[Sequence[Union[Text, int, None]]],
        shape_denotation: Optional[List[Text]] = None,
) -> TypeProto:
    """Build a TypeProto for a sparse tensor with the given type and shape.

    Arguments:
        elem_type: value assigned to ``sparse_tensor_type.elem_type``.
        shape: ``None`` to add no dimensions at all; otherwise one dimension
            is added per entry. An ``int`` entry sets ``dim_value``, a ``str``
            entry sets ``dim_param``, and a ``None`` entry adds a dimension
            with neither set.
        shape_denotation: optional per-dimension denotations. When non-empty
            it must have the same length as ``shape``; entry ``i`` is assigned
            to the ``denotation`` of dimension ``i``. It is not consulted when
            ``shape`` is ``None``.

    Returns:
        A newly created TypeProto with ``sparse_tensor_type`` populated.

    Raises:
        ValueError: if ``shape_denotation`` is non-empty and its length
            differs from ``shape``, or if an entry of ``shape`` is not an
            ``int``, a ``str`` or ``None``.
    """
    result = TypeProto()
    sparse_type = result.sparse_tensor_type
    sparse_type.elem_type = elem_type
    shape_proto = sparse_type.shape

    # Without shape information nothing more is written to the proto.
    if shape is None:
        return result

    # Extending by an empty list looks like a no-op, and for a plain Python
    # list it would be. Protobuf repeated fields differ: a field that is
    # never set is omitted from the resulting protobuf, while one explicitly
    # set to empty gets an (empty) entry. Consumers can observe that
    # difference, so make sure an empty shape is still emitted.
    shape_proto.dim.extend([])

    if shape_denotation and len(shape_denotation) != len(shape):
        raise ValueError(
            'Invalid shape_denotation. '
            'Must be of the same length as shape.')

    for index, extent in enumerate(shape):
        new_dim = shape_proto.dim.add()
        if isinstance(extent, int):
            new_dim.dim_value = extent
        elif isinstance(extent, str):
            new_dim.dim_param = extent
        elif extent is not None:
            # A None entry is allowed and leaves the dimension without a
            # value or a param; anything else is rejected.
            raise ValueError(
                'Invalid item in shape: {}. '
                'Needs to be of int or text.'.format(extent))

        if shape_denotation:
            new_dim.denotation = shape_denotation[index]

    return result