def setUp(self):
    """Create ``self.factory``, an EdgeRequestContextFactory over a fake store.

    The SecretsStore's file watcher is swapped for a mock whose data holds the
    versioned ``secret/authentication/public-key`` secret plus vault settings,
    so no secrets file is read from disk.
    """
    public_key_secret = {
        "type": "versioned",
        "current": AUTH_TOKEN_PUBLIC_KEY,
    }
    vault_settings = {"token": "test", "url": "http://vault.example.com:8200/"}

    fake_watcher = mock.Mock(spec=FileWatcher)
    fake_watcher.get_data.return_value = {
        "secrets": {"secret/authentication/public-key": public_key_secret},
        "vault": vault_settings,
    }

    self.store = store.SecretsStore("/secrets")
    self.store._filewatcher = fake_watcher
    self.factory = EdgeRequestContextFactory(self.store)
def setUp(self):
    """Wire a TestService processor to a Baseplate instance with mocked observers.

    Creates in-memory server transports, a fake secrets store holding the
    auth-token public key, an EdgeRequestContextFactory, and installs a
    BaseplateProcessorEventHandler on ``self.processor``.
    """
    # Server side of the connection: requests are read from itrans/iprot,
    # responses are written to otrans/oprot.
    self.itrans = TMemoryBuffer()
    self.iprot = THeaderProtocol(self.itrans)
    self.otrans = TMemoryBuffer()
    self.oprot = THeaderProtocol(self.otrans)

    self.observer = mock.Mock(spec=BaseplateObserver)
    self.server_observer = mock.Mock(spec=ServerSpanObserver)

    def _attach_server_observer(context, server_span):
        # Hook the span observer onto every server span so tests can see
        # its on_start / on_finish calls.
        server_span.register(self.server_observer)

    self.observer.on_server_span_created.side_effect = _attach_server_observer

    self.logger = mock.Mock(spec=logging.Logger)
    self.server_context = TRpcConnectionContext(self.itrans, self.iprot, self.oprot)

    secrets_payload = {
        "secrets": {
            "secret/authentication/public-key": {
                "type": "versioned",
                "current": AUTH_TOKEN_PUBLIC_KEY,
            },
        },
        "vault": {
            "token": "test",
            "url": "http://vault.example.com:8200/",
        },
    }
    fake_watcher = mock.Mock(spec=FileWatcher)
    fake_watcher.get_data.return_value = secrets_payload
    self.secrets = store.SecretsStore("/secrets")
    self.secrets._filewatcher = fake_watcher

    self.edge_context_factory = EdgeRequestContextFactory(self.secrets)

    baseplate = Baseplate()
    baseplate.register(self.observer)

    self.processor = TestService.ContextProcessor(TestHandler())
    self.processor.setEventHandler(
        BaseplateProcessorEventHandler(
            self.logger,
            baseplate,
            edge_context_factory=self.edge_context_factory,
        )
    )
def setUp(self):
    """Build ``self.test_app``, a webtest app for the example Pyramid views.

    Registers the example routes and views, installs the Baseplate Pyramid
    integration with a mocked observer and an EdgeRequestContextFactory over
    a fake secrets store, and records ServerSpanInitialized events in
    ``self.context_init_event_subscriber``.
    """
    configurator = Configurator()

    for route_name, pattern in (
        ("example", "/example"),
        ("trace_context", "/trace_context"),
    ):
        configurator.add_route(route_name, pattern, request_method="GET")

    for view, route_name in (
        (example_application, "example"),
        (local_tracing_within_context, "trace_context"),
    ):
        configurator.add_view(view, route_name=route_name, renderer="json")

    # Exception views: each control-flow exception type is rendered as JSON.
    for view, exception_type in (
        (render_exception_view, ControlFlowException),
        (render_bad_exception_view, ControlFlowException2),
    ):
        configurator.add_view(view, context=exception_type, renderer="json")

    fake_watcher = mock.Mock(spec=FileWatcher)
    fake_watcher.get_data.return_value = {
        "secrets": {
            "secret/authentication/public-key": {
                "type": "versioned",
                "current": AUTH_TOKEN_PUBLIC_KEY,
            },
        },
        "vault": {
            "token": "test",
            "url": "http://vault.example.com:8200/",
        },
    }
    secrets = store.SecretsStore("/secrets")
    secrets._filewatcher = fake_watcher

    self.observer = mock.Mock(spec=BaseplateObserver)
    self.server_observer = mock.Mock(spec=ServerSpanObserver)

    def _attach_server_observer(context, server_span):
        # Hook the span observer onto every server span the request creates.
        server_span.register(self.server_observer)

    self.observer.on_server_span_created.side_effect = _attach_server_observer

    self.baseplate = Baseplate()
    self.baseplate.register(self.observer)
    self.baseplate_configurator = BaseplateConfigurator(
        self.baseplate,
        trust_trace_headers=True,
        edge_context_factory=EdgeRequestContextFactory(secrets),
    )
    configurator.include(self.baseplate_configurator.includeme)

    self.context_init_event_subscriber = mock.Mock()
    configurator.add_subscriber(self.context_init_event_subscriber, ServerSpanInitialized)

    self.test_app = webtest.TestApp(configurator.make_wsgi_app())
class ThriftTests(unittest.TestCase):
    """Tests for the Baseplate Thrift processor event handler.

    Each test serializes a request with a ``TestService.Client`` into an
    in-memory buffer, replays those bytes through ``self.processor``, and then
    inspects what the mocked observers were called with.
    """

    # Trace headers sent by the tests that propagate an upstream trace.
    _TRACE_HEADERS = (
        ("Trace", "1234"),
        ("Parent", "2345"),
        ("Span", "3456"),
        ("Sampled", "1"),
        ("Flags", "1"),
    )

    def setUp(self):
        # Server side of the connection: requests are read from itrans/iprot,
        # responses are written to otrans/oprot.
        self.itrans = TMemoryBuffer()
        self.iprot = THeaderProtocol(self.itrans)
        self.otrans = TMemoryBuffer()
        self.oprot = THeaderProtocol(self.otrans)

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _register_mock(context, server_span):
            # Hook the span observer onto every server span so tests can
            # assert on its on_start / on_finish calls.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.logger = mock.Mock(spec=logging.Logger)
        self.server_context = TRpcConnectionContext(self.itrans, self.iprot, self.oprot)

        # Fake secrets data holding the public key for auth-token validation.
        mock_filewatcher = mock.Mock(spec=FileWatcher)
        mock_filewatcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        self.secrets = store.SecretsStore("/secrets")
        self.secrets._filewatcher = mock_filewatcher

        baseplate = Baseplate()
        baseplate.register(self.observer)

        self.edge_context_factory = EdgeRequestContextFactory(self.secrets)

        event_handler = BaseplateProcessorEventHandler(
            self.logger,
            baseplate,
            edge_context_factory=self.edge_context_factory,
        )

        handler = TestHandler()
        self.processor = TestService.ContextProcessor(handler)
        self.processor.setEventHandler(event_handler)

    @staticmethod
    def _make_client_transport(headers=()):
        """Return ``(memory_buffer, header_protocol)`` for the client side.

        ``headers`` is an iterable of ``(name, value)`` pairs. They are set, in
        order, on the protocol's header transport before any client is built.
        """
        client_memory_trans = TMemoryBuffer()
        client_prot = THeaderProtocol(client_memory_trans)
        for name, value in headers:
            client_prot.trans.set_header(name, value)
        return client_memory_trans, client_prot

    @staticmethod
    def _send(call, *args, **kwargs):
        """Invoke a client method purely to serialize its request.

        No test response is queued for the client, so a TTransportException
        raised while it waits for a reply is expected and ignored.
        """
        try:
            call(*args, **kwargs)
        except TTransportException:
            pass  # we don't have a test response for the client

    def _process(self, client_memory_trans):
        """Replay the bytes the client wrote through the server processor."""
        self.itrans._readBuffer = StringIO(client_memory_trans.getvalue())
        self.processor.process(self.iprot, self.oprot, self.server_context)

    def _assert_span_finished_cleanly(self):
        """Assert the server span started and finished once, with no exc_info."""
        self.assertEqual(self.server_observer.on_start.call_count, 1)
        self.assertEqual(self.server_observer.on_finish.call_count, 1)
        self.assertEqual(self.server_observer.on_finish.call_args[0], (None,))

    def _assert_valid_edge_context(self, context):
        """Assert ``context`` carries SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH.

        Raises SkipTest if token validation fails with InvalidAlgorithmError.
        That is the signal these tests treat as "cryptography is not installed".
        """
        try:
            request_context = context.request_context
            self.assertEqual(request_context.user.id, "t2_example")
            self.assertEqual(request_context.user.roles, set())
            self.assertTrue(request_context.user.is_logged_in)
            self.assertEqual(request_context.user.loid, "t2_deadbeef")
            self.assertEqual(request_context.user.cookie_created_ms, 100000)
            self.assertIsNone(request_context.oauth_client.id)
            self.assertFalse(request_context.oauth_client.is_type("third_party"))
            self.assertEqual(request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        getrandbits.return_value = 1234
        client_memory_trans, client_prot = self._make_client_transport()
        client = TestService.Client(client_prot)
        self._send(client.example_simple)
        self._process(client_memory_trans)

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        _, server_span = self.observer.on_server_span_created.call_args[0]
        # With no incoming trace headers, the trace and span IDs come from
        # the patched random source.
        self.assertEqual(server_span.trace_id, 1234)
        self.assertIsNone(server_span.parent_id)
        self.assertEqual(server_span.id, 1234)
        self._assert_span_finished_cleanly()

    def test_with_headers(self):
        client_memory_trans, client_prot = self._make_client_transport(self._TRACE_HEADERS)
        client = TestService.Client(client_prot)
        self._send(client.example_simple)
        self._process(client_memory_trans)

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        context, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertTrue(server_span.sampled)
        self.assertEqual(server_span.flags, 1)

        # No Edge-Request header was sent, so there is no authenticated user.
        with self.assertRaises(NoAuthenticationError):
            context.request_context.user.id

        self._assert_span_finished_cleanly()

    def test_edge_request_headers(self):
        headers = (("Edge-Request", SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH),) + self._TRACE_HEADERS
        client_memory_trans, client_prot = self._make_client_transport(headers)
        client = TestService.Client(client_prot)
        self._send(client.example_simple)
        self._process(client_memory_trans)

        context, _ = self.observer.on_server_span_created.call_args[0]
        self._assert_valid_edge_context(context)

    def test_expected_exception_not_passed_to_server_span_finish(self):
        client_memory_trans, client_prot = self._make_client_transport()
        client = TestService.Client(client_prot)
        self._send(client.example_throws, crash=False)
        self._process(client_memory_trans)

        self._assert_span_finished_cleanly()

    def test_unexpected_exception_passed_to_server_span_finish(self):
        client_memory_trans, client_prot = self._make_client_transport()
        client = TestService.Client(client_prot)
        self._send(client.example_throws, crash=True)
        self._process(client_memory_trans)

        self.assertEqual(self.server_observer.on_start.call_count, 1)
        self.assertEqual(self.server_observer.on_finish.call_count, 1)
        # on_finish receives an exc_info-style triple; the middle item is the
        # exception instance.
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, UnexpectedException)

    def test_client_proxy_flow(self):
        client_memory_trans, client_prot = self._make_client_transport()

        class Pool(object):
            # Minimal stand-in for a connection pool: always hands out the
            # in-memory client protocol.
            @contextlib.contextmanager
            def connection(self):
                yield client_prot

        client_factory = ThriftContextFactory(Pool(), TestService.Client)

        span = mock.MagicMock()
        child_span = span.make_child().__enter__()
        child_span.trace_id = 1
        child_span.parent_id = 1
        child_span.id = 1
        child_span.sampled = True
        child_span.flags = None

        edge_context = self.edge_context_factory.from_upstream(
            SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
        edge_context.attach_context(child_span.context)

        client = client_factory.make_object_for_context("test", span)
        self._send(client.example_simple)
        self._process(client_memory_trans)

        context, _ = self.observer.on_server_span_created.call_args[0]
        self._assert_valid_edge_context(context)
class EdgeRequestContextTests(unittest.TestCase):
    """Tests for EdgeRequestContextFactory and the request contexts it builds."""

    LOID_ID = "t2_deadbeef"
    LOID_CREATED_MS = 100000
    SESSION_ID = "beefdead"

    def setUp(self):
        # Data served by the mocked file watcher: the versioned public key
        # used to check authentication tokens, plus vault settings.
        secrets_payload = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        fake_watcher = mock.Mock(spec=FileWatcher)
        fake_watcher.get_data.return_value = secrets_payload

        self.store = store.SecretsStore("/secrets")
        self.store._filewatcher = fake_watcher
        self.factory = EdgeRequestContextFactory(self.store)

    def _assert_no_authentication(self, *accessors):
        """Assert each zero-arg callable raises NoAuthenticationError, in order."""
        for accessor in accessors:
            with self.assertRaises(NoAuthenticationError):
                accessor()

    def _expected_event_fields(self, user_id, logged_in):
        """Return the event_fields() dict expected for this class's LoID/session."""
        return {
            "user_id": user_id,
            "logged_in": logged_in,
            "cookie_created_timestamp": self.LOID_CREATED_MS,
            "session_id": self.SESSION_ID,
            "oauth_client_id": None,
        }

    def test_create(self):
        request_context = self.factory.new(
            authentication_token=AUTH_TOKEN_VALID,
            loid_id=self.LOID_ID,
            loid_created_ms=self.LOID_CREATED_MS,
            session_id=self.SESSION_ID,
        )
        self.assertIsNotNone(request_context._t_request)
        self.assertEqual(request_context._header, SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)

    def test_create_validation(self):
        with self.assertRaises(ValueError):
            self.factory.new(
                authentication_token=None,
                loid_id="abc123",
                loid_created_ms=self.LOID_CREATED_MS,
                session_id=self.SESSION_ID,
            )

    def test_create_empty_context(self):
        empty_header = b'\x0c\x00\x01\x00\x0c\x00\x02\x00\x00'
        self.assertEqual(self.factory.new()._header, empty_header)

    def test_logged_out_user(self):
        request_context = self.factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_NO_AUTH)

        self._assert_no_authentication(
            lambda: request_context.user.id,
            lambda: request_context.user.roles,
        )
        self.assertFalse(request_context.user.is_logged_in)
        self.assertEqual(request_context.user.loid, self.LOID_ID)
        self.assertEqual(request_context.user.cookie_created_ms, self.LOID_CREATED_MS)
        self._assert_no_authentication(
            lambda: request_context.oauth_client.id,
            lambda: request_context.oauth_client.is_type("third_party"),
        )
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self.assertEqual(
            request_context.event_fields(),
            self._expected_event_fields(self.LOID_ID, False),
        )

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_logged_in_user(self):
        request_context = self.factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)

        self.assertEqual(request_context.user.id, "t2_example")
        self.assertTrue(request_context.user.is_logged_in)
        self.assertEqual(request_context.user.loid, self.LOID_ID)
        self.assertEqual(request_context.user.cookie_created_ms, self.LOID_CREATED_MS)
        self.assertEqual(request_context.user.roles, set())
        self.assertFalse(request_context.user.has_role("test"))
        self.assertIsNone(request_context.oauth_client.id)
        self.assertFalse(request_context.oauth_client.is_type("third_party"))
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self.assertEqual(
            request_context.event_fields(),
            self._expected_event_fields("t2_example", True),
        )

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_expired_token(self):
        request_context = self.factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_EXPIRED_AUTH)

        self._assert_no_authentication(
            lambda: request_context.user.id,
            lambda: request_context.user.roles,
            lambda: request_context.oauth_client.id,
            lambda: request_context.oauth_client.is_type("third_party"),
        )
        self.assertFalse(request_context.user.is_logged_in)
        self.assertEqual(request_context.user.loid, self.LOID_ID)
        self.assertEqual(request_context.user.cookie_created_ms, self.LOID_CREATED_MS)
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self.assertEqual(
            request_context.event_fields(),
            self._expected_event_fields(self.LOID_ID, False),
        )

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_anonymous_token(self):
        request_context = self.factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_ANON_AUTH)

        self._assert_no_authentication(lambda: request_context.user.id)
        self.assertFalse(request_context.user.is_logged_in)
        self.assertEqual(request_context.user.loid, self.LOID_ID)
        self.assertEqual(request_context.user.cookie_created_ms, self.LOID_CREATED_MS)
        self.assertEqual(request_context.session.id, self.SESSION_ID)
        self.assertTrue(request_context.user.has_role("anonymous"))