예제 #1
0
def parse_message_level_ex(
    tensor_of_protos: tf.Tensor,
    desc: descriptor.Descriptor,
    field_names: Set[ProtoFieldName],
    message_format: str = "binary",
    backing_str_tensor: Optional[tf.Tensor] = None,
    honor_proto3_optional_semantics: bool = False
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
    """Parses the requested fields and merges in any- and map-derived fields.

    The field names actually handed to
    `struct2tensor_ops.parse_message_level` are computed by
    `_get_field_names_to_parse`. The parsed results are keyed by their
    `field_name` and then overlaid with the outputs of
    `_get_any_parsed_fields` and `_get_map_parsed_fields`, in that order.

    Args:
      tensor_of_protos: forwarded unchanged to `parse_message_level`.
      desc: descriptor passed to the parser and to every helper.
      field_names: the field names requested by the caller.
      message_format: forwarded to `parse_message_level`.
      backing_str_tensor: forwarded to `parse_message_level` and to
        `_get_map_parsed_fields`.
      honor_proto3_optional_semantics: forwarded to `parse_message_level`.

    Returns:
      A dict from field name to parsed field containing the regular fields
      plus whatever the any/map helpers contributed.
    """
    names_to_parse = _get_field_names_to_parse(desc, field_names)

    parsed_by_name = {}
    for parsed_field in struct2tensor_ops.parse_message_level(
            tensor_of_protos,
            desc,
            names_to_parse,
            message_format=message_format,
            backing_str_tensor=backing_str_tensor,
            honor_proto3_optional_semantics=honor_proto3_optional_semantics):
        parsed_by_name[parsed_field.field_name] = parsed_field

    # Both helpers are evaluated against the regular fields only, before
    # anything is merged back in.
    from_any = _get_any_parsed_fields(desc, parsed_by_name, field_names)
    from_map = _get_map_parsed_fields(desc, parsed_by_name, field_names,
                                      backing_str_tensor)

    # On a key collision the later update wins: map over any over regular.
    parsed_by_name.update(from_any)
    parsed_by_name.update(from_map)
    return parsed_by_name
 def test_proto2_optional_field_with_honor_proto3_optional_semantic(self):
     # Three serialized AllSimple messages; only the middle one sets
     # `optional_string`. With honor_proto3_optional_semantics=True the
     # parser is expected to report just that one value.
     unset_message = test_pb2.AllSimple()
     set_message = test_pb2.AllSimple(optional_string="a")
     unset_bytes_first = unset_message.SerializeToString()
     set_bytes = set_message.SerializeToString()
     unset_bytes_last = unset_message.SerializeToString()
     tensor_of_protos = tf.constant(
         [unset_bytes_first, set_bytes, unset_bytes_last])

     parsed_tuples = struct2tensor_ops.parse_message_level(
         tensor_of_protos,
         test_pb2.AllSimple.DESCRIPTOR,
         ["optional_string"],
         honor_proto3_optional_semantics=True)

     index_by_field = {}
     value_by_field = {}
     for parsed in parsed_tuples:
         index_by_field[parsed.field_name] = parsed.index
         value_by_field[parsed.field_name] = parsed.value

     # Only the second proto has value. No default value should be inserted.
     for field_index in index_by_field.values():
         self.assertAllEqual([1], field_index)
     for field_value in value_by_field.values():
         self.assertAllEqual([b"a"], field_value)
예제 #3
0
    def _parse_map_entry(self, messages_with_map, map_field_name, keys_needed):
        """Parses one map field of the given messages via `parse_proto_map`.

        The messages are serialized and parsed at the message level for
        `map_field_name` only; the first parsed field's value, index and
        message type are then handed to `struct2tensor_ops.parse_proto_map`
        together with `keys_needed`.
        """
        serialized = tf.constant(
            [message.SerializeToString() for message in messages_with_map])
        parsed_fields = struct2tensor_ops.parse_message_level(
            serialized, test_map_pb2.MessageWithMap.DESCRIPTOR,
            [map_field_name])
        map_submessage = parsed_fields[0]
        entry_descriptor = map_submessage.field_descriptor.message_type

        return struct2tensor_ops.parse_proto_map(
            map_submessage.value, map_submessage.index, entry_descriptor,
            keys_needed)
예제 #4
0
 def test_parse_message_level(self):
     # Serializes a single Action and checks that requesting
     # `number_of_views` yields exactly one parsed field holding the value 3
     # at parent index 0.
     action = test_pb2.Action(doc_id="3", number_of_views=3)
     serialized_actions = tf.constant([action.SerializeToString()])

     parsed_fields = struct2tensor_ops.parse_message_level(
         serialized_actions,
         test_pb2.Action().DESCRIPTOR,
         ["number_of_views"])
     # Unpacking into a one-element list fails unless exactly one field
     # came back.
     [field_tuple] = parsed_fields

     self.assertAllEqual(field_tuple.index, [0])
     self.assertAllEqual(field_tuple.value, [3])
예제 #5
0
 def test_parse_external_extension(self):
   # Sets the MyExternalExtension extension on a UserInfo and checks that
   # parsing it by its parenthesized full name returns the serialized
   # extension message at parent index 0.
   extension_handle = test_extension_pb2.MyExternalExtension.ext
   user_info = test_pb2.UserInfo()
   user_info.Extensions[extension_handle].special = "shhh"

   expected_extension = test_extension_pb2.MyExternalExtension()
   expected_extension.special = "shhh"

   serialized_users = tf.constant([user_info.SerializeToString()])
   parsed_fields = struct2tensor_ops.parse_message_level(
       serialized_users,
       test_pb2.UserInfo().DESCRIPTOR,
       ["(struct2tensor.test.MyExternalExtension.ext)"])
   # Exactly one parsed field is expected; unpacking enforces that.
   [field_tuple] = parsed_fields

   self.assertAllEqual(field_tuple.index, [0])
   self.assertAllEqual(field_tuple.value,
                       [expected_extension.SerializeToString()])
예제 #6
0
def parse_message_level_ex(tensor_of_protos, desc, field_names):
    """Parses the requested fields and merges in any- and map-derived fields.

    `_get_field_names_to_parse` decides which names are passed to
    `struct2tensor_ops.parse_message_level`. The parsed results are keyed by
    `field_name` and then overlaid with the outputs of
    `_get_any_parsed_fields` and `_get_map_parsed_fields`, in that order.

    Returns:
      A dict from field name to parsed field.
    """
    names_to_parse = _get_field_names_to_parse(desc, field_names)

    parsed_by_name = {}
    for parsed_field in struct2tensor_ops.parse_message_level(
            tensor_of_protos, desc, names_to_parse):
        parsed_by_name[parsed_field.field_name] = parsed_field

    # Both helpers see only the regular fields; merging happens afterwards.
    from_any = _get_any_parsed_fields(desc, parsed_by_name, field_names)
    from_map = _get_map_parsed_fields(desc, parsed_by_name, field_names)

    # On a key collision the later update wins: map over any over regular.
    parsed_by_name.update(from_any)
    parsed_by_name.update(from_map)
    return parsed_by_name
예제 #7
0
    def test_out_of_order_repeated_fields_1(self):
        # Concatenates three serialized Events (query_token, event_id,
        # query_token) into one byte string. The original note describes this
        # as a 2-1-2 wire number pattern.
        wire_fragments = [
            test_pb2.Event(query_token=["aaa"]).SerializeToString(),
            test_pb2.Event(event_id="abc").SerializeToString(),
            test_pb2.Event(query_token=["bbb"]).SerializeToString(),
        ]
        proto = b"".join(wire_fragments)

        expected_by_field = {
            "query_token": [b"aaa", b"bbb"],
            "event_id": [b"abc"]
        }
        field_selections = [
            ["query_token"],
            ["event_id"],
            ["query_token", "event_id"],
        ]

        for selection in field_selections:
            parsed_fields = struct2tensor_ops.parse_message_level(
                [proto], test_pb2.Event.DESCRIPTOR, selection)
            for parsed in parsed_fields:
                expected = expected_by_field[parsed.field_name]
                self.assertAllEqual(expected, parsed.value)
예제 #8
0
    def test_out_of_order_fields(self):
        # Builds one byte string per permutation of four serialized Event
        # fragments and checks that every parsed field comes back with the
        # expected values, whatever order the fragments appear in.
        fragments = [
            test_pb2.Event(query_token=["aaa"]).SerializeToString(),
            test_pb2.Event(query_token=["bbb"]).SerializeToString(),
            test_pb2.Event(event_id="abc").SerializeToString(),
            test_pb2.Event(action_mask=[False, True]).SerializeToString(),
        ]
        # "query" is not on wire at all.
        all_fields_to_parse = [
            "query_token", "event_id", "action_mask", "query"
        ]

        # Test against all 4! permutations of fragments, and for each
        # permutation test parsing combinations of the 4 fields.
        for indices in itertools.permutations(range(len(fragments))):
            proto = b"".join([fragments[i] for i in indices])

            # Fragments 0 and 1 both carry query_token, so the expected order
            # follows whichever of them comes first in this permutation.
            if indices.index(0) < indices.index(1):
                expected_query_tokens = [b"aaa", b"bbb"]
            else:
                expected_query_tokens = [b"bbb", b"aaa"]

            expected_field_value = {
                "action_mask": [False, True],
                "query_token": expected_query_tokens,
                "event_id": [b"abc"],
                # `np.object` was removed in NumPy 1.24; the builtin `object`
                # gives the same dtype.
                "query": np.array([], dtype=object),
            }

            # NOTE(review): range(len(...)) covers sizes 0..3, so the
            # combination with all four fields is never parsed. Confirm
            # whether that is intentional before widening the range.
            for num_fields_to_parse in range(len(all_fields_to_parse)):
                for comb in itertools.combinations(all_fields_to_parse,
                                                   num_fields_to_parse):
                    parsed_fields = struct2tensor_ops.parse_message_level(
                        [proto], test_pb2.Event.DESCRIPTOR, comb)
                    self.assertLen(parsed_fields, len(comb))
                    for f in parsed_fields:
                        self.assertAllEqual(
                            expected_field_value[f.field_name], f.value,
                            "field: {}, permutation: {}, field_to_parse: {}".
                            format(f.field_name, indices, comb))
예제 #9
0
 def test_out_of_order_repeated_fields_2(self):
     # Concatenates three serialized Events (query_token, action_mask,
     # query_token). The original note describes this as a 3-5-3 wire number
     # pattern in which fields 3 and 4 are among the parsed fields.
     wire_fragments = [
         test_pb2.Event(query_token=["aaa"]).SerializeToString(),
         test_pb2.Event(action_mask=[True]).SerializeToString(),
         test_pb2.Event(query_token=["bbb"]).SerializeToString(),
     ]
     proto = b"".join(wire_fragments)

     expected_by_field = {
         "query_token": [b"aaa", b"bbb"],
         "action_mask": [True],
         "action": []
     }
     field_selections = [
         ["query_token"],
         ["action_mask"],
         ["query_token", "action_mask"],
         ["query_token", "action"],
         ["query_token", "action_mask", "action"],
     ]

     for selection in field_selections:
         parsed_fields = struct2tensor_ops.parse_message_level(
             [proto], test_pb2.Event.DESCRIPTOR, selection)
         for parsed in parsed_fields:
             expected = expected_by_field[parsed.field_name]
             # Fields with an empty expectation ("action") are not compared.
             if not expected:
                 continue
             self.assertAllEqual(expected, parsed.value)
예제 #10
0
def parse_message_level_ex(
    tensor_of_protos: tf.Tensor,
    desc: descriptor.Descriptor,
    field_names: Set[ProtoFieldName],
    message_format: str = "binary"
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
    """Parses the requested fields and merges in any- and map-derived fields.

    `_get_field_names_to_parse` decides which names are passed to
    `struct2tensor_ops.parse_message_level`. The parsed results are keyed by
    `field_name` and then overlaid with the outputs of
    `_get_any_parsed_fields` and `_get_map_parsed_fields`, in that order.

    Args:
      tensor_of_protos: forwarded unchanged to `parse_message_level`.
      desc: descriptor passed to the parser and to every helper.
      field_names: the field names requested by the caller.
      message_format: forwarded to `parse_message_level`.

    Returns:
      A dict from field name to parsed field.
    """
    names_to_parse = _get_field_names_to_parse(desc, field_names)

    parsed_by_name = {}
    for parsed_field in struct2tensor_ops.parse_message_level(
            tensor_of_protos,
            desc,
            names_to_parse,
            message_format=message_format):
        parsed_by_name[parsed_field.field_name] = parsed_field

    # Both helpers see only the regular fields; merging happens afterwards.
    from_any = _get_any_parsed_fields(desc, parsed_by_name, field_names)
    from_map = _get_map_parsed_fields(desc, parsed_by_name, field_names)

    # On a key collision the later update wins: map over any over regular.
    parsed_by_name.update(from_any)
    parsed_by_name.update(from_map)
    return parsed_by_name
예제 #11
0
  def test_out_of_order_repeated_fields_3(self):
    # Concatenates three serialized AllSimple messages (repeated_string,
    # repeated_int64, repeated_string) into one byte string.
    # NOTE(review): the original comment called this a "3-5-3 wire number
    # pattern"; that wording matches the Event-based test and may have been
    # copied over - confirm against the AllSimple field numbers.
    wire_fragments = [
        test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString(),
        test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString(),
        test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString(),
    ]
    proto = b"".join(wire_fragments)

    expected_by_field = {
        "repeated_string": [b"aaa", b"bbb"],
        "repeated_int64": [12345],
        "repeated_int32": [],
        "repeated_uint32": []
    }
    field_selections = [
        ["repeated_int64"],
        ["repeated_string"],
        ["repeated_string", "repeated_uint32", "repeated_int32"],
    ]

    for selection in field_selections:
      parsed_fields = struct2tensor_ops.parse_message_level(
          [proto], test_pb2.AllSimple.DESCRIPTOR, selection)
      for parsed in parsed_fields:
        expected = expected_by_field[parsed.field_name]
        self.assertAllEqual(expected, parsed.value)
예제 #12
0
    def test_parse_packed_fields(self):
        # Serializes the same HasPackedFields message twice and checks that
        # every packed field parses to three values per message, with parent
        # indices 0,0,0,1,1,1.
        packed_field_names = [
            "packed_int32",
            "packed_uint32",
            "packed_int64",
            "packed_uint64",
            "packed_float",
            "packed_double",
        ]
        message_with_packed_fields = test_pb2.HasPackedFields(
            packed_int32=[-1, -2, -3],
            packed_uint32=[100000, 200000, 300000],
            packed_int64=[-400000, -500000, -600000],
            packed_uint64=[4, 5, 6],
            packed_float=[7.0, 8.0, 9.0],
            packed_double=[10.0, 11.0, 12.0],
        )
        serialized_once = message_with_packed_fields.SerializeToString()
        tensor_of_protos = tf.constant([serialized_once] * 2)

        parsed_tuples = struct2tensor_ops.parse_message_level(
            tensor_of_protos, test_pb2.HasPackedFields.DESCRIPTOR,
            packed_field_names)

        index_by_field = {}
        value_by_field = {}
        for parsed in parsed_tuples:
            index_by_field[parsed.field_name] = parsed.index
            value_by_field[parsed.field_name] = parsed.value

        for field_index in index_by_field.values():
            self.assertAllEqual(field_index, [0, 0, 0, 1, 1, 1])
        for field_name, field_value in value_by_field.items():
            # Two identical messages, so the expected values are the
            # message's own list repeated twice.
            per_message = list(getattr(message_with_packed_fields, field_name))
            self.assertAllEqual(field_value, per_message * 2)