def testSubUnknownFields(self):
    """Views of nested unknown fields remain readable after their owner is cleared."""
    # A group-valued field serialized and parsed into a message that declares
    # no fields ends up entirely in the unknown field set.
    typed_msg = unittest_pb2.TestAllTypes()
    typed_msg.optionalgroup.a = 123
    empty_msg = unittest_pb2.TestEmptyMessage()
    empty_msg.ParseFromString(typed_msg.SerializeToString())

    # The group entry's data is itself a field set with one member.
    group_fields = unknown_fields.UnknownFieldSet(empty_msg)[0].data
    self.assertEqual(1, len(group_fields))
    self.assertEqual(group_fields[0].data, 123)

    # Clearing the parent must not invalidate the view obtained above.
    empty_msg.Clear()
    self.assertEqual(1, len(group_fields))
    self.assertEqual(group_fields[0].data, 123)

    # Repeat the exercise for a submessage cleared through its container.
    typed_msg.Clear()
    typed_msg.optional_uint32 = 456
    container = unittest_pb2.NestedTestAllTypes()
    container.payload.optional_nested_message.ParseFromString(
        typed_msg.SerializeToString())
    field_view = unknown_fields.UnknownFieldSet(
        container.payload.optional_nested_message)
    self.assertEqual(field_view[0].data, 456)

    container.ClearField('payload')
    # The previously obtained view still reports the value...
    self.assertEqual(field_view[0].data, 456)
    # ...while a view of the freshly accessed submessage is empty.
    field_view = unknown_fields.UnknownFieldSet(
        container.payload.optional_nested_message)
    self.assertEqual(0, len(field_view))
def testMergeFrom(self):
    """MergeFrom on messages holding only unknown fields merges them correctly."""
    typed_msg = unittest_pb2.TestAllTypes()
    typed_msg.optional_int32 = 1
    typed_msg.optional_uint32 = 2
    merge_src = unittest_pb2.TestEmptyMessage()
    merge_src.ParseFromString(typed_msg.SerializeToString())

    # Reuse the typed message to build a second, overlapping payload.
    typed_msg.ClearField('optional_int32')
    typed_msg.optional_int64 = 3
    typed_msg.optional_uint32 = 4

    merge_dst = unittest_pb2.TestEmptyMessage()
    field_view = unknown_fields.UnknownFieldSet(merge_dst)
    self.assertEqual(0, len(field_view))
    merge_dst.ParseFromString(typed_msg.SerializeToString())
    # An already obtained view keeps its earlier length after parsing.
    self.assertEqual(0, len(field_view))
    field_view = unknown_fields.UnknownFieldSet(merge_dst)
    self.assertEqual(2, len(field_view))
    merge_dst.MergeFrom(merge_src)
    self.assertEqual(2, len(field_view))

    # Although both sides stored these fields as unknown, re-parsing the merged
    # bytes as TestAllTypes must expose the correctly merged values.
    typed_msg.ParseFromString(merge_dst.SerializeToString())
    self.assertEqual(typed_msg.optional_int32, 1)
    self.assertEqual(typed_msg.optional_uint32, 2)
    self.assertEqual(typed_msg.optional_int64, 3)
def testSerializeMessageSetWireFormatUnknownExtension(self):
    """An unrecognized MessageSet item is kept as an unknown field and re-serialized."""
    # Build a raw MessageSet whose only item carries a type_id that is
    # handled as unknown by TestMessageSet (checked below).
    raw = unittest_mset_pb2.RawMessageSet()
    entry = raw.item.add()
    entry.type_id = 98218603
    inner_payload = message_set_extensions_pb2.TestMessageSetExtension1()
    inner_payload.i = 12345
    entry.message = inner_payload.SerializeToString()
    wire_bytes = raw.SerializeToString()

    # Parse using the message set wire format.
    parsed = message_set_extensions_pb2.TestMessageSet()
    parsed.MergeFromString(wire_bytes)
    field_view = unknown_fields.UnknownFieldSet(parsed)
    self.assertEqual(len(field_view), 1)

    # The unknown entry is keyed by the type_id and its length-delimited data
    # parses back into a message equal to the original payload.
    self.assertEqual(field_view[0].field_number, entry.type_id)
    self.assertEqual(field_view[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    decoded = message_set_extensions_pb2.TestMessageSetExtension1()
    decoded.ParseFromString(field_view[0].data)
    self.assertEqual(inner_payload, decoded)

    # Serializing again must reproduce the original raw MessageSet unchanged.
    round_trip = unittest_mset_pb2.RawMessageSet()
    round_trip.MergeFromString(parsed.SerializeToString())
    self.assertEqual(raw, round_trip)
def testUnknownField(self):
    """A single unknown field keeps its value after the owning message is cleared."""
    typed_msg = unittest_pb2.TestAllTypes()
    typed_msg.optional_int32 = 123
    empty_msg = unittest_pb2.TestEmptyMessage()
    empty_msg.ParseFromString(typed_msg.SerializeToString())
    first_field = unknown_fields.UnknownFieldSet(empty_msg)[0]
    empty_msg.Clear()
    self.assertEqual(first_field.data, 123)
def testCheckUnknownFieldValueForEnum(self):
    """Enum fields of self.message show up as unknown fields of self.missing_message."""
    field_view = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertEqual(len(field_view), 5)
    for enum_field in ('optional_nested_enum',
                       'repeated_nested_enum',
                       'packed_nested_enum'):
        self.CheckUnknownField(enum_field, getattr(self.message, enum_field))
def testCheckUnknownFieldValue(self):
    """Every wire type parsed into an empty message is reported correctly."""
    field_view = unknown_fields.UnknownFieldSet(self.empty_message)
    src = self.all_fields
    # Each row: (field name,
    #            value expected via the public UnknownFieldSet,
    #            value expected via InternalCheckUnknownField).
    cases = (
        # Enum.
        ('optional_nested_enum', src.optional_nested_enum,
         src.optional_nested_enum),
        # Repeated enum.
        ('repeated_nested_enum', src.repeated_nested_enum,
         src.repeated_nested_enum),
        # Varint.
        ('optional_int32', src.optional_int32, src.optional_int32),
        # Fixed32.
        ('optional_fixed32', src.optional_fixed32, src.optional_fixed32),
        # Fixed64.
        ('optional_fixed64', src.optional_fixed64, src.optional_fixed64),
        # Length delimited: the public view is compared against UTF-8 bytes.
        ('optional_string', src.optional_string.encode('utf-8'),
         src.optional_string),
        # Group.
        ('optionalgroup', (17, 0, 117), src.optionalgroup),
    )
    for field_name, public_expected, internal_expected in cases:
        self.CheckUnknownField(field_name, field_view, public_expected)
        self.InternalCheckUnknownField(field_name, internal_expected)
    self.assertEqual(98, len(field_view))
def CheckUnknownField(self, name, expected_value):
    """Assert that field ``name`` appears among self.missing_message's unknown fields.

    Args:
      name: field name looked up in ``self.descriptor.fields_by_name``.
      expected_value: for a repeated field, a sized container that every
        matching unknown field's data must be a member of, with one match per
        element; otherwise the single value the one matching field must hold.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    field_view = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertIsInstance(field_view, unknown_fields.UnknownFieldSet)

    matches = [entry for entry in field_view
               if entry.field_number == field_descriptor.number]
    for entry in matches:
        if is_repeated:
            self.assertIn(entry.data, expected_value)
        else:
            self.assertEqual(expected_value, entry.data)
    self.assertEqual(len(matches), len(expected_value) if is_repeated else 1)
def testUnknownExtensions(self):
    """An extendable empty message stores all parsed fields as unknown."""
    extendable = unittest_pb2.TestEmptyMessageWithExtensions()
    extendable.ParseFromString(self.all_fields_data)
    unknown_count = len(unknown_fields.UnknownFieldSet(extendable))
    self.assertEqual(unknown_count, 98)
    # The unknown fields must serialize back to the exact input bytes.
    self.assertEqual(extendable.SerializeToString(), self.all_fields_data)
def leaking_function():
    """Build and immediately discard ``nb_leaks`` UnknownFieldSet views."""
    # nb_leaks and self are taken from the enclosing scope (not visible here).
    for _iteration in range(nb_leaks):
        unknown_fields.UnknownFieldSet(self.empty_message)
def testClear(self):
    """Clear() removes unknown fields, but an earlier view keeps its entries."""
    earlier_view = unknown_fields.UnknownFieldSet(self.empty_message)
    self.empty_message.Clear()
    # Everything is cleared, unknown fields included...
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    # ...yet the view taken beforehand still reports all 98 entries.
    self.assertEqual(len(earlier_view), 98)