Example #1
0
    def _parse_header_format(self, buffer):
        """Decode a THeader frame from ``buffer`` and return its payload.

        Reads, in order: 2 magic bytes (skipped), a U16 flags field, an I32
        sequence id, a U16 header length counted in 4-byte words, the varint
        protocol id, a varint-counted list of transform ids, and then info
        headers until the end of the header area. Only KEY_VALUE info headers
        are decoded; the first header of any other type stops header parsing.

        Side effects: sets ``self.flags``, ``self.sequence_id``,
        ``self._protocol_id`` and ``self._read_headers``.

        Returns:
            A ``BufferIO`` over the remaining payload after the read
            transforms have been applied to it.

        Raises:
            TTransportException: SIZE_LIMIT when the declared header area
                extends past the end of ``buffer``.
            TApplicationException: INVALID_TRANSFORM when a transform id is
                not present in ``READ_TRANSFORMS_BY_ID``.
        """
        # The varint/string readers want a TTransport, so wrap the BufferIO.
        transport = TMemoryBuffer()
        transport._buffer = buffer

        buffer.read(2)  # magic bytes: skipped, not validated here
        (self.flags,) = U16.unpack(buffer.read(U16.size))
        (self.sequence_id,) = I32.unpack(buffer.read(I32.size))

        # The header length field counts 4-byte words, not bytes.
        (header_words,) = U16.unpack(buffer.read(U16.size))
        headers_end = buffer.tell() + 4 * header_words
        if len(buffer.getvalue()) < headers_end:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Header size is larger than whole frame.",
            )

        self._protocol_id = readVarint(transport)

        transform_ids = []
        for _ in range(readVarint(transport)):
            tid = readVarint(transport)
            if tid not in READ_TRANSFORMS_BY_ID:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform: %d" % tid,
                )
            transform_ids.append(tid)

        info_headers = {}
        while buffer.tell() < headers_end:
            if readVarint(transport) != TInfoHeaderType.KEY_VALUE:
                break  # unknown header type: stop parsing headers
            for _ in range(readVarint(transport)):
                name = _readString(transport)
                info_headers[name] = _readString(transport)
        self._read_headers = info_headers

        # Jump past padding and anything left unparsed in the header area.
        buffer.seek(headers_end)

        payload = buffer.read()
        # Apply the read transforms in the reverse of their listed order.
        for tid in reversed(transform_ids):
            payload = READ_TRANSFORMS_BY_ID[tid](payload)
        return BufferIO(payload)
Example #2
0
    def DeserializeThriftCall(self, buf):
        """Deserialize a Thrift reply held in ``buf`` to a MethodReturnMessage.

        Args:
          buf - The buffer containing the serialized reply.

        Returns:
          A MethodReturnMessage carrying one of:
            - the call's return value (the result struct's ``success`` field);
            - an error: a server-sent TApplicationException, a declared
              exception set on the ``<fn>_result`` struct, or a MISSING_RESULT
              TApplicationException when a non-void method returned nothing;
            - nothing, when no result class is known or the method is void.
        """
        # The protocol reads through a TMemoryBuffer, so point it at ``buf``.
        thrift_buffer = TMemoryBuffer()
        thrift_buffer._buffer = buf
        protocol = self._protocol_factory.getProtocol(thrift_buffer)

        (fn_name, msg_type, seq_id) = protocol.readMessageBegin()
        if msg_type == TMessageType.EXCEPTION:
            x = TApplicationException()
            x.read(protocol)
            protocol.readMessageEnd()
            return MethodReturnMessage(error=x)

        result_cls = self._FindClass('%s_result' % fn_name)
        if result_cls:
            result = result_cls()
            result.read(protocol)
        else:
            result = None
        protocol.readMessageEnd()

        if result is None:
            return MethodReturnMessage()
        if getattr(result, 'success', None) is not None:
            return MethodReturnMessage(return_value=result.success)

        # Entry 0 of thrift_spec is the ``success`` slot; the rest are the
        # declared exceptions. Thrift's generator indexes thrift_spec by field
        # id and fills gaps in the ids with None, so skip those entries.
        result_spec = getattr(result_cls, 'thrift_spec', None)
        if result_spec:
            for field in result_spec[1:]:
                if field is None:
                    continue
                attr_val = getattr(result, field[2], None)
                if attr_val is not None:
                    return MethodReturnMessage(error=attr_val)

        # A result struct without any ``success`` attribute belongs to a void
        # method: a reply with no exception set is a normal, empty return.
        if not hasattr(result, 'success'):
            return MethodReturnMessage()

        # Pass the exception as ``error`` (as the other error paths do) so a
        # missing result is not mistaken for a successful return value.
        return MethodReturnMessage(
            error=TApplicationException(
                TApplicationException.MISSING_RESULT,
                "%s failed: unknown result" % fn_name))
Example #3
0
  def DeserializeThriftCall(self, buf):
    """Deserialize a Thrift reply held in ``buf`` to a MethodReturnMessage.

    Args:
      buf - The buffer containing the serialized reply.

    Returns:
      A MethodReturnMessage carrying one of:
        - the call's return value (the result struct's ``success`` field);
        - an error: a server-sent TApplicationException, a declared
          exception set on the ``<fn>_result`` struct, or a MISSING_RESULT
          TApplicationException when a non-void method returned nothing;
        - nothing, when no result class is known or the method is void.
    """
    # The protocol reads through a TMemoryBuffer, so point it at ``buf``.
    thrift_buffer = TMemoryBuffer()
    thrift_buffer._buffer = buf
    protocol = self._protocol_factory.getProtocol(thrift_buffer)

    (fn_name, msg_type, seq_id) = protocol.readMessageBegin()
    if msg_type == TMessageType.EXCEPTION:
      x = TApplicationException()
      x.read(protocol)
      protocol.readMessageEnd()
      return MethodReturnMessage(error=x)

    result_cls = self._FindClass('%s_result' % fn_name)
    if result_cls:
      result = result_cls()
      result.read(protocol)
    else:
      result = None
    protocol.readMessageEnd()

    if result is None:
      return MethodReturnMessage()
    if getattr(result, 'success', None) is not None:
      return MethodReturnMessage(return_value=result.success)

    # Entry 0 of thrift_spec is the ``success`` slot; the rest are the
    # declared exceptions. Thrift's generator indexes thrift_spec by field
    # id and fills gaps in the ids with None, so skip those entries.
    result_spec = getattr(result_cls, 'thrift_spec', None)
    if result_spec:
      for field in result_spec[1:]:
        if field is None:
          continue
        attr_val = getattr(result, field[2], None)
        if attr_val is not None:
          return MethodReturnMessage(error=attr_val)

    # A result struct without any ``success`` attribute belongs to a void
    # method: a reply with no exception set is a normal, empty return.
    if not hasattr(result, 'success'):
      return MethodReturnMessage()

    # Pass the exception as ``error`` (as the other error paths do) so a
    # missing result is not mistaken for a successful return value.
    return MethodReturnMessage(error=TApplicationException(
      TApplicationException.MISSING_RESULT, "%s failed: unknown result" % fn_name))
Example #4
0
  def SerializeThriftCall(self, msg, buf):
    """Serialize a MethodCallMessage to a stream.

    Args:
      msg - The MethodCallMessage to serialize.
      buf - The buffer to serialize into.

    Raises:
      AttributeError - If no ``<method>_args`` class can be found.
      TypeError - If msg.args / msg.kwargs do not match the args class; in
        that case nothing has been written to ``buf``.
    """
    thrift_buffer = TMemoryBuffer()
    thrift_buffer._buffer = buf
    protocol = self._protocol_factory.getProtocol(thrift_buffer)
    method, args, kwargs = msg.method, msg.args, msg.kwargs
    # A method without a ``<method>_result`` class is sent as ONEWAY.
    is_one_way = self._FindClass('%s_result' % method) is None
    args_cls = self._FindClass('%s_args' % method)
    if not args_cls:
      raise AttributeError('Unable to find args class for method %s' % method)

    # Build the args struct before writing anything, so a bad call signature
    # fails without leaving a dangling message header in ``buf``.
    thrift_args = args_cls(*args, **kwargs)
    protocol.writeMessageBegin(
        method,
        TMessageType.ONEWAY if is_one_way else TMessageType.CALL,
        self._seq_id)
    thrift_args.write(protocol)
    protocol.writeMessageEnd()
Example #5
0
    def SerializeThriftCall(self, msg, buf):
        """Serialize a MethodCallMessage to a stream.

        Args:
          msg - The MethodCallMessage to serialize.
          buf - The buffer to serialize into.

        Raises:
          AttributeError - If no ``<method>_args`` class can be found.
          TypeError - If msg.args / msg.kwargs do not match the args class;
            in that case nothing has been written to ``buf``.
        """
        thrift_buffer = TMemoryBuffer()
        thrift_buffer._buffer = buf
        protocol = self._protocol_factory.getProtocol(thrift_buffer)
        method, args, kwargs = msg.method, msg.args, msg.kwargs
        # A method without a ``<method>_result`` class is sent as ONEWAY.
        is_one_way = self._FindClass('%s_result' % method) is None
        args_cls = self._FindClass('%s_args' % method)
        if not args_cls:
            raise AttributeError('Unable to find args class for method %s' %
                                 method)

        # Build the args struct before writing anything, so a bad call
        # signature fails without leaving a dangling message header in ``buf``.
        thrift_args = args_cls(*args, **kwargs)
        protocol.writeMessageBegin(
            method,
            TMessageType.ONEWAY if is_one_way else TMessageType.CALL,
            self._seq_id)
        thrift_args.write(protocol)
        protocol.writeMessageEnd()
Example #6
0
    def serializeDummyProtocolToThrift(self,
                                       data: DummyProtocol,
                                       baseException: dict = None,
                                       readWith: str = None):
        """Turn a decoded DummyProtocol response into a Python value.

        Args:
            data: The decoded response.
            baseException: Optional mapping whose ``code`` / ``message`` /
                ``metadata`` entries give the field ids read from the error
                struct (defaulting to 1, 2 and 3).
            readWith: Optional, possibly dotted, name prefix. When
                ``<readWith>_result`` resolves to a class in this module's
                globals, ``data`` is re-serialized with
                ``generateDummyProtocol2`` and read into an instance of that
                class instead of using the generic decoding below.

        Returns:
            The success value (``success`` / ``val_0``), or None when the
            response carries neither a result nor an error.

        Raises:
            LineServiceException: When the response carries an error struct.
        """
        if baseException is None:
            baseException = {}
        if readWith is not None:
            new1 = self.generateDummyProtocol2(data, 4, fixSuccessHeaders=True)
            # Resolve ``<readWith>_result`` as a dotted name starting from this
            # module's globals. This replaces eval() of the caller-supplied
            # string, which would execute arbitrary expressions (and raised an
            # uncaught SyntaxError on malformed names). Names that do not
            # resolve still fall through to the generic decoding below.
            head, *attrs = f'{readWith}_result'.split('.')
            result_cls = globals().get(head)
            for attr in attrs:
                if result_cls is None:
                    break
                result_cls = getattr(result_cls, attr, None)
            a = result_cls() if result_cls is not None else None
            if a is not None:
                e = TMemoryBuffer()
                f = testProtocol(e)
                e._buffer = io.BytesIO(new1)
                a.read(f)
                if getattr(a, 'success', None) is not None:
                    return a.success
                if getattr(a, 'e', None) is not None:
                    code = getattr(a.e, 'code', None)
                    reason = getattr(a.e, 'reason', None)
                    parameterMap = getattr(a.e, 'parameterMap', None)
                    raise LineServiceException({}, code, reason, parameterMap,
                                               a.e)
                return None

        def _gen():
            return DummyThrift()

        def _get(a):
            return a.data if isinstance(a, DummyProtocolData) else a

        # Store the decoded value of container field ``a`` on ``b`` as
        # ``val_<id>``; ``f`` is the per-field callback used for struct members.
        def _genFunc(a: DummyProtocolData, b, f):
            # Build a DummyThrift whose attributes are filled by calling b
            # for every child field.
            def __gen(a: DummyProtocolData, b):
                c = _gen()
                for d in a.data:
                    b(d, c)
                return c

            # Recursively convert by type code (matching Thrift TType codes:
            # 12 STRUCT, 13 MAP, 14 SET, 15 LIST); other values pass through.
            def __cek(a: DummyProtocolData, f):
                if a.type == 12:
                    c = __gen(a, f)
                elif a.type == 13:
                    c = {}
                    d = a.data
                    for e in d:
                        g = d[e]
                        h = e
                        if isinstance(h, DummyProtocolData):
                            h = __cek(h, f)
                        if isinstance(g, DummyProtocolData):
                            g = __cek(g, f)
                        c[h] = g
                elif a.type in (14, 15):
                    c = []
                    for d in a.data:
                        e = d
                        if isinstance(d, DummyProtocolData):
                            e = __cek(d, f)
                        c.append(e)
                else:
                    c = a.data
                return c

            c = __cek(a, f)
            setattr(b, f"val_{a.id}", c)

        a = _gen()

        # Per-field callback: containers (list/dict data) are decoded
        # recursively, scalars are stored directly as ``val_<id>``.
        def b(c, refs):
            return _genFunc(c, refs,
                            b) if type(c.data) in [list, dict] else setattr(
                                refs, f"val_{c.id}", c.data)

        if data.data is not None:
            b(data.data, a)
        if self.checkAndGetValue(a, 'val_0') is not None:
            return a.val_0
        # Field ids inside the error struct (val_1), overridable per call.
        _ecode = baseException.get('code', 1)
        _emsg = baseException.get('message', 2)
        _emeta = baseException.get('metadata', 3)
        if self.checkAndGetValue(a, 'val_1') is not None:
            raise LineServiceException(
                {}, self.checkAndGetValue(a.val_1, f'val_{_ecode}'),
                self.checkAndGetValue(a.val_1, f'val_{_emsg}'),
                self.checkAndGetValue(a.val_1, f'val_{_emeta}'), a.val_1)
        # Neither a result nor an error was present. This used to be a bare
        # print() to stdout; keep the diagnostic, but at debug log level.
        import logging
        logging.getLogger(__name__).debug(
            'serializeDummyProtocolToThrift: no result or error in %r', a)
        return None