Esempio n. 1
0
    def CheckDescriptorMapping(self, mapping):
        """Assert that ``mapping`` behaves like an immutable abc.Mapping.

        Intended for descriptor properties such as
        'messageDescriptor.fields_by_name'. Covers equality, the Sized,
        Iterable and Container protocols, ``get``, the key/value/item views,
        ``__getitem__`` misses and the string form.

        Args:
          mapping: the descriptor mapping under test; must be non-empty.
        """
        # Unrelated objects compare unequal.
        self.assertNotEqual(
            mapping, unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name)
        self.assertNotEqual(mapping, {})
        self.assertNotEqual(mapping, 1)
        self.assertFalse(mapping == 1)  # Only for cpp test coverage

        # A plain-dict snapshot compares equal to the mapping.
        snapshot = dict(list(mapping.items()))
        self.assertEqual(mapping, snapshot)
        self.assertEqual(mapping, mapping)
        self.assertGreater(len(mapping), 0)  # Sized
        self.assertEqual(len(mapping), len(snapshot))  # Iterable

        if sys.version_info < (3, ):
            first_key, first_value = list(mapping.items())[0]
        else:
            first_key, first_value = next(iter(list(mapping.items())))
        self.assertIn(first_key, mapping)  # Container
        self.assertEqual(mapping.get(first_key), first_value)
        # get() without a key argument is a TypeError.
        with self.assertRaises(TypeError):
            mapping.get()
        # TODO(jieluo): Fix python and cpp extension diff.
        # For an unhashable key this test expects TypeError from the pure
        # python implementation and None from the cpp one.
        if api_implementation.Type() != 'python':
            self.assertEqual(None, mapping.get([]))
        else:
            self.assertRaises(TypeError, mapping.get, [])

        # keys(), iterkeys() &co: the first key and first value line up with
        # the first item.
        first_pair = (next(iter(list(mapping.keys()))),
                      next(iter(list(mapping.values()))))
        self.assertEqual(first_pair, next(iter(list(mapping.items()))))
        if sys.version_info < (3, ):
            # Python 2 only: iterating each view matches its list form.
            def _check_iteration(expected_seq, iterator):
                self.assertEqual(next(iterator), expected_seq[0])
                self.assertEqual(list(iterator), expected_seq[1:])

            _check_iteration(list(mapping.keys()), iter(mapping.keys()))
            _check_iteration(list(mapping.values()), iter(mapping.values()))
            _check_iteration(list(mapping.items()), iter(mapping.items()))

        # Once the snapshot diverges, it must no longer compare equal.
        snapshot[first_key] = 'change value'
        self.assertNotEqual(mapping, snapshot)
        del snapshot[first_key]
        snapshot['new_key'] = 'new'
        self.assertNotEqual(mapping, snapshot)
        self.assertRaises(KeyError, mapping.__getitem__, 'key_error')
        self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1)
        # TODO(jieluo): Add __repr__ support for DescriptorMapping.
        if api_implementation.Type() == 'python':
            self.assertEqual(len(str(dict(list(mapping.items())))),
                             len(str(mapping)))
        else:
            self.assertEqual(str(mapping)[0], '<')
Esempio n. 2
0
 def CheckDescriptorSequence(self, sequence):
     """Assert that ``sequence`` behaves like an immutable abc.Sequence.

     Intended for descriptor properties such as 'messageDescriptor.fields'.

     Args:
       sequence: the descriptor sequence under test; must be non-empty.
     """
     # Unrelated objects compare unequal.
     self.assertNotEqual(sequence,
                         unittest_pb2.TestAllExtensions.DESCRIPTOR.fields)
     self.assertNotEqual(sequence, [])
     self.assertNotEqual(sequence, 1)
     self.assertFalse(sequence == 1)  # Only for cpp test coverage
     self.assertEqual(sequence, sequence)

     as_list = list(sequence)
     self.assertEqual(as_list, sequence)
     length = len(sequence)
     self.assertGreater(length, 0)  # Sized
     self.assertEqual(length, len(as_list))  # Iterable
     self.assertEqual(sequence[length - 1], sequence[-1])

     first = sequence[0]
     self.assertEqual(first, sequence[0])
     self.assertIn(first, sequence)  # Container
     self.assertEqual(sequence.index(first), 0)
     self.assertEqual(sequence.count(first), 1)

     # A field taken from a different message must not be found.
     foreign = unittest_pb2.NestedTestAllTypes.DESCRIPTOR.fields[0]
     self.assertNotIn(foreign, sequence)
     self.assertEqual(sequence.count(foreign), 0)
     self.assertRaises(ValueError, sequence.index, foreign)
     self.assertRaises(ValueError, sequence.index, [])

     backwards = reversed(sequence)
     self.assertEqual(list(backwards), list(sequence)[::-1])
     # The reversed iterator has been fully consumed above.
     self.assertRaises(StopIteration, next, backwards)

     as_list[0] = 'change value'
     self.assertNotEqual(as_list, sequence)
     # TODO(jieluo): Change __repr__ support for DescriptorSequence.
     if api_implementation.Type() == 'python':
         self.assertEqual(str(list(sequence)), str(sequence))
     else:
         self.assertEqual(str(sequence)[0], '<')
  def testEnumDefaultValue(self):
    """Test the default value of enums which don't start at zero.

    The check runs on the generated descriptor, on the default pool's copy
    and, where Add is supported, on a copy rebuilt in ``self.pool``.
    """
    test1_file_name = 'google/protobuf/internal/descriptor_pool_test1.proto'

    def _CheckDefaultValue(file_descriptor):
      message_desc = file_descriptor.message_types_by_name[
          'DescriptorPoolTest1']
      enum_field = message_desc.fields_by_name['nested_enum']
      self.assertEqual(enum_field.default_value,
                       descriptor_pool_test1_pb2.DescriptorPoolTest1.BETA)

    # First check what the generated descriptor contains.
    _CheckDefaultValue(descriptor_pool_test1_pb2.DESCRIPTOR)
    # Then the generated pool; normally this is the very same descriptor.
    generated_pool_file = symbol_database.Default().pool.FindFileByName(
        test1_file_name)
    self.assertIs(generated_pool_file, descriptor_pool_test1_pb2.DESCRIPTOR)
    _CheckDefaultValue(generated_pool_file)

    # Cpp extension cannot call Add on a DescriptorPool
    # that uses a DescriptorDatabase.
    # TODO(jieluo): Fix python and cpp extension diff.
    if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
        api_implementation.Type() == 'cpp'):
      return
    # Finally the dynamic pool and its internal DescriptorDatabase.
    file_proto = descriptor_pb2.FileDescriptorProto.FromString(
        descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
    self.pool.Add(file_proto)
    _CheckDefaultValue(self.pool.FindFileByName(test1_file_name))
  def __init__(self, name, package, options=None,
               serialized_options=None, serialized_pb=None,
               dependencies=None, public_dependencies=None,
               syntax=None, pool=None):
    """Constructor.

    Args:
      name: file name, stored as ``self.name``.
      package: proto package name, stored as ``self.package``.
      options: forwarded to the base-class constructor with 'FileOptions'.
      serialized_options: forwarded to the base-class constructor.
      serialized_pb: serialized file bytes, stored as ``self.serialized_pb``;
        with the cpp implementation they are also registered in
        ``_message.default_pool``.
      dependencies: stored as given, or [] when falsy.
      public_dependencies: stored as given, or [] when falsy.
      syntax: stored as given, or "proto2" when falsy.
      pool: owning pool; defaults to ``descriptor_pool.Default()``.
    """
    super(FileDescriptor, self).__init__(
        options, serialized_options, 'FileOptions')

    if pool is None:
      # Function-level import: only needed when no pool is supplied.
      from dis_sdk_python.dependency.google.protobuf import descriptor_pool
      pool = descriptor_pool.Default()
    self.pool = pool

    self.name = name
    self.package = package
    self.syntax = syntax if syntax else "proto2"
    self.serialized_pb = serialized_pb
    self.dependencies = dependencies if dependencies else []
    self.public_dependencies = (
        public_dependencies if public_dependencies else [])

    # Name-keyed lookup tables; all start out empty here.
    self.message_types_by_name = {}
    self.enum_types_by_name = {}
    self.extensions_by_name = {}
    self.services_by_name = {}

    uses_cpp = api_implementation.Type() == 'cpp'
    if uses_cpp and self.serialized_pb is not None:
      _message.default_pool.AddSerializedFile(self.serialized_pb)
    def testGetMessages(self):
        """GetMessages accepts files in any order and is repeatable.

        The body runs twice because repeated calls with the same input must be
        allowed.
        """
        # NOTE(review): these full names carry the vendored Python package
        # prefix; the proto package inside factory_test1/2 may still be
        # 'google.protobuf.python.internal' -- confirm against the _pb2 files.
        package = 'dis_sdk_python.dependency.google.protobuf.python.internal'
        factory1_name = package + '.Factory1Message'
        factory2_name = package + '.Factory2Message'
        one_more_field_name = factory2_name + '.one_more_field'
        another_field_name = package + '.another_field'
        for _ in range(2):
            # GetMessages should work regardless of the order in which the
            # FileDescriptorProtos are provided. In particular, it should
            # succeed when the files are not in topological dependency order.

            # Assuming factory_test2_fd depends on factory_test1_fd.
            self.assertIn(self.factory_test1_fd.name,
                          self.factory_test2_fd.dependency)
            # factory_test2_fd deliberately comes before its dependency.
            messages = message_factory.GetMessages(
                [self.factory_test2_fd, self.factory_test1_fd])
            # assertLessEqual on sets is a subset check that, unlike
            # assertTrue(a.issubset(b)), prints both sets on failure.
            self.assertLessEqual({factory2_name, factory1_name},
                                 set(messages.keys()))
            self._ExerciseDynamicClass(messages[factory2_name])

            factory_msg1 = messages[factory1_name]
            extension_pool = factory_msg1.DESCRIPTOR.file.pool
            self.assertLessEqual(
                {one_more_field_name, another_field_name},
                set(ext.full_name for ext in
                    extension_pool.FindAllExtensions(factory_msg1.DESCRIPTOR)))

            msg1 = factory_msg1()
            ext1 = msg1.Extensions._FindExtensionByName(one_more_field_name)
            ext2 = msg1.Extensions._FindExtensionByName(another_field_name)
            msg1.Extensions[ext1] = 'test1'
            msg1.Extensions[ext2] = 'test2'
            self.assertEqual('test1', msg1.Extensions[ext1])
            self.assertEqual('test2', msg1.Extensions[ext2])
            self.assertEqual(None,
                             msg1.Extensions._FindExtensionByNumber(12321))
            if api_implementation.Type() == 'cpp':
                # TODO(jieluo): Fix len to return the correct value.
                # self.assertEqual(2, len(msg1.Extensions))
                # For now only check that len() can be called at all.
                self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
                self.assertRaises(TypeError,
                                  msg1.Extensions._FindExtensionByName, 0)
                self.assertRaises(TypeError,
                                  msg1.Extensions._FindExtensionByNumber, '')
            else:
                self.assertEqual(None, msg1.Extensions._FindExtensionByName(0))
                self.assertEqual(None,
                                 msg1.Extensions._FindExtensionByNumber(''))
 def testConflictRegister(self):
   """Register two files that define the same symbols under different names.

   With the cpp implementation a TypeError from Add is tolerated; with pure
   python the test expects a RuntimeWarning that names both files.
   """
   # Cpp extension cannot call Add on a DescriptorPool
   # that uses a DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   unittest_fd = descriptor_pb2.FileDescriptorProto.FromString(
       unittest_pb2.DESCRIPTOR.serialized_pb)
   # Identical contents under another file name: every symbol conflicts.
   conflict_fd = copy.deepcopy(unittest_fd)
   conflict_fd.name = 'other_file'

   if api_implementation.Type() == 'cpp':
     try:
       self.pool.Add(unittest_fd)
       self.pool.Add(conflict_fd)
     except TypeError:
       pass
     return

   with warnings.catch_warnings(record=True) as caught:
     # Cause all warnings to always be triggered.
     warnings.simplefilter('always')
     pool = copy.deepcopy(self.pool)
     # Re-adding descriptors that already exist must not warn.
     generated = unittest_pb2.DESCRIPTOR
     pool.AddDescriptor(generated.message_types_by_name['TestAllTypes'])
     pool.AddEnumDescriptor(generated.enum_types_by_name['ForeignEnum'])
     pool.AddServiceDescriptor(generated.services_by_name['TestService'])
     pool.AddExtensionDescriptor(
         generated.extensions_by_name['optional_int32_extension'])
     self.assertEqual(len(caught), 0)

     # Conflicting descriptors with the same names must warn.
     pool.Add(unittest_fd)
     pool.Add(conflict_fd)
     pool.FindFileByName(unittest_fd.name)
     pool.FindFileByName(conflict_fd.name)
     self.assertTrue(len(caught))
     first_warning = caught[0]
     self.assertIs(first_warning.category, RuntimeWarning)
     warning_text = str(first_warning.message)
     self.assertIn('Conflict register for file "other_file": ', warning_text)
     self.assertIn('already defined in file '
                   '"google/protobuf/unittest.proto"',
                   warning_text)
 def testAddFileDescriptor(self):
   """An empty file can be added both as a proto and as serialized bytes."""
   # Cpp extension cannot call Add on a DescriptorPool
   # that uses a DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   empty_file = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
   self.pool.Add(empty_file)
   serialized = empty_file.SerializeToString()
   self.pool.AddSerializedFile(serialized)
 def testAddSerializedFile(self):
   """Files added as serialized bytes to a fresh pool are searchable."""
   # Cpp extension cannot call Add on a DescriptorPool
   # that uses a DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   self.pool = descriptor_pool.DescriptorPool()
   for file_proto in (self.factory_test1_fd, self.factory_test2_fd):
     self.pool.AddSerializedFile(file_proto.SerializeToString())
   # Reuse the message-lookup assertions against the rebuilt pool.
   self.testFindMessageTypeByName()
 def testFindOneofByName(self):
   """A oneof is found by full name; an unknown name raises KeyError."""
   # TODO(jieluo): Fix cpp extension to find oneof correctly
   # when descriptor pool is using an underlying database.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   found = self.pool.FindOneofByName(
       'dis_sdk_python.dependency.google.protobuf.python.internal.Factory2Message.oneof_field')
   self.assertEqual(found.name, 'oneof_field')
   with self.assertRaises(KeyError):
     self.pool.FindOneofByName('Does not exist')
  def testFindTypeErrors(self):
    """Lookups given a key of the wrong type raise implementation errors."""
    self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')

    # TODO(jieluo): Fix python to raise correct errors.
    is_cpp = api_implementation.Type() == 'cpp'
    if is_cpp:
      self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
      self.assertRaises(KeyError, self.pool.FindMethodByName, '')
    error_type = TypeError if is_cpp else AttributeError

    name_lookups = (
        self.pool.FindMessageTypeByName,
        self.pool.FindFieldByName,
        self.pool.FindExtensionByName,
        self.pool.FindEnumTypeByName,
        self.pool.FindOneofByName,
        self.pool.FindServiceByName,
        self.pool.FindFileContainingSymbol,
    )
    for lookup in name_lookups:
      self.assertRaises(error_type, lookup, 0)

    # FindFileByName reports a KeyError in pure python.
    if api_implementation.Type() == 'python':
      error_type = KeyError
    self.assertRaises(error_type, self.pool.FindFileByName, 0)
  def testFindFieldByName(self):
    """A field is found by full name; an unknown name raises KeyError."""
    # TODO(jieluo): Fix cpp extension to find field correctly
    # when descriptor pool is using an underlying database.
    if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
        api_implementation.Type() == 'cpp'):
      return
    list_value = self.pool.FindFieldByName(
        'dis_sdk_python.dependency.google.protobuf.python.internal.Factory1Message.list_value')
    self.assertEqual(list_value.name, 'list_value')
    self.assertEqual(list_value.label, list_value.LABEL_REPEATED)
    self.assertFalse(list_value.has_options)

    with self.assertRaises(KeyError):
      self.pool.FindFieldByName('Does not exist')
 def testFindExtensionByName(self):
   """Extensions are found by full name, whether nested or file-level.

   An unknown extension name must raise KeyError.
   """
   if isinstance(self, SecondaryDescriptorFromDescriptorDB):
     if api_implementation.Type() == 'cpp':
       # TODO(jieluo): Fix cpp extension to find extension correctly
       # when descriptor pool is using an underlying database.
       return
   # An extension defined in a message.
   extension = self.pool.FindExtensionByName(
       'dis_sdk_python.dependency.google.protobuf.python.internal.Factory2Message.one_more_field')
   self.assertEqual(extension.name, 'one_more_field')
   # An extension defined at file scope.
   extension = self.pool.FindExtensionByName(
       'dis_sdk_python.dependency.google.protobuf.python.internal.another_field')
   self.assertEqual(extension.name, 'another_field')
   self.assertEqual(extension.number, 1002)
   # The miss must go through FindExtensionByName itself; a FindFieldByName
   # call here would leave the extension lookup's error path untested.
   with self.assertRaises(KeyError):
     self.pool.FindExtensionByName('Does not exist')
Esempio n. 13
0
    def testMakeDescriptorWithNestedFields(self):
        """MakeDescriptor resolves nested message and nested enum fields.

        Builds message Foo2 { Sub nested_message_field; uint64 uint64_field }
        where Sub has enum FOO { BAR = 3 } and a field bar_field of that type.
        """
        field_cls = descriptor.FieldDescriptor
        file_proto = descriptor_pb2.FileDescriptorProto(name='Foo2')
        message_proto = file_proto.message_type.add(name=file_proto.name)

        sub_proto = message_proto.nested_type.add(name='Sub')
        foo_enum = sub_proto.enum_type.add(name='FOO')
        foo_enum.value.add(name='BAR', number=3)

        message_proto.field.add(
            number=1,
            name='uint64_field',
            label=field_cls.LABEL_REQUIRED,
            type=field_cls.TYPE_UINT64)
        message_proto.field.add(
            number=2,
            name='nested_message_field',
            label=field_cls.LABEL_REQUIRED,
            type=field_cls.TYPE_MESSAGE,
            type_name='Sub')
        sub_proto.field.add(
            number=2,
            name='bar_field',
            label=field_cls.LABEL_REQUIRED,
            type=field_cls.TYPE_ENUM,
            type_name='Foo2.Sub.FOO')

        result = descriptor.MakeDescriptor(message_proto)
        top_fields = result.fields
        sub_desc = result.nested_types[0]
        self.assertEqual(top_fields[0].cpp_type, field_cls.CPPTYPE_UINT64)
        self.assertEqual(top_fields[1].cpp_type, field_cls.CPPTYPE_MESSAGE)
        self.assertEqual(top_fields[1].message_type.containing_type, result)
        self.assertEqual(sub_desc.fields[0].full_name, 'Foo2.Sub.bar_field')
        self.assertEqual(sub_desc.fields[0].enum_type, sub_desc.enum_types[0])
        self.assertFalse(result.has_options)
        self.assertFalse(top_fields[0].has_options)
        if api_implementation.Type() == 'cpp':
            # The cpp implementation is expected to reject the assignment.
            with self.assertRaises(AttributeError):
                top_fields[0].has_options = False
 def testComplexNesting(self):
   """Files with nested types, once added, match TEST1_FILE / TEST2_FILE."""
   # Cpp extension cannot call Add on a DescriptorPool
   # that uses a DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   # Parse all three files first, then add them in the same order.
   file_protos = [
       descriptor_pb2.FileDescriptorProto.FromString(
           generated.DESCRIPTOR.serialized_pb)
       for generated in (more_messages_pb2,
                         descriptor_pool_test1_pb2,
                         descriptor_pool_test2_pb2)]
   for file_proto in file_protos:
     self.pool.Add(file_proto)
   for expected_file in (TEST1_FILE, TEST2_FILE):
     expected_file.CheckFile(self, self.pool)
  def __init__(self, name, full_name, index, number, type, cpp_type, label,
               default_value, message_type, enum_type, containing_type,
               is_extension, extension_scope, options=None,
               serialized_options=None,
               has_default_value=True, containing_oneof=None, json_name=None,
               file=None):  # pylint: disable=redefined-builtin
    """Initialize a field descriptor.

    Each argument is stored on the attribute of the same name, which is
    documented in the FieldDescriptor class docstring. Exceptions:
    ``json_name`` is derived from ``name`` via _ToJsonName when None, and
    ``options``/``serialized_options`` go to the base-class constructor.

    containing_type and extension_scope may be None and filled in later,
    e.g. to deal with circular references between message types.
    """
    super(FieldDescriptor, self).__init__(
        options, serialized_options, 'FieldOptions')
    # Naming.
    self.name = name
    self.full_name = full_name
    self.file = file
    self._camelcase_name = None
    self.json_name = _ToJsonName(name) if json_name is None else json_name
    # Position and type information.
    self.index = index
    self.number = number
    self.type = type
    self.cpp_type = cpp_type
    self.label = label
    self.has_default_value = has_default_value
    self.default_value = default_value
    # Links to related descriptors.
    self.containing_type = containing_type
    self.message_type = message_type
    self.enum_type = enum_type
    self.is_extension = is_extension
    self.extension_scope = extension_scope
    self.containing_oneof = containing_oneof
    if api_implementation.Type() != 'cpp':
      self._cdescriptor = None
    else:
      c_pool = _message.default_pool
      find = (c_pool.FindExtensionByName if is_extension
              else c_pool.FindFieldByName)
      self._cdescriptor = find(full_name)
 def CheckField(self, test, msg_desc, name, index, file_desc):
   """Assert that message-typed field ``name`` of ``msg_desc`` matches self.

   Args:
     test: the TestCase whose assert* methods are used.
     msg_desc: descriptor of the message that owns the field.
     name: the field's short name.
     index: expected ``index`` of the field.
     file_desc: expected ``file`` of the field.
   """
   field_desc = msg_desc.fields_by_name[name]
   # The field's type is the message nested in msg_desc named self.type_name.
   nested_type_desc = msg_desc.nested_types_by_name[self.type_name]
   field_cls = descriptor.FieldDescriptor
   test.assertEqual(name, field_desc.name)
   test.assertEqual('%s.%s' % (msg_desc.full_name, name),
                    field_desc.full_name)
   test.assertEqual(index, field_desc.index)
   test.assertEqual(self.number, field_desc.number)
   test.assertEqual(field_cls.TYPE_MESSAGE, field_desc.type)
   test.assertEqual(field_cls.CPPTYPE_MESSAGE, field_desc.cpp_type)
   test.assertFalse(field_desc.has_default_value)
   test.assertEqual(msg_desc, field_desc.containing_type)
   test.assertEqual(nested_type_desc, field_desc.message_type)
   test.assertEqual(file_desc, field_desc.file)
   # TODO(jieluo): Fix python and cpp extension diff for message field
   # default value.
   if api_implementation.Type() == 'cpp':
     test.assertRaises(
         NotImplementedError, getattr, field_desc, 'default_value')
Esempio n. 17
0
 def CheckFieldDescriptor(self, field_descriptor):
     """Assert the properties of TestAllTypes.optional_int32's descriptor."""
     owner = field_descriptor.containing_type
     # Basic properties
     self.assertEqual(field_descriptor.name, 'optional_int32')
     self.assertEqual(field_descriptor.camelcase_name, 'optionalInt32')
     self.assertEqual(field_descriptor.full_name,
                      'protobuf_unittest.TestAllTypes.optional_int32')
     self.assertEqual(owner.name, 'TestAllTypes')
     self.assertEqual(field_descriptor.file, unittest_pb2.DESCRIPTOR)
     # Equality, and usability as a list element and a dict key.
     self.assertEqual(field_descriptor, field_descriptor)
     self.assertEqual(owner.fields_by_name['optional_int32'],
                      field_descriptor)
     self.assertEqual(owner.fields_by_camelcase_name['optionalInt32'],
                      field_descriptor)
     self.assertIn(field_descriptor, [field_descriptor])
     self.assertIn(field_descriptor, {field_descriptor: None})
     # A plain scalar field: no extension scope and no enum type.
     self.assertEqual(None, field_descriptor.extension_scope)
     self.assertEqual(None, field_descriptor.enum_type)
     if api_implementation.Type() == 'cpp':
         # For test coverage only
         self.assertEqual(field_descriptor.id, field_descriptor.id)
Esempio n. 18
0
class DescriptorCopyToProtoTest(unittest.TestCase):
    """Tests for CopyTo functions of Descriptor.

    Each test copies a descriptor into a fresh proto and compares it against
    the expected proto written in text format.
    """

    def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
        """Assert ``actual_proto`` equals ``expected_ascii`` parsed into a new
        ``expected_class`` message; on failure both protos are printed."""
        expected = expected_class()
        text_format.Merge(expected_ascii, expected)
        failure_message = 'Not equal,\nActual:\n%s\nExpected:\n%s\n' % (
            str(actual_proto), str(expected))
        self.assertEqual(actual_proto, expected, failure_message)

    def _InternalTestCopyToProto(self, desc, expected_proto_class,
                                 expected_proto_ascii):
        """Copy ``desc`` into a new ``expected_proto_class`` and compare."""
        copied = expected_proto_class()
        desc.CopyToProto(copied)
        self._AssertProtoEqual(copied, expected_proto_class,
                               expected_proto_ascii)

    def testCopyToProto_EmptyMessage(self):
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestEmptyMessage.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII)

    def testCopyToProto_NestedMessage(self):
        expected_ascii = """
      name: 'NestedMessage'
      field: <
        name: 'bb'
        number: 1
        label: 1  # Optional
        type: 5  # TYPE_INT32
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_ForeignNestedMessage(self):
        expected_ascii = """
      name: 'TestForeignNested'
      field: <
        name: 'foreign_nested'
        number: 1
        label: 1  # Optional
        type: 11  # TYPE_MESSAGE
        type_name: '.protobuf_unittest.TestAllTypes.NestedMessage'
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestForeignNested.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_ForeignEnum(self):
        expected_ascii = """
      name: 'ForeignEnum'
      value: <
        name: 'FOREIGN_FOO'
        number: 4
      >
      value: <
        name: 'FOREIGN_BAR'
        number: 5
      >
      value: <
        name: 'FOREIGN_BAZ'
        number: 6
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.ForeignEnum.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.EnumDescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_Options(self):
        expected_ascii = """
      name: 'TestDeprecatedFields'
      field: <
        name: 'deprecated_int32'
        number: 1
        label: 1  # Optional
        type: 5  # TYPE_INT32
        options: <
          deprecated: true
        >
      >
      field {
        name: "deprecated_int32_in_oneof"
        number: 2
        label: LABEL_OPTIONAL
        type: TYPE_INT32
        options {
          deprecated: true
        }
        oneof_index: 0
      }
      oneof_decl {
        name: "oneof_fields"
      }
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestDeprecatedFields.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_AllExtensions(self):
        expected_ascii = """
      name: 'TestEmptyMessageWithExtensions'
      extension_range: <
        start: 1
        end: 536870912
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_SeveralExtensions(self):
        expected_ascii = """
      name: 'TestMultipleExtensionRanges'
      extension_range: <
        start: 42
        end: 43
      >
      extension_range: <
        start: 4143
        end: 4244
      >
      extension_range: <
        start: 65536
        end: 536870912
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.DescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_FileDescriptor(self):
        # NOTE(review): java_package below looks find/replace-mangled by the
        # vendoring (presumably 'com.google.protobuf.test' in
        # unittest_import.proto) -- confirm against the generated descriptor.
        expected_ascii = ("""
      name: 'google/protobuf/unittest_import.proto'
      package: 'protobuf_unittest_import'
      dependency: 'google/protobuf/unittest_import_public.proto'
      message_type: <
        name: 'ImportMessage'
        field: <
          name: 'd'
          number: 1
          label: 1  # Optional
          type: 5  # TYPE_INT32
        >
      >
      """ + """enum_type: <
        name: 'ImportEnum'
        value: <
          name: 'IMPORT_FOO'
          number: 7
        >
        value: <
          name: 'IMPORT_BAR'
          number: 8
        >
        value: <
          name: 'IMPORT_BAZ'
          number: 9
        >
      >
      enum_type: <
        name: 'ImportEnumForMap'
        value: <
          name: 'UNKNOWN'
          number: 0
        >
        value: <
          name: 'FOO'
          number: 1
        >
        value: <
          name: 'BAR'
          number: 2
        >
      >
      options: <
        java_package: 'com.dis_sdk_python.dependency.google.protobuf.test'
        optimize_for: 1  # SPEED
      """ + """
        cc_enable_arenas: true
      >
      public_dependency: 0
    """)
        self._InternalTestCopyToProto(
            desc=unittest_import_pb2.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.FileDescriptorProto,
            expected_proto_ascii=expected_ascii)

    def testCopyToProto_ServiceDescriptor(self):
        expected_ascii = """
      name: 'TestService'
      method: <
        name: 'Foo'
        input_type: '.protobuf_unittest.FooRequest'
        output_type: '.protobuf_unittest.FooResponse'
      >
      method: <
        name: 'Bar'
        input_type: '.protobuf_unittest.BarRequest'
        output_type: '.protobuf_unittest.BarResponse'
      >
      """
        self._InternalTestCopyToProto(
            desc=unittest_pb2.TestService.DESCRIPTOR,
            expected_proto_class=descriptor_pb2.ServiceDescriptorProto,
            expected_proto_ascii=expected_ascii)

    # TODO(jieluo): Add support for pure python or remove in c extension.
    @unittest.skipIf(api_implementation.Type() == 'python',
                     'It is not implemented in python.')
    def testCopyToProto_MethodDescriptor(self):
        expected_ascii = """
      name: 'Foo'
      input_type: '.protobuf_unittest.FooRequest'
      output_type: '.protobuf_unittest.FooResponse'
    """
        service_desc = unittest_pb2.TestService.DESCRIPTOR
        foo_method = service_desc.FindMethodByName('Foo')
        self._InternalTestCopyToProto(
            desc=foo_method,
            expected_proto_class=descriptor_pb2.MethodDescriptorProto,
            expected_proto_ascii=expected_ascii)

    # TODO(jieluo): Fix pure python to check with the proto type.
    @unittest.skipIf(api_implementation.Type() == 'python',
                     'Pure python does not raise error.')
    def testCopyToProto_TypeError(self):
        # Copying into a proto of the wrong kind must raise TypeError.
        file_proto = descriptor_pb2.FileDescriptorProto()
        for desc in (unittest_pb2.TestEmptyMessage.DESCRIPTOR,
                     unittest_pb2.ForeignEnum.DESCRIPTOR,
                     unittest_pb2.TestService.DESCRIPTOR):
            self.assertRaises(TypeError, desc.CopyToProto, file_proto)
        message_proto = descriptor_pb2.DescriptorProto()
        self.assertRaises(TypeError,
                          unittest_import_pb2.DESCRIPTOR.CopyToProto,
                          message_proto)
Esempio n. 19
0
class DescriptorTest(unittest.TestCase):
    """Descriptor behaviour for a hand-built file and for generated modules.

    setUp() registers a small proto2 file in the pool returned by
    GetDescriptorPool(); that lookup is its own method, presumably so that
    subclasses can point the same tests at a different pool.
    """

    def setUp(self):
        # Hand-build one file: NestedMessage (int32 field 'bb' plus a nested
        # ForeignEnum), ResponseMessage, and a Service with one method.
        file_proto = descriptor_pb2.FileDescriptorProto(
            name='some/filename/some.proto', package='protobuf_unittest')
        nested = file_proto.message_type.add(name='NestedMessage')
        nested.field.add(
            name='bb',
            number=1,
            type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
            label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
        enum_proto = nested.enum_type.add(name='ForeignEnum')
        for value_name, value_number in (('FOREIGN_FOO', 4),
                                         ('FOREIGN_BAR', 5),
                                         ('FOREIGN_BAZ', 6)):
            enum_proto.value.add(name=value_name, number=value_number)

        file_proto.message_type.add(name='ResponseMessage')
        service_proto = file_proto.service.add(name='Service')
        method_proto = service_proto.method.add(
            name='CallMethod',
            input_type='.protobuf_unittest.NestedMessage',
            output_type='.protobuf_unittest.ResponseMessage')

        # DescriptorPool.Add() accepts the same file repeatedly only when the
        # proto is canonical, which is why every type name is fully qualified.
        self.pool = self.GetDescriptorPool()
        self.pool.Add(file_proto)
        self.my_file = self.pool.FindFileByName(file_proto.name)
        self.my_message = self.my_file.message_types_by_name[nested.name]
        self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
        self.my_service = self.my_file.services_by_name[service_proto.name]
        self.my_method = self.my_service.methods_by_name[method_proto.name]

    def GetDescriptorPool(self):
        """Returns the pool that setUp() registers its test file in."""
        default_db = symbol_database.Default()
        return default_db.pool

    def testEnumValueName(self):
        message = self.my_message
        self.assertEqual(
            message.EnumValueName('ForeignEnum', 4), 'FOREIGN_FOO')

        by_number = message.enum_types_by_name['ForeignEnum'].values_by_number
        self.assertEqual(by_number[4].name,
                         message.EnumValueName('ForeignEnum', 4))
        # An unknown number and an unknown enum name both raise KeyError.
        for enum_name, number in (('ForeignEnum', 999), ('NoneEnum', 999)):
            with self.assertRaises(KeyError):
                message.EnumValueName(enum_name, number)
        with self.assertRaises(TypeError):
            message.EnumValueName()

    def testEnumFixups(self):
        first_value = self.my_enum.values[0]
        self.assertEqual(self.my_enum, first_value.type)

    def testContainingTypeFixups(self):
        first_field = self.my_message.fields[0]
        self.assertEqual(self.my_message, first_field.containing_type)
        self.assertEqual(self.my_message, self.my_enum.containing_type)

    def testContainingServiceFixups(self):
        owner = self.my_method.containing_service
        self.assertEqual(self.my_service, owner)

    def testGetOptions(self):
        # Nothing in setUp() sets options, so each descriptor must report a
        # default-constructed options message of the matching type.
        cases = (
            (self.my_enum, descriptor_pb2.EnumOptions),
            (self.my_enum.values[0], descriptor_pb2.EnumValueOptions),
            (self.my_message, descriptor_pb2.MessageOptions),
            (self.my_message.fields[0], descriptor_pb2.FieldOptions),
            (self.my_method, descriptor_pb2.MethodOptions),
            (self.my_service, descriptor_pb2.ServiceOptions),
        )
        for desc, options_cls in cases:
            self.assertEqual(desc.GetOptions(), options_cls())

    def testSimpleCustomOptions(self):
        opts_pb2 = unittest_custom_options_pb2
        file_descriptor = opts_pb2.DESCRIPTOR
        message_descriptor = opts_pb2.TestMessageWithCustomOptions.DESCRIPTOR
        field_descriptor = message_descriptor.fields_by_name['field1']
        oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof']
        enum_descriptor = message_descriptor.enum_types_by_name['AnEnum']
        enum_values = message_descriptor.enum_values_by_name
        enum_value_descriptor = enum_values['ANENUM_VAL2']
        other_enum_value_descriptor = enum_values['ANENUM_VAL1']
        service_descriptor = opts_pb2.TestServiceWithCustomOptions.DESCRIPTOR
        method_descriptor = service_descriptor.FindMethodByName('Foo')

        # (descriptor, custom option extension, expected value).
        expectations = (
            (file_descriptor, opts_pb2.file_opt1, 9876543210),
            (message_descriptor, opts_pb2.message_opt1, -56),
            (field_descriptor, opts_pb2.field_opt1, 8765432109),
            (field_descriptor, opts_pb2.field_opt2, 42),
            (oneof_descriptor, opts_pb2.oneof_opt1, -99),
            (enum_descriptor, opts_pb2.enum_opt1, -789),
            (enum_value_descriptor, opts_pb2.enum_value_opt1, 123),
            (service_descriptor, opts_pb2.service_opt1, -9876543210),
            (method_descriptor, opts_pb2.method_opt1,
             opts_pb2.METHODOPT1_VAL2),
        )
        for desc, extension, expected in expectations:
            self.assertEqual(expected, desc.GetOptions().Extensions[extension])

        # has_options is expected True on the descriptors declared with options
        # above, and False on DummyMessageContainingEnum and ANENUM_VAL1.
        container = opts_pb2.DummyMessageContainingEnum.DESCRIPTOR
        self.assertTrue(file_descriptor.has_options)
        self.assertFalse(container.has_options)
        for desc in (field_descriptor, oneof_descriptor, enum_descriptor,
                     enum_value_descriptor):
            self.assertTrue(desc.has_options)
        self.assertFalse(other_enum_value_descriptor.has_options)

    def testDifferentCustomOptionTypes(self):
        opts_pb2 = unittest_custom_options_pb2
        int32_min, int32_max = -2**31, 2**31 - 1
        int64_min, int64_max = -2**63, 2**63 - 1
        uint32_max = 2**32 - 1
        uint64_max = 2**64 - 1

        def check_exact(message_cls, expectations):
            # Compare every (extension, expected) pair with exact equality.
            options = message_cls.DESCRIPTOR.GetOptions()
            for extension, expected in expectations:
                self.assertEqual(expected, options.Extensions[extension])

        def check_reals(message_cls, float_value, double_value):
            # float_opt is only compared to 6 places (presumably because it is
            # a 32-bit float); double_opt uses the default precision.
            options = message_cls.DESCRIPTOR.GetOptions()
            self.assertAlmostEqual(
                float_value, options.Extensions[opts_pb2.float_opt], 6)
            self.assertAlmostEqual(
                double_value, options.Extensions[opts_pb2.double_opt])

        # Every integer-typed option at the low end of its range.
        check_exact(opts_pb2.CustomOptionMinIntegerValues, (
            (opts_pb2.bool_opt, False),
            (opts_pb2.int32_opt, int32_min),
            (opts_pb2.int64_opt, int64_min),
            (opts_pb2.uint32_opt, 0),
            (opts_pb2.uint64_opt, 0),
            (opts_pb2.sint32_opt, int32_min),
            (opts_pb2.sint64_opt, int64_min),
            (opts_pb2.fixed32_opt, 0),
            (opts_pb2.fixed64_opt, 0),
            (opts_pb2.sfixed32_opt, int32_min),
            (opts_pb2.sfixed64_opt, int64_min),
        ))

        # ...and at the high end.
        check_exact(opts_pb2.CustomOptionMaxIntegerValues, (
            (opts_pb2.bool_opt, True),
            (opts_pb2.int32_opt, int32_max),
            (opts_pb2.int64_opt, int64_max),
            (opts_pb2.uint32_opt, uint32_max),
            (opts_pb2.uint64_opt, uint64_max),
            (opts_pb2.sint32_opt, int32_max),
            (opts_pb2.sint64_opt, int64_max),
            (opts_pb2.fixed32_opt, uint32_max),
            (opts_pb2.fixed64_opt, uint64_max),
            (opts_pb2.sfixed32_opt, int32_max),
            (opts_pb2.sfixed64_opt, int64_max),
        ))

        # Non-integer option types.
        other = opts_pb2.CustomOptionOtherValues.DESCRIPTOR.GetOptions()
        self.assertEqual(-100, other.Extensions[opts_pb2.int32_opt])
        self.assertAlmostEqual(
            12.3456789, other.Extensions[opts_pb2.float_opt], 6)
        self.assertAlmostEqual(
            1.234567890123456789, other.Extensions[opts_pb2.double_opt])
        self.assertEqual(
            "Hello, \"World\"", other.Extensions[opts_pb2.string_opt])
        self.assertEqual(
            b"Hello\0World", other.Extensions[opts_pb2.bytes_opt])
        self.assertEqual(
            opts_pb2.DummyMessageContainingEnum.TEST_OPTION_ENUM_TYPE2,
            other.Extensions[opts_pb2.enum_opt])

        # Float/double options set from integer values (per the message names).
        check_reals(opts_pb2.SettingRealsFromPositiveInts, 12, 154)
        check_reals(opts_pb2.SettingRealsFromNegativeInts, -12, -154)

    def testComplexExtensionOptions(self):
        opts_pb2 = unittest_custom_options_pb2
        options = opts_pb2.VariousComplexOptions.DESCRIPTOR.GetOptions()

        opt1 = options.Extensions[opts_pb2.complex_opt1]
        self.assertEqual(42, opt1.foo)
        self.assertEqual(324, opt1.Extensions[opts_pb2.quux])
        self.assertEqual(876, opt1.Extensions[opts_pb2.corge].qux)

        opt2 = options.Extensions[opts_pb2.complex_opt2]
        self.assertEqual(987, opt2.baz)
        self.assertEqual(654, opt2.Extensions[opts_pb2.grault])
        bar = opt2.bar
        self.assertEqual(743, bar.foo)
        self.assertEqual(1999, bar.Extensions[opts_pb2.quux])
        self.assertEqual(2008, bar.Extensions[opts_pb2.corge].qux)
        garply = opt2.Extensions[opts_pb2.garply]
        self.assertEqual(741, garply.foo)
        self.assertEqual(1998, garply.Extensions[opts_pb2.quux])
        self.assertEqual(2121, garply.Extensions[opts_pb2.corge].qux)

        # complex_opt4 is an extension declared inside a nested message type.
        complex_opt4 = (
            opts_pb2.ComplexOptionType2.ComplexOptionType4.complex_opt4)
        self.assertEqual(1971, options.Extensions[complex_opt4].waldo)
        self.assertEqual(321, opt2.fred.waldo)

        opt3 = options.Extensions[opts_pb2.complex_opt3]
        self.assertEqual(9, opt3.qux)
        self.assertEqual(22, opt3.complexoptiontype5.plugh)
        self.assertEqual(24, options.Extensions[opts_pb2.complexopt6].xyzzy)

    # Aggregate (message-valued) custom options must be parsed and attached to
    # each kind of descriptor they annotate.
    def testAggregateOptions(self):
        opts_pb2 = unittest_custom_options_pb2
        file_descriptor = opts_pb2.DESCRIPTOR
        message_descriptor = opts_pb2.AggregateMessage.DESCRIPTOR
        field_descriptor = message_descriptor.fields_by_name["fieldname"]
        enum_descriptor = opts_pb2.AggregateEnum.DESCRIPTOR
        enum_value_descriptor = enum_descriptor.values_by_name["VALUE"]
        service_descriptor = opts_pb2.AggregateService.DESCRIPTOR
        method_descriptor = service_descriptor.FindMethodByName("Method")

        # The file-level aggregate carries scalar, nested-message, extension
        # and MessageSet payloads.
        fileopt = file_descriptor.GetOptions().Extensions[opts_pb2.fileopt]
        self.assertEqual(100, fileopt.i)
        self.assertEqual("FileAnnotation", fileopt.s)
        self.assertEqual("NestedFileAnnotation", fileopt.sub.s)
        self.assertEqual("FileExtensionAnnotation",
                         fileopt.file.Extensions[opts_pb2.fileopt].s)
        mset_ext = opts_pb2.AggregateMessageSetElement.message_set_extension
        self.assertEqual("EmbeddedMessageSetElement",
                         fileopt.mset.Extensions[mset_ext].s)

        # Each remaining descriptor kind carries one string annotation.
        annotations = (
            (message_descriptor, opts_pb2.msgopt, "MessageAnnotation"),
            (field_descriptor, opts_pb2.fieldopt, "FieldAnnotation"),
            (enum_descriptor, opts_pb2.enumopt, "EnumAnnotation"),
            (enum_value_descriptor, opts_pb2.enumvalopt,
             "EnumValueAnnotation"),
            (service_descriptor, opts_pb2.serviceopt, "ServiceAnnotation"),
            (method_descriptor, opts_pb2.methodopt, "MethodAnnotation"),
        )
        for desc, extension, expected in annotations:
            self.assertEqual(expected,
                             desc.GetOptions().Extensions[extension].s)

    def testNestedOptions(self):
        opts_pb2 = unittest_custom_options_pb2

        def assert_option(expected, desc, extension):
            self.assertEqual(expected, desc.GetOptions().Extensions[extension])

        nested_message = opts_pb2.NestedOptionType.NestedMessage.DESCRIPTOR
        assert_option(1001, nested_message, opts_pb2.message_opt1)
        assert_option(1002, nested_message.fields_by_name["nested_field"],
                      opts_pb2.field_opt1)

        outer_message = opts_pb2.NestedOptionType.DESCRIPTOR
        assert_option(1003, outer_message.enum_types_by_name["NestedEnum"],
                      opts_pb2.enum_opt1)
        assert_option(1004,
                      outer_message.enum_values_by_name["NESTED_ENUM_VALUE"],
                      opts_pb2.enum_value_opt1)
        assert_option(1005,
                      outer_message.extensions_by_name["nested_extension"],
                      opts_pb2.field_opt2)

    def testFileDescriptorReferences(self):
        for desc in (self.my_enum, self.my_message):
            self.assertEqual(desc.file, self.my_file)

    def testFileDescriptor(self):
        my_file = self.my_file
        self.assertEqual(my_file.name, 'some/filename/some.proto')
        self.assertEqual(my_file.package, 'protobuf_unittest')
        self.assertEqual(my_file.pool, self.pool)
        self.assertFalse(my_file.has_options)
        self.assertEqual('proto2', my_file.syntax)
        # serialized_pb must equal the bytes of the proto CopyToProto produces.
        copied = descriptor_pb2.FileDescriptorProto()
        my_file.CopyToProto(copied)
        self.assertEqual(my_file.serialized_pb, copied.SerializeToString())
        # Generated modules also belong to the default pool.
        self.assertEqual(unittest_pb2.DESCRIPTOR.pool,
                         descriptor_pool.Default())

    @unittest.skipIf(
        not (api_implementation.Type() == 'cpp'
             and api_implementation.Version() == 2),
        'Immutability of descriptors is only enforced in v2 implementation')
    def testImmutableCppDescriptor(self):
        file_descriptor = unittest_pb2.DESCRIPTOR
        message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
        field_descriptor = message_descriptor.fields_by_name['optional_int32']
        enum_descriptor = message_descriptor.enum_types_by_name['NestedEnum']
        oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field']

        # Attributes cannot be rebound, and the exposed containers are
        # read-only.
        with self.assertRaises(AttributeError):
            message_descriptor.fields_by_name = None
        with self.assertRaises(TypeError):
            message_descriptor.fields_by_name['Another'] = None
        with self.assertRaises(TypeError):
            message_descriptor.fields.append(None)
        with self.assertRaises(AttributeError):
            field_descriptor.containing_type = message_descriptor

        for desc in (file_descriptor, field_descriptor, oneof_descriptor,
                     enum_descriptor):
            with self.assertRaises(AttributeError):
                desc.has_options = False
        with self.assertRaises(AttributeError) as ctx:
            message_descriptor.has_options = True
        self.assertEqual('attribute is not writable: has_options',
                         str(ctx.exception))
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = '[email protected] (Will Robinson)'


from dis_sdk_python.dependency.google.protobuf.internal import api_implementation
from dis_sdk_python.dependency.google.protobuf import message


# Choose the module that implements generated message classes: the C++
# extension when api_implementation reports 'cpp', otherwise the pure-Python
# implementation. Both are imported under the same alias, message_impl.
if api_implementation.Type() == 'cpp':
  from dis_sdk_python.dependency.google.protobuf.pyext import cpp_message as message_impl
else:
  from dis_sdk_python.dependency.google.protobuf.internal import python_message as message_impl

# The type of all Message classes.
# Part of the public interface, but normally only used by message factories.
# Re-exported from whichever implementation was selected above.
GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType

# Module-level cache of message classes, shared by all callers in this module.
# NOTE(review): the key/value layout is defined by code outside this view
# (presumably descriptor -> generated class); confirm before relying on it.
MESSAGE_CLASS_CACHE = {}


def ParseMessage(descriptor, byte_str):
  """Generate a new Message instance from this Descriptor and a byte string.

  Args:
def SkipCheckUnknownFieldIfCppImplementation(func):
    """Skip ``func`` when running on the C++ implementation, API version 2.

    According to its skip reason, the decorated test exercises protected
    members of the pure-Python implementation, so it cannot run against the
    cpp v2 backend.

    Args:
      func: The test function or method to decorate.

    Returns:
      ``func`` wrapped by ``unittest.skipIf``. The implementation check runs
      once, when the decorator is applied, not on every test run.
    """
    # Version() is only queried when Type() is 'cpp' (short-circuit ``and``).
    runs_on_cpp_v2 = (api_implementation.Type() == 'cpp'
                      and api_implementation.Version() == 2)
    return unittest.skipIf(
        runs_on_cpp_v2,
        'Additional test for pure python involving protected members')(func)
class AddDescriptorTest(unittest.TestCase):
  """Tests for populating a fresh DescriptorPool through its Add* methods."""

  def _CheckUnittestFileIndexed(self, pool, symbol):
    # Adding a message or enum descriptor also indexes its defining file, both
    # by file name and by the symbol it contains.
    unittest_file = 'google/protobuf/unittest.proto'
    self.assertEqual(unittest_file, pool.FindFileByName(unittest_file).name)
    self.assertEqual(unittest_file,
                     pool.FindFileContainingSymbol(symbol).name)

  def _TestMessage(self, prefix):
    outer_name = 'protobuf_unittest.TestAllTypes'
    nested_name = 'protobuf_unittest.TestAllTypes.NestedMessage'
    pool = descriptor_pool.DescriptorPool()

    pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR)
    found_outer = pool.FindMessageTypeByName(prefix + outer_name)
    self.assertEqual(outer_name, found_outer.full_name)
    # AddDescriptor does not also register nested message types.
    with self.assertRaises(KeyError):
      pool.FindMessageTypeByName(prefix + nested_name)

    pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR)
    found_nested = pool.FindMessageTypeByName(prefix + nested_name)
    self.assertEqual(nested_name, found_nested.full_name)
    self._CheckUnittestFileIndexed(pool, prefix + nested_name)

  @unittest.skipIf(api_implementation.Type() == 'cpp',
                   'With the cpp implementation, Add() must be called first')
  def testMessage(self):
    for prefix in ('', '.'):
      self._TestMessage(prefix)

  def _TestEnum(self, prefix):
    foreign_name = 'protobuf_unittest.ForeignEnum'
    nested_name = 'protobuf_unittest.TestAllTypes.NestedEnum'
    pool = descriptor_pool.DescriptorPool()

    pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
    found_foreign = pool.FindEnumTypeByName(prefix + foreign_name)
    self.assertEqual(foreign_name, found_foreign.full_name)
    # AddEnumDescriptor is not recursive.
    with self.assertRaises(KeyError):
      pool.FindEnumTypeByName(
          prefix + 'protobuf_unittest.ForeignEnum.NestedEnum')

    pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
    found_nested = pool.FindEnumTypeByName(prefix + nested_name)
    self.assertEqual(nested_name, found_nested.full_name)
    self._CheckUnittestFileIndexed(pool, prefix + nested_name)

  @unittest.skipIf(api_implementation.Type() == 'cpp',
                   'With the cpp implementation, Add() must be called first')
  def testEnum(self):
    for prefix in ('', '.'):
      self._TestEnum(prefix)

  @unittest.skipIf(api_implementation.Type() == 'cpp',
                   'With the cpp implementation, Add() must be called first')
  def testService(self):
    service_name = 'protobuf_unittest.TestService'
    pool = descriptor_pool.DescriptorPool()
    # Unknown until registered.
    with self.assertRaises(KeyError):
      pool.FindServiceByName(service_name)
    pool.AddServiceDescriptor(unittest_pb2._TESTSERVICE)
    self.assertEqual(service_name,
                     pool.FindServiceByName(service_name).full_name)

  @unittest.skipIf(api_implementation.Type() == 'cpp',
                   'With the cpp implementation, Add() must be called first')
  def testFile(self):
    unittest_file = 'google/protobuf/unittest.proto'
    pool = descriptor_pool.DescriptorPool()
    pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
    self.assertEqual(unittest_file, pool.FindFileByName(unittest_file).name)
    # AddFileDescriptor is not recursive: the messages and enums the file
    # defines stay unregistered until they are added explicitly.
    with self.assertRaises(KeyError):
      pool.FindFileContainingSymbol('protobuf_unittest.TestAllTypes')

  def testEmptyDescriptorPool(self):
    pool = descriptor_pool.DescriptorPool()
    descriptor_file_name = descriptor_pb2.DESCRIPTOR.name
    # A brand-new pool does not know descriptor.proto...
    with self.assertRaises(KeyError):
      pool.FindFileByName(descriptor_file_name)
    # ...until its FileDescriptorProto is added explicitly.
    descriptor_file_proto = descriptor_pb2.FileDescriptorProto()
    descriptor_pb2.DESCRIPTOR.CopyToProto(descriptor_file_proto)
    pool.Add(descriptor_file_proto)
    self.assertTrue(pool.FindFileByName(descriptor_file_name))

  def testCustomDescriptorPool(self):
    # A file with a package: types are found by package-qualified name.
    packaged_pool = descriptor_pool.DescriptorPool()
    packaged_file = descriptor_pb2.FileDescriptorProto(
        name='some/file.proto', package='package')
    packaged_file.message_type.add(name='Message')
    packaged_pool.Add(packaged_file)
    self.assertEqual(packaged_pool.FindFileByName('some/file.proto').name,
                     'some/file.proto')
    found_message = packaged_pool.FindMessageTypeByName('package.Message')
    self.assertEqual(found_message.name, 'Message')

    # A file without a package: top-level names are looked up unqualified.
    file_proto = descriptor_pb2.FileDescriptorProto(
        name='some/filename/container.proto')
    top_message = file_proto.message_type.add(name='TopMessage')
    top_message.field.add(
        name='bb',
        number=1,
        type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
        label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
    top_enum = file_proto.enum_type.add(name='TopEnum')
    top_enum.value.add(name='FOREIGN_FOO', number=4)
    file_proto.service.add(name='TopService')
    unpackaged_pool = descriptor_pool.DescriptorPool()
    unpackaged_pool.Add(file_proto)
    lookups = ((unpackaged_pool.FindMessageTypeByName, 'TopMessage'),
               (unpackaged_pool.FindEnumTypeByName, 'TopEnum'),
               (unpackaged_pool.FindServiceByName, 'TopService'))
    for find, name in lookups:
      self.assertEqual(name, find(name).name)

  def testFileDescriptorOptionsWithCustomDescriptorPool(self):
    pool = descriptor_pool.DescriptorPool()
    file_name = 'file_descriptor_options_with_custom_descriptor_pool.proto'
    foo_options = file_options_test_pb2.foo_options
    file_proto = descriptor_pb2.FileDescriptorProto(name=file_name)
    file_proto.options.Extensions[foo_options].foo_name = 'foo'
    pool.Add(file_proto)
    # Options set on the FileDescriptorProto remain readable from the
    # descriptor, even when they use extensions the pool itself cannot
    # deserialize...
    file_descriptor = pool.FindFileByName(file_name)
    options = file_descriptor.GetOptions()
    self.assertEqual('foo', options.Extensions[foo_options].foo_name)
    # ...and GetOptions() returns the same cached object on every call.
    self.assertIs(options, file_descriptor.GetOptions())

  def testAddTypeError(self):
    pool = descriptor_pool.DescriptorPool()
    # Every Add*Descriptor entry point rejects a non-descriptor argument.
    adders = (pool.AddDescriptor,
              pool.AddEnumDescriptor,
              pool.AddServiceDescriptor,
              pool.AddExtensionDescriptor,
              pool.AddFileDescriptor)
    for add in adders:
      with self.assertRaises(TypeError):
        add(0)
def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
                   syntax=None):
  """Build a message Descriptor from a descriptor_pb2.DescriptorProto.

  Enum types and message types declared directly inside desc_proto are turned
  into descriptors as well (nested messages recursively). A field whose
  type_name does not name one of those locally declared types is left with
  neither a message type nor an enum type; only the last dotted component of
  type_name is used for that lookup.

  When the C++ implementation is active and build_file_if_cpp is true, the
  message is first wrapped in a FileDescriptorProto with a random file name
  and added to the C++ default pool; if C descriptors are in use, the
  descriptor found in that pool is returned directly.

  Args:
    desc_proto: The descriptor_pb2.DescriptorProto message to convert.
    package: Optional package name (string) prefixed to the message's full
      name.
    build_file_if_cpp: Whether to register a file in the C++ pool when the
      C++ implementation is active. Recursive calls for nested types pass
      False so no duplicate file is created.
    syntax: Forwarded unchanged to the recursive calls for nested types.
      NOTE(review): it is not otherwise applied to the returned Descriptor
      here, although it is intended to select proto3 field presence
      semantics -- confirm.
  Returns:
    A Descriptor for the message described by desc_proto.
  """
  if api_implementation.Type() == 'cpp' and build_file_if_cpp:
    # Every descriptor in the C++ implementation must be backed by a matching
    # definition in the C++ pool, so wrap desc_proto in a throwaway file and
    # register that file there.
    from dis_sdk_python.dependency.google.protobuf import descriptor_pb2
    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.message_type.add().MergeFrom(desc_proto)

    # The pool insists on a file name. A random one cannot clash with any
    # imported .proto file, and its actual value does not matter.
    file_basename = binascii.hexlify(os.urandom(16)).decode('ascii') + '.proto'
    if not package:
      file_proto.name = file_basename
    else:
      file_proto.name = os.path.join(package.replace('.', '/'), file_basename)
      file_proto.package = package

    _message.default_pool.Add(file_proto)
    built_file = _message.default_pool.FindFileByName(file_proto.name)
    if _USE_C_DESCRIPTORS:
      return built_file.message_types_by_name[desc_proto.name]

  # Dotted full name of this message; it is also the scope that qualifies
  # everything declared inside it.
  message_name = package + '.' + desc_proto.name if package else desc_proto.name

  # Enum types declared inside the message, keyed by full name.
  enum_types = {}
  for enum_proto in desc_proto.enum_type:
    enum_full_name = message_name + '.' + enum_proto.name
    value_descs = [EnumValueDescriptor(value.name, position, value.number)
                   for position, value in enumerate(enum_proto.value)]
    enum_types[enum_full_name] = EnumDescriptor(
        enum_proto.name, enum_full_name, None, value_descs)

  # Message types declared inside the message, keyed by full name. Only
  # declarations (not every type a field uses) are followed, so the recursion
  # walks the finite nesting tree and cannot cycle.
  nested_types = {}
  for nested_proto in desc_proto.nested_type:
    nested_types[message_name + '.' + nested_proto.name] = MakeDescriptor(
        nested_proto,
        package=message_name,
        build_file_if_cpp=False,
        syntax=syntax)

  fields = []
  for field_proto in desc_proto.field:
    message_type = None
    enum_type = None
    if field_proto.HasField('type_name'):
      # Resolve against this message's own nested types/enums only, using the
      # last dotted component of the reference. Anything else is a non-local
      # type, which is not supported and stays unresolved.
      local_name = (message_name + '.' +
                    field_proto.type_name.rsplit('.', 1)[-1])
      message_type = nested_types.get(local_name)
      if message_type is None:
        enum_type = enum_types.get(local_name)
    fields.append(FieldDescriptor(
        field_proto.name, message_name + '.' + field_proto.name,
        # NOTE(review): this positional argument (presumably the field index)
        # is derived from the field number, not the field's position in
        # desc_proto.field -- confirm.
        field_proto.number - 1,
        field_proto.number, field_proto.type,
        FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
        field_proto.label, None, message_type, enum_type, None, False, None,
        options=_OptionsOrNone(field_proto), has_default_value=False,
        # An unset (empty) json_name is passed as None.
        json_name=field_proto.json_name or None))

  return Descriptor(desc_proto.name, message_name, None, None, fields,
                    list(nested_types.values()), list(enum_types.values()),
                    [], options=_OptionsOrNone(desc_proto))