def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build the decoder closure for a group field.

    Args:
      field_number: Field number used to build the start/end group tags.
      is_repeated: When true, the returned closure adds one sub-message per
        group occurrence; otherwise it parses into a single stored value.
      is_packed: Must be false; a true value raises AssertionError.
      key: Key under which the value lives in ``field_dict``.
      new_default: Callable invoked with the parent message to create the
        stored value when ``field_dict`` has none yet.

    Returns:
      A function ``(buffer, pos, end, message, field_dict)`` that returns the
      position just past the group's end tag.
    """
    group_end_tag = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_END_GROUP)
    group_end_len = len(group_end_tag)

    if is_packed:
        raise AssertionError

    def _stored_value(message, field_dict):
        # Fetch the value stored under `key`, creating it on first use.
        existing = field_dict.get(key)
        if existing is None:
            existing = field_dict.setdefault(key, new_default(message))
        return existing

    def _skip_end_tag(buffer, body_end, end):
        # The group body must be followed, within bounds, by the end tag.
        after_tag = body_end + group_end_len
        if buffer[body_end:after_tag] != group_end_tag or after_tag > end:
            raise _DecodeError('Missing group end tag.')
        return after_tag

    if not is_repeated:
        def DecodeField(buffer, pos, end, message, field_dict):
            group = _stored_value(message, field_dict)
            body_end = group._InternalParse(buffer, pos, end)
            return _skip_end_tag(buffer, body_end, end)
        return DecodeField

    group_start_tag = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_START_GROUP)
    group_start_len = len(group_start_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        _stored_value(message, field_dict)
        while True:
            # Looked up again on every pass, exactly as before each add().
            groups = _stored_value(message, field_dict)
            body_end = groups.add()._InternalParse(buffer, pos, end)
            resume = _skip_end_tag(buffer, body_end, end)
            # Keep looping only while the next tag is this field's start tag
            # and input remains.
            pos = resume + group_start_len
            if buffer[resume:pos] == group_start_tag and resume != end:
                continue
            return resume
    return DecodeRepeatedField
def MessageSetItemDecoder(extensions_by_number):
    """Return a decoder for one MessageSet item.

    The item is read as a sequence of tags up to the end-group tag of field 1:
    a varint ``type_id`` (field 2) and a length-delimited message (field 3),
    in either order; any other field is skipped.

    Args:
      extensions_by_number: Mapping from type_id to the extension handle used
        as the key in ``field_dict``.

    Returns:
      A function ``(buffer, pos, end, message, field_dict)`` returning the
      position just past the item's end tag.

    Raises (from the returned function):
      _DecodeError: on a missing end tag, truncation, a missing type_id or
        message, or a payload that does not parse to exactly its length.
    """
    type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    local_ReadTag = ReadTag
    local_DecodeVarint = _DecodeVarint
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        message_set_item_start = pos
        type_id = -1
        message_start = -1
        message_end = -1

        # type_id and message may appear in either order, hence the loop.
        while True:
            (tag_bytes, pos) = local_ReadTag(buffer, pos)
            if tag_bytes == type_id_tag_bytes:
                (type_id, pos) = local_DecodeVarint(buffer, pos)
            elif tag_bytes == message_tag_bytes:
                (size, message_start) = local_DecodeVarint(buffer, pos)
                pos = message_end = message_start + size
            elif tag_bytes == item_end_tag_bytes:
                break
            else:
                pos = local_SkipField(buffer, pos, end, tag_bytes)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')

        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if message_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        extension = extensions_by_number.get(type_id)
        if extension is not None:
            value = field_dict.get(extension)
            if value is None:
                value = field_dict.setdefault(
                    extension, extension.message_type._concrete_class())
            # Parse the payload into the extension value; stopping short of
            # message_end is reported as an unexpected end-group tag.
            if value._InternalParse(buffer, message_start, message_end) != message_end:
                raise _DecodeError('Unexpected end-group tag.')
        else:
            # Unknown type_id: keep the raw item bytes as an unknown field.
            if not message._unknown_fields:
                message._unknown_fields = []
            message._unknown_fields.append(
                (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos]))

        return pos

    return DecodeItem
def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                    clear_if_default=False):
    """Build a decoder closure for a field whose values ``decode_value`` reads.

    Args:
      field_number: Field number, used to build the tag for repeated fields.
      is_repeated: Select the repeated (non-packed) decoder.
      is_packed: Select the packed decoder; checked before ``is_repeated``.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable creating the container for repeated/packed fields.
      clear_if_default: For singular fields, remove the key from
        ``field_dict`` when the decoded value is falsy instead of storing it.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    if is_packed:
        read_varint = _DecodeVarint

        def DecodePackedField(buffer, pos, end, message, field_dict):
            items = field_dict.get(key)
            if items is None:
                items = field_dict.setdefault(key, new_default(message))
            (limit, pos) = read_varint(buffer, pos)
            limit += pos
            if limit > end:
                raise _DecodeError('Truncated message.')
            while pos < limit:
                (item, pos) = decode_value(buffer, pos)
                items.append(item)
            if pos > limit:
                # The last element ran past the packed region; drop it.
                del items[-1]
                raise _DecodeError('Packed element was truncated.')
            return pos
        return DecodePackedField

    if is_repeated:
        repeat_tag = encoder.TagBytes(field_number, wire_type)
        repeat_tag_len = len(repeat_tag)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            items = field_dict.get(key)
            if items is None:
                items = field_dict.setdefault(key, new_default(message))
            while True:
                (item, after_item) = decode_value(buffer, pos)
                items.append(item)
                # Continue only while the next tag repeats this field and
                # input remains.
                pos = after_item + repeat_tag_len
                if buffer[after_item:pos] == repeat_tag and after_item < end:
                    continue
                if after_item > end:
                    raise _DecodeError('Truncated message.')
                return after_item
        return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
        (decoded, pos) = decode_value(buffer, pos)
        if pos > end:
            raise _DecodeError('Truncated message.')
        if not clear_if_default or decoded:
            field_dict[key] = decoded
        else:
            field_dict.pop(key, None)
        return pos
    return DecodeField
def DecodeField(buffer, pos, end, message, field_dict):
    """Decode one serialized enum value and return the new position.

    A number present in ``enum_type.values_by_number`` is stored under
    ``key``; any other number is recorded in ``message._unknown_fields``
    together with its raw bytes.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    start = pos
    (number, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
        raise _DecodeError('Truncated message.')

    if number in enum_type.values_by_number:
        field_dict[key] = number
        return pos

    # pylint: disable=protected-access
    if not message._unknown_fields:
        message._unknown_fields = []
    varint_tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    raw_value = buffer[start:pos].tobytes()
    message._unknown_fields.append((varint_tag, raw_value))
    # pylint: enable=protected-access
    return pos
def DecodePackedField(buffer, pos, end, message, field_dict):
    """Decode a packed run of enum values and return the new position.

    Numbers found in ``enum_type.values_by_number`` are appended to the
    container stored under ``key``; other numbers are appended to
    ``message._unknown_fields`` with their raw slice of ``buffer``.

    Raises:
      _DecodeError: if the packed region extends past ``end`` or the last
        element overruns the region (that element is discarded first).
    """
    items = field_dict.get(key)
    if items is None:
        items = field_dict.setdefault(key, new_default(message))

    (limit, pos) = local_DecodeVarint(buffer, pos)
    limit += pos
    if limit > end:
        raise _DecodeError('Truncated message.')

    last_known = False
    while pos < limit:
        start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        last_known = number in enum_type.values_by_number
        if last_known:
            items.append(number)
            continue
        if not message._unknown_fields:
            message._unknown_fields = []
        varint_tag = encoder.TagBytes(field_number,
                                      wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append((varint_tag, buffer[start:pos]))

    if pos > limit:
        # Undo whichever record the overrunning element produced.
        if last_known:
            del items[-1]  # Discard corrupt value.
        else:
            del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
    return pos
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build the decoder closure for a bytes field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, decoded values are appended to a container.
      is_packed: Asserted to be false.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable creating the container for a repeated field.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    read_varint = _DecodeVarint

    assert not is_packed
    if not is_repeated:
        def DecodeField(buffer, pos, end, message, field_dict):
            (length, pos) = read_varint(buffer, pos)
            value_end = pos + length
            if value_end > end:
                raise _DecodeError('Truncated string.')
            field_dict[key] = buffer[pos:value_end].tobytes()
            return value_end
        return DecodeField

    repeat_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    repeat_tag_len = len(repeat_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        items = field_dict.get(key)
        if items is None:
            items = field_dict.setdefault(key, new_default(message))
        while True:
            (length, pos) = read_varint(buffer, pos)
            value_end = pos + length
            if value_end > end:
                raise _DecodeError('Truncated string.')
            items.append(buffer[pos:value_end].tobytes())
            # Continue only while the next tag repeats this field and input
            # remains.
            pos = value_end + repeat_tag_len
            if buffer[value_end:pos] == repeat_tag and value_end != end:
                continue
            return value_end
    return DecodeRepeatedField
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a bytes field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, decoded slices are appended to a container.
      is_packed: Not read by this function.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable creating the container for a repeated field.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``. It
      raises _DecodeError('Truncated string.') when a value's length runs
      past ``end``.
    """
    local_DecodeVarint = _DecodeVarint

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while True:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated string.')
                value.append(buffer[pos:new_pos])
                # Predict that the next tag is another copy of the same
                # repeated field; stop as soon as it is not, or when the
                # input is exhausted.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    return new_pos
        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise _DecodeError('Truncated string.')
            field_dict[key] = buffer[pos:new_pos]
            return new_pos
        return DecodeField
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
    """Build the decoder closure for a string field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, decoded strings are appended to a container.
      is_packed: Asserted to be false.
      key: Key under which the value is stored; ``key.full_name`` is used in
        UTF-8 error messages.
      new_default: Callable creating the container for a repeated field.
      clear_if_default: For singular fields, remove the key from
        ``field_dict`` when the encoded length is zero instead of storing.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    read_varint = _DecodeVarint

    def _ConvertToUnicode(memview):
        """Decode the bytes of ``memview`` as UTF-8 text."""
        raw = memview.tobytes()
        try:
            return str(raw, 'utf-8')
        except UnicodeDecodeError as e:
            # Name the offending field in the error, then re-raise it.
            e.reason = '%s in field: %s' % (e, key.full_name)
            raise

    assert not is_packed
    if not is_repeated:
        def DecodeField(buffer, pos, end, message, field_dict):
            (length, pos) = read_varint(buffer, pos)
            value_end = pos + length
            if value_end > end:
                raise _DecodeError('Truncated string.')
            if not clear_if_default or length:
                field_dict[key] = _ConvertToUnicode(buffer[pos:value_end])
            else:
                field_dict.pop(key, None)
            return value_end
        return DecodeField

    repeat_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    repeat_tag_len = len(repeat_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        items = field_dict.get(key)
        if items is None:
            items = field_dict.setdefault(key, new_default(message))
        while True:
            (length, pos) = read_varint(buffer, pos)
            value_end = pos + length
            if value_end > end:
                raise _DecodeError('Truncated string.')
            items.append(_ConvertToUnicode(buffer[pos:value_end]))
            # Continue only while the next tag repeats this field and input
            # remains.
            pos = value_end + repeat_tag_len
            if buffer[value_end:pos] == repeat_tag and value_end != end:
                continue
            return value_end
    return DecodeRepeatedField
def AddDecoder(wiretype, is_packed):
    """Register the decoder for this field under the tag for ``wiretype``.

    The entry stored in ``cls._decoders_by_tag`` is a pair: the decoder
    closure, and the field descriptor when the field belongs to a oneof
    (otherwise None).
    """
    tag = encoder.TagBytes(field_descriptor.number, wiretype)
    make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
    decoder = make_decoder(
        field_descriptor.number,
        is_repeated,
        is_packed,
        field_descriptor,
        field_descriptor._default_constructor)
    if field_descriptor.containing_oneof is not None:
        oneof_field = field_descriptor
    else:
        oneof_field = None
    cls._decoders_by_tag[tag] = (decoder, oneof_field)
def encode_group(value, typedef, field_number):
    """Encode a protobuf group type.

    Args:
      value: Data passed to ``encode_message`` as the group's contents.
      typedef: Type definition passed to ``encode_message`` for the group.
      field_number: Field number of the group (converted with ``int``), used
        to build the end-group tag.

    Returns:
      The output of ``encode_message`` with the end-group tag appended.
    """
    # This function only adds the end tag; the start tag is not written here.
    output = encode_message(value, typedef, group=True)
    end_tag = encoder.TagBytes(int(field_number), wire_format.WIRETYPE_END_GROUP)
    # The tag is a byte string, so it has to be concatenated:
    # bytearray.append() only accepts a single integer.
    output += end_tag
    return output
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  is_strict_utf8=False):
    """Returns a decoder for a string field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, decoded strings are appended to a container.
      is_packed: Asserted to be false.
      key: Key under which the value is stored; ``key.full_name`` is used in
        error messages.
      new_default: Callable creating the container for a repeated field.
      is_strict_utf8: When true (and running on a Python 2 build whose
        ``sys.maxunicode`` exceeds ``_UCS2_MAXUNICODE``), reject decoded
        text that matches ``_SURROGATE_PATTERN``.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    local_DecodeVarint = _DecodeVarint
    local_unicode = six.text_type

    def _ConvertToUnicode(memview):
        """Convert byte to unicode."""
        byte_str = memview.tobytes()
        try:
            value = local_unicode(byte_str, 'utf-8')
        except UnicodeDecodeError as e:
            # add more information to the error message and re-raise it.
            e.reason = '%s in field: %s' % (e, key.full_name)
            raise

        if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
            # Only do the check for python2 ucs4 when is_strict_utf8 enabled
            if _SURROGATE_PATTERN.search(value):
                # Each fragment ends with a space so the joined sentence
                # reads correctly.
                reason = ('String field %s contains invalid UTF-8 data when '
                          'parsing a protocol buffer: surrogates not '
                          'allowed. Use the bytes type if you intend to '
                          'send raw bytes.') % (key.full_name)
                raise message.DecodeError(reason)

        return value

    assert not is_packed
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while 1:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated string.')
                value.append(_ConvertToUnicode(buffer[pos:new_pos]))
                # Predict that the next tag is another copy of the same
                # repeated field.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos
        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise _DecodeError('Truncated string.')
            field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
            return new_pos
        return DecodeField
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build the decoder closure for a message field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, each occurrence is parsed into a newly added
        sub-message; otherwise into the single stored value.
      is_packed: Asserted to be false.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable invoked with the parent message to create the
        stored value when ``field_dict`` has none yet.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    read_varint = _DecodeVarint

    assert not is_packed

    def _stored_value(message, field_dict):
        # Fetch the value stored under `key`, creating it on first use.
        existing = field_dict.get(key)
        if existing is None:
            existing = field_dict.setdefault(key, new_default(message))
        return existing

    if not is_repeated:
        def DecodeField(buffer, pos, end, message, field_dict):
            submessage = _stored_value(message, field_dict)
            (length, pos) = read_varint(buffer, pos)
            body_end = pos + length
            if body_end > end:
                raise _DecodeError('Truncated message.(newpos:%d, end:%d)' % (body_end, end))
            # Stopping short of body_end is reported as an end-group tag.
            if submessage._InternalParse(buffer, pos, body_end) != body_end:
                raise _DecodeError('Unexpected end-group tag.')
            return body_end
        return DecodeField

    repeat_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    repeat_tag_len = len(repeat_tag)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        _stored_value(message, field_dict)
        while True:
            # Looked up again on every pass, exactly as before each add().
            submessages = _stored_value(message, field_dict)
            (length, pos) = read_varint(buffer, pos)
            body_end = pos + length
            if body_end > end:
                raise _DecodeError('Truncated message.')
            if submessages.add()._InternalParse(buffer, pos, body_end) != body_end:
                raise _DecodeError('Unexpected end-group tag.')
            # Continue only while the next tag repeats this field and input
            # remains.
            pos = body_end + repeat_tag_len
            if buffer[body_end:pos] == repeat_tag and body_end != end:
                continue
            return body_end
    return DecodeRepeatedField
def CheckUnknownField(self, name, expected_value):
    """Decode the named field from the unknown fields and assert its value.

    Every entry of ``self.empty_message._unknown_fields`` whose tag matches
    the field is run through the ``TestAllTypes`` decoder for that tag; the
    decoded result must equal ``expected_value``.
    """
    descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
    wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
    decoded = {}
    for tag, raw in self.empty_message._unknown_fields:
        if tag != wanted_tag:
            continue
        decode = unittest_pb2.TestAllTypes._decoders_by_tag[tag][0]
        decode(raw, 0, len(raw), self.all_fields, decoded)
    self.assertEqual(expected_value, decoded[descriptor])
def GetField(self, name):
    """Decode and return the named field from ``self.unknown_fields``.

    Every entry whose tag matches the field is run through the
    ``TestAllTypes`` decoder for that tag; a KeyError results if no entry
    produced a value.
    """
    descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
    wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
    decoded = {}
    for tag, raw in self.unknown_fields:
        if tag != wanted_tag:
            continue
        decode = unittest_pb2.TestAllTypes._decoders_by_tag[tag][0]
        decode(raw, 0, len(raw), self.all_fields, decoded)
    return decoded[descriptor]
def GetUnknownField(self, name):
    """Decode and return the named field from the message's unknown fields.

    Every entry of ``self.missing_message._unknown_fields`` whose tag matches
    the field is run through the ``TestEnumValues`` decoder for that tag; a
    KeyError results if no entry produced a value.
    """
    descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[descriptor.type]
    wanted_tag = encoder.TagBytes(descriptor.number, wire_type)
    decoded = {}
    for tag, raw in self.missing_message._unknown_fields:
        if tag != wanted_tag:
            continue
        decode = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[tag][0]
        decode(raw, 0, len(raw), self.message, decoded)
    return decoded[descriptor]
def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Build a decoder closure for a field whose values ``decode_value`` reads.

    Args:
      field_number: Field number, used to build the tag for repeated fields.
      is_repeated: Select the repeated (non-packed) decoder.
      is_packed: Select the packed decoder; checked before ``is_repeated``.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable creating the container for repeated/packed fields.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos`` that
      raises _DecodeError when a value runs past its allowed end.
    """
    if is_packed:
        local_DecodeVarint = _DecodeVarint

        def DecodePackedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            (endpoint, pos) = local_DecodeVarint(buffer, pos)
            endpoint += pos
            if endpoint > end:
                raise _DecodeError('Truncated message.')
            while pos < endpoint:
                (element, pos) = decode_value(buffer, pos)
                value.append(element)
            if pos > endpoint:
                # The last element overran the packed region; discard it.
                del value[-1]
                raise _DecodeError('Packed element was truncated.')
            return pos
        return DecodePackedField

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number, wire_type)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while True:
                (element, new_pos) = decode_value(buffer, pos)
                value.append(element)
                # Predict that the next tag is another copy of the same
                # repeated field; stop as soon as it is not, or when the
                # input is exhausted.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
                    if new_pos > end:
                        raise _DecodeError('Truncated message.')
                    return new_pos
        return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
        (field_dict[key], pos) = decode_value(buffer, pos)
        if pos > end:
            # Do not leave a value decoded from beyond the end in place.
            del field_dict[key]
            raise _DecodeError('Truncated message.')
        return pos
    return DecodeField
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a group field.

    Args:
      field_number: Field number used to build the start/end group tags.
      is_repeated: When true, each occurrence is parsed into a newly added
        sub-message; otherwise into the single stored value.
      is_packed: Not read by this function.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable invoked with the parent message to create the
        stored value when ``field_dict`` has none yet.

    Returns:
      A function ``(buffer, pos, end, message, field_dict)`` returning the
      position just past the group's end tag; it raises
      _DecodeError('Missing group end tag.') when that tag is absent or lies
      beyond ``end``.
    """
    end_tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_END_GROUP)
    end_tag_len = len(end_tag_bytes)

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_START_GROUP)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            while True:
                # Get (or create) the container before every add().
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                # Read sub-message.
                pos = value.add()._InternalParse(buffer, pos, end)
                # Read end tag.
                new_pos = pos + end_tag_len
                if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
                    raise _DecodeError('Missing group end tag.')
                # Predict that the next tag is another copy of the same
                # repeated field; stop as soon as it is not, or when the
                # input is exhausted.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    return new_pos
        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            # Read sub-message.
            pos = value._InternalParse(buffer, pos, end)
            # Read end tag.
            new_pos = pos + end_tag_len
            if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
                raise _DecodeError('Missing group end tag.')
            return new_pos
        return DecodeField
def UnknownMessageSetItemDecoder():
    """Returns a decoder for a Unknown MessageSet item.

    The returned function takes the item's bytes, reads tags up to the
    end-group tag of field 1, and returns ``(type_id, message_bytes)`` taken
    from the varint field 2 and the length-delimited field 3. Any other field
    is skipped.

    Raises (from the returned function):
      _DecodeError: on a missing end tag, truncation, or a missing type_id
        or message.
    """
    type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    def DecodeUnknownItem(buffer):
        pos = 0
        end = len(buffer)
        # -1 marks "not seen yet"; type_id must start out defined so the
        # missing-type_id check below can report a _DecodeError.
        type_id = -1
        message_start = -1
        message_end = -1
        while True:
            (tag_bytes, pos) = ReadTag(buffer, pos)
            if tag_bytes == type_id_tag_bytes:
                (type_id, pos) = _DecodeVarint(buffer, pos)
            elif tag_bytes == message_tag_bytes:
                (size, message_start) = _DecodeVarint(buffer, pos)
                pos = message_end = message_start + size
            elif tag_bytes == item_end_tag_bytes:
                break
            else:
                pos = SkipField(buffer, pos, end, tag_bytes)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')

        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if message_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        return (type_id, buffer[message_start:message_end].tobytes())

    return DecodeUnknownItem
def DecodePackedField(buffer, pos, end, message, field_dict):
    """Decode serialized packed enum to its value and a new position.

    Numbers found in ``enum_type.values_by_number`` are appended to the
    container stored under ``key``. Other numbers are recorded both in
    ``message._unknown_fields`` (tag plus raw bytes) and in
    ``message._unknown_field_set``.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    items = field_dict.get(key)
    if items is None:
        items = field_dict.setdefault(key, new_default(message))

    (limit, pos) = local_DecodeVarint(buffer, pos)
    limit += pos
    if limit > end:
        raise _DecodeError('Truncated message.')

    last_known = False
    # pylint: disable=protected-access
    while pos < limit:
        start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        last_known = number in enum_type.values_by_number
        if last_known:
            items.append(number)
            continue
        if not message._unknown_fields:
            message._unknown_fields = []
        varint_tag = encoder.TagBytes(field_number,
                                      wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (varint_tag, buffer[start:pos].tobytes()))
        if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
        message._unknown_field_set._add(
            field_number, wire_format.WIRETYPE_VARINT, number)

    if pos > limit:
        # Undo whichever records the overrunning element produced.
        if last_known:
            del items[-1]  # Discard corrupt value.
        else:
            del message._unknown_fields[-1]
            del message._unknown_field_set._values[-1]
        raise _DecodeError('Packed element was truncated.')
    # pylint: enable=protected-access
    return pos
def DecodeField(buffer, pos, end, message, field_dict):
    """Decode one enum value and return the new position.

    A number present in ``enum_type.values_by_number`` is stored under
    ``key``; any other number is appended to ``message._unknown_fields``
    with its raw slice of ``buffer``.

    Raises:
      _DecodeError: if the value ends beyond ``end``.
    """
    start = pos
    (number, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
        raise _DecodeError('Truncated message.')

    if number in enum_type.values_by_number:
        field_dict[key] = number
        return pos

    if not message._unknown_fields:
        message._unknown_fields = []
    varint_tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    message._unknown_fields.append((varint_tag, buffer[start:pos]))
    return pos
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a message field.

    Args:
      field_number: Field number used to build the length-delimited tag.
      is_repeated: When true, each occurrence is parsed into a newly added
        sub-message; otherwise into the single stored value.
      is_packed: Not read by this function.
      key: Key under which the value is stored in ``field_dict``.
      new_default: Callable invoked with the parent message to create the
        stored value when ``field_dict`` has none yet.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos`` that
      raises _DecodeError when the sub-message is truncated or does not
      parse to exactly its declared length.
    """
    local_DecodeVarint = _DecodeVarint

    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            while True:
                # Get (or create) the container before every add().
                value = field_dict.get(key)
                if value is None:
                    value = field_dict.setdefault(key, new_default(message))
                # Read length.
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise _DecodeError('Truncated message.')
                # Read sub-message; stopping short of new_pos is reported as
                # an unexpected end-group tag.
                if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
                    raise _DecodeError('Unexpected end-group tag.')
                # Predict that the next tag is another copy of the same
                # repeated field; stop as soon as it is not, or when the
                # input is exhausted.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    return new_pos
        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            # Read length.
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise _DecodeError('Truncated message.')
            # Read sub-message.
            if value._InternalParse(buffer, pos, new_pos) != new_pos:
                raise _DecodeError('Unexpected end-group tag.')
            return new_pos
        return DecodeField
def MapDecoder(field_descriptor, new_default, is_message_map):
    """Build the decoder closure for a map field.

    Args:
      field_descriptor: Descriptor of the map field; it is also the key
        under which the map is stored in ``field_dict``.
      new_default: Callable invoked with the parent message to create the
        map when ``field_dict`` has none yet.
      is_message_map: When true, each entry's value is merged into the
        existing map value with ``MergeFrom``; otherwise it is assigned.

    Returns:
      A function ``(buffer, pos, end, message, field_dict) -> new_pos``.
    """
    key = field_descriptor
    entry_tag = encoder.TagBytes(field_descriptor.number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    entry_tag_len = len(entry_tag)
    read_varint = _DecodeVarint
    # Can't read _concrete_class yet; might not be initialized.
    entry_type = field_descriptor.message_type

    def DecodeMap(buffer, pos, end, message, field_dict):
        # One scratch entry message is reused (cleared) for every entry.
        entry = entry_type._concrete_class()
        mapping = field_dict.get(key)
        if mapping is None:
            mapping = field_dict.setdefault(key, new_default(message))
        while True:
            (length, pos) = read_varint(buffer, pos)
            entry_end = pos + length
            if entry_end > end:
                raise _DecodeError('Truncated message.')
            entry.Clear()
            # Stopping short of entry_end is reported as an end-group tag.
            if entry._InternalParse(buffer, pos, entry_end) != entry_end:
                raise _DecodeError('Unexpected end-group tag.')

            if is_message_map:
                mapping[entry.key].MergeFrom(entry.value)
            else:
                mapping[entry.key] = entry.value

            # Continue only while the next tag repeats this field and input
            # remains.
            pos = entry_end + entry_tag_len
            if buffer[entry_end:pos] == entry_tag and entry_end != end:
                continue
            return entry_end

    return DecodeMap
def AddDecoder(wiretype, is_packed):
    """Register the decoder for this field under the tag for ``wiretype``.

    The decoder closure is built by the factory that
    ``type_checkers.TYPE_TO_DECODER`` maps to the field's type and stored in
    ``cls._decoders_by_tag``.
    """
    tag = encoder.TagBytes(field_descriptor.number, wiretype)
    make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
    decoder = make_decoder(
        field_descriptor.number,
        is_repeated,
        is_packed,
        field_descriptor,
        field_descriptor._default_constructor)
    cls._decoders_by_tag[tag] = decoder
def DecodeField(buffer, pos, end, message, field_dict): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated message.') if value._InternalParse(buffer, pos, new_pos) != new_pos: raise _DecodeError('Unexpected end-group tag.') return new_pos return DecodeField MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) def MessageSetItemDecoder(extensions_by_number): type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) local_ReadTag = ReadTag local_DecodeVarint = _DecodeVarint local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): message_set_item_start = pos type_id = -1 message_start = -1 message_end = -1 while True:
def MessageSetItemDecoder(extensions_by_number):
    """Build the decoder closure for a single MessageSet item.

    ``extensions_by_number`` is the _extensions_by_number map for the message
    class.

    The message set message looks like this:
      message MessageSet {
        repeated group Item = 1 {
          required int32 type_id = 2;
          required string message = 3;
        }
      }
    """
    type_id_tag = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    payload_tag = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    item_end_tag = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

    read_tag = ReadTag
    read_varint = _DecodeVarint
    # Bound here as before; the loop below calls the module-level SkipField.
    local_SkipField = SkipField

    def DecodeItem(buffer, pos, end, message, field_dict):
        item_start = pos
        type_id = -1
        payload_start = -1
        payload_end = -1

        # type_id and message may come in either order, so keep reading tags
        # until the item's end tag shows up.
        while True:
            (tag, pos) = read_tag(buffer, pos)
            if tag == item_end_tag:
                break
            if tag == type_id_tag:
                (type_id, pos) = read_varint(buffer, pos)
            elif tag == payload_tag:
                (length, payload_start) = read_varint(buffer, pos)
                payload_end = payload_start + length
                pos = payload_end
            else:
                pos = SkipField(buffer, pos, end, tag)
                if pos == -1:
                    raise _DecodeError('Missing group end tag.')

        if pos > end:
            raise _DecodeError('Truncated message.')
        if type_id == -1:
            raise _DecodeError('MessageSet item missing type_id.')
        if payload_start == -1:
            raise _DecodeError('MessageSet item missing message.')

        extension = extensions_by_number.get(type_id)
        if extension is None:
            # Unknown type_id: keep the raw item bytes as an unknown field.
            if not message._unknown_fields:
                message._unknown_fields = []
            message._unknown_fields.append(
                (MESSAGE_SET_ITEM_TAG, buffer[item_start:pos]))
            return pos

        target = field_dict.get(extension)
        if target is None:
            target = field_dict.setdefault(
                extension, extension.message_type._concrete_class())
        # Stopping short of payload_end is reported as an end-group tag.
        if target._InternalParse(buffer, payload_start, payload_end) != payload_end:
            raise _DecodeError('Unexpected end-group tag.')
        return pos

    return DecodeItem
def encode_message(data, typedef, group=False):
    """Encode a Python dictionary representing a protobuf message.

    Args:
      data: Dictionary mapping field numbers (or field names, optionally
        suffixed with ``-<alt>`` to select an alternate typedef) to values.
      typedef: Dictionary mapping field-number strings to per-field type
        information ('type', and where applicable 'name',
        'message_typedef', 'alt_typedefs', 'group_typedef').
      group: Accepted for interface compatibility; not read here.

    Returns:
      A bytearray holding, for every field, its tag followed by its encoded
      value (once per element for non-packed list values).

    Raises:
      ValueError: for an unknown field name/number, a missing or unknown
        type, a missing inner type definition, or any failure while encoding
        a value (the original exception is chained as the cause).
    """
    output = bytearray()

    for field_number, value in data.items():
        # Resolve the key to the field-number string used by `typedef`.
        alt_field_number = None

        if isinstance(field_number, str):
            if '-' in field_number:
                field_number, alt_field_number = field_number.split('-')
            # A non-empty key may be a field name; map it to its number.
            # Entries without a 'name' are simply skipped.
            for number, info in typedef.items():
                if field_number != '' and info.get('name') == field_number:
                    field_number = number
                    break
        else:
            field_number = str(field_number)

        if field_number not in typedef:
            raise ValueError('Provided field name/number %s is not valid'
                             % (field_number))

        field_typedef = typedef[field_number]

        if 'type' not in field_typedef:
            raise ValueError('Field %s does not have a defined type'
                             % field_number)

        field_type = field_typedef['type']

        # Pick the encoder for this field.
        if field_type == 'message':
            if alt_field_number is not None:
                alt_typedefs = field_typedef.get('alt_typedefs', {})
                if alt_field_number not in alt_typedefs:
                    raise ValueError(
                        'Provided alt field name/number %s is not valid for field_number %s'
                        % (alt_field_number, field_number))
                innertypedef = alt_typedefs[alt_field_number]
            elif 'message_typedef' in field_typedef:
                # "Anonymous" inner message
                innertypedef = field_typedef['message_typedef']
            else:
                raise ValueError('Message type (%s) has not been defined'
                                 % field_typedef.get('message_type_name'))

            # Bind the typedef as a default so the encoder does not depend
            # on the loop variable's later values.
            field_encoder = (
                lambda item, _typedef=innertypedef:
                encode_lendelim_message(item, _typedef))
        elif field_type == 'group':
            if 'group_typedef' not in field_typedef:
                raise ValueError(
                    'Could not find type definition for group field: %s'
                    % field_number)
            innertypedef = field_typedef['group_typedef']

            field_encoder = (
                lambda item, _typedef=innertypedef, _number=field_number:
                encode_group(item, _typedef, _number))
        else:
            if field_type not in types.encoders:
                raise ValueError('Unknown type: %s' % field_type)
            field_encoder = types.encoders[field_type]
            if field_encoder is None:
                raise ValueError('Encoder not implemented: %s' % field_type)

        # Encode the tag
        tag = encoder.TagBytes(int(field_number), types.wiretypes[field_type])

        try:
            # A list is a repeated field unless the type packs it itself.
            if isinstance(value, list) and not field_type.startswith('packed_'):
                for repeated in value:
                    output += tag
                    output += field_encoder(repeated)
            else:
                output += tag
                output += field_encoder(value)
        except Exception as exc:
            # Re-raise as ValueError, keeping the original as the cause.
            raise ValueError(
                'Error attempting to encode "%s" as %s: %s'
                % (value, field_type, exc)) from exc

    return output
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Build the decoder callable for an enum field.

  One of three closures is returned, chosen by ``is_packed`` and then
  ``is_repeated``. Each has the signature
  ``(buffer, pos, end, message, field_dict)`` and returns the position just
  past the data it consumed.

  Numbers found in ``key.enum_type.values_by_number`` are stored in
  ``field_dict`` under ``key``; any other number is recorded on
  ``message._unknown_fields`` as a ``(tag_bytes, raw_slice)`` pair.
  """
  enum_type = key.enum_type

  if is_packed:
    local_DecodeVarint = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      # The packed payload is prefixed with its byte length.
      (payload_len, pos) = local_DecodeVarint(buffer, pos)
      limit = pos + payload_len
      if limit > end:
        raise _DecodeError('Truncated message.')
      while pos < limit:
        element_start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        if number in enum_type.values_by_number:
          container.append(number)
          continue
        if not message._unknown_fields:
          message._unknown_fields = []
        unknown_tag = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (unknown_tag, buffer[element_start:pos]))
      if pos > limit:
        # The last element ran past the declared payload: undo whichever
        # record it produced before reporting the error.
        if number in enum_type.values_by_number:
          del container[-1]
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField

  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      while True:
        (number, after_value) = _DecodeSignedVarint32(buffer, pos)
        if number in enum_type.values_by_number:
          container.append(number)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:after_value]))
        # Keep looping only while the next tag is this same field and there
        # is input left; otherwise hand control back to the caller.
        pos = after_value + tag_len
        same_field_follows = buffer[after_value:pos] == tag_bytes
        if same_field_follows and after_value < end:
          continue
        if after_value > end:
          raise _DecodeError('Truncated message.')
        return after_value

    return DecodeRepeatedField

  def DecodeField(buffer, pos, end, message, field_dict):
    element_start = pos
    (number, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
      raise _DecodeError('Truncated message.')
    if number in enum_type.values_by_number:
      field_dict[key] = number
      return pos
    if not message._unknown_fields:
      message._unknown_fields = []
    unknown_tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    message._unknown_fields.append((unknown_tag, buffer[element_start:pos]))
    return pos

  return DecodeField
"""Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint local_unicode = unicode def _ConvertToUnicode(byte_str): try: return local_unicode(byte_str, 'utf-8') except UnicodeDecodeError, e: # add more information to the error message and re-raise it. e.reason = '%s in field: %s' % (e, key.full_name) raise assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) def DecodeRepeatedField(buffer, pos, end, message, field_dict): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end:
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The ``descriptor`` argument (the message Descriptor) is accepted but not
  read by this function.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """
  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode one MessageSet item and return the position after it.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object whose extensions are looked up and on which
        unknown items are recorded.
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    item_start = pos
    type_id = -1
    payload_start = -1
    payload_end = -1

    # type_id and message may arrive in either order, so scan tags until the
    # item's end-group tag shows up.
    while True:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == item_end_tag_bytes:
        break
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, payload_start) = local_DecodeVarint(buffer, pos)
        payload_end = payload_start + size
        pos = payload_end
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')
    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if payload_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is None:
      # No extension registered for this type_id: keep the raw item and its
      # payload as unknown data.
      if not message._unknown_fields:
        message._unknown_fields = []
      raw_item = buffer[item_start:pos].tobytes()
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, raw_item))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[payload_start:payload_end].tobytes())
      return pos

    value = field_dict.get(extension)
    if value is None:
      message_type = extension.message_type
      if not hasattr(message_type, '_concrete_class'):
        # pylint: disable=protected-access
        message._FACTORY.GetPrototype(message_type)
      value = field_dict.setdefault(extension, message_type._concrete_class())
    parsed_to = value._InternalParse(buffer, payload_start, payload_end)
    if parsed_to != payload_end:
      # The only reason _InternalParse would return early is if it encountered
      # an end-group tag.
      raise _DecodeError('Unexpected end-group tag.')
    # pylint: enable=protected-access
    return pos

  return DecodeItem
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  One of three closures is returned, chosen by ``is_packed`` and then
  ``is_repeated``. Numbers present in ``key.enum_type.values_by_number`` are
  stored in ``field_dict`` under ``key``; any other number is recorded both on
  ``message._unknown_fields`` (tag bytes plus raw bytes) and on
  ``message._unknown_field_set``.

  ``clear_if_default`` only affects the singular decoder: when set, a decoded
  value that is falsy removes ``key`` from ``field_dict`` instead of being
  stored.
  """
  enum_type = key.enum_type

  if is_packed:
    local_DecodeVarint = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode a packed run of enum values.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      # The packed payload is prefixed with its byte length.
      (payload_len, pos) = local_DecodeVarint(buffer, pos)
      limit = pos + payload_len
      if limit > end:
        raise _DecodeError('Truncated message.')
      while pos < limit:
        element_start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        if number in enum_type.values_by_number:
          container.append(number)
          continue
        # pylint: disable=protected-access
        if not message._unknown_fields:
          message._unknown_fields = []
        unknown_tag = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (unknown_tag, buffer[element_start:pos].tobytes()))
        if message._unknown_field_set is None:
          message._unknown_field_set = containers.UnknownFieldSet()
        message._unknown_field_set._add(
            field_number, wire_format.WIRETYPE_VARINT, number)
        # pylint: enable=protected-access
      if pos > limit:
        # The last element ran past the declared payload: undo whatever it
        # recorded before reporting the error.
        if number in enum_type.values_by_number:
          del container[-1]
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField

  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode consecutive occurrences of a repeated enum field.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      while True:
        (number, after_value) = _DecodeSignedVarint32(buffer, pos)
        if number in enum_type.values_by_number:
          container.append(number)
        else:
          # pylint: disable=protected-access
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:after_value].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, number)
          # pylint: enable=protected-access
        # Keep looping only while the next tag is this same field and there
        # is input left; otherwise hand control back to the caller.
        pos = after_value + tag_len
        same_field_follows = buffer[after_value:pos] == tag_bytes
        if same_field_follows and after_value < end:
          continue
        if after_value > end:
          raise _DecodeError('Truncated message.')
        return after_value

    return DecodeRepeatedField

  def DecodeField(buffer, pos, end, message, field_dict):
    """Decode a single (non-repeated) enum value.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    element_start = pos
    (number, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
      raise _DecodeError('Truncated message.')
    if clear_if_default and not number:
      field_dict.pop(key, None)
      return pos
    if number in enum_type.values_by_number:
      field_dict[key] = number
      return pos
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    unknown_tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    message._unknown_fields.append(
        (unknown_tag, buffer[element_start:pos].tobytes()))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, number)
    # pylint: enable=protected-access
    return pos

  return DecodeField