def __init__(self,
                 name,
                 full_name,
                 index,
                 number,
                 type,
                 cpp_type,
                 label,
                 default_value,
                 message_type,
                 enum_type,
                 containing_type,
                 is_extension,
                 extension_scope,
                 options=None,
                 serialized_options=None,
                 has_default_value=True,
                 containing_oneof=None,
                 json_name=None,
                 file=None):  # pylint: disable=redefined-builtin
        """Initialise the descriptor of a single message field or extension.

        Every argument populates the FieldDescriptor attribute of the same
        name.  When json_name is not supplied it is derived from name with
        _ToJsonName.

        containing_type may be None at construction time and assigned later
        (e.g. to cope with message types that reference each other); the same
        holds for extension_scope.
        """
        super(FieldDescriptor, self).__init__(options, serialized_options,
                                              'FieldOptions')
        # Naming and ownership.
        self.name = name
        self.full_name = full_name
        self.file = file
        self._camelcase_name = None
        self.json_name = (_ToJsonName(name) if json_name is None
                          else json_name)
        # Position, wire type and label.
        self.index = index
        self.number = number
        self.type = type
        self.cpp_type = cpp_type
        self.label = label
        # Default value handling.
        self.has_default_value = has_default_value
        self.default_value = default_value
        # Related descriptors.
        self.containing_type = containing_type
        self.message_type = message_type
        self.enum_type = enum_type
        self.is_extension = is_extension
        self.extension_scope = extension_scope
        self.containing_oneof = containing_oneof
        # With the 'cpp' backend, the matching native descriptor is fetched by
        # full name from _message.default_pool; otherwise there is none.
        if api_implementation.Type() != 'cpp':
            self._cdescriptor = None
        else:
            native_pool = _message.default_pool
            find = (native_pool.FindExtensionByName if is_extension
                    else native_pool.FindFieldByName)
            self._cdescriptor = find(full_name)
 def setUp(self):
     super().setUp()
     # Verify that the implementation we requested via the Flag is honoured.
     # self.fail() is used instead of a bare `assert` so that this check is
     # not stripped when the tests run under `python -O`; it raises the same
     # AssertionError (unittest's failureException) with the same message.
     if api_implementation.Type() != flags.FLAGS.proto_implementation_type:
         self.fail('Expected proto implementation type '
                   f'"{flags.FLAGS.proto_implementation_type}", got: '
                   f'"{api_implementation.Type()}"')
Ejemplo n.º 3
0
 def _GetDescriptorPoolClass(self):
     # Return the DescriptorPool class matching the active backend, so the
     # tests exercise both pool implementations.
     if api_implementation.Type() != 'cpp':
         return descriptor_pool.DescriptorPool
     # Imported lazily so that non-cpp runs never touch the pyext module.
     # pylint: disable=g-import-not-at-top
     from google.protobuf.pyext import _message
     return _message.DescriptorPool
Ejemplo n.º 4
0
 def testUnknownPackedEnumValue(self):
     is_cpp = api_implementation.Type() == 'cpp'
     # For repeated enums both implementations agree: the field is empty.
     self.assertEqual([], self.missing_message.packed_nested_enum)
     if is_cpp:
         return
     # The unknown-field round trip is only checked for the non-cpp backend.
     unknown = self.GetUnknownField('packed_nested_enum')
     self.assertEqual(self.message.packed_nested_enum, unknown)
 def testConflictRegister(self):
   # Cpp extension cannot call Add on a DescriptorPool that uses a
   # DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   original_fd = descriptor_pb2.FileDescriptorProto.FromString(
       unittest_pb2.DESCRIPTOR.serialized_pb)
   # Same contents as original_fd, registered under a different file name.
   duplicate_fd = copy.deepcopy(original_fd)
   duplicate_fd.name = 'other_file'

   if api_implementation.Type() == 'cpp':
     # A TypeError from the cpp pool while adding the pair is tolerated.
     try:
       self.pool.Add(original_fd)
       self.pool.Add(duplicate_fd)
     except TypeError:
       pass
     return

   with warnings.catch_warnings(record=True) as caught:
     # Record every warning, even ones that were already emitted once.
     warnings.simplefilter('always')
     pool = copy.deepcopy(self.pool)
     golden = unittest_pb2.DESCRIPTOR
     # Re-adding descriptors that are already present must not warn.
     pool.AddDescriptor(golden.message_types_by_name['TestAllTypes'])
     pool.AddEnumDescriptor(golden.enum_types_by_name['ForeignEnum'])
     pool.AddServiceDescriptor(golden.services_by_name['TestService'])
     pool.AddExtensionDescriptor(
         golden.extensions_by_name['optional_int32_extension'])
     self.assertEqual(len(caught), 0)
     # Two files defining the same symbols must trigger a conflict warning.
     pool.Add(original_fd)
     pool.Add(duplicate_fd)
     pool.FindFileByName(original_fd.name)
     pool.FindFileByName(duplicate_fd.name)
     self.assertTrue(len(caught))
     first = caught[0]
     self.assertIs(first.category, RuntimeWarning)
     self.assertIn('Conflict register for file "other_file": ',
                   str(first.message))
     self.assertIn('already defined in file '
                   '"google/protobuf/unittest.proto"',
                   str(first.message))
 def testBadUtf8String(self, message_module):
   # Only the 'python' backend is expected to raise UnicodeDecodeError here.
   if api_implementation.Type() != 'python':
     self.skipTest("Skipping testBadUtf8String, currently only the python "
                   "api implementation raises UnicodeDecodeError when a "
                   "string field contains bad utf-8.")
   payload = test_util.GoldenFileData('bad_utf8_string')
   with self.assertRaises(UnicodeDecodeError) as ctx:
     message_module.TestAllTypes.FromString(payload)
   # The error must name the offending field.
   self.assertIn('TestAllTypes.optional_string', str(ctx.exception))
Ejemplo n.º 7
0
    def RegisterMessageDescriptor(self, message_descriptor):
        """Record a message descriptor in the local database.

        Only the 'python' backend registers anything here; for any other
        backend the call does nothing.

        Args:
          message_descriptor (Descriptor): the message descriptor to add.
        """
        if api_implementation.Type() != 'python':
            return
        # pylint: disable=protected-access
        self.pool._AddDescriptor(message_descriptor)
Ejemplo n.º 8
0
    def RegisterFileDescriptor(self, file_descriptor):
        """Record a file descriptor in the local database.

        Only the 'python' backend registers anything here; for any other
        backend the call does nothing.

        Args:
          file_descriptor (FileDescriptor): The file descriptor to register.
        """
        if api_implementation.Type() != 'python':
            return
        # pylint: disable=protected-access
        self.pool._InternalAddFileDescriptor(file_descriptor)
Ejemplo n.º 9
0
 def setUp(self):
   self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
   # A fully populated message and its serialized form.
   self.all_fields = full = unittest_pb2.TestAllTypes()
   test_util.SetAllFields(full)
   self.all_fields_data = full.SerializeToString()
   # Re-parse that data as TestEmptyMessage (presumably a message with no
   # declared fields, so the data is kept as unknown fields).
   self.empty_message = empty = unittest_pb2.TestEmptyMessage()
   empty.ParseFromString(self.all_fields_data)
   if api_implementation.Type() == 'cpp':
     return
   # _unknown_fields is an implementation detail.
   self.unknown_fields = empty._unknown_fields
 def testAddFileDescriptor(self):
   # Cpp extension cannot call Add on a DescriptorPool that uses a
   # DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   proto = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
   # Register the same file once as a proto and once in serialized form.
   self.pool.Add(proto)
   self.pool.AddSerializedFile(proto.SerializeToString())
Ejemplo n.º 11
0
    def RegisterServiceDescriptor(self, service_descriptor):
        """Record a service descriptor in the local database.

        Only the 'python' backend registers anything here; for any other
        backend the call does nothing.

        Args:
          service_descriptor (ServiceDescriptor): the service descriptor to
            register.
        """
        if api_implementation.Type() != 'python':
            return
        # pylint: disable=protected-access
        self.pool._AddServiceDescriptor(service_descriptor)
Ejemplo n.º 12
0
    def testGetMessages(self):
        msg1_name = 'google.protobuf.python.internal.Factory1Message'
        msg2_name = 'google.protobuf.python.internal.Factory2Message'
        nested_ext_name = (
            'google.protobuf.python.internal.Factory2Message.one_more_field')
        file_ext_name = 'google.protobuf.python.internal.another_field'
        # The scenario runs twice: repeated calls with the same input must be
        # allowed.
        for _ in range(2):
            # GetMessages must not depend on the order of the
            # FileDescriptorProtos; in particular it must succeed when a file
            # is listed before the files it depends on.
            self.assertIn(self.factory_test1_fd.name,
                          self.factory_test2_fd.dependency)
            # factory_test2_fd depends on factory_test1_fd yet comes first.
            messages = message_factory.GetMessages(
                [self.factory_test2_fd, self.factory_test1_fd])
            self.assertTrue({msg2_name, msg1_name}.issubset(messages.keys()))
            self._ExerciseDynamicClass(messages[msg2_name])

            factory_msg1 = messages[msg1_name]
            msg1_descriptor = factory_msg1.DESCRIPTOR
            known_extensions = {
                ext.full_name
                for ext in msg1_descriptor.file.pool.FindAllExtensions(
                    msg1_descriptor)
            }
            self.assertTrue(
                {nested_ext_name, file_ext_name}.issubset(known_extensions))

            msg1 = factory_msg1()
            ext1 = msg1.Extensions._FindExtensionByName(nested_ext_name)
            ext2 = msg1.Extensions._FindExtensionByName(file_ext_name)
            msg1.Extensions[ext1] = 'test1'
            msg1.Extensions[ext2] = 'test2'
            self.assertEqual('test1', msg1.Extensions[ext1])
            self.assertEqual('test2', msg1.Extensions[ext2])
            self.assertEqual(None,
                             msg1.Extensions._FindExtensionByNumber(12321))

            if api_implementation.Type() != 'cpp':
                # Non-cpp backends answer bad key types with None.
                self.assertEqual(None, msg1.Extensions._FindExtensionByName(0))
                self.assertEqual(None,
                                 msg1.Extensions._FindExtensionByNumber(''))
                continue
            # TODO(jieluo): Fix len to return the correct value.
            # self.assertEqual(2, len(msg1.Extensions))
            self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
            # The cpp backend rejects bad key types with TypeError.
            self.assertRaises(TypeError,
                              msg1.Extensions._FindExtensionByName, 0)
            self.assertRaises(TypeError,
                              msg1.Extensions._FindExtensionByNumber, '')
 def testAddSerializedFile(self):
   # Cpp extension cannot call Add on a DescriptorPool that uses a
   # DescriptorDatabase.
   # TODO(jieluo): Fix python and cpp extension diff.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   # Start from a fresh pool populated only from serialized files.
   self.pool = descriptor_pool.DescriptorPool()
   for file_proto in (self.factory_test1_fd, self.factory_test2_fd):
     self.pool.AddSerializedFile(file_proto.SerializeToString())
   self.testFindMessageTypeByName()
Ejemplo n.º 14
0
    def __init__(self,
                 name,
                 full_name,
                 index,
                 number,
                 type,
                 cpp_type,
                 label,
                 default_value,
                 message_type,
                 enum_type,
                 containing_type,
                 is_extension,
                 extension_scope,
                 options=None,
                 has_default_value=True):
        """Initialise a field descriptor.

        Each argument populates the FieldDescriptor attribute of the same
        name.  containing_type may be None here and assigned later (for
        example to break circular references between message types); the
        same applies to extension_scope.
        """
        super(FieldDescriptor, self).__init__(options, 'FieldOptions')
        self.name = name
        self.full_name = full_name
        self.index = index
        self.number = number
        self.type = type
        self.cpp_type = cpp_type
        self.label = label
        self.has_default_value = has_default_value
        self.default_value = default_value
        self.containing_type = containing_type
        self.message_type = message_type
        self.enum_type = enum_type
        self.is_extension = is_extension
        self.extension_scope = extension_scope
        if api_implementation.Type() != 'cpp':
            self._cdescriptor = None
            return
        # API version 2 of the C++ backend is reached through _message, any
        # other version through cpp_message.
        backend = (_message if api_implementation.Version() == 2
                   else cpp_message)
        if is_extension:
            self._cdescriptor = backend.GetExtensionDescriptor(full_name)
        else:
            self._cdescriptor = backend.GetFieldDescriptor(full_name)
 def testFindOneofByName(self):
   # TODO(jieluo): Fix cpp extension to find oneof correctly
   # when descriptor pool is using an underlying database.
   if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
       api_implementation.Type() == 'cpp'):
     return
   found = self.pool.FindOneofByName(
       'google.protobuf.python.internal.Factory2Message.oneof_field')
   self.assertEqual(found.name, 'oneof_field')
   # Unknown names must raise KeyError.
   with self.assertRaises(KeyError):
     self.pool.FindOneofByName('Does not exist')
Ejemplo n.º 16
0
  def __init__(self, name, package, options=None, serialized_pb=None):
    """Create a file descriptor.

    With the 'cpp' backend, a non-None serialized_pb is also passed to
    cpp_message.BuildFile.
    """
    super(FileDescriptor, self).__init__(options, 'FileOptions')

    self.message_types_by_name = {}
    self.name = name
    self.package = package
    self.serialized_pb = serialized_pb
    if api_implementation.Type() != 'cpp' or self.serialized_pb is None:
      return
    cpp_message.BuildFile(self.serialized_pb)
  def testFindTypeErrors(self):
    self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')

    # TODO(jieluo): Fix python to raise correct errors.
    is_cpp = api_implementation.Type() == 'cpp'
    if is_cpp:
      self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
      self.assertRaises(KeyError, self.pool.FindMethodByName, '')
    # A non-string name is expected to raise TypeError under cpp and
    # AttributeError otherwise.
    error_type = TypeError if is_cpp else AttributeError
    name_lookups = (
        self.pool.FindMessageTypeByName,
        self.pool.FindFieldByName,
        self.pool.FindExtensionByName,
        self.pool.FindEnumTypeByName,
        self.pool.FindOneofByName,
        self.pool.FindServiceByName,
        self.pool.FindFileContainingSymbol,
    )
    for lookup in name_lookups:
      self.assertRaises(error_type, lookup, 0)
    # FindFileByName is expected to raise KeyError under the 'python' backend.
    if api_implementation.Type() == 'python':
      error_type = KeyError
    self.assertRaises(error_type, self.pool.FindFileByName, 0)
Ejemplo n.º 18
0
 def InternalCheckUnknownField(self, name, expected_value):
   # The unknown-field bytes are only inspected for non-cpp backends.
   if api_implementation.Type() == 'cpp':
     return
   field = self.descriptor.fields_by_name[name]
   wanted_tag = encoder.TagBytes(
       field.number, type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field.type])
   decoded = {}
   for tag_bytes, raw in self.empty_message._unknown_fields:
     if tag_bytes != wanted_tag:
       continue
     # Decode with TestAllTypes' decoder for this tag; results land in
     # `decoded`, keyed by field descriptor.
     decode = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
     decode(memoryview(raw), 0, len(raw), self.all_fields, decoded)
   self.assertEqual(expected_value, decoded[field])
Ejemplo n.º 19
0
 def CheckDescriptorMapping(self, mapping):
     # Verifies that a property like 'messageDescriptor.fields' has all the
     # properties of an immutable abc.Mapping.
     other_mapping = unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name
     for unequal in (other_mapping, {}, 1):
         self.assertNotEqual(mapping, unequal)
     self.assertFalse(mapping == 1)  # Only for cpp test coverage
     snapshot = dict(mapping.items())
     self.assertEqual(mapping, snapshot)
     self.assertEqual(mapping, mapping)
     self.assertGreater(len(mapping), 0)  # Sized
     self.assertEqual(len(mapping), len(snapshot))  # Iterable
     key, value = next(iter(mapping.items()))
     self.assertIn(key, mapping)  # Container
     self.assertEqual(mapping.get(key), value)
     with self.assertRaises(TypeError):
         mapping.get()
     # TODO(jieluo): Fix python and cpp extension diff.
     if api_implementation.Type() == 'python':
         self.assertRaises(TypeError, mapping.get, [])
     else:
         self.assertEqual(None, mapping.get([]))
     # keys(), values() and items() must iterate in the same order.
     first_pair = (next(iter(mapping.keys())), next(iter(mapping.values())))
     self.assertEqual(first_pair, next(iter(mapping.items())))
     # Any change to the plain-dict copy must break equality.
     snapshot[key] = 'change value'
     self.assertNotEqual(mapping, snapshot)
     del snapshot[key]
     snapshot['new_key'] = 'new'
     self.assertNotEqual(mapping, snapshot)
     self.assertRaises(KeyError, mapping.__getitem__, 'key_error')
     self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1)
     # TODO(jieluo): Add __repr__ support for DescriptorMapping.
     if api_implementation.Type() == 'python':
         self.assertEqual(len(str(dict(mapping.items()))),
                          len(str(mapping)))
     else:
         self.assertEqual(str(mapping)[0], '<')
Ejemplo n.º 20
0
    def RegisterEnumDescriptor(self, enum_descriptor):
        """Record an enum descriptor in the local database.

        Only the 'python' backend registers anything here; the descriptor is
        returned unchanged for every backend.

        Args:
          enum_descriptor (EnumDescriptor): The enum descriptor to register.

        Returns:
          EnumDescriptor: The provided descriptor.
        """
        if api_implementation.Type() != 'python':
            return enum_descriptor
        # pylint: disable=protected-access
        self.pool._AddEnumDescriptor(enum_descriptor)
        return enum_descriptor
 def testPickleRepeatedScalarContainer(self, message_module):
   # TODO(tibell): The pure-Python implementation support pickling of
   #   scalar containers in *some* cases. For now the cpp2 version
   #   throws an exception to avoid a segfault. Investigate if we
   #   want to support pickling of these fields.
   #
   # For more information see: https://b2.corp.google.com/u/0/issues/18677897
   # NOTE(review): the check below only runs for cpp with Version() != 2,
   # whereas the comment above talks about the cpp2 version raising;
   # confirm which backend is meant.
   runs_here = (api_implementation.Type() == 'cpp' and
                api_implementation.Version() != 2)
   if not runs_here:
     return
   msg = message_module.TestAllTypes()
   with self.assertRaises(pickle.PickleError):
     pickle.dumps(msg.repeated_int32, pickle.HIGHEST_PROTOCOL)
Ejemplo n.º 22
0
 def __init__(self,
              name,
              full_name,
              index,
              number,
              type,
              cpp_type,
              label,
              default_value,
              message_type,
              enum_type,
              containing_type,
              is_extension,
              extension_scope,
              options=None,
              has_default_value=True):
     """Initialise a field descriptor.

     Each argument populates the FieldDescriptor attribute of the same name.
     containing_type and extension_scope may be None here and assigned later
     (for example to break circular references between message types).

     With the 'cpp' backend, _cdescriptor is looked up by full_name:
     extensions with GetExtensionDescriptor and ordinary fields with
     GetFieldDescriptor. Otherwise _cdescriptor is None.
     """
     super(FieldDescriptor, self).__init__(options, 'FieldOptions')
     self.name = name
     self.full_name = full_name
     self.index = index
     self.number = number
     self.type = type
     self.cpp_type = cpp_type
     self.label = label
     self.has_default_value = has_default_value
     self.default_value = default_value
     self.containing_type = containing_type
     self.message_type = message_type
     self.enum_type = enum_type
     self.is_extension = is_extension
     self.extension_scope = extension_scope
     if api_implementation.Type() == 'cpp':
         # API version 2 of the C++ backend lives in _message; other versions
         # in cpp_message.
         if api_implementation.Version() == 2:
             backend = _message
         else:
             backend = cpp_message
         # An extension must keep the descriptor from GetExtensionDescriptor;
         # it must not be re-queried (and overwritten) via GetFieldDescriptor.
         if is_extension:
             self._cdescriptor = backend.GetExtensionDescriptor(full_name)
         else:
             self._cdescriptor = backend.GetFieldDescriptor(full_name)
     else:
         self._cdescriptor = None
  def testFindFieldByName(self):
    # TODO(jieluo): Fix cpp extension to find field correctly
    # when descriptor pool is using an underlying database.
    if (isinstance(self, SecondaryDescriptorFromDescriptorDB) and
        api_implementation.Type() == 'cpp'):
      return
    list_value = self.pool.FindFieldByName(
        'google.protobuf.python.internal.Factory1Message.list_value')
    self.assertEqual(list_value.name, 'list_value')
    self.assertEqual(list_value.label, list_value.LABEL_REPEATED)
    self.assertFalse(list_value.has_options)

    # Unknown names must raise KeyError.
    with self.assertRaises(KeyError):
      self.pool.FindFieldByName('Does not exist')
def parse_archive_payload(
        raw_bytes: bytes,
        package_id: 'Optional[PackageRef]' = None) -> 'G.ArchivePayload':
    """
    Convert ``bytes`` into a :class:`G.ArchivePayload`.

    Note that this function will temporarily increase Python's recursion limit to handle cases where
    parsing a DAML-LF archive requires deeper recursion limits. The limit is raised to at least
    5000 (never lowered) and the previous value is restored afterwards, even on failure.

    :param raw_bytes: The serialized payload to parse.
    :param package_id: Optional package reference; only used in the log message.
    :return: The parsed :class:`G.ArchivePayload`.
    :raises google.protobuf.message.DecodeError: If ``raw_bytes`` cannot be parsed.
    """
    # noinspection PyPackageRequirements
    from google.protobuf.message import DecodeError
    from . import model as G

    # perf_counter() is monotonic, so the elapsed time is not affected by
    # wall-clock adjustments the way time.time() is.
    start_time = time.perf_counter()

    # Construct the message before touching the recursion limit, so that a
    # failure here cannot leave the interpreter with a modified limit.
    archive_payload = G.ArchivePayload()

    prev_recursion_limit = sys.getrecursionlimit()
    # Only ever raise the limit: lowering one the caller already increased
    # would make deep payloads fail that would otherwise parse.
    sys.setrecursionlimit(max(prev_recursion_limit, 5000))
    try:
        archive_payload.ParseFromString(raw_bytes)
    except DecodeError:
        # noinspection PyPackageRequirements
        from google.protobuf.internal import api_implementation
        if api_implementation.Type() == 'cpp':
            LOG.error(
                'Failed to decode metadata. This may be due to bugs in the native Protobuf'
            )
            LOG.error(
                'implementation as exposed through Python, so setting an environment'
            )
            LOG.error(
                'variable to force a non-native implementation may help work around this'
            )
            LOG.error('problem:')
            LOG.error(
                '    export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python')
        raise
    finally:
        sys.setrecursionlimit(prev_recursion_limit)

    total_millis = (time.perf_counter() - start_time) * 1000
    # NOTE(review): '%2.f' renders whole milliseconds padded to width 2; if two
    # decimal places were intended the format should be '%.2f'.
    if package_id is None:
        LOG.info('Parsed %s bytes of metadata in %2.f ms.', len(raw_bytes),
                 total_millis)
    else:
        LOG.info('Parsed %s bytes of metadata (package ID %r) in %2.f ms.',
                 len(raw_bytes), package_id, total_millis)

    return archive_payload
Ejemplo n.º 25
0
def main_graph_base_converter(file_config):
    """
    The entrance for converter, script files will be converted.

    Args:
        file_config (dict): The config of file which to convert. This function reads
            ``model_file``, ``shape``, ``input_nodes``, ``output_nodes``, ``outfile_dir``,
            ``report_dir`` and, optionally, ``query_result_folder``.

    Raises:
        ParamMissingError: If ``shape`` is missing or empty.
        BadParamError: If ``shape`` and ``input_nodes`` differ in length.
        UnknownModelError: If the framework type is neither ONNX nor TensorFlow.
    """
    # api_implementation.Type() reports the backend actually in use, while
    # PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION only requests one. Also requiring
    # the variable produced a false "Python" warning whenever the cpp backend
    # was active without it being set, so only the real backend is checked.
    if api_implementation.Type() == 'python':
        log_console.warning("Protobuf is currently implemented in \"Python\". "
                            "The conversion process may take a long time. "
                            "Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.")

    graph_path = file_config['model_file']
    frame_type = get_framework_type(graph_path)
    if not file_config.get("shape"):
        raise ParamMissingError("Param missing, `--shape` is required when using graph mode.", only_console=True)

    check_params = ['input_nodes', 'output_nodes']
    check_params_exist(check_params, file_config)

    if len(file_config['shape']) != len(file_config.get("input_nodes", [])):
        raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
                            "and no redundant node in `--input_nodes`.", only_console=True)

    # Map each input node name to its shape.
    input_nodes = dict(zip(file_config['input_nodes'], file_config['shape']))

    # Choose the converter before reading the output settings, so an
    # unsupported model is reported as such.
    if frame_type == FrameworkType.ONNX.value:
        converter = graph_based_converter_onnx_to_ms
    elif frame_type == FrameworkType.TENSORFLOW.value:
        converter = graph_based_converter_tf_to_ms
    else:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModelError(error_msg)
        raise error

    converter(graph_path=graph_path,
              input_nodes=input_nodes,
              output_nodes=file_config['output_nodes'],
              output_folder=file_config['outfile_dir'],
              report_folder=file_config['report_dir'],
              query_result_folder=file_config.get("query_result_folder"))
Ejemplo n.º 26
0
    def __init__(self,
                 name,
                 package,
                 options=None,
                 serialized_options=None,
                 serialized_pb=None,
                 dependencies=None,
                 public_dependencies=None,
                 syntax=None,
                 pool=None):
        """Create a FileDescriptor.

        If pool is None, the backend's default pool is used: _message's for
        'cpp', descriptor_pool.Default() otherwise. syntax defaults to
        "proto2". With the 'cpp' backend, a non-None serialized_pb is also
        added to the pool via AddSerializedFile.
        """
        super(FileDescriptor, self).__init__(options, serialized_options,
                                             'FileOptions')

        if pool is None:
            if api_implementation.Type() != 'cpp':
                from google.protobuf import descriptor_pool
                pool = descriptor_pool.Default()
            else:
                pool = _message.default_pool
        self.pool = pool
        self.message_types_by_name = {}
        self.name = name
        self.package = package
        self.syntax = syntax or "proto2"
        self.serialized_pb = serialized_pb

        self.enum_types_by_name = {}
        self.extensions_by_name = {}
        self.services_by_name = {}
        self.dependencies = dependencies or []
        self.public_dependencies = public_dependencies or []

        if self.serialized_pb is None:
            return
        if api_implementation.Type() == 'cpp':
            self.pool.AddSerializedFile(self.serialized_pb)
Ejemplo n.º 27
0
    def encode(self, instance):
        """Encode a tf.transform encoded dict as serialized tf.Example.

        A scratch tf.train.Example is prepared by the feature handlers on the
        first call and reused afterwards. With the 'python' protobuf backend
        it is rebuilt on every call.
        """
        needs_rebuild = (self._encode_example_cache is None
                         or api_implementation.Type() == 'python')
        if needs_rebuild:
            scratch = tf.train.Example()
            for handler in self._feature_handlers:
                handler.initialize_encode_cache(scratch)
            self._encode_example_cache = scratch

        # Push every feature value into the cached Example, then serialize.
        for handler in self._feature_handlers:
            feature_value = instance[handler.name]
            handler.encode_value(feature_value)
        return self._encode_example_cache.SerializeToString()
 def testFindExtensionByName(self):
   if isinstance(self, SecondaryDescriptorFromDescriptorDB):
     if api_implementation.Type() == 'cpp':
       # TODO(jieluo): Fix cpp extension to find extension correctly
       # when descriptor pool is using an underlying database.
       return
   # An extension defined in a message.
   extension = self.pool.FindExtensionByName(
       'google.protobuf.python.internal.Factory2Message.one_more_field')
   self.assertEqual(extension.name, 'one_more_field')
   # An extension defined at file scope.
   extension = self.pool.FindExtensionByName(
       'google.protobuf.python.internal.another_field')
   self.assertEqual(extension.name, 'another_field')
   self.assertEqual(extension.number, 1002)
   # The miss path must go through FindExtensionByName itself (not
   # FindFieldByName) so that the API under test is what raises KeyError.
   with self.assertRaises(KeyError):
     self.pool.FindExtensionByName('Does not exist')
Ejemplo n.º 29
0
  def testMakeDescriptorWithNestedFields(self):
    field_desc = descriptor.FieldDescriptor

    def add_required_field(owner, number, name, field_type, type_name=None):
      # Append a LABEL_REQUIRED field to owner's `field` list.
      new_field = owner.field.add()
      new_field.number = number
      new_field.name = name
      new_field.label = field_desc.LABEL_REQUIRED
      new_field.type = field_type
      if type_name is not None:
        new_field.type_name = type_name
      return new_field

    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.name = 'Foo2'
    # Message Foo2 containing nested message Sub, which owns enum FOO {BAR=3}.
    outer = file_proto.message_type.add()
    outer.name = file_proto.name
    inner = outer.nested_type.add()
    inner.name = 'Sub'
    inner_enum = inner.enum_type.add()
    inner_enum.name = 'FOO'
    bar_value = inner_enum.value.add()
    bar_value.name = 'BAR'
    bar_value.number = 3
    add_required_field(outer, 1, 'uint64_field', field_desc.TYPE_UINT64)
    add_required_field(outer, 2, 'nested_message_field',
                       field_desc.TYPE_MESSAGE, 'Sub')
    add_required_field(inner, 2, 'bar_field', field_desc.TYPE_ENUM,
                       'Foo2.Sub.FOO')

    result = descriptor.MakeDescriptor(outer)
    fields = result.fields
    nested = result.nested_types[0]
    self.assertEqual(fields[0].cpp_type, field_desc.CPPTYPE_UINT64)
    self.assertEqual(fields[1].cpp_type, field_desc.CPPTYPE_MESSAGE)
    self.assertEqual(fields[1].message_type.containing_type, result)
    self.assertEqual(nested.fields[0].full_name, 'Foo2.Sub.bar_field')
    self.assertEqual(nested.fields[0].enum_type, nested.enum_types[0])
    self.assertFalse(result.has_options)
    self.assertFalse(fields[0].has_options)
    if api_implementation.Type() == 'cpp':
      # The cpp backend rejects attribute assignment on this descriptor.
      with self.assertRaises(AttributeError):
        fields[0].has_options = False
Ejemplo n.º 30
0
 def testGetMessages(self):
     msg1_name = 'google.protobuf.python.internal.Factory1Message'
     msg2_name = 'google.protobuf.python.internal.Factory2Message'
     nested_ext_name = (
         'google.protobuf.python.internal.Factory2Message.one_more_field')
     file_ext_name = 'google.protobuf.python.internal.another_field'
     # The scenario runs twice: repeated calls with the same input must be
     # allowed.
     for _ in range(2):
         messages = message_factory.GetMessages(
             [self.factory_test1_fd, self.factory_test2_fd])
         self.assertTrue({msg2_name, msg1_name}.issubset(messages.keys()))
         self._ExerciseDynamicClass(messages[msg2_name])

         factory_msg1 = messages[msg1_name]
         msg1_descriptor = factory_msg1.DESCRIPTOR
         known_extensions = {
             ext.full_name
             for ext in msg1_descriptor.file.pool.FindAllExtensions(
                 msg1_descriptor)
         }
         self.assertTrue(
             {nested_ext_name, file_ext_name}.issubset(known_extensions))

         msg1 = factory_msg1()
         ext1 = msg1.Extensions._FindExtensionByName(nested_ext_name)
         ext2 = msg1.Extensions._FindExtensionByName(file_ext_name)
         msg1.Extensions[ext1] = 'test1'
         msg1.Extensions[ext2] = 'test2'
         self.assertEqual('test1', msg1.Extensions[ext1])
         self.assertEqual('test2', msg1.Extensions[ext2])
         self.assertEqual(None,
                          msg1.Extensions._FindExtensionByNumber(12321))

         if api_implementation.Type() != 'cpp':
             # Non-cpp backends answer bad key types with None.
             self.assertEqual(None, msg1.Extensions._FindExtensionByName(0))
             self.assertEqual(None,
                              msg1.Extensions._FindExtensionByNumber(''))
             continue
         # TODO(jieluo): Fix len to return the correct value.
         # self.assertEqual(2, len(msg1.Extensions))
         self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
         # The cpp backend rejects bad key types with TypeError.
         self.assertRaises(TypeError,
                           msg1.Extensions._FindExtensionByName, 0)
         self.assertRaises(TypeError,
                           msg1.Extensions._FindExtensionByNumber, '')