def testMergeBadExtension(self):
    """Merging text that names an unknown extension raises ParseError.

    The expected message differs depending on the target message type.
    """
    bad_text = '[unknown_extension]: 8\n'

    # Target: TestAllExtensions -> the extension name is reported as
    # not registered.
    extensions_message = unittest_pb2.TestAllExtensions()
    self.assertRaisesWithMessage(
        text_format.ParseError,
        '1:2 : Extension "unknown_extension" not registered.',
        text_format.Merge, bad_text, extensions_message)

    # Target: TestAllTypes -> the message type is reported as having no
    # extensions at all.
    no_extensions_error = (
        '1:2 : Message type "protobuf_unittest.TestAllTypes" '
        'does not have extensions.')
    plain_message = unittest_pb2.TestAllTypes()
    self.assertRaisesWithMessage(
        text_format.ParseError, no_extensions_error,
        text_format.Merge, bad_text, plain_message)
def testParseGroupNotClosed(self):
    """An unterminated group is reported with the closing token it needs."""
    message = unittest_pb2.TestAllTypes()
    # (input text, exact error text expected from the parser)
    cases = (
        ('RepeatedGroup: <', '1:16 : Expected ">".'),
        ('RepeatedGroup: {', '1:16 : Expected "}".'),
    )
    for unterminated_text, expected_error in cases:
        self.assertRaisesWithLiteralMatch(
            text_format.ParseError, expected_error,
            text_format.Parse, unterminated_text, message)
def testPrintRepeatedFieldsAsOneLine(self):
    """Repeated fields print as space-separated entries with as_one_line."""
    message = unittest_pb2.TestAllTypes()
    for number in (1, 1, 3):
        message.repeated_int32.append(number)
    for word in ("Google", "Zurich"):
        message.repeated_string.append(word)

    rendered = text_format.MessageToString(message, as_one_line=True)
    expected = ('repeated_int32: 1 '
                'repeated_int32: 1 '
                'repeated_int32: 3 '
                'repeated_string: "Google" '
                'repeated_string: "Zurich"')
    self.CompareToGoldenText(rendered, expected)
def testPublicImports(self):
    """Publicly imported messages are usable and shared across modules."""
    # Test public imports as embedded message.
    all_type_proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, all_type_proto.optional_public_import_message.e)

    # PublicImportMessage is actually defined in unittest_import_public_pb2
    # module, and is public imported by unittest_import_pb2 module.
    public_import_proto = unittest_import_pb2.PublicImportMessage()
    self.assertEqual(0, public_import_proto.e)
    # assertIs reports both operands on failure, which assertTrue(a is b)
    # cannot do.
    self.assertIs(unittest_import_public_pb2.PublicImportMessage,
                  unittest_import_pb2.PublicImportMessage)
def testUnknownField(self):
    """Reading an unknown field after its parent was cleared raises."""
    original = unittest_pb2.TestAllTypes()
    original.optional_int32 = 123
    wire_bytes = original.SerializeToString()

    # TestEmptyMessage is parsed from TestAllTypes bytes; the test then
    # reads the first entry of UnknownFields().
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(wire_bytes)
    stale_field = destination.UnknownFields()[0]
    destination.Clear()

    with self.assertRaises(ValueError) as context:
        stale_field.data  # pylint: disable=pointless-statement
    error_text = str(context.exception)
    self.assertIn('The parent message might be cleared.', error_text)
def testParseMessageByFieldNumber(self):
    """Fields can be addressed by number only with allow_field_number=True."""
    # A field number and a field name may be mixed in the same input.
    scalar_message = unittest_pb2.TestAllTypes()
    scalar_text = ('34: 1\n'
                   'repeated_uint64: 2\n')
    text_format.Parse(scalar_text, scalar_message, allow_field_number=True)
    for index, expected in enumerate((1, 2)):
        self.assertEqual(expected, scalar_message.repeated_uint64[index])

    # Nested messages and message-set extensions by number.
    container = unittest_mset_pb2.TestMessageSetContainer()
    nested_text = ('1 {\n'
                   ' 1545008 {\n'
                   ' 15: 23\n'
                   ' }\n'
                   ' 1547769 {\n'
                   ' 25: \"foo\"\n'
                   ' }\n'
                   '}\n')
    text_format.Parse(nested_text, container, allow_field_number=True)
    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
    extensions = container.message_set.Extensions
    self.assertEqual(23, extensions[ext1].i)
    self.assertEqual('foo', extensions[ext2].str)

    target = unittest_pb2.TestAllTypes()
    failing_cases = (
        # A field number is rejected unless allow_field_number=True is set.
        ('34:1\n',
         (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
          r'"34".'),
         {}),
        # A field number that does not exist is rejected even when allowed.
        ('1234:1\n',
         (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
          r'"1234".'),
         {'allow_field_number': True}),
    )
    for bad_text, pattern, parse_kwargs in failing_cases:
        six.assertRaisesRegex(
            self, text_format.ParseError, pattern,
            text_format.Parse, bad_text, target, **parse_kwargs)
def testNegativeInfinity(self):
    """Negative-infinity float/double fields parse and re-serialize exactly."""
    # Serialized protobuf data is binary, so the golden value is written as
    # bytes literals. On Python 3 a text literal is a different type from
    # what SerializeToString() returns and the final comparison could never
    # succeed; on Python 2 the b'' prefix is a no-op.
    golden_data = (b'\x5D\x00\x00\x80\xFF'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                   b'\xCD\x02\x00\x00\x80\xFF'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(IsNegInf(golden_message.optional_float))
    self.assertTrue(IsNegInf(golden_message.optional_double))
    self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
    self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
    self.assertEqual(golden_data, golden_message.SerializeToString())
def testNotANumber(self):
    """NaN float/double fields parse as NaN and survive a round trip."""
    # Serialized protobuf data is binary, so use bytes literals (a no-op on
    # Python 2, required on Python 3).
    golden_data = (b'\x5D\x00\x00\xC0\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                   b'\xCD\x02\x00\x00\xC0\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(isnan(golden_message.optional_float))
    self.assertTrue(isnan(golden_message.optional_double))
    self.assertTrue(isnan(golden_message.repeated_float[0]))
    self.assertTrue(isnan(golden_message.repeated_double[0]))

    # A NaN has more than one valid bit pattern, so the serializer is not
    # obliged to reproduce the golden bytes exactly. Instead of comparing
    # bytes, verify that the serialized form parses back into NaN values.
    serialized = golden_message.SerializeToString()
    message = unittest_pb2.TestAllTypes()
    message.ParseFromString(serialized)
    self.assertTrue(isnan(message.optional_float))
    self.assertTrue(isnan(message.optional_double))
    self.assertTrue(isnan(message.repeated_float[0]))
    self.assertTrue(isnan(message.repeated_double[0]))
def testRepeatedNestedFieldIteration(self):
    """Repeated message fields support forward, reversed and sliced iteration."""
    msg = unittest_pb2.TestAllTypes()
    msg.repeated_nested_message.add(bb=1)
    msg.repeated_nested_message.add(bb=2)
    msg.repeated_nested_message.add(bb=3)
    msg.repeated_nested_message.add(bb=4)

    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual([1, 2, 3, 4],
                     [m.bb for m in msg.repeated_nested_message])
    self.assertEqual([4, 3, 2, 1],
                     [m.bb for m in reversed(msg.repeated_nested_message)])
    self.assertEqual([4, 3, 2, 1],
                     [m.bb for m in msg.repeated_nested_message[::-1]])
def testRoundTripExoticAsOneLine(self):
    """Extreme values survive a one-line print/merge cycle for both as_utf8 modes."""
    message = unittest_pb2.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')

    # Test as_utf8 = False.
    wire_text = text_format.MessageToString(
        message, as_one_line=True, as_utf8=False)
    parsed_message = unittest_pb2.TestAllTypes()
    text_format.Merge(wire_text, parsed_message)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(message, parsed_message)

    # Test as_utf8 = True.
    wire_text = text_format.MessageToString(
        message, as_one_line=True, as_utf8=True)
    parsed_message = unittest_pb2.TestAllTypes()
    text_format.Merge(wire_text, parsed_message)
    self.assertEqual(message, parsed_message)
def testNotANumber(self):
    """NaN float/double fields parse as NaN and survive a round trip."""
    # Serialized protobuf data is binary, so use bytes literals (a no-op on
    # Python 2, required on Python 3).
    golden_data = (b'\x5D\x00\x00\xC0\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                   b'\xCD\x02\x00\x00\xC0\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(isnan(golden_message.optional_float))
    self.assertTrue(isnan(golden_message.optional_double))
    self.assertTrue(isnan(golden_message.repeated_float[0]))
    self.assertTrue(isnan(golden_message.repeated_double[0]))

    # The protocol buffer may serialize to any one of multiple different
    # representations of a NaN. Rather than verify a specific representation,
    # verify the serialized string can be converted into a correctly
    # behaving protocol buffer.
    serialized = golden_message.SerializeToString()
    message = unittest_pb2.TestAllTypes()
    message.ParseFromString(serialized)
    self.assertTrue(isnan(message.optional_float))
    self.assertTrue(isnan(message.optional_double))
    self.assertTrue(isnan(message.repeated_float[0]))
    self.assertTrue(isnan(message.repeated_double[0]))
def testParseTrailingCommas(self):
    """Parse accepts a ',' or ';' after a field value."""
    message = unittest_pb2.TestAllTypes()
    separated_text = ('repeated_int64: 100;\n'
                      'repeated_int64: 200;\n'
                      'repeated_int64: 300,\n'
                      'repeated_string: "one",\n'
                      'repeated_string: "two";\n')
    text_format.Parse(separated_text, message)

    for index, expected_number in enumerate((100, 200, 300)):
        self.assertEqual(expected_number, message.repeated_int64[index])
    for index, expected_word in enumerate((u'one', u'two')):
        self.assertEqual(expected_word, message.repeated_string[index])
def testEnums(self):
    """Enum constants are exposed on the module, the class and instances."""
    # We test only module-level enums here.
    # TODO(robinson): Examine descriptors directly to check
    # enum descriptor output.
    module_level = ((4, 'FOREIGN_FOO'), (5, 'FOREIGN_BAR'), (6, 'FOREIGN_BAZ'))
    for expected, constant in module_level:
        self.assertEqual(expected, getattr(unittest_pb2, constant))

    # Nested enum values are checked on an instance and then on the class.
    proto = unittest_pb2.TestAllTypes()
    for expected, constant in ((1, 'FOO'), (2, 'BAR'), (3, 'BAZ')):
        self.assertEqual(expected, getattr(proto, constant))
        self.assertEqual(expected, getattr(unittest_pb2.TestAllTypes, constant))
def testFieldPresence(self):
    """HasField tracks explicit assignment, not the value that was assigned."""
    message = unittest_pb2.TestAllTypes()
    singular_fields = ("optional_int32", "optional_bool",
                       "optional_nested_message")

    def assert_all_absent():
        for field_name in singular_fields:
            self.assertFalse(message.HasField(field_name))

    def assert_all_present():
        for field_name in singular_fields:
            self.assertTrue(message.HasField(field_name))

    def assert_default_values():
        self.assertEqual(0, message.optional_int32)
        self.assertEqual(False, message.optional_bool)
        self.assertEqual(0, message.optional_nested_message.bb)

    # A fresh message has nothing set.
    assert_all_absent()

    # HasField rejects unknown names and repeated fields.
    for invalid_name in ("field_doesnt_exist", "repeated_int32",
                         "repeated_nested_message"):
        with self.assertRaises(ValueError):
            message.HasField(invalid_name)

    assert_default_values()

    # Fields are set even when setting the values to default values.
    message.optional_int32 = 0
    message.optional_bool = False
    message.optional_nested_message.bb = 0
    assert_all_present()

    # Set the fields to non-default values.
    message.optional_int32 = 5
    message.optional_bool = True
    message.optional_nested_message.bb = 15
    assert_all_present()

    # Clearing the fields unsets them and resets their value to default.
    for field_name in singular_fields:
        message.ClearField(field_name)
    assert_all_absent()
    assert_default_values()
def testMergeFrom(self):
    """MergeFrom appends the source's unknown fields to the destination's."""
    typed = unittest_pb2.TestAllTypes()
    typed.optional_int32 = 1
    typed.optional_uint32 = 2
    source_payload = typed.SerializeToString()
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(source_payload)

    # Change the typed message before serializing it for the destination.
    typed.ClearField('optional_int32')
    typed.optional_int64 = 3
    typed.optional_uint32 = 4
    destination_payload = typed.SerializeToString()
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(destination_payload)

    # Snapshot (copy) the destination's unknown fields before merging.
    before_merge = destination._unknown_fields[:]
    destination.MergeFrom(source)
    expected_after_merge = before_merge + source._unknown_fields
    self.assertEqual(expected_after_merge, destination._unknown_fields)
def testPrintExotic(self):
    """Extreme integers, doubles and control characters print as expected."""
    message = unittest_pb2.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    for double_value in (123.456, 1.23e22, 1.23e-18):
        message.repeated_double.append(double_value)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"')

    printed = text_format.MessageToString(message)
    golden = (
        'repeated_int64: -9223372036854775808\n'
        'repeated_uint64: 18446744073709551615\n'
        'repeated_double: 123.456\n'
        'repeated_double: 1.23e+22\n'
        'repeated_double: 1.23e-18\n'
        'repeated_string: '
        '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
    self.CompareToGoldenText(printed, golden)
def testSortingRepeatedCompositeFieldsCustomComparator(self):
    """Check passing a custom comparator to sort a repeated composite field."""
    message = unittest_pb2.TestAllTypes()
    for unsorted_bb in (1, 3, 2, 6, 5, 4):
        message.repeated_nested_message.add().bb = unsorted_bb

    def compare_by_bb(left, right):
        return cmp(left.bb, right.bb)

    # The comparator is passed positionally.
    message.repeated_nested_message.sort(compare_by_bb)

    for index, expected_bb in enumerate((1, 2, 3, 4, 5, 6)):
        self.assertEqual(message.repeated_nested_message[index].bb,
                         expected_bb)
def testPrintExoticAsOneLine(self):
    """Extreme values and non-ASCII text print correctly with as_one_line."""
    message = unittest_pb2.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    for double_value in (123.456, 1.23e22, 1.23e-18):
        message.repeated_double.append(double_value)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u('\u00fc\ua71f'))

    printed = self.RemoveRedundantZeros(
        text_format.MessageToString(message, as_one_line=True))

    # Each fragment goes through b() and the results are concatenated
    # left to right.
    fragments = (
        'repeated_int64: -9223372036854775808',
        ' repeated_uint64: 18446744073709551615',
        ' repeated_double: 123.456',
        ' repeated_double: 1.23e+22',
        ' repeated_double: 1.23e-18',
        ' repeated_string: ',
        '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""',
        # Octal escapes of the UTF-8 encoding of u'\u00fc\ua71f'.
        ' repeated_string: "\\303\\274\\352\\234\\237"',
    )
    golden = b(fragments[0])
    for fragment in fragments[1:]:
        golden = golden + b(fragment)

    self.CompareToGoldenText(printed, golden)
def testMergeExotic(self):
    """Merge reads extreme integers, doubles and escaped control characters."""
    message = unittest_pb2.TestAllTypes()
    exotic_text = (
        'repeated_int64: -9223372036854775808\n'
        'repeated_uint64: 18446744073709551615\n'
        'repeated_double: 123.456\n'
        'repeated_double: 1.23e+22\n'
        'repeated_double: 1.23e-18\n'
        'repeated_string: \n'
        '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
    text_format.Merge(exotic_text, message)

    self.assertEqual(-9223372036854775808, message.repeated_int64[0])
    self.assertEqual(18446744073709551615, message.repeated_uint64[0])
    for index, expected_double in enumerate((123.456, 1.23e22, 1.23e-18)):
        self.assertEqual(expected_double, message.repeated_double[index])
    expected_string = '\000\001\a\b\f\n\r\t\v\\\'\"'
    self.assertEqual(expected_string, message.repeated_string[0])
def testMergeMessageSet(self):
    """Merge handles repeated scalars and extensions inside a message set."""
    # Plain repeated field first.
    scalar_message = unittest_pb2.TestAllTypes()
    scalar_text = (b('repeated_uint64: 1\n') +
                   b('repeated_uint64: 2\n'))
    text_format.Merge(scalar_text, scalar_message)
    for index, expected in enumerate((1, 2)):
        self.assertEqual(expected, scalar_message.repeated_uint64[index])

    # Then a message set whose entries are written as bracketed type names.
    container = unittest_mset_pb2.TestMessageSetContainer()
    message_set_text = (
        b('message_set {\n') +
        b(' [protobuf_unittest.TestMessageSetExtension1] {\n') +
        b(' i: 23\n') +
        b(' }\n') +
        b(' [protobuf_unittest.TestMessageSetExtension2] {\n') +
        b(' str: \"foo\"\n') +
        b(' }\n') +
        b('}\n'))
    text_format.Merge(message_set_text, container)

    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
    extensions = container.message_set.Extensions
    self.assertEqual(23, extensions[ext1].i)
    self.assertEqual('foo', extensions[ext2].str)
def testExtremeDoubleValues(self):
    """Doubles with extreme exponents survive a serialize/parse round trip.

    The same message instance is reused for every value, exactly as in the
    unrolled form this replaces: the value is assigned, the message is
    serialized and parsed back into itself, and the field is compared.
    """
    message = unittest_pb2.TestAllTypes()

    # Most positive exponent, no significand bits set.
    kMostPosExponentNoSigBits = math.pow(2, 1023)
    # Most positive exponent, one significand bit set.
    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
    # Most negative exponent, no significand bits set.
    kMostNegExponentNoSigBits = math.pow(2, -1023)
    # Most negative exponent, one significand bit set.
    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)

    # Each positive pair is followed by the same magnitudes negated.
    extreme_values = (
        kMostPosExponentNoSigBits,
        kMostPosExponentOneSigBit,
        -kMostPosExponentNoSigBits,
        -kMostPosExponentOneSigBit,
        kMostNegExponentNoSigBits,
        kMostNegExponentOneSigBit,
        -kMostNegExponentNoSigBits,
        -kMostNegExponentOneSigBit,
    )
    for value in extreme_values:
        message.optional_double = value
        message.ParseFromString(message.SerializeToString())
        # assertEqual reports both operands on failure, which
        # assertTrue(a == b) cannot do.
        self.assertEqual(message.optional_double, value)
def testParseStringFieldUnescape(self):
    """Backslash and \\x escapes in string fields are decoded correctly."""
    message = unittest_pb2.TestAllTypes()
    # Raw fragments so the backslashes reach the parser untouched; entries
    # are separated by a single space.
    escaped_text = (r'repeated_string: "\xf\x62" '
                    r'repeated_string: "\\xf\\x62" '
                    r'repeated_string: "\\\xf\\\x62" '
                    r'repeated_string: "\\\\xf\\\\x62" '
                    r'repeated_string: "\\\\\xf\\\\\x62" '
                    r'repeated_string: "\x5cx20"')
    text_format.Parse(escaped_text, message)

    backslash = '\\'
    expected_values = (
        '\x0fb',
        backslash + 'xf' + backslash + 'x62',
        backslash + '\x0f' + backslash + 'b',
        backslash * 2 + 'xf' + backslash * 2 + 'x62',
        backslash * 2 + '\x0f' + backslash * 2 + 'b',
        backslash + 'x20',
    )
    for index, expected in enumerate(expected_values):
        self.assertEqual(expected, message.repeated_string[index])
def testSortingRepeatedScalarFieldsCustomComparator(self):
    """Check some different types with custom comparator."""
    message = unittest_pb2.TestAllTypes()

    # Integers ordered by absolute value.
    for number in (-3, -2, -1):
        message.repeated_int32.append(number)

    def compare_abs(left, right):
        return cmp(abs(left), abs(right))

    message.repeated_int32.sort(compare_abs)
    for index, expected in enumerate((-1, -2, -3)):
        self.assertEqual(message.repeated_int32[index], expected)

    # Strings ordered by length.
    for word in ('aaa', 'bb', 'c'):
        message.repeated_string.append(word)

    def compare_len(left, right):
        return cmp(len(left), len(right))

    message.repeated_string.sort(compare_len)
    for index, expected in enumerate(('c', 'bb', 'aaa')):
        self.assertEqual(message.repeated_string[index], expected)
def testParseEnumValue(self):
    """JSON enum parsing: numeric values, bad names, unknown numbers."""
    message = json_format_proto3_pb2.TestMessage()
    text = '{"enumValue": 0}'
    json_format.Parse(text, message)
    text = '{"enumValue": 1}'
    json_format.Parse(text, message)
    self.CheckError(
        '{"enumValue": "baz"}',
        'Failed to parse enumValue field: Invalid enum value baz '
        'for enum type proto3.EnumType.')
    # Proto3 accepts numeric unknown enums.
    text = '{"enumValue": 12345}'
    json_format.Parse(text, message)
    # Proto2 does not accept unknown enums.
    message = unittest_pb2.TestAllTypes()
    # The expected text is a literal message, not a pattern, so it is
    # checked with a substring match. This avoids the deprecated
    # assertRaisesRegexp alias and the unescaped '.' wildcards the
    # regex form contained.
    with self.assertRaises(json_format.ParseError) as context:
        json_format.Parse('{"optionalNestedEnum": 12345}', message)
    self.assertIn(
        'Failed to parse optionalNestedEnum field: Invalid enum value 12345 '
        'for enum type protobuf_unittest.TestAllTypes.NestedEnum.',
        str(context.exception))
def testMergeFrom(self):
    """Fields merged between two empty messages re-parse as typed values."""
    typed = unittest_pb2.TestAllTypes()
    typed.optional_int32 = 1
    typed.optional_uint32 = 2
    source_payload = typed.SerializeToString()
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(source_payload)

    # Change the typed message before serializing it for the destination.
    typed.ClearField('optional_int32')
    typed.optional_int64 = 3
    typed.optional_uint32 = 4
    destination_payload = typed.SerializeToString()
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(destination_payload)

    destination.MergeFrom(source)

    # Check that the fields were correctly merged, even though they are
    # stored in the unknown fields set: re-parse the merged bytes as
    # TestAllTypes and inspect the typed values.
    typed.ParseFromString(destination.SerializeToString())
    expected_fields = (('optional_int32', 1),
                       ('optional_uint32', 2),
                       ('optional_int64', 3))
    for field_name, expected in expected_fields:
        self.assertEqual(getattr(typed, field_name), expected)
def testDefaultValueForCustomMessages(self):
    """Check the value returned by non-existent fields."""

    def _CheckValueAndType(value, expected_value, expected_type):
        """Assert both the value and the Python type of a field default."""
        self.assertEqual(value, expected_value)
        self.assertIsInstance(value, expected_type)

    def _CheckDefaultValues(msg):
        """Assert the defaults of a representative set of scalar fields."""
        # The former try/except NameError shims bound int64 to int and
        # unicode_type to str in both branches and could never raise, so
        # the builtin types are used directly.
        _CheckValueAndType(msg.optional_int32, 0, int)
        _CheckValueAndType(msg.optional_uint64, 0, int)
        _CheckValueAndType(msg.optional_float, 0, (float, int))
        _CheckValueAndType(msg.optional_double, 0, (float, int))
        _CheckValueAndType(msg.optional_bool, False, bool)
        _CheckValueAndType(msg.optional_string, '', str)
        _CheckValueAndType(msg.optional_bytes, b'', bytes)
        _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int)

    # First for the generated message.
    _CheckDefaultValues(unittest_pb2.TestAllTypes())

    # Then for a message built from the DescriptorPool. The files are added
    # in the same order as before (presumably dependencies first; confirm
    # before reordering).
    pool = descriptor_pool.DescriptorPool()
    for generated_module in (unittest_import_public_pb2,
                             unittest_import_pb2,
                             unittest_pb2):
        pool.Add(
            descriptor_pb2.FileDescriptorProto.FromString(
                generated_module.DESCRIPTOR.serialized_pb))
    message_class = message_factory.MessageFactory(pool).GetPrototype(
        pool.FindMessageTypeByName(
            unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
    _CheckDefaultValues(message_class())
def testUnion(self):
    """FieldMask.Union merges two masks and rejects non-mask arguments."""
    mask1 = field_mask_pb2.FieldMask()
    mask2 = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()

    # (first mask, second mask, expected union), all as JSON strings.
    union_cases = (
        ('foo,baz', 'bar,quz', 'bar,baz,foo,quz'),
        # Overlap with duplicated paths.
        ('foo,baz.bb', 'baz.bb,quz', 'baz.bb,foo,quz'),
        # Overlap with paths covering some other paths.
        ('foo.bar.baz,quz', 'foo.bar,bar', 'bar,foo.bar,quz'),
    )
    for first_json, second_json, expected_json in union_cases:
        mask1.FromJsonString(first_json)
        mask2.FromJsonString(second_json)
        out_mask.Union(mask1, mask2)
        self.assertEqual(expected_json, out_mask.ToJsonString())

    # Passing a message that is not a FieldMask raises ValueError.
    not_a_mask = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
        out_mask.Union(not_a_mask, mask2)
def testOneofSemantics(self):
    """Setting one member of a oneof clears the previously set member."""
    message = unittest_pb2.TestAllTypes()

    def expect_active(current, previous=None):
        # WhichOneof names the current member, the previous member (if
        # any) is no longer present, and the current member is present.
        self.assertEqual(current, message.WhichOneof('oneof_field'))
        if previous is not None:
            self.assertFalse(message.HasField(previous))
        self.assertTrue(message.HasField(current))

    self.assertIs(None, message.WhichOneof('oneof_field'))

    message.oneof_uint32 = 11
    expect_active('oneof_uint32')

    message.oneof_string = u'foo'
    expect_active('oneof_string', previous='oneof_uint32')

    message.oneof_nested_message.bb = 11
    expect_active('oneof_nested_message', previous='oneof_string')

    message.oneof_bytes = b'bb'
    expect_active('oneof_bytes', previous='oneof_nested_message')
def testRepeatedCompositeFieldSortArguments(self):
    """Check sorting a repeated composite field using list.sort() arguments."""
    message = unittest_pb2.TestAllTypes()

    get_bb = operator.attrgetter('bb')

    def cmp_bb(a, b):
        return cmp(a.bb, b.bb)

    for unsorted_bb in (1, 3, 2, 6, 5, 4):
        message.repeated_nested_message.add().bb = unsorted_bb

    ascending = [1, 2, 3, 4, 5, 6]
    descending = [6, 5, 4, 3, 2, 1]
    # Keyword arguments passed to sort(), with the order each must produce.
    scenarios = (
        ({'key': get_bb}, ascending),
        ({'key': get_bb, 'reverse': True}, descending),
        ({'sort_function': cmp_bb}, ascending),
        ({'cmp': cmp_bb, 'reverse': True}, descending),
    )
    for sort_kwargs, expected_order in scenarios:
        message.repeated_nested_message.sort(**sort_kwargs)
        actual_order = [k.bb for k in message.repeated_nested_message]
        self.assertEqual(actual_order, expected_order)
def testSortingRepeatedScalarFieldsDefaultComparator(self):
    """Check some different types with the default comparator."""
    message = unittest_pb2.TestAllTypes()

    # TODO(mattp): would testing more scalar types strengthen test?
    message.repeated_int32.append(1)
    message.repeated_int32.append(3)
    message.repeated_int32.append(2)
    message.repeated_int32.sort()
    self.assertEqual(message.repeated_int32[0], 1)
    self.assertEqual(message.repeated_int32[1], 2)
    self.assertEqual(message.repeated_int32[2], 3)

    message.repeated_float.append(1.1)
    message.repeated_float.append(1.3)
    message.repeated_float.append(1.2)
    message.repeated_float.sort()
    self.assertAlmostEqual(message.repeated_float[0], 1.1)
    self.assertAlmostEqual(message.repeated_float[1], 1.2)
    self.assertAlmostEqual(message.repeated_float[2], 1.3)

    message.repeated_string.append('a')
    message.repeated_string.append('c')
    message.repeated_string.append('b')
    message.repeated_string.sort()
    self.assertEqual(message.repeated_string[0], 'a')
    self.assertEqual(message.repeated_string[1], 'b')
    self.assertEqual(message.repeated_string[2], 'c')

    # Bytes fields take bytes values: the b'' prefix is a no-op on
    # Python 2, while on Python 3 a text literal is a different type.
    message.repeated_bytes.append(b'a')
    message.repeated_bytes.append(b'c')
    message.repeated_bytes.append(b'b')
    message.repeated_bytes.sort()
    self.assertEqual(message.repeated_bytes[0], b'a')
    self.assertEqual(message.repeated_bytes[1], b'b')
    self.assertEqual(message.repeated_bytes[2], b'c')