Ejemplo n.º 1
0
    def test_out_of_order_repeated_fields_1(self):
        """Values of a repeated field split around another field are all kept."""
        # This is a 2-1-2 wire number pattern: the repeated `query_token` field
        # appears, then `event_id`, then `query_token` again.
        proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
                 test_pb2.Event(event_id="abc").SerializeToString() +
                 test_pb2.Event(query_token=["bbb"]).SerializeToString())
        expected_field_value = {
            "query_token": [b"aaa", b"bbb"],
            "event_id": [b"abc"]
        }

        for fields_to_parse in [["query_token"], ["event_id"],
                                ["query_token", "event_id"]]:
            parsed_fields = struct2tensor_ops.parse_message_level(
                [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
            # Without this, a requested field silently missing from the output
            # would make the loop below pass vacuously.
            self.assertLen(parsed_fields, len(fields_to_parse))
            for f in parsed_fields:
                self.assertAllEqual(
                    expected_field_value[f.field_name], f.value,
                    "field: {}, fields_to_parse: {}".format(
                        f.field_name, fields_to_parse))
Ejemplo n.º 2
0
    def test_out_of_order_fields(self):
        """Parsed values do not depend on the order of fields on the wire."""
        fragments = [
            test_pb2.Event(query_token=["aaa"]).SerializeToString(),
            test_pb2.Event(query_token=["bbb"]).SerializeToString(),
            test_pb2.Event(event_id="abc").SerializeToString(),
            test_pb2.Event(action_mask=[False, True]).SerializeToString(),
        ]
        # "query" is not on wire at all.
        all_fields_to_parse = [
            "query_token", "event_id", "action_mask", "query"
        ]
        # Test against all 4! permutations of fragments, and for each permutation
        # test parsing every combination (of every size, including all 4) of the
        # fields.
        for indices in itertools.permutations(range(len(fragments))):
            proto = b"".join(fragments[i] for i in indices)

            # query_token values are expected in wire order, so the expected list
            # depends on which query_token fragment (index 0 or 1) comes first.
            if indices.index(0) < indices.index(1):
                expected_query_tokens = [b"aaa", b"bbb"]
            else:
                expected_query_tokens = [b"bbb", b"aaa"]

            expected_field_value = {
                "action_mask": [False, True],
                "query_token": expected_query_tokens,
                "event_id": [b"abc"],
                # `np.object` was removed in NumPy 1.24; the builtin `object`
                # is the same dtype.
                "query": np.array([], dtype=object),
            }

            # `+ 1` so the combination containing every field is also tested.
            for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
                for comb in itertools.combinations(all_fields_to_parse,
                                                   num_fields_to_parse):
                    parsed_fields = struct2tensor_ops.parse_message_level(
                        [proto], test_pb2.Event.DESCRIPTOR, comb)
                    self.assertLen(parsed_fields, len(comb))
                    for f in parsed_fields:
                        self.assertAllEqual(
                            expected_field_value[f.field_name], f.value,
                            "field: {}, permutation: {}, field_to_parse: {}".
                            format(f.field_name, indices, comb))
Ejemplo n.º 3
0
 def test_out_of_order_repeated_fields_2(self):
     """Repeated values split by another field are kept; absent fields are empty."""
     import numpy as np  # Local import: needed for the object-dtype empty array.

     # This follows a 3-5-3 wire number pattern: query_token, then action_mask,
     # then query_token again. "action" is requested in some cases but never
     # appears on the wire, so it must parse to no values.
     proto = (test_pb2.Event(query_token=["aaa"]).SerializeToString() +
              test_pb2.Event(action_mask=[True]).SerializeToString() +
              test_pb2.Event(query_token=["bbb"]).SerializeToString())
     expected_field_value = {
         "query_token": [b"aaa", b"bbb"],
         "action_mask": [True],
         # Object dtype so the empty expectation compares cleanly against an
         # empty bytes-valued result (a plain [] would be float64).
         "action": np.array([], dtype=object),
     }
     for fields_to_parse in [["query_token"], ["action_mask"],
                             ["query_token", "action_mask"],
                             ["query_token", "action"],
                             ["query_token", "action_mask", "action"]]:
         parsed_fields = struct2tensor_ops.parse_message_level(
             [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse)
         # Guards against a requested field being dropped from the output.
         self.assertLen(parsed_fields, len(fields_to_parse))
         for f in parsed_fields:
             # Every field is checked, including the empty "action" one, which
             # the previous `if expected_value:` guard silently skipped.
             self.assertAllEqual(
                 expected_field_value[f.field_name], f.value,
                 "field: {}, fields_to_parse: {}".format(
                     f.field_name, fields_to_parse))
Ejemplo n.º 4
0
    def test_parse_full_message_level_for_event(self):
        """The `action` field of two events is parsed with parent indices."""
        # First event: several populated fields plus two actions, the second of
        # which is left empty.
        first_event = test_pb2.Event(
            event_id="foo", query="query", query_token=["a", "b"])
        first_action = first_event.action.add(doc_id="abc")
        empty_action = first_event.action.add()
        first_event.user_info.age_in_years = 38
        # Second event: a single action only.
        second_event = test_pb2.Event()
        second_action = second_event.action.add(doc_id="def")

        parsed_by_name = _parse_full_message_level_as_dict(
            [first_event, second_event])
        parsed_actions = parsed_by_name["action"]
        expected_values = [
            msg.SerializeToString()
            for msg in (first_action, empty_action, second_action)
        ]

        # Two actions come from event 0 and one from event 1.
        self.assertAllEqual(parsed_actions.index, [0, 0, 1])
        self.assertAllEqual(parsed_actions.value, expected_values)