Ejemplo n.º 1
0
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build the decode callable for a group field.

    The returned function takes ``(buffer, pos, end, message, field_dict)``,
    parses the group body via ``_InternalParse`` and returns the offset just
    past the group's end tag.  A packed request raises AssertionError.
    """

    closing_tag = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
    closing_len = len(closing_tag)

    if is_packed:
        raise AssertionError

    def _skip_end_tag(buffer, pos, end):
        """Return the offset just past the end-group tag expected at pos."""
        after = pos + closing_len
        if buffer[pos:after] != closing_tag or after > end:
            raise _DecodeError('Missing group end tag.')
        return after

    if not is_repeated:

        def DecodeField(buffer, pos, end, message, field_dict):
            group = field_dict.get(key)
            if group is None:
                group = field_dict.setdefault(key, new_default(message))
            # Parse the group body, then require its closing tag.
            body_end = group._InternalParse(buffer, pos, end)
            return _skip_end_tag(buffer, body_end, end)

        return DecodeField

    opening_tag = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_START_GROUP)
    opening_len = len(opening_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        while True:
            # Look the container up on every pass, creating it on demand.
            groups = field_dict.get(key)
            if groups is None:
                groups = field_dict.setdefault(key, new_default(message))
            body_end = groups.add()._InternalParse(buffer, pos, end)
            resume = _skip_end_tag(buffer, body_end, end)
            # Keep looping only while the next tag opens the same group
            # field again and there is input left.
            pos = resume + opening_len
            if resume == end or buffer[resume:pos] != opening_tag:
                return resume

    return DecodeRepeatedField
Ejemplo n.º 2
0
def MessageSetItemDecoder(extensions_by_number):
    """Returns a decoder for a MessageSet item.

    The item is scanned tag by tag until its end-group tag (field 1).  The
    varint under field 2 is taken as the type id and the length-delimited
    value under field 3 as the message payload; any other field is skipped.

    Args:
        extensions_by_number: mapping from type id to extension descriptor
            (only ``.get`` is used here).

    Returns:
        A function ``DecodeItem(buffer, pos, end, message, field_dict)`` that
        returns the position reached after reading the item's end tag.

    Raises (from the returned function):
        _DecodeError: if the item is truncated, lacks a type id or a message,
            contains an unterminated skipped group, or if parsing the payload
            stops before the payload's end.
    """
    type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    message_tag_bytes = encoder.TagBytes(3,
                                         wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
    # Bound once so the closure below uses fast local lookups.
    local_ReadTag = ReadTag
    local_DecodeVarint = _DecodeVarint
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        message_set_item_start = pos
        # -1 marks "not seen yet" for each of the three values below.
        type_id = -1
        message_start = -1
        message_end = -1
        while True:
            (tag_bytes, pos) = local_ReadTag(buffer, pos)
            if tag_bytes == type_id_tag_bytes:
                (type_id, pos) = local_DecodeVarint(buffer, pos)
            elif tag_bytes == message_tag_bytes:
                (size, message_start) = local_DecodeVarint(buffer, pos)
                # Jump over the payload; it is parsed after the scan.
                pos = message_end = message_start + size
            elif tag_bytes == item_end_tag_bytes:
                break
            else:
                pos = local_SkipField(buffer, pos, end, tag_bytes)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')
        if pos > end:
            raise _DecodeError('Truncated message.')
        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if message_start == -1:
            raise _DecodeError('MessageSet item missing message.')
        extension = extensions_by_number.get(type_id)
        if extension is not None:
            value = field_dict.get(extension)
            if value is None:
                value = field_dict.setdefault(
                    extension, extension.message_type._concrete_class())
            # Parse the payload into the extension message.  Stopping short
            # of message_end is reported as an unexpected end-group tag.
            if value._InternalParse(buffer, message_start,
                                    message_end) != message_end:
                raise _DecodeError('Unexpected end-group tag.')
        else:
            # Unregistered type id: keep the raw item bytes as unknown.
            if not message._unknown_fields:
                message._unknown_fields = []
            message._unknown_fields.append(
                (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos]))
        return pos

    return DecodeItem
Ejemplo n.º 3
0
    def SpecificDecoder(field_number,
                        is_repeated,
                        is_packed,
                        key,
                        new_default,
                        clear_if_default=False):
        """Pick the decode routine matching the field's cardinality.

        ``decode_value`` and ``wire_type`` come from the enclosing scope.
        Packed takes priority over repeated; otherwise a singular decoder is
        returned.  Each routine has the signature
        ``(buffer, pos, end, message, field_dict)`` and returns the new
        position.
        """
        if is_packed:
            read_varint = _DecodeVarint

            def DecodePackedField(buffer, pos, end, message, field_dict):
                items = field_dict.get(key)
                if items is None:
                    items = field_dict.setdefault(key, new_default(message))
                # A varint length prefix gives the size of the packed payload.
                (payload_size, pos) = read_varint(buffer, pos)
                stop = pos + payload_size
                if stop > end:
                    raise _DecodeError('Truncated message.')
                while pos < stop:
                    (item, pos) = decode_value(buffer, pos)
                    items.append(item)
                if pos > stop:
                    # The last element ran past the payload; drop it.
                    del items[-1]
                    raise _DecodeError('Packed element was truncated.')
                return pos

            return DecodePackedField

        if not is_repeated:

            def DecodeField(buffer, pos, end, message, field_dict):
                (decoded, pos) = decode_value(buffer, pos)
                if pos > end:
                    raise _DecodeError('Truncated message.')
                if clear_if_default and not decoded:
                    # A falsy value clears the entry instead of storing it.
                    field_dict.pop(key, None)
                else:
                    field_dict[key] = decoded
                return pos

            return DecodeField

        own_tag = encoder.TagBytes(field_number, wire_type)
        own_tag_len = len(own_tag)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            items = field_dict.get(key)
            if items is None:
                items = field_dict.setdefault(key, new_default(message))
            while True:
                (item, after_item) = decode_value(buffer, pos)
                items.append(item)
                # Stay in the loop while the next tag is this field's tag
                # again and input remains.
                pos = after_item + own_tag_len
                if after_item < end and buffer[after_item:pos] == own_tag:
                    continue
                if after_item > end:
                    raise _DecodeError('Truncated message.')
                return after_item

        return DecodeRepeatedField
Ejemplo n.º 4
0
        def DecodeField(buffer, pos, end, message, field_dict):
            """Decode one serialized enum value and return the new position.

            A number listed in ``enum_type.values_by_number`` is stored in
            ``field_dict[key]``; any other number is appended, as raw bytes
            together with the field's varint tag, to
            ``message._unknown_fields``.

            Args:
              buffer: memoryview of the serialized bytes.
              pos: int, position in the memory view to start at.
              end: int, end position of serialized data.
              message: Message object to store unknown fields in.
              field_dict: Map[Descriptor, Any] to store decoded values in.

            Returns:
              int, new position in serialized data.

            Raises:
              _DecodeError: if the varint extends past ``end``.
            """
            start = pos
            (number, pos) = _DecodeSignedVarint32(buffer, pos)
            if pos > end:
                raise _DecodeError('Truncated message.')
            if number in enum_type.values_by_number:
                field_dict[key] = number
                return pos
            # Unrecognised enum number: keep its raw bytes.
            # pylint: disable=protected-access
            if not message._unknown_fields:
                message._unknown_fields = []
            varint_tag = encoder.TagBytes(field_number,
                                          wire_format.WIRETYPE_VARINT)
            raw = buffer[start:pos].tobytes()
            message._unknown_fields.append((varint_tag, raw))
            # pylint: enable=protected-access
            return pos
Ejemplo n.º 5
0
 def DecodePackedField(buffer, pos, end, message, field_dict):
     """Decode a packed run of enum values and return the new position.

     Numbers found in ``enum_type.values_by_number`` are appended to the
     field's container; others are recorded in ``message._unknown_fields``
     with the field's varint tag.  Raises _DecodeError when the payload or
     its last element is truncated.
     """
     elements = field_dict.get(key)
     if elements is None:
         elements = field_dict.setdefault(key, new_default(message))
     # A varint length prefix gives the size of the packed payload.
     (payload_len, pos) = local_DecodeVarint(buffer, pos)
     limit = pos + payload_len
     if limit > end:
         raise _DecodeError('Truncated message.')
     while pos < limit:
         element_start = pos
         (element, pos) = _DecodeSignedVarint32(buffer, pos)
         if element not in enum_type.values_by_number:
             if not message._unknown_fields:
                 message._unknown_fields = []
             varint_tag = encoder.TagBytes(field_number,
                                           wire_format.WIRETYPE_VARINT)
             message._unknown_fields.append(
                 (varint_tag, buffer[element_start:pos]))
         else:
             elements.append(element)
     if pos <= limit:
         return pos
     # The last element overran the payload: undo whichever record it made.
     if element in enum_type.values_by_number:
         del elements[-1]
     else:
         del message._unknown_fields[-1]
     raise _DecodeError('Packed element was truncated.')
Ejemplo n.º 6
0
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Build the decode callable for a bytes field.

  Each value is read as a varint length followed by that many bytes, which
  are copied out of the buffer with ``tobytes()``.  The returned function
  takes ``(buffer, pos, end, message, field_dict)`` and returns the new
  position.  Packed bytes fields are rejected by assertion.
  """

  read_varint = _DecodeVarint

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, pos) = read_varint(buffer, pos)
      stop = pos + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[pos:stop].tobytes()
      return stop
    return DecodeField

  own_tag = encoder.TagBytes(field_number,
                             wire_format.WIRETYPE_LENGTH_DELIMITED)
  own_tag_len = len(own_tag)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      (length, pos) = read_varint(buffer, pos)
      stop = pos + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:stop].tobytes())
      # Keep looping only while the next tag is this field's tag again and
      # input remains.
      pos = stop + own_tag_len
      if stop == end or buffer[stop:pos] != own_tag:
        return stop
  return DecodeRepeatedField
Ejemplo n.º 7
0
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a bytes field.

    Each value is read as a varint length followed by that many bytes,
    stored as the corresponding slice of ``buffer``.

    Args:
        field_number: field number used to build the repeated-field tag.
        is_repeated: whether to build the repeated-field decoder.
        is_packed: accepted for signature parity; not inspected here.
        key: key under which values are stored in ``field_dict``.
        new_default: callable ``new_default(message)`` creating the
            container for a repeated field.

    Returns:
        A function ``(buffer, pos, end, message, field_dict) -> new_pos``.

    Raises (from the returned function):
        _DecodeError: if a value's declared length runs past ``end``.
    """
    local_DecodeVarint = _DecodeVarint
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while True:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated string.')
                value.append(buffer[pos:new_pos])
                # Predict that the next tag is another copy of the same
                # repeated field; if so, decode it in this same call.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos

        return DecodeRepeatedField
    else:

        def DecodeField(buffer, pos, end, message, field_dict):
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise _DecodeError('Truncated string.')
            field_dict[key] = buffer[pos:new_pos]
            return new_pos

        return DecodeField
Ejemplo n.º 8
0
def StringDecoder(field_number,
                  is_repeated,
                  is_packed,
                  key,
                  new_default,
                  clear_if_default=False):
    """Build the decode callable for a string field.

    Each value is read as a varint length followed by that many bytes, which
    are decoded as UTF-8.  The returned function takes
    ``(buffer, pos, end, message, field_dict)`` and returns the new position.
    Packed string fields are rejected by assertion.
    """

    read_varint = _DecodeVarint

    def _decode_text(memview):
        """Decode the viewed bytes as UTF-8 text."""
        try:
            return str(memview.tobytes(), 'utf-8')
        except UnicodeDecodeError as e:
            # Name the offending field in the error before re-raising it.
            e.reason = '%s in field: %s' % (e, key.full_name)
            raise

    assert not is_packed

    if not is_repeated:

        def DecodeField(buffer, pos, end, message, field_dict):
            (length, pos) = read_varint(buffer, pos)
            stop = pos + length
            if stop > end:
                raise _DecodeError('Truncated string.')
            if clear_if_default and not length:
                # An empty string clears the entry instead of storing it.
                field_dict.pop(key, None)
            else:
                field_dict[key] = _decode_text(buffer[pos:stop])
            return stop

        return DecodeField

    own_tag = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
    own_tag_len = len(own_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        strings = field_dict.get(key)
        if strings is None:
            strings = field_dict.setdefault(key, new_default(message))
        while True:
            (length, pos) = read_varint(buffer, pos)
            stop = pos + length
            if stop > end:
                raise _DecodeError('Truncated string.')
            strings.append(_decode_text(buffer[pos:stop]))
            # Keep looping only while the next tag is this field's tag again
            # and input remains.
            pos = stop + own_tag_len
            if stop == end or buffer[stop:pos] != own_tag:
                return stop

    return DecodeRepeatedField
Ejemplo n.º 9
0
 def AddDecoder(wiretype, is_packed):
   """Register the field's decoder in ``cls._decoders_by_tag``.

   The entry is stored under the tag built from the field number and
   ``wiretype``.  It is a pair: the decoder built for the field's type, and
   the field descriptor when the field belongs to a oneof (otherwise None).
   """
   tag = encoder.TagBytes(field_descriptor.number, wiretype)
   make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
   decoder = make_decoder(
       field_descriptor.number, is_repeated, is_packed,
       field_descriptor, field_descriptor._default_constructor)
   if field_descriptor.containing_oneof is not None:
     oneof_field = field_descriptor
   else:
     oneof_field = None
   cls._decoders_by_tag[tag] = (decoder, oneof_field)
Ejemplo n.º 10
0
def encode_group(value, typedef, field_number):
    """Encode a protobuf group and terminate it with its end-group tag.

    ``encode_message`` is called with ``group=True`` to produce the output;
    the end tag for ``field_number`` is then appended to it and the same
    object is returned.
    """
    encoded = encode_message(value, typedef, group=True)
    # Only the closing tag is added here.
    closing_tag = encoder.TagBytes(int(field_number),
                                   wire_format.WIRETYPE_END_GROUP)
    encoded.append(closing_tag)
    return encoded
Ejemplo n.º 11
0
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  is_strict_utf8=False):
  """Returns a decoder for a string field.

  Each value is read as a varint length followed by that many bytes, which
  are decoded as UTF-8.

  Args:
    field_number: field number used to build the repeated-field tag.
    is_repeated: whether to build the repeated-field decoder.
    is_packed: must be false; string fields cannot be packed.
    key: field descriptor used as the key in ``field_dict``; its
      ``full_name`` appears in error messages.
    new_default: callable ``new_default(message)`` creating the container
      for a repeated field.
    is_strict_utf8: when true, additionally reject surrogates on Python 2
      builds whose ``sys.maxunicode`` exceeds ``_UCS2_MAXUNICODE``.

  Returns:
    A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    byte_str = memview.tobytes()
    try:
      value = local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

    if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
      # Only do the check for python2 ucs4 when is_strict_utf8 enabled
      if _SURROGATE_PATTERN.search(value):
        # Each literal ends with a space so the joined message reads as
        # separate words.
        reason = ('String field %s contains invalid UTF-8 data when parsing '
                  'a protocol buffer: surrogates not allowed. Use '
                  'the bytes type if you intend to send raw bytes.') % (
                      key.full_name)
        # `message` resolves to the module-level name here; the decoder
        # argument of the same name is local to the functions below.
        raise message.DecodeError(reason)

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField
Ejemplo n.º 12
0
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build the decode callable for an embedded-message field.

    Each sub-message is read as a varint length followed by a body that is
    handed to ``_InternalParse``.  The returned function takes
    ``(buffer, pos, end, message, field_dict)`` and returns the position just
    past the body.  Packed message fields are rejected by assertion.
    """

    read_varint = _DecodeVarint

    assert not is_packed

    if not is_repeated:

        def DecodeField(buffer, pos, end, message, field_dict):
            submessage = field_dict.get(key)
            if submessage is None:
                submessage = field_dict.setdefault(key, new_default(message))
            (length, pos) = read_varint(buffer, pos)
            stop = pos + length
            if stop > end:
                raise _DecodeError('Truncated message.(newpos:%d, end:%d)' %
                                   (stop, end))
            # A parse that stops short of the body's end is reported as an
            # unexpected end-group tag.
            if submessage._InternalParse(buffer, pos, stop) != stop:
                raise _DecodeError('Unexpected end-group tag.')
            return stop

        return DecodeField

    own_tag = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
    own_tag_len = len(own_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        while True:
            # Look the container up on every pass, creating it on demand.
            container = field_dict.get(key)
            if container is None:
                container = field_dict.setdefault(key, new_default(message))
            (length, pos) = read_varint(buffer, pos)
            stop = pos + length
            if stop > end:
                raise _DecodeError('Truncated message.')
            parsed_to = container.add()._InternalParse(buffer, pos, stop)
            if parsed_to != stop:
                raise _DecodeError('Unexpected end-group tag.')
            # Keep looping only while the next tag is this field's tag again
            # and input remains.
            pos = stop + own_tag_len
            if stop == end or buffer[stop:pos] != own_tag:
                return stop

    return DecodeRepeatedField
Ejemplo n.º 13
0
 def CheckUnknownField(self, name, expected_value):
   """Assert that the named field decodes to ``expected_value``.

   Every entry of ``self.empty_message._unknown_fields`` whose tag matches
   the field is run through the ``TestAllTypes`` decoder for that tag; the
   value stored under the field descriptor is then compared.
   """
   descriptor = self.descriptor.fields_by_name[name]
   wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
   wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
   decoded = {}
   for tag, payload in self.empty_message._unknown_fields:
     if tag != wanted_tag:
       continue
     decode = unittest_pb2.TestAllTypes._decoders_by_tag[tag][0]
     decode(payload, 0, len(payload), self.all_fields, decoded)
   self.assertEqual(expected_value, decoded[descriptor])
Ejemplo n.º 14
0
 def GetField(self, name):
   """Decode the named field from ``self.unknown_fields`` and return it.

   Every entry whose tag matches the field is run through the
   ``TestAllTypes`` decoder for that tag.  The final lookup raises KeyError
   if nothing was stored under the field descriptor.
   """
   descriptor = self.descriptor.fields_by_name[name]
   wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
   wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
   decoded = {}
   for tag, payload in self.unknown_fields:
     if tag != wanted_tag:
       continue
     decode = unittest_pb2.TestAllTypes._decoders_by_tag[tag][0]
     decode(payload, 0, len(payload), self.all_fields, decoded)
   return decoded[descriptor]
Ejemplo n.º 15
0
 def GetUnknownField(self, name):
     """Decode the named field from the missing message's unknown fields.

     Every entry of ``self.missing_message._unknown_fields`` whose tag
     matches the field is run through the ``TestEnumValues`` decoder for
     that tag.  The final lookup raises KeyError if nothing was stored under
     the field descriptor.
     """
     descriptor = self.descriptor.fields_by_name[name]
     wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
     wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
     decoded = {}
     for tag, payload in self.missing_message._unknown_fields:
         if tag != wanted_tag:
             continue
         decoders = missing_enum_values_pb2.TestEnumValues._decoders_by_tag
         decode = decoders[tag][0]
         decode(payload, 0, len(payload), self.message, decoded)
     return decoded[descriptor]
Ejemplo n.º 16
0
    def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
        """Returns the decoder matching the field's cardinality.

        ``decode_value`` and ``wire_type`` come from the enclosing scope.
        Packed takes priority over repeated; otherwise a singular decoder is
        returned.  Each decoder has the signature
        ``(buffer, pos, end, message, field_dict)`` and returns the new
        position, raising _DecodeError on truncated input.
        """
        if is_packed:
            local_DecodeVarint = _DecodeVarint

            def DecodePackedField(buffer, pos, end, message, field_dict):
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                # A varint length prefix gives the size of the packed payload.
                (endpoint, pos) = local_DecodeVarint(buffer, pos)
                endpoint += pos
                if endpoint > end:
                    raise _DecodeError('Truncated message.')
                while pos < endpoint:
                    (element, pos) = decode_value(buffer, pos)
                    value.append(element)
                if pos > endpoint:
                    # Discard the element that ran past the payload.
                    del value[-1]
                    raise _DecodeError('Packed element was truncated.')
                return pos

            return DecodePackedField
        if is_repeated:
            tag_bytes = encoder.TagBytes(field_number, wire_type)
            tag_len = len(tag_bytes)

            def DecodeRepeatedField(buffer, pos, end, message, field_dict):
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                while True:
                    (element, new_pos) = decode_value(buffer, pos)
                    value.append(element)
                    # Predict that the next tag is another copy of the same
                    # repeated field; if so, decode it in this same call.
                    pos = new_pos + tag_len
                    if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
                        # Prediction failed.  Return.
                        if new_pos > end:
                            raise _DecodeError('Truncated message.')
                        return new_pos

            return DecodeRepeatedField
        else:

            def DecodeField(buffer, pos, end, message, field_dict):
                # The value is stored first and removed again if it turns
                # out to run past `end`.
                (field_dict[key], pos) = decode_value(buffer, pos)
                if pos > end:
                    del field_dict[key]
                    raise _DecodeError('Truncated message.')
                return pos

            return DecodeField
Ejemplo n.º 17
0
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a group field.

    Args:
        field_number: field number used to build the start/end group tags.
        is_repeated: whether to build the repeated-field decoder.
        is_packed: accepted for signature parity; not inspected here.
        key: key under which the group (or its container) is stored in
            ``field_dict``.
        new_default: callable ``new_default(message)`` creating the stored
            object when the key is absent.

    Returns:
        A function ``(buffer, pos, end, message, field_dict) -> new_pos``
        where ``new_pos`` is just past the last end-group tag consumed.

    Raises (from the returned function):
        _DecodeError: if the group body is not followed by its end tag
            within ``end``.
    """
    end_tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_END_GROUP)
    end_tag_len = len(end_tag_bytes)
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_START_GROUP)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while True:
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                # Read sub-message.
                pos = value.add()._InternalParse(buffer, pos, end)
                # Read end tag.
                new_pos = pos + end_tag_len
                if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
                    raise _DecodeError('Missing group end tag.')
                # Predict that the next tag is another copy of the same
                # repeated field; if so, decode it in this same call.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos

        return DecodeRepeatedField
    else:

        def DecodeField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            # Read sub-message.
            pos = value._InternalParse(buffer, pos, end)
            # Read end tag.
            new_pos = pos + end_tag_len
            if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
                raise _DecodeError('Missing group end tag.')
            return new_pos

        return DecodeField
Ejemplo n.º 18
0
def UnknownMessageSetItemDecoder():
    """Returns a decoder for a Unknown MessageSet item.

    The returned ``DecodeUnknownItem(buffer)`` scans the item tag by tag
    until its end-group tag (field 1).  The varint under field 2 is taken as
    the type id and the length-delimited value under field 3 as the message
    payload; any other field is skipped.

    Returns (from the returned function):
        A ``(type_id, message_bytes)`` tuple, where ``message_bytes`` is the
        payload copied out of ``buffer`` with ``tobytes()``.

    Raises (from the returned function):
        _DecodeError: if the item is truncated, lacks a type id or a message,
            or contains an unterminated skipped group.
    """

    type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    message_tag_bytes = encoder.TagBytes(3,
                                         wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    def DecodeUnknownItem(buffer):
        pos = 0
        end = len(buffer)
        # -1 marks "not seen yet".  type_id needs the sentinel too, so the
        # missing-type_id check below can run when no type-id tag was read.
        type_id = -1
        message_start = -1
        message_end = -1
        while 1:
            (tag_bytes, pos) = ReadTag(buffer, pos)
            if tag_bytes == type_id_tag_bytes:
                (type_id, pos) = _DecodeVarint(buffer, pos)
            elif tag_bytes == message_tag_bytes:
                (size, message_start) = _DecodeVarint(buffer, pos)
                # Jump over the payload; it is sliced out after the scan.
                pos = message_end = message_start + size
            elif tag_bytes == item_end_tag_bytes:
                break
            else:
                pos = SkipField(buffer, pos, end, tag_bytes)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')

        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if message_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        return (type_id, buffer[message_start:message_end].tobytes())

    return DecodeUnknownItem
Ejemplo n.º 19
0
        def DecodePackedField(buffer, pos, end, message, field_dict):
            """Decode a packed run of enum values and return the new position.

            Numbers found in ``enum_type.values_by_number`` are appended to
            the field's container.  Others are recorded both in
            ``message._unknown_fields`` (raw bytes with the varint tag) and
            in ``message._unknown_field_set``.

            Args:
              buffer: memoryview of the serialized bytes.
              pos: int, position in the memory view to start at.
              end: int, end position of serialized data.
              message: Message object to store unknown fields in.
              field_dict: Map[Descriptor, Any] to store decoded values in.

            Returns:
              int, new position in serialized data.

            Raises:
              _DecodeError: if the payload or its last element is truncated.
            """
            elements = field_dict.get(key)
            if elements is None:
                elements = field_dict.setdefault(key, new_default(message))
            # A varint length prefix gives the size of the packed payload.
            (payload_len, pos) = local_DecodeVarint(buffer, pos)
            limit = pos + payload_len
            if limit > end:
                raise _DecodeError('Truncated message.')
            while pos < limit:
                element_start = pos
                (element, pos) = _DecodeSignedVarint32(buffer, pos)
                # pylint: disable=protected-access
                if element in enum_type.values_by_number:
                    elements.append(element)
                    continue
                # Unrecognised enum number: record it in both stores.
                if not message._unknown_fields:
                    message._unknown_fields = []
                varint_tag = encoder.TagBytes(field_number,
                                              wire_format.WIRETYPE_VARINT)
                raw = buffer[element_start:pos].tobytes()
                message._unknown_fields.append((varint_tag, raw))
                if message._unknown_field_set is None:
                    message._unknown_field_set = containers.UnknownFieldSet()
                message._unknown_field_set._add(
                    field_number, wire_format.WIRETYPE_VARINT, element)
                # pylint: enable=protected-access
            if pos <= limit:
                return pos
            # The last element overran the payload: undo whatever it recorded.
            if element in enum_type.values_by_number:
                del elements[-1]
            else:
                # pylint: disable=protected-access
                del message._unknown_fields[-1]
                del message._unknown_field_set._values[-1]
                # pylint: enable=protected-access
            raise _DecodeError('Packed element was truncated.')
Ejemplo n.º 20
0
 def DecodeField(buffer, pos, end, message, field_dict):
     """Decode one serialized enum value and return the new position.

     A number listed in ``enum_type.values_by_number`` is stored in
     ``field_dict[key]``; any other number is appended, as the raw buffer
     slice together with the field's varint tag, to
     ``message._unknown_fields``.  Raises _DecodeError if the varint extends
     past ``end``.
     """
     start = pos
     (number, pos) = _DecodeSignedVarint32(buffer, pos)
     if pos > end:
         raise _DecodeError('Truncated message.')
     if number in enum_type.values_by_number:
         field_dict[key] = number
         return pos
     # Unrecognised enum number: keep its raw encoding.
     if not message._unknown_fields:
         message._unknown_fields = []
     varint_tag = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_VARINT)
     message._unknown_fields.append((varint_tag, buffer[start:pos]))
     return pos
Ejemplo n.º 21
0
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a message field.

    Each sub-message is read as a varint length followed by a body that is
    handed to ``_InternalParse``.

    Args:
        field_number: field number used to build the repeated-field tag.
        is_repeated: whether to build the repeated-field decoder.
        is_packed: accepted for signature parity; not inspected here.
        key: key under which the message (or its container) is stored in
            ``field_dict``.
        new_default: callable ``new_default(message)`` creating the stored
            object when the key is absent.

    Returns:
        A function ``(buffer, pos, end, message, field_dict) -> new_pos``.

    Raises (from the returned function):
        _DecodeError: if a declared length runs past ``end``, or if parsing
            a body stops before the body's end.
    """
    local_DecodeVarint = _DecodeVarint
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while True:
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                # Read length.
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated message.')
                # Read sub-message.
                if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
                    raise _DecodeError('Unexpected end-group tag.')
                # Predict that the next tag is another copy of the same
                # repeated field; if so, decode it in this same call.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos

        return DecodeRepeatedField
    else:

        def DecodeField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            # Read length.
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise _DecodeError('Truncated message.')
            # Read sub-message.
            if value._InternalParse(buffer, pos, new_pos) != new_pos:
                raise _DecodeError('Unexpected end-group tag.')
            return new_pos

        return DecodeField
Ejemplo n.º 22
0
def MapDecoder(field_descriptor, new_default, is_message_map):
    """Builds the decoder function for a map field.

    Each map entry arrives as a length-delimited sub-message exposing ``key``
    and ``value`` attributes.  When ``is_message_map`` is true the entry's
    value is merged into the stored message; otherwise it is assigned.
    """

    entry_tag = encoder.TagBytes(field_descriptor.number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    entry_tag_len = len(entry_tag)
    read_varint = _DecodeVarint
    # _concrete_class is looked up lazily inside DecodeMap; it might not be
    # initialized yet at this point.
    entry_type = field_descriptor.message_type

    def DecodeMap(buffer, pos, end, message, field_dict):
        # One scratch entry message is reused (and cleared) for every entry.
        entry = entry_type._concrete_class()
        mapping = field_dict.get(field_descriptor)
        if mapping is None:
            mapping = field_dict.setdefault(field_descriptor,
                                            new_default(message))
        while True:
            # Length prefix of this entry.
            (length, pos) = read_varint(buffer, pos)
            entry_end = pos + length
            if entry_end > end:
                raise _DecodeError('Truncated message.')

            # Parse the entry; stopping short of entry_end is reported as an
            # unexpected end-group tag.
            entry.Clear()
            parsed_to = entry._InternalParse(buffer, pos, entry_end)
            if parsed_to != entry_end:
                raise _DecodeError('Unexpected end-group tag.')

            if is_message_map:
                mapping[entry.key].MergeFrom(entry.value)
            else:
                mapping[entry.key] = entry.value

            # Guess that another entry of the same field follows directly.
            pos = entry_end + entry_tag_len
            following = buffer[entry_end:pos]
            if following != entry_tag or entry_end == end:
                # Guess was wrong (or input exhausted): hand control back.
                return entry_end

    return DecodeMap
 def AddDecoder(wiretype, is_packed):
     # Registers, under the tag built from this field's number and
     # `wiretype`, the decoder produced by the factory selected by the
     # field's type.  `field_descriptor`, `is_repeated` and `cls` come from
     # the enclosing scope.
     field_tag = encoder.TagBytes(field_descriptor.number, wiretype)
     decoder_factory = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
     field_decoder = decoder_factory(
         field_descriptor.number,
         is_repeated,
         is_packed,
         field_descriptor,
         field_descriptor._default_constructor)
     cls._decoders_by_tag[field_tag] = field_decoder
Ejemplo n.º 24
0
        def DecodeField(buffer, pos, end, message, field_dict):
            # Decodes one length-delimited sub-message into the value stored
            # under `key`, creating that value from `new_default` on first
            # use.  Returns the position just past the sub-message.
            submessage = field_dict.get(key)
            if submessage is None:
                submessage = field_dict.setdefault(key, new_default(message))
            # Length prefix, then the extent of the payload.
            (length, pos) = local_DecodeVarint(buffer, pos)
            limit = pos + length
            if limit > end:
                raise _DecodeError('Truncated message.')
            parsed_to = submessage._InternalParse(buffer, pos, limit)
            if parsed_to != limit:
                raise _DecodeError('Unexpected end-group tag.')
            return limit

        return DecodeField

# Tag bytes for field number 1 with the START_GROUP wire type.  DecodeItem
# pairs this tag with the raw item bytes when it records an unrecognized
# MessageSet item in a message's _unknown_fields.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(extensions_by_number):
    type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
    local_ReadTag = ReadTag
    local_DecodeVarint = _DecodeVarint
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        message_set_item_start = pos
        type_id = -1
        message_start = -1
        message_end = -1
        while True:
Ejemplo n.º 25
0
def MessageSetItemDecoder(extensions_by_number):
    """Builds the decoder for a single MessageSet item.

  `extensions_by_number` is the _extensions_by_number map of the message
  class; it is consulted with the item's type_id to find the extension.

  On the wire a message set looks like:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

    tag_type_id = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    tag_payload = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_item_end = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    read_tag = ReadTag
    read_varint = _DecodeVarint
    # Bound here alongside the other helpers; the loop below resolves
    # SkipField through the module namespace.
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        item_start = pos
        type_id = -1
        payload_start = -1
        payload_end = -1

        # type_id and message may come in either order, so scan tags until
        # the item's end-group tag shows up.
        while True:
            (tag, pos) = read_tag(buffer, pos)
            if tag == tag_item_end:
                break
            if tag == tag_type_id:
                (type_id, pos) = read_varint(buffer, pos)
            elif tag == tag_payload:
                (length, payload_start) = read_varint(buffer, pos)
                payload_end = payload_start + length
                pos = payload_end
            else:
                pos = SkipField(buffer, pos, end, tag)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')

        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if payload_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        extension = extensions_by_number.get(type_id)
        if extension is None:
            # Unrecognized type_id: keep the raw item bytes as an unknown
            # field.
            if not message._unknown_fields:
                message._unknown_fields = []
            raw_item = buffer[item_start:pos]
            message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, raw_item))
            return pos

        target = field_dict.get(extension)
        if target is None:
            target = field_dict.setdefault(
                extension, extension.message_type._concrete_class())
        parsed_to = target._InternalParse(buffer, payload_start, payload_end)
        if parsed_to != payload_end:
            # Stopping early is reported as an unexpected end-group tag.
            raise _DecodeError('Unexpected end-group tag.')
        return pos

    return DecodeItem
Ejemplo n.º 26
0
def encode_message(data, typedef, group=False):
    """Encode a Python dictionary representing a protobuf message.

    Args:
      data: dict mapping field numbers (int or str) or field names to values.
        A string key of the form ``"<field>-<alt>"`` selects the entry
        ``<alt>`` of that field's ``'alt_typedefs'``.  A list value is
        emitted as one tagged element per item unless the field type starts
        with ``'packed_'``.
      typedef: dict mapping field-number strings to per-field info dicts.
        Each info dict must carry ``'type'``; ``'name'``,
        ``'message_typedef'``, ``'alt_typedefs'`` and ``'group_typedef'``
        are read where relevant.
      group: not used by this function; kept for interface compatibility.

    Returns:
      bytearray holding the encoded fields in ``data`` iteration order.

    Raises:
      ValueError: for an unknown field name/number, a field without a usable
        type definition, an unknown or unimplemented type, or when encoding
        a value fails (the original exception is chained as ``__cause__``).
    """
    output = bytearray()

    for field_number, value in data.items():
        # Normalize the key to the field-number string used by `typedef`.
        alt_field_number = None
        if isinstance(field_number, str):
            if '-' in field_number:
                field_number, alt_field_number = field_number.split('-')
            if field_number != '':
                # The key may be a field name; map it back to its number.
                for number, info in typedef.items():
                    if info.get('name') == field_number:
                        field_number = number
                        break
        else:
            field_number = str(field_number)

        if field_number not in typedef:
            raise ValueError('Provided field name/number %s is not valid' %
                             field_number)

        field_typedef = typedef[field_number]

        if 'type' not in field_typedef:
            raise ValueError('Field %s does not have a defined type' %
                             field_number)

        field_type = field_typedef['type']

        # Pick the encoder for this field.
        if field_type == 'message':
            if alt_field_number is not None:
                alt_typedefs = field_typedef.get('alt_typedefs', {})
                if alt_field_number not in alt_typedefs:
                    raise ValueError(
                        'Provided alt field name/number %s is not valid for field_number %s'
                        % (alt_field_number, field_number))
                innertypedef = alt_typedefs[alt_field_number]
            elif 'message_typedef' in field_typedef:
                # "Anonymous" inner message described inline.
                innertypedef = field_typedef['message_typedef']
            else:
                # Named message types are not resolved here.
                raise ValueError('Message type (%s) has not been defined' %
                                 field_typedef.get('message_type_name'))

            # Bind the typedef now so the encoder does not depend on loop
            # variables that are reassigned on later iterations.
            field_encoder = (
                lambda item, _typedef=innertypedef: encode_lendelim_message(
                    item, _typedef))
        elif field_type == 'group':
            if 'group_typedef' not in field_typedef:
                raise ValueError(
                    'Could not find type definition for group field: %s' %
                    field_number)
            innertypedef = field_typedef['group_typedef']

            field_encoder = (
                lambda item, _typedef=innertypedef, _number=field_number:
                encode_group(item, _typedef, _number))
        else:
            if field_type not in types.encoders:
                raise ValueError('Unknown type: %s' % field_type)
            field_encoder = types.encoders[field_type]
            if field_encoder is None:
                raise ValueError('Encoder not implemented: %s' % field_type)

        # Encode the tag
        tag = encoder.TagBytes(int(field_number), types.wiretypes[field_type])

        try:
            # Non-packed lists are written as repeated tag/value pairs.
            if isinstance(value,
                          list) and not field_type.startswith('packed_'):
                for repeated in value:
                    output += tag
                    output += field_encoder(repeated)
            else:
                output += tag
                output += field_encoder(value)
        except Exception as exc:
            # Report every encoding failure as ValueError, keeping the
            # original exception (and its traceback) as the cause.
            raise ValueError('Error attempting to encode "%s" as %s: %s' %
                             (value, field_type, exc)) from exc

    return output
Ejemplo n.º 27
0
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Builds the decoder for an enum field.

    Numbers found in ``key.enum_type.values_by_number`` are stored as field
    values; any other number is recorded, with its raw bytes, in the
    message's ``_unknown_fields`` list.
    """
    enum_type = key.enum_type

    if is_packed:
        read_varint = _DecodeVarint

        def DecodePackedField(buffer, pos, end, message, field_dict):
            values = field_dict.get(key)
            if values is None:
                values = field_dict.setdefault(key, new_default(message))
            # The packed run is prefixed with its byte length.
            (length, pos) = read_varint(buffer, pos)
            limit = pos + length
            if limit > end:
                raise _DecodeError('Truncated message.')
            while pos < limit:
                element_start = pos
                (element, pos) = _DecodeSignedVarint32(buffer, pos)
                if element in enum_type.values_by_number:
                    values.append(element)
                    continue
                if not message._unknown_fields:
                    message._unknown_fields = []
                unknown_tag = encoder.TagBytes(field_number,
                                               wire_format.WIRETYPE_VARINT)
                message._unknown_fields.append(
                    (unknown_tag, buffer[element_start:pos]))
            if pos <= limit:
                return pos
            # The last element ran past the declared length: undo whatever
            # was recorded for it before reporting the error.
            if element in enum_type.values_by_number:
                del values[-1]  # Discard corrupt value.
            else:
                del message._unknown_fields[-1]
            raise _DecodeError('Packed element was truncated.')

        return DecodePackedField

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            values = field_dict.get(key)
            if values is None:
                values = field_dict.setdefault(key, new_default(message))
            while True:
                (element, after) = _DecodeSignedVarint32(buffer, pos)
                if element in enum_type.values_by_number:
                    values.append(element)
                else:
                    if not message._unknown_fields:
                        message._unknown_fields = []
                    message._unknown_fields.append(
                        (tag_bytes, buffer[pos:after]))
                # Guess that the same field's tag follows immediately.
                pos = after + tag_len
                if buffer[after:pos] != tag_bytes or after >= end:
                    # Guess was wrong (or input exhausted).
                    if after > end:
                        raise _DecodeError('Truncated message.')
                    return after

        return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
        start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        if pos > end:
            raise _DecodeError('Truncated message.')
        if enum_type.values_by_number.__contains__(number) if False else (
                number in enum_type.values_by_number):
            field_dict[key] = number
            return pos
        if not message._unknown_fields:
            message._unknown_fields = []
        unknown_tag = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append((unknown_tag, buffer[start:pos]))
        return pos

    return DecodeField
Ejemplo n.º 28
0
    """Returns a decoder for a string field."""

    local_DecodeVarint = _DecodeVarint
    local_unicode = unicode

    def _ConvertToUnicode(byte_str):
        # Decodes `byte_str` as UTF-8 through `local_unicode`.
        try:
            return local_unicode(byte_str, 'utf-8')
        except UnicodeDecodeError as e:
            # Add the field's full name to the error's reason, then re-raise
            # the same exception.  The `as` form parses on Python 2.6+ and 3,
            # unlike the comma form.
            e.reason = '%s in field: %s' % (e, key.full_name)
            raise

    assert not is_packed
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while 1:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated string.')
                value.append(_ConvertToUnicode(buffer[pos:new_pos]))
                # Predict that the next tag is another copy of the same repeated field.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
Ejemplo n.º 29
0
def MessageSetItemDecoder(descriptor):
    """Builds the decoder for a single MessageSet item.

  `descriptor` is the message Descriptor.

  On the wire a message set looks like:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

    tag_type_id = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    tag_payload = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_item_end = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    read_tag = ReadTag
    read_varint = _DecodeVarint
    # Bound here alongside the other helpers; the loop below resolves
    # SkipField through the module namespace.
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        """Decodes one MessageSet item and returns the position after it.

    Args:
      buffer: memoryview over the serialized bytes.
      pos: int, offset in the buffer at which the item starts.
      end: int, offset at which the serialized data ends.
      message: Message that receives unknown-field records.
      field_dict: Map[Descriptor, Any] that receives decoded values.

    Returns:
      int, offset just past the item's end-group tag.
    """
        item_start = pos
        type_id = -1
        payload_start = -1
        payload_end = -1

        # type_id and message may come in either order, so scan tags until
        # the item's end-group tag shows up.
        while True:
            (tag, pos) = read_tag(buffer, pos)
            if tag == tag_item_end:
                break
            if tag == tag_type_id:
                (type_id, pos) = read_varint(buffer, pos)
            elif tag == tag_payload:
                (length, payload_start) = read_varint(buffer, pos)
                payload_end = payload_start + length
                pos = payload_end
            else:
                pos = SkipField(buffer, pos, end, tag)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')

        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if payload_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        extension = message.Extensions._FindExtensionByNumber(type_id)
        # pylint: disable=protected-access
        if extension is None:
            # Unrecognized type_id: record the raw item bytes and, separately,
            # the payload in the unknown-field set.
            if not message._unknown_fields:
                message._unknown_fields = []
            raw_item = buffer[item_start:pos].tobytes()
            message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, raw_item))
            if message._unknown_field_set is None:
                message._unknown_field_set = containers.UnknownFieldSet()
            message._unknown_field_set._add(
                type_id, wire_format.WIRETYPE_LENGTH_DELIMITED,
                buffer[payload_start:payload_end].tobytes())
            return pos

        target = field_dict.get(extension)
        if target is None:
            message_type = extension.message_type
            if not hasattr(message_type, '_concrete_class'):
                # No concrete class attached yet; ask the message's factory
                # for the prototype before instantiating below.
                message._FACTORY.GetPrototype(message_type)
            target = field_dict.setdefault(extension,
                                           message_type._concrete_class())
        parsed_to = target._InternalParse(buffer, payload_start, payload_end)
        if parsed_to != payload_end:
            # Stopping early is reported as an unexpected end-group tag.
            raise _DecodeError('Unexpected end-group tag.')
        # pylint: enable=protected-access
        return pos

    return DecodeItem
Ejemplo n.º 30
0
def EnumDecoder(field_number,
                is_repeated,
                is_packed,
                key,
                new_default,
                clear_if_default=False):
    """Builds the decoder for an enum field.

    Numbers found in ``key.enum_type.values_by_number`` are stored as field
    values.  Any other number is recorded both in the message's
    ``_unknown_fields`` list (tag plus raw bytes) and in its
    ``_unknown_field_set``.  ``clear_if_default`` only affects the singular
    decoder: a falsy decoded value then removes ``key`` from the field dict.
    """
    enum_type = key.enum_type

    if is_packed:
        read_varint = _DecodeVarint

        def DecodePackedField(buffer, pos, end, message, field_dict):
            """Decodes a packed run of enum values.

      Args:
        buffer: memoryview over the serialized bytes.
        pos: int, offset in the buffer at which to start.
        end: int, offset at which the serialized data ends.
        message: Message that receives unknown-field records.
        field_dict: Map[Descriptor, Any] that receives decoded values.

      Returns:
        int, offset just past the packed run.
      """
            values = field_dict.get(key)
            if values is None:
                values = field_dict.setdefault(key, new_default(message))
            # The packed run is prefixed with its byte length.
            (length, pos) = read_varint(buffer, pos)
            limit = pos + length
            if limit > end:
                raise _DecodeError('Truncated message.')
            while pos < limit:
                element_start = pos
                (element, pos) = _DecodeSignedVarint32(buffer, pos)
                # pylint: disable=protected-access
                if element in enum_type.values_by_number:
                    values.append(element)
                    continue
                if not message._unknown_fields:
                    message._unknown_fields = []
                unknown_tag = encoder.TagBytes(field_number,
                                               wire_format.WIRETYPE_VARINT)
                raw = buffer[element_start:pos].tobytes()
                message._unknown_fields.append((unknown_tag, raw))
                if message._unknown_field_set is None:
                    message._unknown_field_set = containers.UnknownFieldSet()
                message._unknown_field_set._add(
                    field_number, wire_format.WIRETYPE_VARINT, element)
                # pylint: enable=protected-access
            if pos <= limit:
                return pos
            # The last element ran past the declared length: undo whatever
            # was recorded for it before reporting the error.
            if element in enum_type.values_by_number:
                del values[-1]  # Discard corrupt value.
            else:
                del message._unknown_fields[-1]
                # pylint: disable=protected-access
                del message._unknown_field_set._values[-1]
                # pylint: enable=protected-access
            raise _DecodeError('Packed element was truncated.')

        return DecodePackedField

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            """Decodes consecutive occurrences of a repeated enum field.

      Args:
        buffer: memoryview over the serialized bytes.
        pos: int, offset in the buffer at which to start.
        end: int, offset at which the serialized data ends.
        message: Message that receives unknown-field records.
        field_dict: Map[Descriptor, Any] that receives decoded values.

      Returns:
        int, offset just past the last value consumed.
      """
            values = field_dict.get(key)
            if values is None:
                values = field_dict.setdefault(key, new_default(message))
            while True:
                (element, after) = _DecodeSignedVarint32(buffer, pos)
                # pylint: disable=protected-access
                if element in enum_type.values_by_number:
                    values.append(element)
                else:
                    if not message._unknown_fields:
                        message._unknown_fields = []
                    raw = buffer[pos:after].tobytes()
                    message._unknown_fields.append((tag_bytes, raw))
                    if message._unknown_field_set is None:
                        message._unknown_field_set = (
                            containers.UnknownFieldSet())
                    message._unknown_field_set._add(
                        field_number, wire_format.WIRETYPE_VARINT, element)
                # pylint: enable=protected-access
                # Guess that the same field's tag follows immediately.
                pos = after + tag_len
                if buffer[after:pos] != tag_bytes or after >= end:
                    # Guess was wrong (or input exhausted).
                    if after > end:
                        raise _DecodeError('Truncated message.')
                    return after

        return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
        """Decodes a single (non-repeated) enum value.

      Args:
        buffer: memoryview over the serialized bytes.
        pos: int, offset in the buffer at which to start.
        end: int, offset at which the serialized data ends.
        message: Message that receives unknown-field records.
        field_dict: Map[Descriptor, Any] that receives decoded values.

      Returns:
        int, offset just past the decoded value.
      """
        start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        if pos > end:
            raise _DecodeError('Truncated message.')
        if clear_if_default and not number:
            # A falsy value drops any stored entry instead of storing it.
            field_dict.pop(key, None)
            return pos
        # pylint: disable=protected-access
        if number in enum_type.values_by_number:
            field_dict[key] = number
            return pos
        if not message._unknown_fields:
            message._unknown_fields = []
        unknown_tag = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)
        raw = buffer[start:pos].tobytes()
        message._unknown_fields.append((unknown_tag, raw))
        if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
        message._unknown_field_set._add(field_number,
                                        wire_format.WIRETYPE_VARINT,
                                        number)
        # pylint: enable=protected-access
        return pos

    return DecodeField