def _create_any_protos():
  """Wraps three sample messages with `_create_any` and returns the results.

  The wrapped messages are, in order: an AllSimple whose optional_int32 is
  explicitly assigned 0, an empty UserInfo, and an AllSimple whose
  optional_int32 is assigned 20.

  Returns:
    A list with one `_create_any(...)` result per sample message.
  """
  zero_simple = test_pb2.AllSimple()
  zero_simple.optional_int32 = 0
  empty_user = test_pb2.UserInfo()
  twenty_simple = test_pb2.AllSimple()
  twenty_simple.optional_int32 = 20

  wrapped = []
  for message in (zero_simple, empty_user, twenty_simple):
    wrapped.append(_create_any(message))
  return wrapped
def test_proto2_optional_field_with_honor_proto3_optional_semantic(self):
  """Parses a proto2 optional field with proto3-optional semantics enabled.

  Of three serialized AllSimple messages only the middle one sets
  optional_string, so the parse must report exactly that one (index 1,
  value b"a") and must not synthesize default values for the others.
  """
  unset_message = test_pb2.AllSimple()
  set_message = test_pb2.AllSimple(optional_string="a")
  serialized_batch = tf.constant([
      unset_message.SerializeToString(),
      set_message.SerializeToString(),
      unset_message.SerializeToString(),
  ])
  parsed_tuples = struct2tensor_ops.parse_message_level(
      serialized_batch,
      test_pb2.AllSimple.DESCRIPTOR,
      ["optional_string"],
      honor_proto3_optional_semantics=True)
  # Keyed by field name; if a name repeats, the last parsed tuple wins.
  by_field_name = {parsed.field_name: parsed for parsed in parsed_tuples}

  # Only the second proto has a value, so no default may be inserted.
  for parsed in by_field_name.values():
    self.assertAllEqual([1], parsed.index)
  for parsed in by_field_name.values():
    self.assertAllEqual([b"a"], parsed.value)
def test_parse_full_message_level_for_all_simple_repeated_repeated(self):
  """Test seven messages, three of which set seven repeated fields twice.

  Messages at positions 0, 3 and 5 are the same populated AllSimple proto and
  the other four are empty. Each populated repeated field therefore reports
  parent indices [0, 0, 3, 3, 5, 5] and its two values repeated once per
  populated message.
  """
  all_simple = test_pb2.AllSimple()
  all_simple.repeated_string.extend(["foo", "foo2"])
  all_simple.repeated_int32.extend([32, 322])
  all_simple.repeated_uint32.extend([123, 1232])
  all_simple.repeated_int64.extend([123456, 1234562])
  all_simple.repeated_uint64.extend([123, 1232])
  all_simple.repeated_float.extend([1.0, 2.0])
  all_simple.repeated_double.extend([1.5, 2.5])
  result = _get_full_message_level_runnable([
      all_simple, test_pb2.AllSimple(), test_pb2.AllSimple(), all_simple,
      test_pb2.AllSimple(), all_simple, test_pb2.AllSimple()
  ])

  # Per-message values as they appear after parsing (strings become bytes).
  expected_pairs = {
      "repeated_string": [b"foo", b"foo2"],
      "repeated_int32": [32, 322],
      "repeated_uint32": [123, 1232],
      "repeated_int64": [123456, 1234562],
      "repeated_uint64": [123, 1232],
      "repeated_float": [1.0, 2.0],
      "repeated_double": [1.5, 2.5],
  }
  for field_name, pair in expected_pairs.items():
    self.assertAllEqual(result[field_name][INDEX], [0, 0, 3, 3, 5, 5])
    # One copy of the pair for each of the three populated messages.
    self.assertAllEqual(result[field_name][VALUE], pair * 3)
def _get_expression_with_any():
  """Builds an expression over three serialized MessageWithAny protos.

  The `my_any` field of each message packs, in order: an AllSimple whose
  optional_int32 is explicitly assigned 0, an empty UserInfo, and an
  AllSimple whose optional_int32 is assigned 20.

  Returns:
    The result of `proto.create_expression_from_proto` on those messages.
  """

  def pack_and_serialize(payload):
    # Wrap `payload` in a fresh MessageWithAny and return its wire bytes.
    holder = test_any_pb2.MessageWithAny()
    holder.my_any.Pack(payload)
    return holder.SerializeToString()

  zero_simple = test_pb2.AllSimple()
  zero_simple.optional_int32 = 0
  twenty_simple = test_pb2.AllSimple()
  twenty_simple.optional_int32 = 20
  payloads = (zero_simple, test_pb2.UserInfo(), twenty_simple)

  return proto.create_expression_from_proto(
      [pack_and_serialize(payload) for payload in payloads],
      test_any_pb2.MessageWithAny.DESCRIPTOR)
def test_any_path(self):
  """Promotes a field packed inside an Any and checks value and source path.

  A single MessageWithAny packs an AllSimple with optional_int32 = 17. That
  field is promoted to `my_any.new_int32`. The computed leaf must hold [17]
  under parent index 0, and its source path must be the original Any path.
  """
  any_field_steps = [
      "my_any", "(struct2tensor.test.AllSimple)", "optional_int32"
  ]

  payload = test_pb2.AllSimple()
  payload.optional_int32 = 17
  holder = test_any_pb2.MessageWithAny()
  holder.my_any.Pack(payload)
  root = proto.create_expression_from_proto(
      [holder.SerializeToString()], test_any_pb2.MessageWithAny.DESCRIPTOR)

  promoted_root = promote.promote(root, path.Path(any_field_steps), "new_int32")
  promoted_leaf = promoted_root.get_descendant_or_error(
      path.Path(["my_any", "new_int32"]))

  prensors, summaries = (
      calculate_with_source_paths.calculate_prensors_with_source_paths(
          [promoted_leaf]))
  self.assertLen(prensors, 1)
  self.assertLen(summaries, 1)

  leaf = prensors[0].node
  self.assertAllEqual(leaf.parent_index, [0])
  self.assertAllEqual(leaf.values, [17])
  self.equal_ignore_order(summaries[0].paths, [path.Path(any_field_steps)])
def test_normal_field(self, message_format):
  """Test four messages (the second one empty) with a repeated string.

  Only `repeated_string` is requested. Other fields such as `repeated_bool`
  must therefore be absent from the result. The three populated messages
  (positions 0, 2 and 3) each contribute the two strings "foo" and "foo2".

  Args:
    message_format: forwarded unchanged to `_run_parse_message_level_ex`.
  """
  all_simple = test_pb2.AllSimple()
  all_simple.repeated_string.append("foo")
  all_simple.repeated_string.append("foo2")
  all_simple_empty = test_pb2.AllSimple()

  result = _run_parse_message_level_ex(
      [all_simple, all_simple_empty, all_simple, all_simple],
      {"repeated_string"}, message_format)

  # Fields that were not requested must not be parsed.
  self.assertNotIn("repeated_bool", result)
  self.assertAllEqual(result["repeated_string"][_INDEX], [0, 0, 2, 2, 3, 3])
  self.assertAllEqual(
      result["repeated_string"][_VALUE],
      [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"])
def test_parse_full_message_level_for_simple_action_multiple(self):
  """Test multiple messages.

  Of three AllSimple messages, the first and third set `repeated_string`
  (two entries each) and the second does not. The parsed field must
  therefore report parent indices [0, 0, 2, 2].
  """
  first = test_pb2.AllSimple()
  first.optional_string = "a"
  first.repeated_string.extend(["b", "c"])

  second = test_pb2.AllSimple()
  second.optional_string = "d"
  second.optional_int32 = 123

  third = test_pb2.AllSimple()
  third.repeated_string.extend(["d", "e"])
  third.optional_int32 = 123

  fields_by_name = _parse_full_message_level_as_dict([first, second, third])
  repeated_string_field = fields_by_name["repeated_string"]
  self.assertAllEqual(repeated_string_field.index, [0, 0, 2, 2])
  self.assertAllEqual(repeated_string_field.value, [b"b", b"c", b"d", b"e"])
def test_out_of_order_repeated_fields_3(self):
  """Parses a buffer where repeated_string entries are split by another field.

  Three serialized AllSimple messages are concatenated into one buffer:
  repeated_string, then repeated_int64, then repeated_string again. The
  original comment describes this as a "3-5-3" wire field-number pattern.
  Several subsets of fields are parsed, and every parsed field must match
  the expected values, including the empty fields that were never written.
  """
  # Named `serialized` rather than `proto` to avoid shadowing the `proto`
  # module used elsewhere in this file.
  serialized = (
      test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString() +
      test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString() +
      test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString())
  expected_field_value = {
      "repeated_string": [b"aaa", b"bbb"],
      "repeated_int64": [12345],
      "repeated_int32": [],
      "repeated_uint32": [],
  }
  field_subsets = [
      ["repeated_int64"],
      ["repeated_string"],
      ["repeated_string", "repeated_uint32", "repeated_int32"],
  ]
  for subset in field_subsets:
    parsed_fields = struct2tensor_ops.parse_message_level(
        [serialized], test_pb2.AllSimple.DESCRIPTOR, subset)
    for parsed in parsed_fields:
      self.assertAllEqual(expected_field_value[parsed.field_name],
                          parsed.value)
def test_parse_full_message_level_for_all_simple(self):
  """Test a single message with string, integer and float fields populated.

  Every optional and repeated string/int32/uint32/int64/uint64/float/double
  field of one AllSimple is set exactly once. Each parsed field should
  therefore have parent index [0] and carry the single value that was
  written. Boolean fields are not exercised here.
  """
  all_simple = test_pb2.AllSimple()
  all_simple.optional_string = "foo"
  all_simple.optional_int32 = -5
  all_simple.optional_uint32 = 2**31
  all_simple.optional_int64 = 100123
  all_simple.optional_uint64 = 2**63
  all_simple.optional_float = 6.5
  all_simple.optional_double = -7.0
  all_simple.repeated_string.append("foo")
  all_simple.repeated_int32.append(32)
  all_simple.repeated_uint32.append(123)
  all_simple.repeated_int64.append(123456)
  all_simple.repeated_uint64.append(123)
  all_simple.repeated_float.append(1.0)
  all_simple.repeated_double.append(1.5)
  result = _get_full_message_level_runnable([all_simple])

  # Index and value tensors are expected to be rank-1 for both optional and
  # repeated fields.
  self.assertLen(result["optional_string"][INDEX].shape.dims, 1)
  self.assertLen(result["optional_string"][VALUE].shape.dims, 1)
  self.assertLen(result["repeated_string"][INDEX].shape.dims, 1)
  self.assertLen(result["repeated_string"][VALUE].shape.dims, 1)

  # Parsed values per field (strings come back as bytes). The repeated
  # fields were previously populated but never checked.
  expected_values = {
      "optional_string": [b"foo"],
      "optional_int32": [-5],
      "optional_uint32": [2**31],
      "optional_int64": [100123],
      "optional_uint64": [2**63],
      "optional_float": [6.5],
      "optional_double": [-7.0],
      "repeated_string": [b"foo"],
      "repeated_int32": [32],
      "repeated_uint32": [123],
      "repeated_int64": [123456],
      "repeated_uint64": [123],
      "repeated_float": [1.0],
      "repeated_double": [1.5],
  }
  for field_name, expected in expected_values.items():
    self.assertAllEqual(result[field_name][INDEX], [0])
    self.assertAllEqual(result[field_name][VALUE], expected)
def _get_optional_int32(serialized_all_simple):
  """Deserializes a test_pb2.AllSimple and returns its optional_int32 field.

  Args:
    serialized_all_simple: wire-format bytes of a test_pb2.AllSimple.

  Returns:
    The message's optional_int32 value. Reading an unset field yields the
    field's default.
  """
  message = test_pb2.AllSimple.FromString(serialized_all_simple)
  return message.optional_int32
def _get_empty_all_simple():
  """Return the serialized bytes of a default (empty) test_pb2.AllSimple."""
  return test_pb2.AllSimple().SerializeToString()