Ejemplo n.º 1
0
    def setUp(self):
        """Wire a BaseplateService processor to in-memory header transports.

        Request and response each get their own TMemoryBuffer wrapped in a
        THeaderProtocol. A mocked BaseplateObserver registers a mocked
        ServerSpanObserver on every server span it is notified about.
        """
        # Request side and response side in-memory protocol stacks.
        self.itrans = TMemoryBuffer()
        self.iprot = THeaderProtocol(self.itrans)
        self.otrans = TMemoryBuffer()
        self.oprot = THeaderProtocol(self.otrans)

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _attach_server_observer(context, server_span):
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = (
            _attach_server_observer)

        self.logger = mock.Mock(spec=logging.Logger)
        self.server_context = TRpcConnectionContext(
            self.itrans, self.iprot, self.oprot)

        bp = Baseplate()
        bp.register(self.observer)

        self.processor = BaseplateService.ContextProcessor(TestHandler())
        self.processor.setEventHandler(
            BaseplateProcessorEventHandler(self.logger, bp))
Ejemplo n.º 2
0
    def message_received(self, frame):
        """Run one incoming frame through the processor and write any reply.

        Both the HEADER client type and the deprecated FRAMED client type are
        accepted, so old fb303 clients keep working. If processing raises, the
        exception is logged and the transport is closed.
        """
        accepted_types = {
            THeaderTransport.HEADERS_CLIENT_TYPE,
            # Deprecated FRAMED transport, still spoken by old fb303 clients.
            THeaderTransport.FRAMED_DEPRECATED,
        }

        in_proto = THeaderProtocol(TReadOnlyBuffer(frame),
                                   client_types=accepted_types)
        out_buf = TWriteOnlyBuffer()
        out_proto = THeaderProtocol(out_buf, client_types=accepted_types)

        try:
            yield from self.processor.process(
                in_proto, out_proto, self.server_context)
            reply = out_buf.getvalue()
            # Oneway calls (or similar) leave nothing to send back.
            if reply:
                self.transport.write(reply)
        except Exception:
            logger.exception("Exception while processing request")
            self.transport.close()
Ejemplo n.º 3
0
 def flush(self):
     """Send the buffered message after arming its timeout.

     The buffered bytes are parsed just far enough to recover the method
     name and sequence id, which are passed to ``schedule_timeout``. The
     raw bytes are then handed to the transport and the write buffer is
     replaced by a fresh, empty BytesIO.
     """
     payload = self._writeBuffer.getvalue()
     name, _mtype, seq = THeaderProtocol(
         TMemoryBuffer(payload)).readMessageBegin()
     self._proto.schedule_timeout(name.decode(), seq)
     self._trans.send_message(payload)
     self._writeBuffer = BytesIO()
Ejemplo n.º 4
0
 def flush(self):
     """Send the buffered message after arming its timeout.

     The buffered bytes are parsed just far enough to recover the method
     name and sequence id, which are passed to ``schedule_timeout``. The
     raw bytes are then handed to the transport and the write buffer is
     replaced by a fresh, empty BytesIO.
     """
     payload = self._writeBuffer.getvalue()
     name, _mtype, seq = THeaderProtocol(
         TMemoryBuffer(payload)).readMessageBegin()
     self._proto.schedule_timeout(name.decode(), seq)
     self._trans.send_message(payload)
     self._writeBuffer = BytesIO()
Ejemplo n.º 5
0
 def flush(self):
     """Send this buffer's contents after arming the call's timeout.

     The buffered bytes are parsed just far enough to recover the method
     name and sequence id for ``schedule_timeout``. The bytes are then
     sent, and the buffer is reset for the next message.
     """
     payload = self.getvalue()
     name, _mtype, seq = THeaderProtocol(
         TReadOnlyBuffer(payload)).readMessageBegin()
     self._proto.schedule_timeout(name.decode(), seq)
     self._trans.send_message(payload)
     self.reset()
Ejemplo n.º 6
0
 def flush(self):
     """Send this buffer's contents after arming the call's timeout.

     The buffered bytes are parsed just far enough to recover the method
     name and sequence id for ``schedule_timeout``. The bytes are then
     sent, and the buffer is reset for the next message.
     """
     payload = self.getvalue()
     name, _mtype, seq = THeaderProtocol(
         TReadOnlyBuffer(payload)).readMessageBegin()
     self._proto.schedule_timeout(name.decode(), seq)
     self._trans.send_message(payload)
     self.reset()
Ejemplo n.º 7
0
    def message_received(self, frame):
        """Dispatch an incoming reply frame to the matching ``recv_`` method.

        The frame is parsed as a THeaderProtocol message, and its method name
        selects ``self.client.recv_<name>``. If no such method exists, an
        error is logged and the transport is closed.
        """
        iprot = THeaderProtocol(TMemoryBuffer(frame))
        (fname, mtype, rseqid) = iprot.readMessageBegin()
        # readMessageBegin returns the name as bytes here (it is .decode()d).
        # Decode once so the same str serves both the lookup and the log line.
        name = fname.decode()

        method = getattr(self.client, "recv_" + name, None)
        if method is None:
            # Use lazy %-formatting with the decoded str. Concatenating the
            # raw bytes name into a str raises TypeError on Python 3.
            logging.error("Method %s isn't supported, bug?", name)
            self.transport.close()
        else:
            method(iprot, mtype, rseqid)
Ejemplo n.º 8
0
    def message_received(self, frame):
        """Dispatch an incoming reply frame to the matching ``recv_`` method.

        The frame is parsed as a THeaderProtocol message, and its method name
        selects ``self.client.recv_<name>``. If no such method exists, an
        error naming the method is logged and the transport is closed.
        """
        iprot = THeaderProtocol(TMemoryBuffer(frame))
        (fname, mtype, rseqid) = iprot.readMessageBegin()
        name = fname.decode()

        method = getattr(self.client, "recv_" + name, None)
        if method is None:
            # Log the unsupported method's *name*; `method` is always None
            # in this branch, so logging it carried no information.
            logger.error("Method %r is not supported", name)
            self.transport.close()
        else:
            method(iprot, mtype, rseqid)
Ejemplo n.º 9
0
    def setUp(self):
        """Build a TestService processor backed by in-memory transports.

        Besides the span observers, this wires an EdgeRequestContextFactory
        around a SecretsStore. The store's file watcher is mocked to return a
        fixed authentication public key and vault settings.
        """
        self.itrans = TMemoryBuffer()
        self.iprot = THeaderProtocol(self.itrans)
        self.otrans = TMemoryBuffer()
        self.oprot = THeaderProtocol(self.otrans)

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _attach_server_observer(context, server_span):
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = (
            _attach_server_observer)

        self.logger = mock.Mock(spec=logging.Logger)
        self.server_context = TRpcConnectionContext(
            self.itrans, self.iprot, self.oprot)

        # Secrets are served by a mocked file watcher instead of the disk.
        fake_watcher = mock.Mock(spec=FileWatcher)
        fake_watcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        self.secrets = store.SecretsStore("/secrets")
        self.secrets._filewatcher = fake_watcher

        bp = Baseplate()
        bp.register(self.observer)

        self.edge_context_factory = EdgeRequestContextFactory(self.secrets)

        self.processor = TestService.ContextProcessor(TestHandler())
        self.processor.setEventHandler(BaseplateProcessorEventHandler(
            self.logger,
            bp,
            edge_context_factory=self.edge_context_factory,
        ))
Ejemplo n.º 10
0
    def test_auth_headers(self):
        """An Authentication header produces a valid authentication context."""
        client_buf = TMemoryBuffer()
        client_prot = THeaderProtocol(client_buf)
        header_trans = client_prot.trans
        for key, value in (
            ("Authentication", self.VALID_TOKEN),
            ("Trace", "1234"),
            ("Parent", "2345"),
            ("Span", "3456"),
            ("Sampled", "1"),
            ("Flags", "1"),
        ):
            header_trans.set_header(key, value)

        try:
            TestService.Client(client_prot).example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read

        # Replay the captured request bytes into the server-side transport.
        self.itrans._readBuffer = StringIO(client_buf.getvalue())
        self.processor.process(self.iprot, self.oprot, self.server_context)

        context, _span = self.observer.on_server_span_created.call_args[0]

        try:
            self.assertTrue(context.authentication.valid)
            self.assertEqual(context.authentication.account_id, "test_user_id")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")
Ejemplo n.º 11
0
def create_thrift_client(
    eden_dir: "Optional[str]" = None,
    socket_path: "Optional[str]" = None,
    timeout: "Optional[float]" = None,
) -> "EdenClient":
    """
    Build a thrift client for the running eden server.

    The server is reached over a unix-domain socket. An explicit
    `socket_path` wins; otherwise the socket lives at SOCKET_PATH inside
    `eden_dir`. `timeout` is in seconds. When it is None, no timeout is
    applied.

    @return An EdenClient (usable as a context manager for
            EdenService.Client).
    @raise TypeError if neither eden_dir nor socket_path is supplied.
    """
    if socket_path is None:
        if eden_dir is None:
            raise TypeError("one of eden_dir or socket_path is required")
        socket_path = os.path.join(eden_dir, SOCKET_PATH)

    socket_cls = WinTSocket if sys.platform == "win32" else TSocket
    socket = socket_cls(unix_socket=socket_path)

    # There is deliberately no default timeout: choosing a duration is hard,
    # and retrying an arbitrary thrift call may not be safe. TSocket expects
    # milliseconds.
    socket.setTimeout(None if timeout is None else timeout * 1000)

    transport = THeaderTransport(socket)
    return EdenClient(socket_path, transport, THeaderProtocol(transport))
Ejemplo n.º 12
0
    def test_no_trace_headers(self, getrandbits):
        """Without trace headers the server span's ids come from getrandbits."""
        getrandbits.return_value = 1234

        client_buf = TMemoryBuffer()
        client = TestService.Client(THeaderProtocol(client_buf))
        try:
            client.example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        span_created = self.observer.on_server_span_created
        self.assertEqual(span_created.call_count, 1)

        _context, server_span = span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        observer = self.server_observer
        self.assertEqual(observer.on_start.call_count, 1)
        self.assertEqual(observer.on_finish.call_count, 1)
        self.assertEqual(observer.on_finish.call_args[0], (None, ))
Ejemplo n.º 13
0
def create_client(
    client_klass,
    host=None,
    port=None,
    client_type=None,
    path=None,
    timeout=None,
):
    """
    Given a thrift client class, and a host/port (or unix socket path),
    return a connected client that speaks HeaderTransport.

    `client_type`, when truthy, is used both as the type of the initial
    request and as the only type accepted back. Otherwise the protocol's
    default accepted set is left alone. The socket is opened before the
    client is returned.
    """
    from thrift.transport.TSocket import TSocket
    from thrift.protocol.THeaderProtocol import THeaderProtocol

    sock = TSocket(host=host, port=port, unix_socket=path)
    sock.setTimeout(timeout)

    # Accept the same client type we use for the initial send.
    accepted_types = [client_type] if client_type else None
    protocol = THeaderProtocol(
        sock,
        client_types=accepted_types,
        client_type=client_type,  # type used for the initial send
    )
    sock.open()
    return client_klass(protocol)
Ejemplo n.º 14
0
    def test_edge_request_headers(self):
        """An Edge-Request header is decoded into the request context."""
        client_buf = TMemoryBuffer()
        client_prot = THeaderProtocol(client_buf)
        header_trans = client_prot.trans
        for key, value in (
            ("Edge-Request", SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH),
            ("Trace", "1234"),
            ("Parent", "2345"),
            ("Span", "3456"),
            ("Sampled", "1"),
            ("Flags", "1"),
        ):
            header_trans.set_header(key, value)

        try:
            TestService.Client(client_prot).example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        context, _span = self.observer.on_server_span_created.call_args[0]

        try:
            self.assertEqual(context.request_context.user.id, "t2_example")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(
                context.request_context.user.cookie_created_ms, 100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(
                context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")
Ejemplo n.º 15
0
    def call_processor(self, input, client_type, protocol_type,
                       client_principal):
        """Run one header-stripped request through the Python processor.

        Returns the processor's reply with its THeader frame stripped off,
        because the C++ caller adds its own. If anything raises, the
        traceback is printed and None is returned, so that no exception
        crosses back into C++.
        """
        try:
            # The C++ side already removed the header from `input`, but the
            # Python processor expects one. Re-add it by writing the payload
            # through a THeaderTransport into a memory buffer.
            framed = TMemoryBuffer()
            writer = THeaderTransport(framed, client_types=[client_type])
            writer.set_protocol_id(protocol_type)
            writer.write(input)
            writer.flush()

            proto_buf = TMemoryBuffer(framed.getvalue())
            proto = THeaderProtocol(proto_buf)
            self.processor.process(
                proto, proto, TCppConnectionContext(client_principal))

            # Strip the header back off the reply; C++ will add its own.
            reader = THeaderTransport(TMemoryBuffer(proto_buf.getvalue()),
                                      client_types=[client_type])
            reader.readFrame(0)
            return reader.cstringio_buf.read()
        except:  # noqa: E722 -- deliberately catch everything
            # Never let an exception escape back into C++.
            traceback.print_exc()
Ejemplo n.º 16
0
    def message_received(self, frame):
        """Process one request frame and write back any non-empty reply.

        If the processor raises, the exception is logged and the transport
        is closed.
        """
        out_buf = TMemoryBuffer()
        in_proto = THeaderProtocol(TMemoryBuffer(frame))
        out_proto = THeaderProtocol(out_buf)

        try:
            yield from self.processor.process(
                in_proto, out_proto, self.server_context)
            reply = out_buf.getvalue()
            if reply:
                self.transport.write(reply)
        except Exception:
            logging.exception("Exception while processing request")
            self.transport.close()
Ejemplo n.º 17
0
    def test_with_headers(self):
        """Trace headers populate the server span; no auth means no user."""
        client_buf = TMemoryBuffer()
        client_prot = THeaderProtocol(client_buf)
        header_trans = client_prot.trans
        for key, value in (
            ("Trace", "1234"),
            ("Parent", "2345"),
            ("Span", "3456"),
            ("Sampled", "1"),
            ("Flags", "1"),
        ):
            header_trans.set_header(key, value)

        try:
            TestService.Client(client_prot).example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)
        span_created = self.observer.on_server_span_created
        self.assertEqual(span_created.call_count, 1)

        context, server_span = span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertTrue(server_span.sampled)
        self.assertEqual(server_span.flags, 1)

        # No authentication header was sent, so the user id is unavailable.
        with self.assertRaises(NoAuthenticationError):
            context.request_context.user.id

        observer = self.server_observer
        self.assertEqual(observer.on_start.call_count, 1)
        self.assertEqual(observer.on_finish.call_count, 1)
        self.assertEqual(observer.on_finish.call_args[0], (None, ))
Ejemplo n.º 18
0
    def test_with_headers(self):
        """Trace headers sent by the client populate the server span."""
        client_memory_trans = TMemoryBuffer()
        client_prot = THeaderProtocol(client_memory_trans)
        client_header_trans = client_prot.trans
        client_header_trans.set_header("Trace", "1234")
        client_header_trans.set_header("Parent", "2345")
        client_header_trans.set_header("Span", "3456")
        client_header_trans.set_header("Sampled", "1")
        client_header_trans.set_header("Flags", "1")
        client = BaseplateService.Client(client_prot)
        try:
            client.is_healthy()
        except Exception:
            # There is no canned response, so reading the reply fails.
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit still abort the test run.
            pass
        # Replay the captured request bytes into the server-side transport.
        self.itrans._readBuffer = StringIO(client_memory_trans.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)
        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        context, server_span = self.observer.on_server_span_created.call_args[
            0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertTrue(server_span.sampled)
        self.assertEqual(server_span.flags, 1)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
Ejemplo n.º 19
0
    def call_processor(self, input, headers, client_type, protocol_type,
                       context_data, callback):
        """Run one request through the Python processor on behalf of C++.

        The header-stripped request bytes are re-wrapped in a THeader frame,
        using the given headers, client type and protocol id, and handed to
        ``self.processor.process``. ``_ProcessorAdapter.done`` is then called
        with the response buffer: immediately, or once the returned Future
        completes if ``process`` returned a Future.

        When ``self.observer`` is set, ``callCompleted`` is reported with
        process begin/end timestamps in microseconds. Both are 0 unless this
        request was sampled.

        Any exception is printed and swallowed so it never propagates back
        into C++. In that case ``callback`` may never be invoked.
        """
        try:
            # The input string has already had the header removed, but
            # the python processor will expect it to be there.  In
            # order to reconstitute the message with headers, we use
            # the THeaderProtocol object to write into a memory
            # buffer, then pass that buffer to the python processor.

            should_sample = self._shouldSample()

            # Zeroed by default; only filled in for sampled requests.
            timestamps = CallTimestamps()
            timestamps.processBegin = 0
            timestamps.processEnd = 0
            if self.observer and should_sample:
                # Wall-clock time in microseconds.
                timestamps.processBegin = int(time.time() * 10**6)

            write_buf = TMemoryBuffer()
            trans = THeaderTransport(write_buf)
            # Set THeaderTransport's name-mangled private fields directly.
            # NOTE(review): presumably there is no public setter for the
            # client type / write headers — confirm before refactoring.
            trans._THeaderTransport__client_type = client_type
            trans._THeaderTransport__write_headers = headers
            trans.set_protocol_id(protocol_type)
            trans.write(input)
            trans.flush()

            # Request and response share one buffer: the processor reads the
            # framed request from prot_buf and writes its reply to it.
            prot_buf = TMemoryBuffer(write_buf.getvalue())
            prot = THeaderProtocol(prot_buf, client_types=[client_type])

            ctx = TCppConnectionContext(context_data)

            ret = self.processor.process(prot, prot, ctx)

            done_callback = partial(_ProcessorAdapter.done,
                                    prot_buf=prot_buf,
                                    client_type=client_type,
                                    callback=callback)

            if self.observer:
                # NOTE(review): processEnd is taken as soon as process()
                # returns, i.e. before a returned Future has completed.
                if should_sample:
                    timestamps.processEnd = int(time.time() * 10**6)

                # This only bumps counters if `processBegin != 0` and
                # `processEnd != 0` and these will only be non-zero if
                # we are sampling this request.
                self.observer.callCompleted(timestamps)

            # This future is created by and returned from the processor's
            # ThreadPoolExecutor, which keeps a reference to it. So it is
            # fine for this future to end its lifecycle here.
            if isinstance(ret, Future):
                # Default-arg binding captures done_callback now; the
                # completed future passed as `x` is ignored.
                ret.add_done_callback(lambda x, d=done_callback: d())
            else:
                done_callback()
        except:
            # Don't let exceptions escape back into C++
            traceback.print_exc()
    def setUp(self):
        """Create two buffer/transport/protocol stacks for comparison.

        ``self._h_*`` is a plain THeaderTransport stack. ``self._f_*`` is a
        TFuzzyHeaderTransport stack whose fuzzed fields come from the class
        attribute ``fuzz_fields``.
        """
        # Plain THeaderTransport stack.
        self._h_buf = TMemoryBuffer()
        self._h_trans = THeaderTransport(self._h_buf)
        self._h_prot = THeaderProtocol(self._h_trans)

        # Fuzzing stack: only the listed fields are fuzzed, quietly.
        self._f_buf = TMemoryBuffer()
        self._f_trans = TFuzzyHeaderTransport(
            self._f_buf,
            fuzz_fields=type(self).fuzz_fields,
            fuzz_all_if_empty=False,
            verbose=False,
        )
        self._f_prot = THeaderProtocol(self._f_trans)
Ejemplo n.º 21
0
    def schedule_timeout(self, fname, seqid):
        """Schedule a synthetic TIMEOUT reply for call `seqid` to `fname`.

        If ``self.timeouts[fname]`` is falsy, nothing is scheduled.
        Otherwise a serialized TApplicationException(TIMEOUT) reply is
        built and fed to ``message_received`` after that delay, via a task
        on ``self.loop``. The task is registered with
        ``update_pending_tasks`` under `seqid`.
        """
        delay = self.timeouts[fname]
        if not delay:
            return

        buf = TMemoryBuffer()
        header_trans = THeaderTransport(buf)
        proto = THeaderProtocol(header_trans)
        timeout_exc = TApplicationException(
            TApplicationException.TIMEOUT,
            "Call to {} timed out".format(fname),
        )
        proto.writeMessageBegin(fname, TMessageType.EXCEPTION, seqid)
        timeout_exc.write(proto)
        proto.writeMessageEnd()
        header_trans.flush()

        pending = self.loop.create_task(
            self.message_received(buf.getvalue(), delay=delay))
        self.update_pending_tasks(seqid, pending)
Ejemplo n.º 22
0
    def test_client_proxy_flow(self):
        """An edge context attached to the caller's span reaches the server."""
        client_buf = TMemoryBuffer()
        client_prot = THeaderProtocol(client_buf)

        class Pool(object):
            @contextlib.contextmanager
            def connection(self):
                yield client_prot

        client_factory = ThriftContextFactory(Pool(), TestService.Client)
        span = mock.MagicMock()
        child_span = span.make_child().__enter__()
        for attr, value in (
            ("trace_id", 1),
            ("parent_id", 1),
            ("id", 1),
            ("sampled", True),
            ("flags", None),
        ):
            setattr(child_span, attr, value)

        # The token is decoded to unicode on purpose, to check that
        # AuthenticationContext converts it back to bytes correctly. On
        # Python 2 a unicode token makes Thrift raise UnicodeDecodeError
        # when writing the header.
        auth_context = AuthenticationContext(
            token=self.VALID_TOKEN.decode(),
            secrets=self.secrets,
        )
        edge_context = EdgeRequestContext(
            authentication_context=auth_context,
            header=self.SERIALIZED_REQUEST_HEADER,
        )
        edge_context.attach_context(child_span.context)

        client = client_factory.make_object_for_context("test", span)
        try:
            client.example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        context, _span = self.observer.on_server_span_created.call_args[0]

        try:
            self.assertEqual(context.request_context.user.id, "test_user_id")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(
                context.request_context.user.cookie_created_ms, 100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(
                context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")
Ejemplo n.º 23
0
 def __init__(self, eden_dir=None, mounted_path=None):
     """Connect to the eden server over its unix-domain socket.

     If `mounted_path` is truthy, ``<mounted_path>/.eden/socket`` is used.
     Otherwise the socket is ``<eden_dir>/SOCKET_PATH``. The socket timeout
     is 60 seconds.
     """
     self._eden_dir = eden_dir
     sock_path = (
         os.path.join(mounted_path, '.eden', 'socket') if mounted_path
         else os.path.join(self._eden_dir, SOCKET_PATH)
     )
     self._socket = TSocket(unix_socket=sock_path)
     self._socket.setTimeout(60 * 1000)  # TSocket expects milliseconds
     self._transport = THeaderTransport(self._socket)
     self._protocol = THeaderProtocol(self._transport)
     super(EdenClient, self).__init__(self._protocol)
Ejemplo n.º 24
0
    def __init__(self, host, port=None, timeout=2.0):
        """Open a header-protocol connection to the QSFP service on `host`.

        :param host: host to connect to.
        :param port: TCP port; ``self.DEFAULT_PORT`` is used when None.
        :param timeout: socket timeout in seconds.
        """
        self.host = host
        # Honour an explicitly requested port; fall back to the class default.
        if port is None:
            port = self.DEFAULT_PORT

        self._socket = TSocket(host, port)
        # TSocket.setTimeout() takes a value in milliseconds
        self._socket.setTimeout(timeout * 1000)
        self._transport = THeaderTransport(self._socket)
        self._protocol = THeaderProtocol(self._transport)

        self._transport.open()
        QsfpService.Client.__init__(self, self._protocol)
Ejemplo n.º 25
0
    def __init__(self, host, port=None, timeout=5.0):
        """Connect a PcapPushSubscriber client to `host`.

        :param host: host to connect to.
        :param port: TCP port; ``self.DEFAULT_PORT`` is used when None.
        :param timeout: socket timeout in seconds.
        """
        self.host = host
        chosen_port = self.DEFAULT_PORT if port is None else port

        self._socket = TSocket(host, chosen_port)
        self._socket.setTimeout(timeout * 1000)  # TSocket wants milliseconds
        self._transport = THeaderTransport(self._socket)
        self._protocol = THeaderProtocol(self._transport)
        self._transport.open()
        PcapPushSubscriber.Client.__init__(self, self._protocol)
Ejemplo n.º 26
0
    def __init__(self, host, port=None, timeout=5.0):
        """Connect a NetlinkManagerService client to `host`.

        :param host: host to connect to.
        :param port: TCP port; ``self.DEFAULT_PORT`` is used when None.
        :param timeout: socket timeout in seconds.
        """
        self.host = host
        chosen_port = self.DEFAULT_PORT if port is None else port

        self._socket = TSocket(host, chosen_port)
        self._socket.setTimeout(timeout * 1000)  # TSocket wants milliseconds
        self._transport = THeaderTransport(self._socket)
        self._protocol = THeaderProtocol(self._transport)

        self._transport.open()
        NetlinkManagerService.Client.__init__(self, self._protocol)
Ejemplo n.º 27
0
    def __init__(self, host, port=None, timeout=10.0):
        """Open a header-protocol connection to the QSFP service on `host`.

        :param host: host to connect to.
        :param port: TCP port; ``self.DEFAULT_PORT`` is used when None.
        :param timeout: socket timeout in seconds.
        """
        # In a box with all 32 QSFP ports populated, it takes about 7.5s right
        # now to read all 32 QSFP ports. So, put the default timeout to 10s.
        self.host = host
        # Honour an explicitly requested port; fall back to the class default.
        if port is None:
            port = self.DEFAULT_PORT

        self._socket = TSocket(host, port)
        # TSocket.setTimeout() takes a value in milliseconds
        self._socket.setTimeout(timeout * 1000)
        self._transport = THeaderTransport(self._socket)
        self._protocol = THeaderProtocol(self._transport)

        self._transport.open()
        QsfpService.Client.__init__(self, self._protocol)
Ejemplo n.º 28
0
    def message_received(self, frame, delay=0):
        """Deliver a reply frame to the matching ``recv_`` method on the client.

        If `delay` is non-zero, wait that long before delivering the frame.
        ``schedule_timeout`` uses this for its pre-built timeout replies.
        Otherwise this is a real reply, so any pending timeout task for the
        same sequence id is cancelled. If the client has no
        ``recv_<name>`` method, an error naming the method is logged and the
        transport is aborted.
        """
        iprot = THeaderProtocol(TReadOnlyBuffer(frame))
        (fname, mtype, rseqid) = iprot.readMessageBegin()
        name = fname.decode()

        if delay:
            yield from asyncio.sleep(delay)
        else:
            try:
                timeout_task = self.pending_tasks.pop(rseqid)
            except KeyError:
                # Task doesn't have a timeout or has already been cancelled
                # and pruned from `pending_tasks`.
                pass
            else:
                timeout_task.cancel()

        method = getattr(self.client, "recv_" + name, None)
        if method is None:
            # Log the unsupported method's *name*; `method` is always None
            # in this branch, so logging it carried no information.
            logger.error("Method %r is not supported", name)
            self.transport.abort()
        else:
            method(iprot, mtype, rseqid)
Ejemplo n.º 29
0
    def message_received(self, frame, delay=0):
        """Deliver a reply frame to the matching ``recv_`` method on the client.

        If `delay` is non-zero, wait that long before delivering the frame.
        ``schedule_timeout`` uses this for its pre-built timeout replies.
        Otherwise this is a real reply, so any pending timeout task for the
        same sequence id is cancelled. If the client has no
        ``recv_<name>`` method, an error naming the method is logged and the
        transport is aborted.
        """
        iprot = THeaderProtocol(TReadOnlyBuffer(frame))
        (fname, mtype, rseqid) = iprot.readMessageBegin()
        name = fname.decode()

        if delay:
            yield from asyncio.sleep(delay)
        else:
            try:
                timeout_task = self.pending_tasks.pop(rseqid)
            except KeyError:
                # Task doesn't have a timeout or has already been cancelled
                # and pruned from `pending_tasks`.
                pass
            else:
                timeout_task.cancel()

        method = getattr(self.client, "recv_" + name, None)
        if method is None:
            # Log the unsupported method's *name*; `method` is always None
            # in this branch, so logging it carried no information.
            logger.error("Method %r is not supported", name)
            self.transport.abort()
        else:
            method(iprot, mtype, rseqid)
Ejemplo n.º 30
0
 def __init__(self, eden_dir=None, socket_path=None):
     """Connect to the eden server over its unix-domain socket.

     An explicit `socket_path` wins; otherwise the socket lives at
     SOCKET_PATH inside `eden_dir`. Raises TypeError if neither is given.
     """
     if socket_path is None:
         if eden_dir is None:
             raise TypeError("one of eden_dir or socket_path is required")
         socket_path = os.path.join(eden_dir, SOCKET_PATH)
     self._socket_path = socket_path
     self._socket = TSocket(unix_socket=self._socket_path)
     # No timeout is set on purpose: picking the right duration is hard, and
     # retrying an arbitrary thrift call may not be safe.
     self._transport = THeaderTransport(self._socket)
     self._protocol = THeaderProtocol(self._transport)
     super(EdenClient, self).__init__(self._protocol)
Ejemplo n.º 31
0
    def test_expected_exception_not_passed_to_server_span_finish(self):
        """A declared (expected) exception does not reach span on_finish."""
        client_buf = TMemoryBuffer()
        client = TestService.Client(THeaderProtocol(client_buf))
        try:
            client.example_throws(crash=False)
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        observer = self.server_observer
        self.assertEqual(observer.on_start.call_count, 1)
        self.assertEqual(observer.on_finish.call_count, 1)
        self.assertEqual(observer.on_finish.call_args[0], (None, ))
Ejemplo n.º 32
0
 def __init__(self, eden_dir=None, mounted_path=None):
     """Connect to the eden server over its unix-domain socket.

     If `mounted_path` is truthy, the socket path is the target of the
     ``<mounted_path>/.eden/socket`` symlink. Otherwise it is
     ``<eden_dir>/SOCKET_PATH``.
     """
     self._eden_dir = eden_dir
     if not mounted_path:
         sock_path = os.path.join(self._eden_dir, SOCKET_PATH)
     else:
         link = os.path.join(mounted_path, '.eden', 'socket')
         sock_path = os.readlink(link)
     self._socket = TSocket(unix_socket=sock_path)
     # No timeout is set on purpose: picking the right duration is hard, and
     # retrying an arbitrary thrift call may not be safe.
     self._transport = THeaderTransport(self._socket)
     self._protocol = THeaderProtocol(self._transport)
     super(EdenClient, self).__init__(self._protocol)
Ejemplo n.º 33
0
    def test_client_proxy_flow(self):
        """An upstream edge context on the caller's span reaches the server."""
        client_buf = TMemoryBuffer()
        client_prot = THeaderProtocol(client_buf)

        class Pool(object):
            @contextlib.contextmanager
            def connection(self):
                yield client_prot

        client_factory = ThriftContextFactory(Pool(), TestService.Client)
        span = mock.MagicMock()
        child_span = span.make_child().__enter__()
        for attr, value in (
            ("trace_id", 1),
            ("parent_id", 1),
            ("id", 1),
            ("sampled", True),
            ("flags", None),
        ):
            setattr(child_span, attr, value)

        edge_context = self.edge_context_factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
        edge_context.attach_context(child_span.context)

        client = client_factory.make_object_for_context("test", span)
        try:
            client.example_simple()
        except TTransportException:
            pass  # no canned response exists for the client to read
        self.itrans._readBuffer = StringIO(client_buf.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        context, _span = self.observer.on_server_span_created.call_args[0]

        try:
            self.assertEqual(context.request_context.user.id, "t2_example")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(
                context.request_context.user.cookie_created_ms, 100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(
                context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")
Ejemplo n.º 34
0
    def call_processor(self, input, headers, client_type, protocol_type,
                       context_data):
        """Run one header-stripped request through the Python processor.

        Returns the reply with its THeader frame removed, because the C++
        caller adds its own. An empty reply is returned unchanged. If
        anything raises, the traceback is printed and None is returned, so
        that no exception crosses back into C++.
        """
        try:
            # The C++ side already removed the header from `input`, but the
            # Python processor expects one. Rebuild the framed message by
            # writing through a THeaderTransport into a memory buffer.
            framed = TMemoryBuffer()
            writer = THeaderTransport(framed)
            # Set THeaderTransport's name-mangled private fields directly.
            writer._THeaderTransport__client_type = client_type
            writer._THeaderTransport__write_headers = headers
            writer.set_protocol_id(protocol_type)
            writer.write(input)
            writer.flush()

            proto_buf = TMemoryBuffer(framed.getvalue())
            proto = THeaderProtocol(proto_buf, client_types=[client_type])
            self.processor.process(
                proto, proto, TCppConnectionContext(context_data))

            reply = proto_buf.getvalue()
            if not reply:
                # Probably a oneway request, though that can't be told
                # reliably. The C++ code does basically the same thing.
                return reply

            # Strip the header back off; C++ will add its own.
            reader = THeaderTransport(TMemoryBuffer(reply),
                                      client_types=[client_type])
            reader.readFrame(len(reply))
            return reader.cstringio_buf.read()
        except:  # noqa: E722 -- deliberately catch everything
            # Never let an exception escape back into C++.
            traceback.print_exc()
Ejemplo n.º 35
0
    def test_no_headers(self):
        """Without trace headers the root span gets placeholder ids."""
        client_memory_trans = TMemoryBuffer()
        client_prot = THeaderProtocol(client_memory_trans)
        client = BaseplateService.Client(client_prot)
        try:
            client.is_healthy()
        except Exception:
            # There is no canned response, so reading the reply fails.
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit still abort the test run.
            pass
        # Replay the captured request bytes into the server-side transport.
        self.itrans._readBuffer = StringIO(client_memory_trans.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

        self.assertEqual(self.observer.on_root_span_created.call_count, 1)

        context, root_span = self.observer.on_root_span_created.call_args[0]
        self.assertEqual(root_span.trace_id, "no-trace")
        self.assertEqual(root_span.parent_id, "no-parent")
        self.assertEqual(root_span.id, "no-span")

        self.assertTrue(self.root_observer.on_start.called)
        self.assertTrue(self.root_observer.on_stop.called)
Ejemplo n.º 36
0
 def __init__(self, probe, *args, **kwargs):
     """Initialise the THeaderProtocol base and keep a reference to `probe`.

     Every argument other than `probe` is forwarded unchanged to
     ``THeaderProtocol.__init__``.
     """
     base_init = THeaderProtocol.__init__
     base_init(self, *args, **kwargs)
     self.probe = probe
Ejemplo n.º 37
0
 def readMessageBegin(self):
     """Touch the probe, then defer to THeaderProtocol.readMessageBegin."""
     self.probe.touch()
     header = THeaderProtocol.readMessageBegin(self)
     return header