예제 #1
0
 def setUp(self):
     """Create a secrets store fed by a mocked file watcher, plus a factory.

     The mocked watcher's ``get_data`` returns a fixed payload containing the
     authentication public key and vault settings.
     """
     # Canned payload pieces returned by the mocked watcher.
     public_key_entry = {
         "type": "versioned",
         "current": AUTH_TOKEN_PUBLIC_KEY,
     }
     vault_settings = {"token": "test", "url": "http://vault.example.com:8200/"}

     watcher = mock.Mock(spec=FileWatcher)
     watcher.get_data.return_value = {
         "secrets": {"secret/authentication/public-key": public_key_entry},
         "vault": vault_settings,
     }

     # Swap the store's watcher for the mock before the factory uses it.
     secrets_store = store.SecretsStore("/secrets")
     secrets_store._filewatcher = watcher
     self.store = secrets_store
     self.factory = EdgeRequestContextFactory(secrets_store)
예제 #2
0
    def setUp(self):
        """Assemble a ``TestService`` processor wired to mocked collaborators.

        Creates in-memory input/output transports with header protocols,
        mocked Baseplate and server-span observers, a secrets store whose
        file watcher is mocked, and a processor whose event handler uses
        all of the above.
        """
        # Server-side input and output transports with their protocols.
        self.itrans = TMemoryBuffer()
        self.iprot = THeaderProtocol(self.itrans)
        self.otrans = TMemoryBuffer()
        self.oprot = THeaderProtocol(self.otrans)

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _attach_server_observer(context, server_span):
            # Register the span-level mock on each created server span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = (
            _attach_server_observer)

        self.logger = mock.Mock(spec=logging.Logger)
        self.server_context = TRpcConnectionContext(
            self.itrans, self.iprot, self.oprot)

        # Secrets store whose watcher is a mock returning a fixed payload.
        watcher = mock.Mock(spec=FileWatcher)
        watcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        secrets_store = store.SecretsStore("/secrets")
        self.secrets = secrets_store
        secrets_store._filewatcher = watcher

        baseplate = Baseplate()
        baseplate.register(self.observer)

        self.edge_context_factory = EdgeRequestContextFactory(secrets_store)

        event_handler = BaseplateProcessorEventHandler(
            self.logger,
            baseplate,
            edge_context_factory=self.edge_context_factory,
        )

        self.processor = TestService.ContextProcessor(TestHandler())
        self.processor.setEventHandler(event_handler)
예제 #3
0
    def setUp(self):
        """Configure routes and views, include Baseplate, and build a test app.

        The resulting WSGI app is wrapped in a ``webtest.TestApp`` and stored
        on ``self.test_app``; the observers and the event subscriber are
        mocks kept on ``self`` for later inspection.
        """
        configurator = Configurator()

        # (route name, URL pattern, view callable) for the GET-only routes.
        routed_views = (
            ("example", "/example", example_application),
            ("trace_context", "/trace_context", local_tracing_within_context),
        )
        for route_name, pattern, _view in routed_views:
            configurator.add_route(route_name, pattern, request_method="GET")
        for route_name, _pattern, view in routed_views:
            configurator.add_view(view, route_name=route_name, renderer="json")

        # (view callable, exception type used as the view's context).
        exception_views = (
            (render_exception_view, ControlFlowException),
            (render_bad_exception_view, ControlFlowException2),
        )
        for view, exception_type in exception_views:
            configurator.add_view(
                view,
                context=exception_type,
                renderer="json",
            )

        # Secrets store whose watcher is a mock returning a fixed payload.
        watcher = mock.Mock(spec=FileWatcher)
        watcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        secrets = store.SecretsStore("/secrets")
        secrets._filewatcher = watcher

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _attach_server_observer(context, server_span):
            # Register the span-level mock on each created server span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = (
            _attach_server_observer)

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        edge_context_factory = EdgeRequestContextFactory(secrets)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
            edge_context_factory=edge_context_factory,
        )
        configurator.include(self.baseplate_configurator.includeme)

        self.context_init_event_subscriber = mock.Mock()
        configurator.add_subscriber(
            self.context_init_event_subscriber, ServerSpanInitialized)

        self.test_app = webtest.TestApp(configurator.make_wsgi_app())
예제 #4
0
class ThriftTests(unittest.TestCase):
    """Drive the Thrift processor through in-memory transports.

    Each test serializes a client call into a memory buffer, loads those
    bytes into the server-side input transport, runs the processor and then
    inspects what the mocked observers recorded.
    """

    def setUp(self):
        """Build the processor, its transports and the mocked collaborators."""
        # Server-side input and output transports with their protocols.
        self.itrans = TMemoryBuffer()
        self.iprot = THeaderProtocol(self.itrans)
        self.otrans = TMemoryBuffer()
        self.oprot = THeaderProtocol(self.otrans)

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _attach_server_observer(context, server_span):
            # Register the span-level mock on each created server span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = (
            _attach_server_observer)

        self.logger = mock.Mock(spec=logging.Logger)
        self.server_context = TRpcConnectionContext(
            self.itrans, self.iprot, self.oprot)

        # Secrets store whose watcher is a mock returning a fixed payload.
        watcher = mock.Mock(spec=FileWatcher)
        watcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        secrets_store = store.SecretsStore("/secrets")
        self.secrets = secrets_store
        secrets_store._filewatcher = watcher

        baseplate = Baseplate()
        baseplate.register(self.observer)

        self.edge_context_factory = EdgeRequestContextFactory(secrets_store)

        event_handler = BaseplateProcessorEventHandler(
            self.logger,
            baseplate,
            edge_context_factory=self.edge_context_factory,
        )

        self.processor = TestService.ContextProcessor(TestHandler())
        self.processor.setEventHandler(event_handler)

    def _call_and_process(self, client, client_transport, method_name,
                          **call_kwargs):
        """Invoke a client method, then replay its bytes through the server.

        A ``TTransportException`` raised by the client call is ignored.
        Whatever ``client_transport`` holds afterwards becomes the read
        buffer of the server input transport, and the processor is run once.
        """
        try:
            getattr(client, method_name)(**call_kwargs)
        except TTransportException:
            pass  # we don't have a test response for the client
        self.itrans._readBuffer = StringIO(client_transport.getvalue())

        self.processor.process(self.iprot, self.oprot, self.server_context)

    def _set_trace_headers(self, header_transport):
        """Set the Trace/Parent/Span/Sampled/Flags headers used by the tests."""
        trace_headers = (
            ("Trace", "1234"),
            ("Parent", "2345"),
            ("Span", "3456"),
            ("Sampled", "1"),
            ("Flags", "1"),
        )
        for header_name, header_value in trace_headers:
            header_transport.set_header(header_name, header_value)

    def _assert_span_started_and_finished_once(self):
        """Assert the server span observer saw one start and one finish."""
        self.assertEqual(self.server_observer.on_start.call_count, 1)
        self.assertEqual(self.server_observer.on_finish.call_count, 1)

    def _assert_span_finished_without_error(self):
        """Assert a single start/finish and that finish received ``None``."""
        self._assert_span_started_and_finished_once()
        self.assertEqual(self.server_observer.on_finish.call_args[0], (None, ))

    def _assert_valid_auth_request_context(self, context):
        """Check the request context fields expected for the valid-auth header.

        Skips the test when reading the fields raises
        ``jwt.exceptions.InvalidAlgorithmError``.
        """
        try:
            self.assertEqual(context.request_context.user.id, "t2_example")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(
                context.request_context.user.cookie_created_ms, 100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(
                context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        """Without trace headers the span ids come from ``random.getrandbits``."""
        getrandbits.return_value = 1234

        client_transport = TMemoryBuffer()
        client = TestService.Client(THeaderProtocol(client_transport))
        self._call_and_process(client, client_transport, "example_simple")

        created = self.observer.on_server_span_created
        self.assertEqual(created.call_count, 1)

        _context, span = created.call_args[0]
        self.assertEqual(span.trace_id, 1234)
        self.assertEqual(span.parent_id, None)
        self.assertEqual(span.id, 1234)

        self._assert_span_finished_without_error()

    def test_with_headers(self):
        """Trace headers sent by the client populate the server span."""
        client_transport = TMemoryBuffer()
        client_prot = THeaderProtocol(client_transport)
        self._set_trace_headers(client_prot.trans)
        client = TestService.Client(client_prot)
        self._call_and_process(client, client_transport, "example_simple")

        created = self.observer.on_server_span_created
        self.assertEqual(created.call_count, 1)

        context, span = created.call_args[0]
        self.assertEqual(span.trace_id, 1234)
        self.assertEqual(span.parent_id, 2345)
        self.assertEqual(span.id, 3456)
        self.assertTrue(span.sampled)
        self.assertEqual(span.flags, 1)

        # No Edge-Request header was sent, so reading the user id must fail.
        with self.assertRaises(NoAuthenticationError):
            context.request_context.user.id

        self._assert_span_finished_without_error()

    def test_edge_request_headers(self):
        """An Edge-Request header is exposed through the request context."""
        client_transport = TMemoryBuffer()
        client_prot = THeaderProtocol(client_transport)
        header_transport = client_prot.trans
        header_transport.set_header(
            "Edge-Request", SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
        self._set_trace_headers(header_transport)
        client = TestService.Client(client_prot)
        self._call_and_process(client, client_transport, "example_simple")

        context, _ = self.observer.on_server_span_created.call_args[0]

        self._assert_valid_auth_request_context(context)

    def test_expected_exception_not_passed_to_server_span_finish(self):
        """``example_throws(crash=False)`` finishes the span with ``None``."""
        client_transport = TMemoryBuffer()
        client = TestService.Client(THeaderProtocol(client_transport))
        self._call_and_process(
            client, client_transport, "example_throws", crash=False)

        self._assert_span_finished_without_error()

    def test_unexpected_exception_passed_to_server_span_finish(self):
        """``example_throws(crash=True)`` hands the exception to ``on_finish``."""
        client_transport = TMemoryBuffer()
        client = TestService.Client(THeaderProtocol(client_transport))
        self._call_and_process(
            client, client_transport, "example_throws", crash=True)

        self._assert_span_started_and_finished_once()
        # on_finish receives a three-item tuple; the middle item is checked.
        finish_args = self.server_observer.on_finish.call_args[0]
        _exc_type, raised, _traceback = finish_args[0]
        self.assertIsInstance(raised, UnexpectedException)

    def test_client_proxy_flow(self):
        """An edge context attached to a client span reaches the server."""
        client_transport = TMemoryBuffer()
        client_prot = THeaderProtocol(client_transport)

        class Pool(object):
            # Minimal pool stand-in: always yields the in-memory protocol.
            @contextlib.contextmanager
            def connection(self):
                yield client_prot

        client_factory = ThriftContextFactory(Pool(), TestService.Client)
        span = mock.MagicMock()
        child_span = span.make_child().__enter__()
        child_span.trace_id = 1
        child_span.parent_id = 1
        child_span.id = 1
        child_span.sampled = True
        child_span.flags = None

        upstream_context = self.edge_context_factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
        upstream_context.attach_context(child_span.context)
        client = client_factory.make_object_for_context("test", span)
        self._call_and_process(client, client_transport, "example_simple")

        context, _ = self.observer.on_server_span_created.call_args[0]

        self._assert_valid_auth_request_context(context)
예제 #5
0
class EdgeRequestContextTests(unittest.TestCase):
    """Check edge request contexts created or parsed through the factory."""

    LOID_ID = "t2_deadbeef"
    LOID_CREATED_MS = 100000
    SESSION_ID = "beefdead"

    def setUp(self):
        """Create a secrets store with a mocked watcher and a factory on it."""
        # Canned payload pieces returned by the mocked watcher.
        public_key_entry = {
            "type": "versioned",
            "current": AUTH_TOKEN_PUBLIC_KEY,
        }
        vault_settings = {
            "token": "test",
            "url": "http://vault.example.com:8200/",
        }

        watcher = mock.Mock(spec=FileWatcher)
        watcher.get_data.return_value = {
            "secrets": {"secret/authentication/public-key": public_key_entry},
            "vault": vault_settings,
        }

        secrets_store = store.SecretsStore("/secrets")
        secrets_store._filewatcher = watcher
        self.store = secrets_store
        self.factory = EdgeRequestContextFactory(secrets_store)

    def _assert_requires_authentication(self, *accessors):
        """Assert each zero-argument callable raises NoAuthenticationError."""
        for accessor in accessors:
            with self.assertRaises(NoAuthenticationError):
                accessor()

    def _assert_loid_fields(self, request_context):
        """Assert the user's loid and cookie timestamp match the constants."""
        self.assertEqual(request_context.user.loid, self.LOID_ID)
        self.assertEqual(
            request_context.user.cookie_created_ms, self.LOID_CREATED_MS)

    def _assert_event_fields(self, request_context, user_id, logged_in):
        """Assert ``event_fields()`` equals the dict expected for this user."""
        expected_fields = {
            "user_id": user_id,
            "logged_in": logged_in,
            "cookie_created_timestamp": self.LOID_CREATED_MS,
            "session_id": self.SESSION_ID,
            "oauth_client_id": None,
        }
        self.assertEqual(request_context.event_fields(), expected_fields)

    def test_create(self):
        """A new context built from a valid token serializes as expected."""
        creation_args = dict(
            authentication_token=AUTH_TOKEN_VALID,
            loid_id=self.LOID_ID,
            loid_created_ms=self.LOID_CREATED_MS,
            session_id=self.SESSION_ID,
        )
        request_context = self.factory.new(**creation_args)
        self.assertIsNot(request_context._t_request, None)
        self.assertEqual(
            request_context._header, SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)

    def test_create_validation(self):
        """Creating a context with loid id "abc123" raises ``ValueError``."""
        creation_args = dict(
            authentication_token=None,
            loid_id="abc123",
            loid_created_ms=self.LOID_CREATED_MS,
            session_id=self.SESSION_ID,
        )
        with self.assertRaises(ValueError):
            self.factory.new(**creation_args)

    def test_create_empty_context(self):
        """A context created with no arguments has a fixed serialized header."""
        empty_header = b'\x0c\x00\x01\x00\x0c\x00\x02\x00\x00'
        request_context = self.factory.new()
        self.assertEqual(request_context._header, empty_header)

    def test_logged_out_user(self):
        """A header without auth exposes loid/session but no user identity."""
        request_context = self.factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_NO_AUTH)

        self._assert_requires_authentication(
            lambda: request_context.user.id,
            lambda: request_context.user.roles,
        )

        self.assertFalse(request_context.user.is_logged_in)
        self._assert_loid_fields(request_context)

        self._assert_requires_authentication(
            lambda: request_context.oauth_client.id,
            lambda: request_context.oauth_client.is_type("third_party"),
        )

        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self._assert_event_fields(
            request_context, user_id=self.LOID_ID, logged_in=False)

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_logged_in_user(self):
        """A header with valid auth exposes the authenticated user's fields."""
        request_context = self.factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)

        self.assertEqual(request_context.user.id, "t2_example")
        self.assertTrue(request_context.user.is_logged_in)
        self._assert_loid_fields(request_context)
        self.assertEqual(request_context.user.roles, set())
        self.assertFalse(request_context.user.has_role("test"))
        self.assertIs(request_context.oauth_client.id, None)
        self.assertFalse(request_context.oauth_client.is_type("third_party"))
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self._assert_event_fields(
            request_context, user_id="t2_example", logged_in=True)

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_expired_token(self):
        """A header with expired auth behaves like a logged-out user."""
        request_context = self.factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_EXPIRED_AUTH)

        self._assert_requires_authentication(
            lambda: request_context.user.id,
            lambda: request_context.user.roles,
            lambda: request_context.oauth_client.id,
            lambda: request_context.oauth_client.is_type("third_party"),
        )

        self.assertFalse(request_context.user.is_logged_in)
        self._assert_loid_fields(request_context)
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self._assert_event_fields(
            request_context, user_id=self.LOID_ID, logged_in=False)

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_anonymous_token(self):
        """An anonymous token has no user id but carries the anonymous role."""
        request_context = self.factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_ANON_AUTH)

        self._assert_requires_authentication(lambda: request_context.user.id)
        self.assertFalse(request_context.user.is_logged_in)
        self._assert_loid_fields(request_context)
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self.assertTrue(request_context.user.has_role("anonymous"))