예제 #1
0
    def __init__(self, s: socket):
        """Take over a connected socket and run the init handshake.

        Args:
            s: an already-connected stream socket; it is wrapped in a
                binary read/write file object for all further I/O.
        """
        # Starts one below the first id; presumably new_id() pre-increments
        # so the first message id is 0 — confirm against new_id.
        self.current_id = -1
        self.s = s
        stream = self.s.makefile("rwb")
        self.file = stream
        self.wrapper = IOWrapper(stream, socket_=s)

        # The connection is only usable after the handshake exchange.
        self.handshake()
예제 #2
0
class TChannelConnection:
    """Blocking client connection that speaks TChannel framing over TCP.

    The init handshake is performed during construction. Calls are then
    served synchronously: ``call_function`` writes all request frames and
    does not return until the response is complete.
    """

    file: IO
    wrapper: IOWrapper
    s: socket.socket
    current_id: int

    @classmethod
    def open(cls,
             host: str,
             port: int,
             timeout: "float | None" = None) -> "TChannelConnection":
        """Connect to ``(host, port)`` over IPv4 TCP and return a ready connection.

        Args:
            host: server host name or IPv4 address.
            port: server TCP port.
            timeout: socket timeout in seconds, applied to ``connect`` and
                left set on the socket for later I/O; ``None`` means
                blocking mode.

        Returns:
            A connection whose init handshake has already completed.

        Raises:
            OSError: if the TCP connection cannot be established.
            Exception: anything raised by the handshake in ``__init__``.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.settimeout(timeout)
            s.connect((host, port))
            return cls(s)
        except BaseException:
            # This method owns the socket until construction succeeds; close
            # it so a failed connect/handshake does not leak the descriptor.
            s.close()
            raise

    def __init__(self, s: socket.socket):
        """Wrap a connected socket and perform the TChannel init handshake.

        Args:
            s: a connected stream socket. It is not closed here if the
                handshake fails; ``open`` does that for sockets it creates.
        """
        self.s = s
        self.file = self.s.makefile("rwb")
        self.wrapper = IOWrapper(self.file, socket_=s)
        # new_id() pre-increments, so the first message id used is 0.
        self.current_id = -1

        self.handshake()

    def set_next_timeout_cb(self, cb: Callable):
        """Forward *cb* to ``IOWrapper.set_next_timeout_cb``.

        NOTE(review): the callback's semantics are defined by IOWrapper;
        presumably it is invoked on the next I/O timeout — confirm there.
        """
        self.wrapper.set_next_timeout_cb(cb)

    def new_id(self) -> int:
        """Return the next message id (sequential, starting from 0)."""
        self.current_id += 1
        return self.current_id

    def handshake(self):
        """Send an init request and require an init response in reply.

        Raises:
            TChannelException: if the server answers with an error frame.
            Exception: if the reply is any other non-init-response frame.
        """
        req: InitReqFrame = InitReqFrame()
        req.id = self.new_id()
        # "0.0.0.0:0" is presumably the TChannel convention for a client that
        # does not accept inbound connections — confirm against the spec.
        req.headers.d["host_port"] = "0.0.0.0:0"
        req.headers.d["process_name"] = "python-process"
        self.write_frame(req)

        res = self.read_frame()
        # 0x02 is the TChannel "init res" message type.
        if res.TYPE != 0x02:
            raise Exception("Unexpected response from server")

    def write_frame(self, frame: Frame):
        """Serialize *frame* into the wrapper, then flush so it is not left buffered."""
        frame.write(self.wrapper)
        self.wrapper.flush()

    def read_frame(self) -> Frame:
        """Read the next frame from the connection.

        Raises:
            TChannelException: if the frame read is an ``ErrorFrame``.
        """
        frame = Frame.read_frame(self.wrapper)
        if isinstance(frame, ErrorFrame):
            raise TChannelException(error_frame=frame)
        return frame

    def close(self):
        """Close the buffered wrapper first, then the underlying socket."""
        try:
            # Closing the wrapper first lets any buffered output be flushed
            # while the socket is still open.
            self.wrapper.close()
        finally:
            self.s.close()

    def call_function(self,
                      call: ThriftFunctionCall) -> ThriftFunctionResponse:
        """Send a Thrift call and block until its complete response is read.

        ``call.build_frames`` receives a single freshly allocated message id
        and every frame it returns is written (and flushed) in order.

        Raises:
            TChannelException: if the server replies with an error frame.
            Exception: if a frame other than call-res / call-res-continue
                arrives while the response is incomplete.

        NOTE(review): the ids of response frames are not checked against the
        request id, so a stale reply would be attributed to this call.
        """
        for frame in call.build_frames(self.new_id()):
            self.write_frame(frame)
        response = ThriftFunctionResponse()
        while not response.is_complete():
            frame = self.read_frame()
            if frame.TYPE not in (CallResFrame.TYPE,
                                  CallResContinueFrame.TYPE):
                # Report the received frame's type; TYPE is an int code, so
                # it must be formatted rather than concatenated to a str.
                raise Exception(f"Unexpected type: {frame.TYPE!r}")
            assert isinstance(frame, FrameWithArgs)
            response.process_frame(frame)
        return response
예제 #3
0
 def build_arg1(self) -> bytes:
     """Encode ``self.method_name`` via ``IOWrapper.write_string``.

     Presumably this is the call's arg1 payload (the method name).

     Returns:
         The bytes written into an in-memory buffer.
     """
     buf = BytesIO()
     out: IOWrapper = IOWrapper(buf)
     out.write_string(self.method_name)
     out.flush()
     encoded = buf.getvalue()
     return encoded
예제 #4
0
 def build_arg2(self) -> bytes:
     """Serialize ``self.application_headers`` as the call's arg2 payload.

     Headers are written with ``KVHeaders(..., 2)``; the ``2`` is presumably
     the width of the length prefixes (TChannel Thrift arg scheme) — confirm
     in KVHeaders.

     Returns:
         The encoded header bytes.
     """
     f = BytesIO()
     wrapper: IOWrapper = IOWrapper(f)
     h = KVHeaders(self.application_headers, 2)
     h.write_headers(wrapper)
     # Flush before reading the buffer: if IOWrapper buffers writes, the
     # bytes it still holds would otherwise be missing from the result.
     wrapper.flush()
     return f.getvalue()
예제 #5
0
 def process_arg2(self, b):
     """Decode arg2 bytes *b* and store the result in ``application_headers``.

     The bytes are parsed with ``KVHeaders.read_kv_headers(..., 2, ...)``;
     the ``"ThriftFunctionResponse"`` argument is presumably a context label
     used in error reporting — confirm in KVHeaders.
     """
     reader: IOWrapper = IOWrapper(BytesIO(b))
     headers: KVHeaders = KVHeaders.read_kv_headers(
         reader, 2, "ThriftFunctionResponse")
     self.application_headers = headers.d