def parse_proto_map(
    map_entries,
    map_entry_parent_indices,
    map_entry_descriptor: descriptor.Descriptor,
    keys_needed: Sequence[str],
    backing_str_tensor: Optional[tf.Tensor] = None
) -> Sequence[Tuple[tf.Tensor, tf.Tensor]]:
    """Decodes the requested keys out of serialized Protobuf map entries.

    Dispatches to the `decode_proto_map_v2` op when `backing_str_tensor` is a
    tensor, and to the `decode_proto_map` op otherwise.

    Args:
      map_entries: a 1D string tensor that contains serialized map entry
        sub-messages.
      map_entry_parent_indices: a 1D int64 tensor of the same length as
        map_entries. map_entry_parent_indices[i] == j means map_entries[i]
        belongs to the j-th map.
      map_entry_descriptor: the proto descriptor of the map entry sub-message.
      keys_needed: keys that are needed to be looked up in the map. If the
        map's keys are integers, then these strings will be parsed as integers
        in decimal. If the map's keys are booleans, then only "0" and "1" are
        expected.
      backing_str_tensor: a possible string tensor backing the string_view for
        intermediate serialized protos.

    Returns:
      A list of tuples, one for each key in `keys_needed`. In each tuple, the
      first term contains decoded values; the second term contains the parent
      indices for the values.
    """
    requested_keys = list(keys_needed)
    value_field = map_entry_descriptor.fields_by_name["value"]

    # The backing tensor, when present, is handed to the V2 op wrapped in a
    # list; anything that is not a tensor is treated as "no backing tensor".
    backing_tensors = (
        [backing_str_tensor] if tf.is_tensor(backing_str_tensor) else [])

    # Arguments that both op versions take, in the same positional order.
    shared_args = (
        map_entry_descriptor.full_name,
        requested_keys,
        len(requested_keys),
        _get_dtype_from_cpp_type(value_field.cpp_type),
        file_descriptor_set.get_file_descriptor_set_proto(
            map_entry_descriptor, ["key", "value"]).SerializeToString(),
    )

    # TODO(b/172576749): Once we allow sufficient bake in time for the kernel
    # change, switch to using V2 only.
    if backing_tensors:
        values, parent_indices = gen_decode_proto_map_op.decode_proto_map_v2(
            map_entries, map_entry_parent_indices, backing_tensors,
            *shared_args)
    else:
        values, parent_indices = gen_decode_proto_map_op.decode_proto_map(
            map_entries, map_entry_parent_indices, *shared_args)
    return list(zip(values, parent_indices))
# Example #2
# 0
def parse_proto_map(
        map_entries, map_entry_parent_indices,
        map_entry_descriptor: descriptor.Descriptor,
        keys_needed: Sequence[str]) -> Sequence[Tuple[tf.Tensor, tf.Tensor]]:
    """Decodes the requested keys out of serialized Protobuf map entries.

    Args:
      map_entries: a 1D string tensor that contains serialized map entry
        sub-messages.
      map_entry_parent_indices: a 1D int64 tensor of the same length as
        map_entries. map_entry_parent_indices[i] == j means map_entries[i]
        belongs to the j-th map.
      map_entry_descriptor: the proto descriptor of the map entry sub-message.
      keys_needed: keys that are needed to be looked up in the map. If the
        map's keys are integers, then these strings will be parsed as integers
        in decimal. If the map's keys are booleans, then only "0" and "1" are
        expected.

    Returns:
      A list of tuples, one for each key in `keys_needed`. In each tuple, the
      first term contains decoded values; the second term contains the parent
      indices for the values.
    """
    requested_keys = list(keys_needed)
    value_field = map_entry_descriptor.fields_by_name["value"]

    entry_type_name = map_entry_descriptor.full_name
    num_keys = len(requested_keys)
    value_dtype = _get_dtype_from_cpp_type(value_field.cpp_type)
    # Serialized descriptor set restricted to the entry's "key" and "value"
    # fields, passed to the op alongside the message type name.
    serialized_descriptors = file_descriptor_set.get_file_descriptor_set_proto(
        map_entry_descriptor, ["key", "value"]).SerializeToString()

    values, parent_indices = gen_decode_proto_map_op.decode_proto_map(
        map_entries,
        map_entry_parent_indices,
        entry_type_name,
        requested_keys,
        num_keys,
        value_dtype,
        serialized_descriptors)
    return list(zip(values, parent_indices))
 def test_get_file_descriptor_set_proto_simple_test_map(self):
     """Requesting no fields of SubMessage yields only test_map.proto."""
     descriptor_files = file_descriptor_set.get_file_descriptor_set_proto(
         test_map_pb2.SubMessage.DESCRIPTOR, []).file
     # Exactly one file descriptor is expected in the resulting set.
     self.assertLen(descriptor_files, 1)
     actual_name = descriptor_files[0].name
     expected_name = (
         _get_base_directory() + "struct2tensor/test/test_map.proto")
     self.assertEqual(actual_name, expected_name)
# Example #4
# 0
def parse_message_level(tensor_of_protos, descriptor_type, field_names):
    """Parses a subset of the fields at one level of a message.

    A field with a message type is parsed as a string, so this function can
    then be applied recursively to the result.

    Args:
      tensor_of_protos: a 1-D tensor of strings of protocol buffers.
      descriptor_type: a descriptor for the protocol buffer to parse. See
        https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py
      field_names: the names of the fields to parse.

    Returns:
      A list of named _ParsedField, one per field_name in field_names (an
      empty list when field_names is empty):
        field_name: the string from field_names.
        field_descriptor: descriptor_type.fields_by_name[field_name]
        value: a 1-D tensor of the values from the field field_name.
        index: an index, such that for all i, tensor_of_protos[index[i]] has a
          value value[i]. Note that sometimes index[i]=index[i+1], implying a
          repeated field field_name.
    """
    if not field_names:
        return []

    message_type = descriptor_type.full_name
    descriptor_literal = file_descriptor_set.get_file_descriptor_set_proto(
        descriptor_type, field_names).SerializeToString()
    # TODO(martinz): catch KeyError and give a better error.
    field_descriptors = [
        _get_field_descriptor(descriptor_type, name) for name in field_names
    ]
    output_types = [
        _get_dtype_from_cpp_type(fd.cpp_type) for fd in field_descriptors
    ]

    values, indices = gen_decode_proto_sparse.decode_proto_sparse_v2(
        tensor_of_protos,
        descriptor_literal=descriptor_literal,
        message_type=message_type,
        num_fields=len(field_names),
        field_names=list(field_names),
        output_types=output_types)

    # Pair every requested field with its descriptor and decoded outputs.
    parsed_fields = []
    for name, fd, value, index in zip(field_names, field_descriptors, values,
                                      indices):
        parsed_fields.append(
            _ParsedField(
                field_name=name,
                field_descriptor=fd,
                value=value,
                index=index))
    return parsed_fields
# Example #5
# 0
def parse_message_level(
        tensor_of_protos: tf.Tensor,
        descriptor_type: descriptor.Descriptor,
        field_names: Sequence[str],
        message_format: Text = "binary") -> Sequence[_ParsedField]:
    """Parses a subset of the fields at one level of a message.

    A field with a message type is parsed as a string, so this function can
    then be applied recursively to the result.

    Args:
      tensor_of_protos: a 1-D tensor of strings of protocol buffers.
      descriptor_type: a descriptor for the protocol buffer to parse. See
        https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py
      field_names: the names of the fields to parse.
      message_format: Indicates the format of the protocol buffer: is one of
        'text' or 'binary'.

    Returns:
      A list of named _ParsedField, one per field_name in field_names, in
      sorted field-name order (an empty list when field_names is empty):
        field_name: the string from field_names.
        field_descriptor: descriptor_type.fields_by_name[field_name]
        value: a 1-D tensor of the values from the field field_name.
        index: an index, such that for all i, tensor_of_protos[index[i]] has a
          value value[i]. Note that sometimes index[i]=index[i+1], implying a
          repeated field field_name.
    """
    if not field_names:
        return []

    # Sorting makes the field_names attr handed to the DecodeProtoSparseV2 op
    # deterministic regardless of the caller's ordering.
    ordered_names = sorted(field_names)
    message_type = descriptor_type.full_name
    descriptor_literal = file_descriptor_set.get_file_descriptor_set_proto(
        descriptor_type, ordered_names).SerializeToString()
    # TODO(martinz): catch KeyError and give a better error.
    field_descriptors = [
        _get_field_descriptor(descriptor_type, name) for name in ordered_names
    ]
    output_types = [
        _get_dtype_from_cpp_type(fd.cpp_type) for fd in field_descriptors
    ]

    values, indices = gen_decode_proto_sparse.decode_proto_sparse_v2(
        tensor_of_protos,
        descriptor_literal=descriptor_literal,
        message_type=message_type,
        num_fields=len(ordered_names),
        field_names=list(ordered_names),
        output_types=output_types,
        message_format=message_format)

    return [
        _ParsedField(field_name=name,
                     field_descriptor=fd,
                     value=value,
                     index=index)
        for name, fd, value, index in zip(
            ordered_names, field_descriptors, values, indices)
    ]
def parse_message_level(
        tensor_of_protos: tf.Tensor,
        descriptor_type: descriptor.Descriptor,
        field_names: Sequence[str],
        message_format: str = "binary",
        backing_str_tensor: Optional[tf.Tensor] = None,
        honor_proto3_optional_semantics: bool = False
) -> Sequence[_ParsedField]:
    """Parses a subset of the fields at one level of a message.

    A field with a message type is parsed as a string, so this function can
    then be applied recursively to the result.

    Args:
      tensor_of_protos: a 1-D tensor of strings of protocol buffers.
      descriptor_type: a descriptor for the protocol buffer to parse. See
        https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py
      field_names: the names of the fields to parse.
      message_format: Indicates the format of the protocol buffer: is one of
        'text' or 'binary'. Must be 'binary' when backing_str_tensor is given.
      backing_str_tensor: a possible string tensor backing the string_view for
        intermediate serialized protos.
      honor_proto3_optional_semantics: if True, and if a proto3 primitive
        optional field without the presence semantic (i.e. the field is
        without the "optional" or "repeated" label) is requested to be parsed,
        it will always have a value for each input parent message. If a value
        is not present on wire, the default value (0 or "") will be used.

    Returns:
      A list of named _ParsedField, one per field_name in field_names, in
      sorted field-name order (an empty list when field_names is empty):
        field_name: the string from field_names.
        field_descriptor: descriptor_type.fields_by_name[field_name]
        value: a 1-D tensor of the values from the field field_name.
        index: an index, such that for all i, tensor_of_protos[index[i]] has a
          value value[i]. Note that sometimes index[i]=index[i+1], implying a
          repeated field field_name.
    """
    if not field_names:
        return []

    # Sorting makes the field_names attr handed to the decode op deterministic
    # regardless of the caller's ordering.
    ordered_names = sorted(field_names)
    message_type = descriptor_type.full_name
    descriptor_literal = file_descriptor_set.get_file_descriptor_set_proto(
        descriptor_type, ordered_names).SerializeToString()
    # TODO(martinz): catch KeyError and give a better error.
    field_descriptors = [
        _get_field_descriptor(descriptor_type, name) for name in ordered_names
    ]
    output_types = [
        _get_dtype_from_cpp_type(fd.cpp_type) for fd in field_descriptors
    ]

    # The ops take the backing tensor as a list: empty when none was given.
    backing_tensors = []
    if tf.is_tensor(backing_str_tensor):
        assert message_format == "binary", (
            "message_format must be 'binary' if a backing_str_tensor is provided"
        )
        backing_tensors = [backing_str_tensor]

    # Keyword arguments common to both op versions.
    op_kwargs = dict(
        descriptor_literal=descriptor_literal,
        message_type=message_type,
        num_fields=len(ordered_names),
        field_names=list(ordered_names),
        output_types=output_types,
        message_format=message_format)

    # TODO(b/185908025): Once we allow sufficient bake in time for v4, switch to
    # v4 only. v4 supports both proto3 optional semantics and string view.
    if honor_proto3_optional_semantics:
        values, indices = gen_decode_proto_sparse.decode_proto_sparse_v4(
            tensor_of_protos,
            backing_tensors,
            honor_proto3_optional_semantics=honor_proto3_optional_semantics,
            **op_kwargs)
    else:
        values, indices = gen_decode_proto_sparse.decode_proto_sparse_v3(
            tensor_of_protos, backing_tensors, **op_kwargs)

    return [
        _ParsedField(field_name=name,
                     field_descriptor=fd,
                     value=value,
                     index=index)
        for name, fd, value, index in zip(
            ordered_names, field_descriptors, values, indices)
    ]
def parse_message_level(
        tensor_of_protos: tf.Tensor,
        descriptor_type: descriptor.Descriptor,
        field_names: Sequence[str],
        message_format: str = "binary",
        backing_str_tensor: Optional[tf.Tensor] = None
) -> Sequence[_ParsedField]:
    """Parses a subset of the fields at one level of a message.

    A field with a message type is parsed as a string, so this function can
    then be applied recursively to the result.

    Args:
      tensor_of_protos: a 1-D tensor of strings of protocol buffers.
      descriptor_type: a descriptor for the protocol buffer to parse. See
        https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py
      field_names: the names of the fields to parse.
      message_format: Indicates the format of the protocol buffer: is one of
        'text' or 'binary'. Must be 'binary' when backing_str_tensor is given.
      backing_str_tensor: a possible string tensor backing the string_view for
        intermediate serialized protos.

    Returns:
      A list of named _ParsedField, one per field_name in field_names, in
      sorted field-name order (an empty list when field_names is empty):
        field_name: the string from field_names.
        field_descriptor: descriptor_type.fields_by_name[field_name]
        value: a 1-D tensor of the values from the field field_name.
        index: an index, such that for all i, tensor_of_protos[index[i]] has a
          value value[i]. Note that sometimes index[i]=index[i+1], implying a
          repeated field field_name.
    """
    if not field_names:
        return []

    # Sorting makes the field_names attr handed to the decode op deterministic
    # regardless of the caller's ordering.
    ordered_names = sorted(field_names)
    message_type = descriptor_type.full_name
    descriptor_literal = file_descriptor_set.get_file_descriptor_set_proto(
        descriptor_type, ordered_names).SerializeToString()
    # TODO(martinz): catch KeyError and give a better error.
    field_descriptors = [
        _get_field_descriptor(descriptor_type, name) for name in ordered_names
    ]
    output_types = [
        _get_dtype_from_cpp_type(fd.cpp_type) for fd in field_descriptors
    ]

    # Keyword arguments common to both op versions.
    op_kwargs = dict(
        descriptor_literal=descriptor_literal,
        message_type=message_type,
        num_fields=len(ordered_names),
        field_names=list(ordered_names),
        output_types=output_types,
        message_format=message_format)

    # TODO(b/172576749): Once we allow sufficient bake in time for the kernel
    # change, switch to using V3 only.
    if tf.is_tensor(backing_str_tensor):
        assert message_format == "binary", (
            "message_format must be 'binary' if a backing_str_tensor is provided"
        )
        # V3 takes the backing tensor wrapped in a list.
        values, indices = gen_decode_proto_sparse.decode_proto_sparse_v3(
            tensor_of_protos, [backing_str_tensor], **op_kwargs)
    else:
        values, indices = gen_decode_proto_sparse.decode_proto_sparse_v2(
            tensor_of_protos, **op_kwargs)

    return [
        _ParsedField(field_name=name,
                     field_descriptor=fd,
                     value=value,
                     index=index)
        for name, fd, value, index in zip(
            ordered_names, field_descriptors, values, indices)
    ]