def parse_message_level_ex(
    tensor_of_protos: tf.Tensor,
    desc: descriptor.Descriptor,
    field_names: Set[ProtoFieldName],
    message_format: str = "binary",
    backing_str_tensor: Optional[tf.Tensor] = None,
    honor_proto3_optional_semantics: bool = False
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
  """Parses regular fields, extensions, any casts, and map protos.

  First the raw fields required to satisfy `field_names` are parsed with
  `struct2tensor_ops.parse_message_level`; then the Any-cast and map-field
  results derived from those raw fields are merged into the same dict.

  Args:
    tensor_of_protos: the serialized protos to parse.
    desc: descriptor of the message type held in `tensor_of_protos`.
    field_names: the requested field names.
    message_format: forwarded unchanged to `parse_message_level`.
    backing_str_tensor: forwarded to `parse_message_level` and to the
      map-field helper.
    honor_proto3_optional_semantics: forwarded unchanged to
      `parse_message_level`.

  Returns:
    A dict keyed by field name. Any-cast entries, then map entries, are added
    on top of the regular fields (later entries win on a key collision).
  """
  fields_to_parse = _get_field_names_to_parse(desc, field_names)
  parsed = struct2tensor_ops.parse_message_level(
      tensor_of_protos,
      desc,
      fields_to_parse,
      message_format=message_format,
      backing_str_tensor=backing_str_tensor,
      honor_proto3_optional_semantics=honor_proto3_optional_semantics)

  by_name = {}
  for parsed_field in parsed:
    by_name[parsed_field.field_name] = parsed_field

  # Both helpers must see only the regular fields, so compute them before
  # merging anything into `by_name`.
  any_fields = _get_any_parsed_fields(desc, by_name, field_names)
  map_fields = _get_map_parsed_fields(desc, by_name, field_names,
                                      backing_str_tensor)
  for derived in (any_fields, map_fields):
    by_name.update(derived)
  return by_name
def test_proto2_optional_field_with_honor_proto3_optional_semantic(self):
  """An unset proto2 optional field yields no value when proto3 semantics apply."""
  empty_msg = test_pb2.AllSimple()
  msg_with_string = test_pb2.AllSimple(optional_string="a")
  serialized = [
      empty_msg.SerializeToString(),
      msg_with_string.SerializeToString(),
      empty_msg.SerializeToString(),
  ]
  parsed_tuples = struct2tensor_ops.parse_message_level(
      tf.constant(serialized),
      test_pb2.AllSimple.DESCRIPTOR,
      ["optional_string"],
      honor_proto3_optional_semantics=True)
  latest_by_name = {parsed.field_name: parsed for parsed in parsed_tuples}

  # Only the middle proto sets the field, so a single value at index 1 is
  # expected; no default value should be inserted for the other two protos.
  for parsed in latest_by_name.values():
    self.assertAllEqual([1], parsed.index)
  for parsed in latest_by_name.values():
    self.assertAllEqual([b"a"], parsed.value)
def _parse_map_entry(self, messages_with_map, map_field_name, keys_needed):
  """Parses a map field of MessageWithMap protos and looks up `keys_needed`.

  The map field is first parsed as a plain submessage field, and its
  serialized entries are then handed to `parse_proto_map` along with the
  entry message type taken from the field descriptor.
  """
  serialized = tf.constant(
      [message.SerializeToString() for message in messages_with_map])
  map_entries = struct2tensor_ops.parse_message_level(
      serialized, test_map_pb2.MessageWithMap.DESCRIPTOR, [map_field_name])[0]
  entry_descriptor = map_entries.field_descriptor.message_type
  return struct2tensor_ops.parse_proto_map(
      map_entries.value, map_entries.index, entry_descriptor, keys_needed)
def test_parse_message_level(self):
  """Parses a single int field out of a single Action proto."""
  action = test_pb2.Action(doc_id="3", number_of_views=3)
  serialized = tf.constant([action.SerializeToString()])
  (views_field,) = struct2tensor_ops.parse_message_level(
      serialized, test_pb2.Action.DESCRIPTOR, ["number_of_views"])
  self.assertAllEqual(views_field.index, [0])
  self.assertAllEqual(views_field.value, [3])
def test_parse_external_extension(self):
  """An extension defined in another file is parsed as a serialized message."""
  extension = test_extension_pb2.MyExternalExtension.ext
  user_info = test_pb2.UserInfo()
  user_info.Extensions[extension].special = "shhh"
  expected_value = test_extension_pb2.MyExternalExtension(special="shhh")

  (extension_field,) = struct2tensor_ops.parse_message_level(
      tf.constant([user_info.SerializeToString()]),
      test_pb2.UserInfo.DESCRIPTOR,
      ["(struct2tensor.test.MyExternalExtension.ext)"])
  self.assertAllEqual(extension_field.index, [0])
  self.assertAllEqual(extension_field.value,
                      [expected_value.SerializeToString()])
def parse_message_level_ex(tensor_of_protos, desc, field_names):
  """Parses regular fields, extensions, any casts, and map protos.

  Raw fields needed for `field_names` are parsed first; Any-cast results and
  then map-field results derived from them are merged into the returned dict
  (later entries win on a key collision).
  """
  fields_to_parse = _get_field_names_to_parse(desc, field_names)
  by_name = {}
  for parsed_field in struct2tensor_ops.parse_message_level(
      tensor_of_protos, desc, fields_to_parse):
    by_name[parsed_field.field_name] = parsed_field

  # Derive both extra groups from the regular fields before merging them in.
  any_fields = _get_any_parsed_fields(desc, by_name, field_names)
  map_fields = _get_map_parsed_fields(desc, by_name, field_names)
  for derived in (any_fields, map_fields):
    by_name.update(derived)
  return by_name
def test_out_of_order_repeated_fields_1(self):
  """Repeated values split by another field are still collected in order.

  Concatenated serialized messages parse as one merged message, so `proto`
  holds query_token, then event_id, then query_token again.
  """
  # This is a 2-1-2 wire number pattern.
  proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
           test_pb2.Event(event_id="abc").SerializeToString() +
           test_pb2.Event(query_token=["bbb"]).SerializeToString())
  expected_field_value = {
      "query_token": [b"aaa", b"bbb"],
      "event_id": [b"abc"],
  }
  for fields_to_parse in [["query_token"], ["event_id"],
                          ["query_token", "event_id"]]:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
    # One result per requested field; without this the loop below would pass
    # vacuously if nothing were returned.
    self.assertLen(parsed_fields, len(fields_to_parse))
    for f in parsed_fields:
      self.assertAllEqual(expected_field_value[f.field_name], f.value)
def test_out_of_order_fields(self):
  """Field values are correct for every fragment order and field subset."""
  fragments = [
      test_pb2.Event(query_token=["aaa"]).SerializeToString(),
      test_pb2.Event(query_token=["bbb"]).SerializeToString(),
      test_pb2.Event(event_id="abc").SerializeToString(),
      test_pb2.Event(action_mask=[False, True]).SerializeToString(),
  ]
  # "query" is not on wire at all.
  all_fields_to_parse = ["query_token", "event_id", "action_mask", "query"]
  # Test against all 4! permutations of fragments, and for each permutation
  # test parsing every combination (sizes 0 through 4) of the fields.
  for indices in itertools.permutations(range(len(fragments))):
    proto = b"".join(fragments[i] for i in indices)
    # The two query_token values appear in the order of their fragments.
    if indices.index(0) < indices.index(1):
      expected_query_tokens = [b"aaa", b"bbb"]
    else:
      expected_query_tokens = [b"bbb", b"aaa"]
    expected_field_value = {
        "action_mask": [False, True],
        "query_token": expected_query_tokens,
        "event_id": [b"abc"],
        # `np.object` was removed in NumPy 1.24; the builtin is equivalent.
        "query": np.array([], dtype=object),
    }
    # `+ 1` so that the combination of all four fields is also exercised.
    for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
      for comb in itertools.combinations(all_fields_to_parse,
                                         num_fields_to_parse):
        parsed_fields = struct2tensor_ops.parse_message_level(
            [proto], test_pb2.Event.DESCRIPTOR, comb)
        self.assertLen(parsed_fields, len(comb))
        for f in parsed_fields:
          self.assertAllEqual(
              expected_field_value[f.field_name], f.value,
              "field: {}, permutation: {}, field_to_parse: {}".format(
                  f.field_name, indices, comb))
def test_out_of_order_repeated_fields_2(self):
  """Repeated query_token values split by action_mask are parsed in order."""
  # This follows a 3-5-3 wire number pattern (query_token, action_mask,
  # query_token). "action" is also requested but never appears on the wire.
  proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
           test_pb2.Event(action_mask=[True]).SerializeToString() +
           test_pb2.Event(query_token=["bbb"]).SerializeToString())
  expected_field_value = {
      "query_token": [b"aaa", b"bbb"],
      "action_mask": [True],
      "action": [],
  }
  for fields_to_parse in [["query_token"], ["action_mask"],
                          ["query_token", "action_mask"],
                          ["query_token", "action"],
                          ["query_token", "action_mask", "action"]]:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
    # One result per requested field, including the absent "action".
    self.assertLen(parsed_fields, len(fields_to_parse))
    for f in parsed_fields:
      expected_value = expected_field_value[f.field_name]
      # Fields with an empty expectation (only "action") are deliberately not
      # value-checked; their presence is covered by assertLen above.
      if expected_value:
        self.assertAllEqual(expected_value, f.value)
def parse_message_level_ex(
    tensor_of_protos: tf.Tensor,
    desc: descriptor.Descriptor,
    field_names: Set[ProtoFieldName],
    message_format: str = "binary"
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
  """Parses regular fields, extensions, any casts, and map protos.

  Args:
    tensor_of_protos: the serialized protos to parse.
    desc: descriptor of the message type held in `tensor_of_protos`.
    field_names: the requested field names.
    message_format: forwarded unchanged to
      `struct2tensor_ops.parse_message_level`.

  Returns:
    A dict keyed by field name: the regular fields, with Any-cast entries and
    then map entries merged on top (later entries win on a key collision).
  """
  fields_to_parse = _get_field_names_to_parse(desc, field_names)
  parsed = struct2tensor_ops.parse_message_level(
      tensor_of_protos, desc, fields_to_parse, message_format=message_format)
  by_name = {}
  for parsed_field in parsed:
    by_name[parsed_field.field_name] = parsed_field

  # Both helpers read only the regular fields, so call them before merging.
  any_fields = _get_any_parsed_fields(desc, by_name, field_names)
  map_fields = _get_map_parsed_fields(desc, by_name, field_names)
  for derived in (any_fields, map_fields):
    by_name.update(derived)
  return by_name
def test_out_of_order_repeated_fields_3(self):
  """Repeated strings split by a repeated int64 are still parsed in order."""
  # Fragments interleave repeated_string, repeated_int64, repeated_string;
  # repeated_int32 and repeated_uint32 are requested but never on the wire.
  proto = (
      test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString() +
      test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString() +
      test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString())
  expected_field_value = {
      "repeated_string": [b"aaa", b"bbb"],
      "repeated_int64": [12345],
      "repeated_int32": [],
      "repeated_uint32": [],
  }
  for fields_to_parse in [["repeated_int64"], ["repeated_string"],
                          ["repeated_string", "repeated_uint32",
                           "repeated_int32"]]:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [proto], test_pb2.AllSimple.DESCRIPTOR, fields_to_parse)
    # One result per requested field; guards against a vacuous pass.
    self.assertLen(parsed_fields, len(fields_to_parse))
    for f in parsed_fields:
      self.assertAllEqual(expected_field_value[f.field_name], f.value)
def test_parse_packed_fields(self):
  """Packed repeated numeric fields parse to their values across two protos."""
  packed_field_names = [
      "packed_int32",
      "packed_uint32",
      "packed_int64",
      "packed_uint64",
      "packed_float",
      "packed_double",
  ]
  message_with_packed_fields = test_pb2.HasPackedFields(
      packed_int32=[-1, -2, -3],
      packed_uint32=[100000, 200000, 300000],
      packed_int64=[-400000, -500000, -600000],
      packed_uint64=[4, 5, 6],
      packed_float=[7.0, 8.0, 9.0],
      packed_double=[10.0, 11.0, 12.0],
  )
  # Two copies of the same message; each field has three values per copy.
  serialized = tf.constant(
      [message_with_packed_fields.SerializeToString()] * 2)
  parsed_tuples = struct2tensor_ops.parse_message_level(
      serialized, test_pb2.HasPackedFields.DESCRIPTOR, packed_field_names)
  by_name = {parsed.field_name: parsed for parsed in parsed_tuples}

  for parsed in by_name.values():
    self.assertAllEqual(parsed.index, [0, 0, 0, 1, 1, 1])
  for name, parsed in by_name.items():
    expected = list(getattr(message_with_packed_fields, name)) * 2
    self.assertAllEqual(parsed.value, expected)