def test_out_of_order_repeated_fields_1(self):
  """Repeated field split around another field still parses in wire order."""
  # This is a 2-1-2 wire field-number pattern:
  # query_token, event_id, query_token.
  proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
           test_pb2.Event(event_id="abc").SerializeToString() +
           test_pb2.Event(query_token=["bbb"]).SerializeToString())
  expected_field_value = {
      "query_token": [b"aaa", b"bbb"],
      "event_id": [b"abc"],
  }
  for fields_to_parse in [["query_token"], ["event_id"],
                          ["query_token", "event_id"]]:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
    # Without this, a result missing some requested fields would make the
    # loop below pass vacuously.
    self.assertLen(parsed_fields, len(fields_to_parse))
    for f in parsed_fields:
      self.assertAllEqual(
          expected_field_value[f.field_name], f.value,
          "field: {}, fields_to_parse: {}".format(f.field_name,
                                                  fields_to_parse))
def test_out_of_order_fields(self):
  """Parses every ordering of interleaved fragments, for every field subset."""
  fragments = [
      test_pb2.Event(query_token=["aaa"]).SerializeToString(),
      test_pb2.Event(query_token=["bbb"]).SerializeToString(),
      test_pb2.Event(event_id="abc").SerializeToString(),
      test_pb2.Event(action_mask=[False, True]).SerializeToString(),
  ]
  # "query" is not on wire at all, so it is expected to parse as empty.
  all_fields_to_parse = [
      "query_token", "event_id", "action_mask", "query"
  ]
  # Test against all 4! permutations of fragments, and for each permutation
  # test parsing every combination (from none up to all 4) of the fields.
  for indices in itertools.permutations(range(len(fragments))):
    proto = b"".join(fragments[i] for i in indices)
    # The two query tokens are expected in the order their fragments
    # (indices 0 and 1) were concatenated onto the wire.
    if indices.index(0) < indices.index(1):
      expected_query_tokens = [b"aaa", b"bbb"]
    else:
      expected_query_tokens = [b"bbb", b"aaa"]
    expected_field_value = {
        "action_mask": [False, True],
        "query_token": expected_query_tokens,
        "event_id": [b"abc"],
        # `np.object` was removed in NumPy 1.24; the builtin is equivalent.
        "query": np.array([], dtype=object),
    }
    # `+ 1` so that the combination containing all fields is exercised too.
    for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
      for comb in itertools.combinations(all_fields_to_parse,
                                         num_fields_to_parse):
        parsed_fields = struct2tensor_ops.parse_message_level(
            [proto], test_pb2.Event.DESCRIPTOR, comb)
        self.assertLen(parsed_fields, len(comb))
        for f in parsed_fields:
          self.assertAllEqual(
              expected_field_value[f.field_name], f.value,
              "field: {}, permutation: {}, field_to_parse: {}".format(
                  f.field_name, indices, comb))
def test_out_of_order_repeated_fields_2(self):
  """Repeated field split by another field; also parses an absent field."""
  # This follows a 3-5-3 wire field-number pattern:
  # query_token, action_mask, query_token. Some cases also request `action`,
  # which is not on the wire at all (presumably field 4, per the original
  # note that "3 and 4" are parsed fields -- TODO confirm).
  proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
           test_pb2.Event(action_mask=[True]).SerializeToString() +
           test_pb2.Event(query_token=["bbb"]).SerializeToString())
  expected_field_value = {
      "query_token": [b"aaa", b"bbb"],
      "action_mask": [True],
      # Absent from the wire, so expected to parse as an empty value. This
      # used to be `[]` guarded by a truthiness check, which meant `action`
      # was never actually verified.
      "action": np.array([], dtype=object),
  }
  for fields_to_parse in [["query_token"], ["action_mask"],
                          ["query_token", "action_mask"],
                          ["query_token", "action"],
                          ["query_token", "action_mask", "action"]]:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
    self.assertLen(parsed_fields, len(fields_to_parse))
    for f in parsed_fields:
      self.assertAllEqual(
          expected_field_value[f.field_name], f.value,
          "field: {}, fields_to_parse: {}".format(f.field_name,
                                                  fields_to_parse))
def test_parse_full_message_level_for_event(self):
  """Fully parses two Events and checks the serialized `action` field."""
  # First message: several fields populated, plus two actions, the second of
  # which is left empty.
  event = test_pb2.Event()
  event.event_id = "foo"
  event.query = "query"
  event.query_token.extend(["a", "b"])
  first_action = event.action.add()
  first_action.doc_id = "abc"
  empty_action = event.action.add()
  event.user_info.age_in_years = 38

  # Second message: a single action and nothing else.
  second_event = test_pb2.Event()
  second_event_action = second_event.action.add()
  second_event_action.doc_id = "def"

  parsed_action = _parse_full_message_level_as_dict(
      [event, second_event])["action"]

  expected_values = [
      action.SerializeToString()
      for action in (first_action, empty_action, second_event_action)
  ]
  # The first two actions were added to `event`, the third to `second_event`.
  self.assertAllEqual(parsed_action.index, [0, 0, 1])
  self.assertAllEqual(parsed_action.value, expected_values)