예제 #1
0
def serve_thrift(handler, server_span_observer=None):
    """Serve ``TestService`` on an ephemeral local port for the caller's use.

    Generator intended to be wrapped as a context manager / fixture. It wraps
    ``TestService.Processor(handler)`` with baseplate instrumentation, binds
    it to ``127.0.0.1`` on an OS-chosen port, runs ``serve_forever`` in a
    background greenlet, and yields the server object with an extra
    ``endpoint`` attribute describing the bound address.

    :param handler: Handler passed to ``TestService.Processor``.
    :param server_span_observer: Optional observer registered on every
        server span created while handling requests.
    """
    baseplate = Baseplate()

    if server_span_observer:
        # Forward each newly created server span to the caller's observer.
        class TestBaseplateObserver(BaseplateObserver):
            def on_server_span_created(self, context, server_span):
                server_span.register(server_span_observer)

        baseplate.register(TestBaseplateObserver())

    # Wrap the generated processor so each request runs inside a server span.
    quiet_logger = mock.Mock(spec=logging.Logger)
    edge_factory = make_edge_context_factory()
    raw_processor = TestService.Processor(handler)
    wrapped_processor = baseplateify_processor(
        raw_processor, quiet_logger, baseplate, edge_factory
    )

    # Port 0 asks the OS for any free port.
    listener = make_listener(config.Endpoint("127.0.0.1:0"))
    server_settings = {"max_concurrency": "100", "stop_timeout": "1 millisecond"}
    server = make_server(server_settings, listener, wrapped_processor)

    # Expose the address actually bound so clients know where to connect.
    host, port = listener.getsockname()[:2]
    server.endpoint = config.Endpoint(f"{host}:{port}")

    serving = gevent.spawn(server.serve_forever)
    try:
        yield server
    finally:
        # Tear the server down once the consumer is finished with it.
        serving.kill()
예제 #2
0
def baseplate_thrift_client(endpoint, client_spec, client_span_observer=None):
    """Yield a request context holding a Thrift client for ``endpoint``.

    The context exposes ``example_service`` (a ``ThriftClient`` built from
    ``client_spec.Client``), has a server span created from a fixed upstream
    ``TraceInfo``, and carries ``FakeEdgeContextFactory.RAW_BYTES`` as its
    raw edge context.

    :param endpoint: Address of the Thrift server; stringified into config.
    :param client_spec: Generated Thrift module providing ``Client``.
    :param client_span_observer: Optional observer registered on every child
        span of the server span (i.e. the client call spans).
    """
    app_config = {
        "baseplate.service_name": "fancy test client",
        "example_service.endpoint": str(endpoint),
    }
    baseplate = Baseplate(app_config=app_config)

    if client_span_observer:
        # Server-span observer that hands each child span to the caller's
        # client span observer.
        class TestServerSpanObserver(ServerSpanObserver):
            def on_child_span_created(self, span):
                span.register(client_span_observer)

        child_forwarder = TestServerSpanObserver()

        class TestBaseplateObserver(BaseplateObserver):
            def on_server_span_created(self, context, span):
                span.register(child_forwarder)

        baseplate.register(TestBaseplateObserver())

    context = baseplate.make_context_object()
    upstream_trace = TraceInfo.from_upstream(
        trace_id="1234",
        parent_id="2345",
        span_id="3456",
        flags=4567,
        sampled=True,
    )

    thrift_client = ThriftClient(client_spec.Client)
    baseplate.configure_context({"example_service": thrift_client})
    baseplate.make_server_span(context, "example_service.example", upstream_trace)

    context.raw_edge_context = FakeEdgeContextFactory.RAW_BYTES

    yield context
예제 #3
0
def test_connection_error(client_cls):
    """A request that cannot connect still records a finished, failed client span."""
    allowlist_config = {
        "myclient.filter.ip_allowlist": "127.0.0.0/8",
        "myclient.filter.port_denylist": "0",
    }
    baseplate = Baseplate(allowlist_config)
    baseplate.configure_context({"myclient": client_cls()})

    root_observer = TestBaseplateObserver()
    baseplate.register(root_observer)

    unreachable_url = "http://localhost:1/"
    with pytest.raises(requests.exceptions.ConnectionError):
        with baseplate.server_context("test") as context:
            context.myclient.get(unreachable_url)

    server_observer = root_observer.children[0]
    assert len(server_observer.children) == 1

    client_observer = server_observer.children[0]
    assert client_observer.span.name == "myclient.request"
    assert client_observer.on_start_called
    assert client_observer.on_finish_called
    # The span finished with exception info because the request failed.
    assert client_observer.on_finish_exc_info is not None

    tags = client_observer.tags
    assert tags["http.url"] == unreachable_url
    assert tags["http.method"] == "GET"
    # No response was received, so no status code can be tagged.
    assert "http.status_code" not in tags
예제 #4
0
def test_client_makes_client_span(client_cls, method, http_server):
    """Issuing ``method`` through the client produces one tagged client span."""
    allowlist_config = {
        "myclient.filter.ip_allowlist": "127.0.0.0/8",
        "myclient.filter.port_denylist": "0",
    }
    baseplate = Baseplate(allowlist_config)
    baseplate.configure_context({"myclient": client_cls()})

    root_observer = TestBaseplateObserver()
    baseplate.register(root_observer)

    with baseplate.server_context("test") as context:
        # Look up the client function matching the lower-cased verb name.
        send_request = getattr(context.myclient, method.lower())
        response = send_request(http_server.url)

    assert response.status_code == 204

    server_observer = root_observer.children[0]
    assert len(server_observer.children) == 1

    client_observer = server_observer.children[0]
    assert client_observer.span.name == "myclient.request"
    assert client_observer.on_start_called
    assert client_observer.on_finish_called
    assert client_observer.on_finish_exc_info is None

    tags = client_observer.tags
    assert tags["http.url"] == http_server.url
    assert tags["http.method"] == method
    assert tags["http.status_code"] == 204
예제 #5
0
    def setUp(self):
        """Create a context exposing two Cassandra clients on the test endpoint.

        ``cassandra`` is built with an execution profile named ``"foo"`` using
        QUORUM consistency; ``cassandra_no_prof`` has no custom profiles. Both
        use the ``system`` keyspace.
        """
        self.baseplate_observer = TestBaseplateObserver()

        quorum_profiles = {
            "foo": ExecutionProfile(consistency_level=ConsistencyLevel.QUORUM),
        }

        baseplate = Baseplate()
        baseplate.register(self.baseplate_observer)

        app_config = {
            "cassandra.contact_points": cassandra_endpoint.address.host,
            "cassandra_no_prof.contact_points": cassandra_endpoint.address.host,
        }
        context_spec = {
            "cassandra_no_prof": CassandraClient(keyspace="system"),
            "cassandra": CassandraClient(
                keyspace="system", execution_profiles=quorum_profiles
            ),
        }
        baseplate.configure_context(app_config, context_spec)

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #6
0
    def setUp(self):
        """Create a context whose ``redis`` attribute talks to the test Redis (db 0)."""
        self.baseplate_observer = TestBaseplateObserver()

        app_config = {"redis.url": f"redis://{redis_endpoint}/0"}
        baseplate = Baseplate(app_config)
        baseplate.register(self.baseplate_observer)

        redis_client = RedisClient()
        baseplate.configure_context({"redis": redis_client})

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #7
0
    def setUp(self):
        """Create a context whose ``memcache`` attribute targets the test memcached."""
        self.baseplate_observer = TestBaseplateObserver()

        app_config = {"memcache.endpoint": str(memcached_endpoint)}
        baseplate = Baseplate(app_config)
        baseplate.register(self.baseplate_observer)

        memcache_client = MemcacheClient()
        baseplate.configure_context({"memcache": memcache_client})

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #8
0
    def test_null_server_observer(self):
        """An observer hook that registers nothing leaves the span observer-free."""
        baseplate = Baseplate()
        context = baseplate.make_context_object()

        null_observer = mock.Mock(spec=BaseplateObserver)
        null_observer.on_server_span_created.return_value = None
        baseplate.register(null_observer)

        trace_info = TraceInfo(1, 2, 3, None, 0)
        server_span = baseplate.make_server_span(context, "name", trace_info)

        self.assertEqual(server_span.observers, [])
예제 #9
0
    def setUp(self):
        """Create a context whose ``cassandra`` client uses the ``system`` keyspace."""
        self.baseplate_observer = TestBaseplateObserver()

        app_config = {"cassandra.contact_points": cassandra_endpoint.address.host}
        baseplate = Baseplate(app_config)
        baseplate.register(self.baseplate_observer)

        cassandra_client = CassandraClient(keyspace="system")
        baseplate.configure_context({"cassandra": cassandra_client})

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #10
0
    def test_with_server_context(self):
        """``server_context`` creates the server span on entry and yields a RequestContext."""
        baseplate = Baseplate()
        observer = mock.Mock(spec=BaseplateObserver)
        baseplate.register(observer)

        span_created_hook = observer.on_server_span_created
        # Registering alone must not create a span.
        span_created_hook.assert_not_called()
        with baseplate.server_context("example") as context:
            span_created_hook.assert_called_once()
            self.assertIsInstance(context, RequestContext)
예제 #11
0
    def test_server_observer_made(self):
        """A registered observer is kept and notified once with (context, span)."""
        baseplate = Baseplate()
        context = baseplate.make_context_object()
        observer = mock.Mock(spec=BaseplateObserver)
        baseplate.register(observer)

        trace_info = TraceInfo(1, 2, 3, None, 0)
        server_span = baseplate.make_server_span(context, "name", trace_info)

        self.assertEqual(baseplate.observers, [observer])
        span_created_hook = observer.on_server_span_created
        self.assertEqual(span_created_hook.call_count, 1)
        self.assertEqual(span_created_hook.call_args, mock.call(context, server_span))
예제 #12
0
    def setUp(self):
        """Create a context whose ``rediscluster`` client targets the test Redis.

        The client is configured with a 1 second timeout and at most 4
        connections.
        """
        self.baseplate_observer = TestBaseplateObserver()

        cluster_config = {
            "rediscluster.url": f"redis://{redis_endpoint}/0",
            "rediscluster.timeout": "1 second",
            "rediscluster.max_connections": "4",
        }
        baseplate = Baseplate(cluster_config)
        baseplate.register(self.baseplate_observer)

        cluster_client = ClusterRedisClient()
        baseplate.configure_context({"rediscluster": cluster_client})

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #13
0
    def setUp(self):
        """Create the schema in an in-memory SQLite DB and expose it as ``db``."""
        # A "sqlite://" URL with no path means an in-memory database.
        db_config = {"database.url": "sqlite://"}
        engine = engine_from_config(db_config)
        Base.metadata.create_all(bind=engine)
        session_factory = SQLAlchemySessionContextFactory(engine)

        self.baseplate_observer = TestBaseplateObserver()

        baseplate = Baseplate()
        baseplate.register(self.baseplate_observer)
        baseplate.add_to_context("db", session_factory)

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #14
0
def test_default_timeout():
    """The default server timeout interrupts slow spans and spares fast ones."""
    baseplate = Baseplate()

    timeout_observer = TimeoutBaseplateObserver.from_config(
        {"server_timeout.default": "50 milliseconds"}
    )
    baseplate.register(timeout_observer)

    # Sleeping 1s inside a span exceeds the 50ms budget.
    slow_context = baseplate.make_context_object()
    with baseplate.make_server_span(slow_context, "test"):
        with pytest.raises(ServerTimeout):
            gevent.sleep(1)

    # Yielding once and returning immediately stays well within budget.
    fast_context = baseplate.make_context_object()
    with baseplate.make_server_span(fast_context, "test"):
        gevent.sleep(0)
예제 #15
0
    def setUp(self):
        """Create a context exposing a ``ratelimiter`` built on ``self.backend_factory``.

        The factory is configured with allowance=10 and interval=1.
        NOTE(review): ``self.backend_factory`` is not set here; presumably a
        subclass or class attribute provides it — confirm.
        """
        self.allowance, self.interval = 10, 1
        ratelimiter_factory = RateLimiterContextFactory(
            self.backend_factory, self.allowance, self.interval
        )

        self.baseplate_observer = TestBaseplateObserver()

        baseplate = Baseplate()
        baseplate.register(self.baseplate_observer)
        baseplate.add_to_context("ratelimiter", ratelimiter_factory)

        self.context = baseplate.make_context_object()
        self.server_span = baseplate.make_server_span(self.context, "test")
예제 #16
0
def baseplate_thrift_client(endpoint, client_spec, client_span_observer=None):
    """Yield a request context with a Thrift client and an attached edge context.

    The context exposes ``example_service`` (a ``ThriftClient`` built from
    ``client_spec.Client``), has a server span created from a fixed upstream
    ``TraceInfo``, and has an edge context parsed from
    ``SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH`` attached to it.

    :param endpoint: Address of the Thrift server; stringified into config.
    :param client_spec: Generated Thrift module providing ``Client``.
    :param client_span_observer: Optional observer registered on every child
        span of the server span.
    """
    app_config = {
        "baseplate.service_name": "fancy test client",
        "example_service.endpoint": str(endpoint),
    }
    baseplate = Baseplate(app_config=app_config)

    if client_span_observer:
        # Hand each child span of the server span to the caller's observer.
        class TestServerSpanObserver(ServerSpanObserver):
            def on_child_span_created(self, span):
                span.register(client_span_observer)

        child_forwarder = TestServerSpanObserver()

        class TestBaseplateObserver(BaseplateObserver):
            def on_server_span_created(self, context, span):
                span.register(child_forwarder)

        baseplate.register(TestBaseplateObserver())

    context = baseplate.make_context_object()
    upstream_trace = TraceInfo.from_upstream(
        trace_id=1234,
        parent_id=2345,
        span_id=3456,
        flags=4567,
        sampled=True,
    )

    thrift_client = ThriftClient(client_spec.Client)
    baseplate.configure_context({"example_service": thrift_client})
    baseplate.make_server_span(context, "example_service.example", upstream_trace)

    edge_factory = make_edge_context_factory()
    edge_context = edge_factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
    edge_context.attach_context(context)

    yield context
예제 #17
0
class TracingTests(unittest.TestCase):
    """Checks that an inbound Pyramid request is recorded as a trace span."""

    def _register_mock(self, context, server_span):
        """Stand-in for ``TraceBaseplateObserver.on_server_span_created``.

        Registers a ``TraceServerSpanObserver`` that reports to a
        ``NullRecorder`` on ``server_span``. It also keeps the observer on the
        test case so the test can serialize the span afterwards.
        """
        server_span_observer = TraceServerSpanObserver('test-service',
                                                       'test-hostname',
                                                       server_span,
                                                       NullRecorder())
        server_span.register(server_span_observer)
        self.server_span_observer = server_span_observer

    def setUp(self):
        """Build a Pyramid test app with one JSON route, traced by baseplate."""
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        self.observer = TraceBaseplateObserver('test-service')

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def test_trace_on_inbound_request(self):
        """A GET without trace headers yields a span named 'example' with parentId 0."""
        with mock.patch.object(TraceBaseplateObserver,
                               'on_server_span_created',
                               side_effect=self._register_mock) as mocked:

            self.test_app.get('/example')
            # Verify the patched hook ran exactly once. Otherwise a missing
            # call would surface as a confusing AttributeError on
            # self.server_span_observer below.
            self.assertEqual(mocked.call_count, 1)

            span = self.server_span_observer._serialize()
            self.assertEqual(span['name'], 'example')
            self.assertEqual(len(span['annotations']), 2)
            self.assertEqual(span['parentId'], 0)
예제 #18
0
class TracingTests(unittest.TestCase):
    """Checks that an inbound Pyramid request is recorded as a trace span."""

    def _register_mock(self, context, server_span):
        """Stand-in for ``TraceBaseplateObserver.on_server_span_created``.

        Registers a ``TraceServerSpanObserver`` that reports to a
        ``NullRecorder`` on ``server_span``. It also keeps the observer on the
        test case so the test can serialize the span afterwards.
        """
        server_span_observer = TraceServerSpanObserver('test-service',
                                                       'test-hostname',
                                                       server_span,
                                                       NullRecorder())
        server_span.register(server_span_observer)
        self.server_span_observer = server_span_observer

    def setUp(self):
        """Build a Pyramid test app with one JSON route, traced by baseplate."""
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(
            example_application, route_name="example", renderer="json")

        self.observer = TraceBaseplateObserver('test-service')

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def test_trace_on_inbound_request(self):
        """A GET without trace headers yields a span named 'example' with parentId 0."""
        with mock.patch.object(TraceBaseplateObserver, 'on_server_span_created',
                               side_effect=self._register_mock) as mocked:

            self.test_app.get('/example')
            # Verify the patched hook ran exactly once. Otherwise a missing
            # call would surface as a confusing AttributeError on
            # self.server_span_observer below.
            self.assertEqual(mocked.call_count, 1)

            span = self.server_span_observer._serialize()
            self.assertEqual(span['name'], 'example')
            self.assertEqual(len(span['annotations']), 2)
            self.assertEqual(span['parentId'], 0)
예제 #19
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for the Pyramid integration provided by ``BaseplateConfigurator``."""

    def setUp(self):
        """Build a Pyramid test app wired to a baseplate with mocked observers.

        Routes ``/example`` (``example_application``) and ``/trace_context``
        (``local_tracing_within_context``) are registered. Exception views
        render ``ControlFlowException`` and ``ControlFlowException2``. Edge
        contexts are verified against ``AUTH_TOKEN_PUBLIC_KEY``, served from
        a mocked secrets file watcher.
        """
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_route("trace_context",
                               "/trace_context",
                               request_method="GET")

        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        configurator.add_view(local_tracing_within_context,
                              route_name="trace_context",
                              renderer="json")

        configurator.add_view(render_exception_view,
                              context=ControlFlowException,
                              renderer="json")

        configurator.add_view(render_bad_exception_view,
                              context=ControlFlowException2,
                              renderer="json")

        # Replace the store's file watcher so secrets come from this dict.
        mock_filewatcher = mock.Mock(spec=FileWatcher)
        mock_filewatcher.get_data.return_value = {
            "secrets": {
                "secret/authentication/public-key": {
                    "type": "versioned",
                    "current": AUTH_TOKEN_PUBLIC_KEY,
                }
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/"
            },
        }
        secrets = SecretsStore("/secrets")
        secrets._filewatcher = mock_filewatcher

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        # Attach the mocked server-span observer to every new server span so
        # tests can inspect on_start / on_finish / on_child_span_created.
        def _register_mock(context, server_span):
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
            edge_context_factory=EdgeRequestContextFactory(secrets),
        )
        configurator.include(self.baseplate_configurator.includeme)
        self.context_init_event_subscriber = mock.Mock()
        configurator.add_subscriber(self.context_init_event_subscriber,
                                    ServerSpanInitialized)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        """Without trace headers, trace and span ids come from getrandbits."""
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        _, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_trace_headers(self):
        """Trusted X-Trace/X-Parent/X-Span/X-Sampled/X-Flags populate the span."""
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        context, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)

        # No X-Edge-Request header was sent, so reading the user id must fail.
        with self.assertRaises(NoAuthenticationError):
            context.request_context.user.id

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_edge_request_headers(self):
        """An X-Edge-Request header is decoded into the request context."""
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Edge-Request": SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH,
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        context, _ = self.observer.on_server_span_created.call_args[0]
        try:
            self.assertEqual(context.request_context.user.id, "t2_example")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(context.request_context.user.cookie_created_ms,
                             100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(
                context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    def test_not_found(self):
        """Unrouted paths create no server span and fire no init event."""
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)
        self.assertFalse(self.context_init_event_subscriber.called)

    def test_exception_caught(self):
        """An unhandled view exception propagates and is passed to on_finish."""
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    def test_control_flow_exception_not_caught(self):
        """A request handled by an exception view finishes with no exc_info."""
        self.test_app.get("/example?control_flow_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        args, _ = self.server_observer.on_finish.call_args
        self.assertEqual(args[0], None)

    def test_exception_in_exception_view_caught(self):
        """An ExceptionViewException propagates and is passed to on_finish.

        NOTE(review): presumably raised while rendering an exception view
        (e.g. ``render_bad_exception_view``); confirm against the app code.
        """
        with self.assertRaises(ExceptionViewException):
            self.test_app.get("/example?exception_view_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, ExceptionViewException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        """When headers are not trusted, incoming trace ids are ignored."""
        getrandbits.return_value = 1234
        self.baseplate_configurator.header_trust_handler.trust_headers = False

        self.test_app.get("/example",
                          headers={
                              "X-Trace": "1234",
                              "X-Parent": "2345",
                              "X-Span": "3456"
                          })

        _, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        """A child span started inside a view gets a context distinct from the request's."""
        self.test_app.get("/trace_context")
        self.assertEqual(self.server_observer.on_child_span_created.call_count, 1)
        child_span = self.server_observer.on_child_span_created.call_args[0][0]
        context, _ = self.observer.on_server_span_created.call_args[0]
        self.assertNotEqual(child_span.context, context)
예제 #20
0
class Globals(object):
    # Static configuration schema. Each key is a value converter (a
    # ``ConfigValue`` parser, or a callable such as ``config_gold_price``) and
    # each value lists the option names parsed with that converter.
    # ``__init__`` passes this dict to ``ConfigValueParser.add_spec``.
    spec = {

        ConfigValue.int: [
            'db_pool_size',
            'db_pool_overflow_size',
            'commentpane_cache_time',
            'num_mc_clients',
            'MAX_CAMPAIGNS_PER_LINK',
            'MIN_DOWN_LINK',
            'MIN_UP_KARMA',
            'MIN_DOWN_KARMA',
            'MIN_RATE_LIMIT_KARMA',
            'MIN_RATE_LIMIT_COMMENT_KARMA',
            'HOT_PAGE_AGE',
            'ADMIN_COOKIE_TTL',
            'ADMIN_COOKIE_MAX_IDLE',
            'OTP_COOKIE_TTL',
            'hsts_max_age',
            'num_comments',
            'max_comments',
            'max_comments_gold',
            'max_comment_parent_walk',
            'max_sr_images',
            'num_serendipity',
            'comment_visits_period',
            'butler_max_mentions',
            'min_membership_create_community',
            'bcrypt_work_factor',
            'cassandra_pool_size',
            'sr_banned_quota',
            'sr_muted_quota',
            'sr_wikibanned_quota',
            'sr_wikicontributor_quota',
            'sr_moderator_invite_quota',
            'sr_contributor_quota',
            'sr_quota_time',
            'sr_invite_limit',
            'thumbnail_hidpi_scaling',
            'wiki_keep_recent_days',
            'wiki_max_page_length_bytes',
            'wiki_max_page_name_length',
            'wiki_max_page_separators',
            'RL_RESET_MINUTES',
            'RL_OAUTH_RESET_MINUTES',
            'comment_karma_display_floor',
            'link_karma_display_floor',
            'mobile_auth_gild_time',
            'default_total_budget_pennies',
            'min_total_budget_pennies',
            'max_total_budget_pennies',
            'default_bid_pennies',
            'min_bid_pennies',
            'max_bid_pennies',
            'frequency_cap_min',
            'frequency_cap_default',
            'eu_cookie_max_attempts',
        ],

        ConfigValue.float: [
            'statsd_sample_rate',
            'querycache_prune_chance',
            'RL_AVG_REQ_PER_SEC',
            'RL_OAUTH_AVG_REQ_PER_SEC',
            'RL_LOGIN_AVG_PER_SEC',
            'RL_LOGIN_IP_AVG_PER_SEC',
            'RL_SHARE_AVG_PER_SEC',
            'tracing_sample_rate',
        ],

        ConfigValue.bool: [
            'debug',
            'log_start',
            'sqlprinting',
            'template_debug',
            'reload_templates',
            'uncompressedJS',
            'css_killswitch',
            'db_create_tables',
            'disallow_db_writes',
            'disable_ratelimit',
            'amqp_logging',
            'read_only_mode',
            'disable_wiki',
            'heavy_load_mode',
            'disable_captcha',
            'disable_ads',
            'disable_require_admin_otp',
            'trust_local_proxies',
            'shard_commentstree_queues',
            'shard_author_query_queues',
            'shard_subreddit_query_queues',
            'shard_domain_query_queues',
            'authnet_validate',
            'ENFORCE_RATELIMIT',
            'RL_SITEWIDE_ENABLED',
            'RL_OAUTH_SITEWIDE_ENABLED',
            'enable_loggedout_experiments',
        ],

        ConfigValue.tuple: [
            'plugins',
            'stalecaches',
            'lockcaches',
            'permacache_memcaches',
            'cassandra_seeds',
            'automatic_reddits',
            'hardcache_categories',
            'case_sensitive_domains',
            'known_image_domains',
            'reserved_subdomains',
            'offsite_subdomains',
            'TRAFFIC_LOG_HOSTS',
            'exempt_login_user_agents',
            'autoexpand_media_types',
            'media_preview_domain_whitelist',
            'multi_icons',
            'hide_subscribers_srs',
            'mcrouter_addr',
        ],

        ConfigValue.tuple_of(ConfigValue.int): [
            'thumbnail_size',
            'preview_image_max_size',
            'preview_image_min_size',
            'mobile_ad_image_size',
        ],

        ConfigValue.tuple_of(ConfigValue.float): [
            'ios_versions',
            'android_versions',
        ],

        ConfigValue.dict(ConfigValue.str, ConfigValue.int): [
            'user_agent_ratelimit_regexes',
        ],

        ConfigValue.str: [
            'wiki_page_registration_info',
            'wiki_page_privacy_policy',
            'wiki_page_user_agreement',
            'wiki_page_gold_bottlecaps',
            'fraud_email',
            'feedback_email',
            'share_reply',
            'community_email',
            'smtp_server',
            'events_collector_url',
            'events_collector_test_url',
            'search_provider',
        ],

        # Cassandra consistency levels, selected by name ("ONE" / "QUORUM").
        ConfigValue.choice(ONE=CL_ONE, QUORUM=CL_QUORUM): [
             'cassandra_rcl',
             'cassandra_wcl',
        ],

        ConfigValue.choice(zookeeper="zookeeper", config="config"): [
            "liveconfig_source",
            "secrets_source",
        ],

        ConfigValue.timeinterval: [
            'ARCHIVE_AGE',
            "vote_queue_grace_period",
        ],

        # NOTE(review): ``config_gold_price`` is a converter defined outside
        # this view; presumably it parses price values — confirm.
        config_gold_price: [
            'gold_month_price',
            'gold_year_price',
            'cpm_selfserve',
            'cpm_selfserve_geotarget_metro',
            'cpm_selfserve_geotarget_country',
            'cpm_selfserve_collection',
        ],

        # Parsed with baseplate's config parsers; these endpoints are Optional.
        ConfigValue.baseplate(baseplate_config.Optional(baseplate_config.Endpoint)): [
            "activity_endpoint",
            "tracing_endpoint",
        ],

        ConfigValue.dict(ConfigValue.str, ConfigValue.str): [
            'emr_traffic_tags',
        ],
    }

    # Schema for "live config" options, in the same
    # {converter: [option names]} shape as ``spec``.
    # NOTE(review): not referenced in the visible part of ``__init__``;
    # presumably consumed by the live-config loader later — confirm.
    live_config_spec = {
        ConfigValue.bool: [
            'frontend_logging',
            'mobile_gild_first_login',
            'precomputed_comment_suggested_sort',
        ],
        ConfigValue.int: [
            'captcha_exempt_comment_karma',
            'captcha_exempt_link_karma',
            'create_sr_account_age_days',
            'create_sr_comment_karma',
            'create_sr_link_karma',
            'cflag_min_votes',
            'ads_popularity_threshold',
            'precomputed_comment_sort_min_comments',
            'comment_vote_update_threshold',
            'comment_vote_update_period',
        ],
        ConfigValue.float: [
            'cflag_lower_bound',
            'cflag_upper_bound',
            'spotlight_interest_sub_p',
            'spotlight_interest_nosub_p',
            'gold_revenue_goal',
            'invalid_key_sample_rate',
            'events_collector_vote_sample_rate',
            'events_collector_poison_sample_rate',
            'events_collector_mod_sample_rate',
            'events_collector_quarantine_sample_rate',
            'events_collector_modmail_sample_rate',
            'events_collector_report_sample_rate',
            'events_collector_submit_sample_rate',
            'events_collector_comment_sample_rate',
            'events_collector_use_gzip_chance',
            'https_cert_testing_probability',
        ],
        ConfigValue.tuple: [
            'fastlane_links',
            'listing_chooser_sample_multis',
            'discovery_srs',
            'proxy_gilding_accounts',
            'mweb_blacklist_expressions',
            'global_loid_experiments',
            'precomputed_comment_sorts',
            'mailgun_domains',
        ],
        ConfigValue.str: [
            'listing_chooser_gold_multi',
            'listing_chooser_explore_sr',
        ],
        ConfigValue.messages: [
            'welcomebar_messages',
            'sidebar_message',
            'gold_sidebar_message',
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.int): [
            'ticket_groups',
            'ticket_user_fields', 
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.float): [
            'pennies_per_server_second',
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.str): [
            'employee_approved_clients',
            'modmail_forwarding_email',
            'modmail_account_map',
        ],
        # Values are restricted to the names defined in PERMISSIONS.
        ConfigValue.dict(ConfigValue.str, ConfigValue.choice(**PERMISSIONS)): [
            'employees',
        ],
    }

    def __init__(self, config, global_conf, app_conf, paths, **extra):
        """Initialize the application-wide globals container (``g``).

        Parses the configuration against ``self.spec``, loads plugins,
        creates the stats client and starts the ``app_startup`` timer,
        creates the baseplate instance, and resolves the active languages
        and time zones.

        ``config``
            The PylonsConfig object passed in from ``config/environment.py``;
            it is handed to ``get_active_langs``.

        ``global_conf``
            The variables from the ``[DEFAULT]`` section of the
            configuration file. A ``debug`` key is defaulted to False in
            place before parsing.

        ``app_conf``
            The application's section of the configuration file. Not read
            by this method.

        ``paths``
            Stored unchanged as ``self.paths``.

        ``extra``
            Additional keyword arguments; accepted but not read here.
        """
        # Guarantee a "debug" key exists before the config parser sees it.
        global_conf.setdefault("debug", False)

        # Reloading ``site`` rebuilds sys.path, so forked worker processes
        # build their working set from the current path instead of the one
        # the master process started with. New plugins are then picked up on
        # a regular app reload without restarting the master as well.
        reload(site)
        working_set = pkg_resources.WorkingSet()
        self.pkg_resources_working_set = working_set

        parser = ConfigValueParser(global_conf)
        self.config = parser
        parser.add_spec(self.spec)
        self.plugins = PluginLoader(working_set, parser.get("plugins", []))

        stats = Stats(parser.get('statsd_addr'),
                      parser.get('statsd_sample_rate'))
        self.stats = stats
        timer = stats.get_timer("app_startup")
        self.startup_timer = timer
        timer.start()

        baseplate = self.baseplate = Baseplate()
        baseplate.configure_logging()
        baseplate.register(R2BaseplateObserver())
        # Tracing is currently disabled. To turn it on, call:
        #     baseplate.configure_tracing(
        #         service_name="r2",
        #         tracing_endpoint=self.config.get("tracing_endpoint"),
        #         sample_rate=self.config.get("tracing_sample_rate"),
        #     )

        self.paths = paths
        self.running_as_script = global_conf.get('running_as_script', False)

        # Language support: ``site_lang`` (resolved via config) or English.
        self.lang = getattr(self, 'site_lang', 'en')
        self.languages, self.lang_name = get_active_langs(
            config, default_lang=self.lang)
        self.all_languages = sorted(self.lang_name.keys())

        # The server time zone defaults to UTC; the display time zone
        # defaults to whatever the server time zone is.
        tz_name = global_conf.get('timezone', 'UTC')
        self.tz = pytz.timezone(tz_name)
        self.display_tz = pytz.timezone(
            global_conf.get('display_timezone', tz_name))

        timer.intermediate("init")

    def __getattr__(self, name):
        """Resolve attributes that are not set on ``g`` from the parsed config.

        Python only calls this after normal attribute lookup fails. Names
        with a leading underscore are never looked up in the config, so
        protocol probes such as ``__getstate__`` fail fast.

        Returns ``self.config[name]`` when ``name`` is a config key.

        Raises AttributeError when the name is private, is not a config key,
        or when ``config`` itself has not been assigned yet.
        """
        # Reaching here for 'config' means it has not been assigned yet
        # (e.g. before __init__ has run). Resolving it through self.config
        # would re-enter __getattr__('config') and recurse until the
        # interpreter's recursion limit, so refuse it up front.
        if not name.startswith('_') and name != 'config':
            config = self.config
            if name in config:
                return config[name]
        raise AttributeError("g has no attr %r" % name)

    def setup(self):
        """Build the services, clients and caches that hang off ``g``.

        Must run after ``__init__``: it reads ``self.config``, ``self.stats``,
        ``self.paths`` and ``self.startup_timer``, which are set there.
        After each phase it records an intermediate on the startup timer,
        labelled "providers", "configuration", "zookeeper", "memcache",
        "thrift", "cassandra", "postgres", "cache_chains" and "revisions".

        Raises ValueError if ``amqp_host`` or ``cassandra_seeds`` is unset.
        """
        self.env = ''
        if (
            # handle direct invocation of "nosetests"
            "test" in sys.argv[0] or
            # handle "setup.py test" and all permutations thereof.
            ("setup.py" in sys.argv[0] and "test" in sys.argv[1:])
        ):
            self.env = "unit_test"

        self.queues = queues.declare_queues(self)

        self.extension_subdomains = dict(
            simple="mobile",
            i="compact",
            api="api",
            rss="rss",
            xml="xml",
            json="json",
        )

        ################# PROVIDERS
        # Each select_provider call receives the configured provider name
        # (read through config via __getattr__) and the attribute is then
        # rebound to the selected provider.
        self.auth_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.auth",
            self.authentication_provider,
        )
        self.media_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.media",
            self.media_provider,
        )
        self.cdn_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.cdn",
            self.cdn_provider,
        )
        self.ticket_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.support",
            # TODO: fix this later, it refuses to pick up
            # g.config['ticket_provider'] value, so hardcoding for now.
            # really, the next uncommented line should be:
            #self.ticket_provider,
            # instead of:
            "zendesk",
        )
        self.image_resizing_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.image_resizing",
            self.image_resizing_provider,
        )
        self.email_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.email",
            self.email_provider,
        )
        self.startup_timer.intermediate("providers")

        ################# CONFIGURATION
        # AMQP is required
        if not self.amqp_host:
            raise ValueError("amqp_host not set in the .ini")

        if not self.cassandra_seeds:
            raise ValueError("cassandra_seeds not set in the .ini")

        # heavy load mode is read only mode with a different infobar
        if self.heavy_load_mode:
            self.read_only_mode = True

        origin_prefix = self.domain_prefix + "." if self.domain_prefix else ""
        self.origin = self.default_scheme + "://" + origin_prefix + self.domain

        self.trusted_domains = set([self.domain])
        if self.https_endpoint:
            https_url = urlparse(self.https_endpoint)
            self.trusted_domains.add(https_url.hostname)

        # load the unique hashed names of files under static
        static_files = os.path.join(self.paths.get('static_files'), 'static')
        names_file_path = os.path.join(static_files, 'names.json')
        if os.path.exists(names_file_path):
            with open(names_file_path) as handle:
                self.static_names = json.load(handle)
        else:
            self.static_names = {}

        # make python warnings go through the logging system
        logging.captureWarnings(capture=True)

        log = logging.getLogger('reddit')

        # when we're a script (paster run) just set up super simple logging
        if self.running_as_script:
            log.setLevel(logging.INFO)
            log.addHandler(logging.StreamHandler())

        # if in debug mode, override the logging level to DEBUG
        if self.debug:
            log.setLevel(logging.DEBUG)

        # attempt to figure out which pool we're in and add that to the
        # LogRecords.
        try:
            with open("/etc/ec2_asg", "r") as f:
                pool = f.read().strip()
            # clean up the pool name since we're putting stuff after "-"
            pool = pool.partition("-")[0]
        except IOError:
            pool = "reddit-app"
        self.log = logging.LoggerAdapter(log, {"pool": pool})

        # set locations. resource_stream returns an open file-like object,
        # so close it explicitly instead of leaving it to the GC.
        locations_stream = pkg_resources.resource_stream(
            __name__, "../data/locations.json")
        try:
            self.locations = json.load(locations_stream)
        finally:
            locations_stream.close()

        if not self.media_domain:
            self.media_domain = self.domain
        if self.media_domain == self.domain:
            print >> sys.stderr, ("Warning: g.media_domain == g.domain. " +
                                  "This may give untrusted content access to user cookies")
        if self.oauth_domain == self.domain:
            print >> sys.stderr, ("Warning: g.oauth_domain == g.domain. "
                                  "CORS requests to g.domain will be allowed")

        # Any "key=value" command-line argument overrides g.<key> with the
        # raw string value. Arguments containing more than one "=" are
        # ignored.
        for arg in sys.argv:
            tokens = arg.split("=")
            if len(tokens) == 2:
                k, v = tokens
                self.log.debug("Overriding g.%s to %s", k, v)
                setattr(self, k, v)

        self.reddit_host = socket.gethostname()
        self.reddit_pid = os.getpid()

        if hasattr(signal, 'SIGUSR1'):
            # not all platforms have user signals
            signal.signal(signal.SIGUSR1, thread_dump)

        locale.setlocale(locale.LC_ALL, self.locale)

        # Pre-calculate ratelimit values
        self.RL_RESET_SECONDS = self.config["RL_RESET_MINUTES"] * 60
        self.RL_MAX_REQS = int(self.config["RL_AVG_REQ_PER_SEC"] *
                               self.RL_RESET_SECONDS)

        self.RL_OAUTH_RESET_SECONDS = self.config["RL_OAUTH_RESET_MINUTES"] * 60
        self.RL_OAUTH_MAX_REQS = int(self.config["RL_OAUTH_AVG_REQ_PER_SEC"] *
                                     self.RL_OAUTH_RESET_SECONDS)

        self.RL_LOGIN_MAX_REQS = int(self.config["RL_LOGIN_AVG_PER_SEC"] *
                                     self.RL_RESET_SECONDS)
        self.RL_LOGIN_IP_MAX_REQS = int(self.config["RL_LOGIN_IP_AVG_PER_SEC"] *
                                        self.RL_RESET_SECONDS)
        self.RL_SHARE_MAX_REQS = int(self.config["RL_SHARE_AVG_PER_SEC"] *
                                     self.RL_RESET_SECONDS)

        # Compile ratelimit regexs
        user_agent_ratelimit_regexes = {}
        for agent_re, limit in self.user_agent_ratelimit_regexes.iteritems():
            user_agent_ratelimit_regexes[re.compile(agent_re)] = limit
        self.user_agent_ratelimit_regexes = user_agent_ratelimit_regexes

        self.startup_timer.intermediate("configuration")

        ################# ZOOKEEPER
        zk_hosts = self.config["zookeeper_connection_string"]
        zk_username = self.config["zookeeper_username"]
        zk_password = self.config["zookeeper_password"]
        self.zookeeper = connect_to_zookeeper(zk_hosts, (zk_username,
                                                         zk_password))

        self.throttles = IPNetworkLiveList(
            self.zookeeper,
            root="/throttles",
            reduced_data_node="/throttles_reduced",
        )

        # The raw .ini file is re-read (case-preserving keys) as the fallback
        # source for live config and secrets when they don't come from
        # zookeeper.
        parser = ConfigParser.RawConfigParser()
        parser.optionxform = str
        parser.read([self.config["__file__"]])

        if self.config["liveconfig_source"] == "zookeeper":
            self.live_config = LiveConfig(self.zookeeper, LIVE_CONFIG_NODE)
        else:
            self.live_config = extract_live_config(parser, self.plugins)

        if self.config["secrets_source"] == "zookeeper":
            self.secrets = fetch_secrets(self.zookeeper)
        else:
            self.secrets = extract_secrets(parser)

        ################# PRIVILEGED USERS
        self.admins = PermissionFilteredEmployeeList(
            self.live_config, type="admin")
        self.sponsors = PermissionFilteredEmployeeList(
            self.live_config, type="sponsor")
        self.employees = PermissionFilteredEmployeeList(
            self.live_config, type="employee")

        # Store which OAuth clients employees may use, the keys are just for
        # readability.
        self.employee_approved_clients = \
            self.live_config["employee_approved_clients"].values()

        self.startup_timer.intermediate("zookeeper")

        ################# MEMCACHE
        num_mc_clients = self.num_mc_clients

        # a smaller pool of caches used only for distributed locks.
        self.lock_cache = CMemcache(
            "lock",
            self.lockcaches,
            num_clients=num_mc_clients,
        )
        self.make_lock = make_lock_factory(self.lock_cache, self.stats)

        # memcaches used in front of the permacache CF in cassandra.
        # XXX: this is a legacy thing; permacache was made when C* didn't have
        # a row cache.
        permacache_memcaches = CMemcache(
            "perma",
            self.permacache_memcaches,
            min_compress_len=1400,
            num_clients=num_mc_clients,
        )

        # the stalecache is a memcached local to the current app server used
        # for data that's frequently fetched but doesn't need to be fresh.
        if self.stalecaches:
            stalecaches = CMemcache(
                "stale",
                self.stalecaches,
                num_clients=num_mc_clients,
            )
        else:
            stalecaches = None

        self.startup_timer.intermediate("memcache")

        ################# MCROUTER
        self.mcrouter = Mcrouter(
            "mcrouter",
            self.mcrouter_addr,
            min_compress_len=1400,
            num_clients=num_mc_clients,
        )

        ################# THRIFT-BASED SERVICES
        activity_endpoint = self.config.get("activity_endpoint")
        if activity_endpoint:
            # make ActivityInfo objects rendercache-key friendly
            # TODO: figure out a more general solution for this if
            # we need to do this for other thrift-generated objects
            ActivityInfo.cache_key = lambda self, style: repr(self)

            activity_pool = ThriftConnectionPool(activity_endpoint, timeout=0.1)
            self.baseplate.add_to_context("activity_service",
                ThriftContextFactory(activity_pool, ActivityService.Client))

        self.startup_timer.intermediate("thrift")

        ################# CASSANDRA
        keyspace = "reddit"
        self.cassandra_pools = {
            "main":
                StatsCollectingConnectionPool(
                    keyspace,
                    stats=self.stats,
                    logging_name="main",
                    server_list=self.cassandra_seeds,
                    pool_size=self.cassandra_pool_size,
                    timeout=4,
                    max_retries=3,
                    prefill=False
                ),
        }

        permacache_cf = Permacache._setup_column_family(
            'permacache',
            self.cassandra_pools[self.cassandra_default_pool],
        )

        self.startup_timer.intermediate("cassandra")

        ################# POSTGRES
        self.dbm = self.load_db_params()
        self.startup_timer.intermediate("postgres")

        ################# CHAINS
        # initialize caches. Any cache-chains built here must be added
        # to cache_chains (closed around by reset_caches) so that they
        # can properly reset their local components
        cache_chains = {}
        localcache_cls = (SelfEmptyingCache if self.running_as_script
                          else LocalCache)

        def make_mcrouter_chain(fallback_chain_cls):
            # Each call builds a fresh local cache instance. With stale caches
            # configured the chain is StaleCacheChain(local, stale, mcrouter);
            # otherwise it is fallback_chain_cls((local, mcrouter)).
            if stalecaches:
                return StaleCacheChain(
                    localcache_cls(),
                    stalecaches,
                    self.mcrouter,
                )
            return fallback_chain_cls((localcache_cls(), self.mcrouter))

        self.gencache = make_mcrouter_chain(CacheChain)
        cache_chains.update(gencache=self.gencache)

        self.thingcache = make_mcrouter_chain(CacheChain)
        cache_chains.update(thingcache=self.thingcache)

        self.memoizecache = make_mcrouter_chain(MemcacheChain)
        cache_chains.update(memoizecache=self.memoizecache)

        self.srmembercache = make_mcrouter_chain(MemcacheChain)
        cache_chains.update(srmembercache=self.srmembercache)

        self.relcache = make_mcrouter_chain(MemcacheChain)
        cache_chains.update(relcache=self.relcache)

        # the ratelimit cache never goes through the stale cache.
        self.ratelimitcache = MemcacheChain(
            (localcache_cls(), self.mcrouter))
        cache_chains.update(ratelimitcache=self.ratelimitcache)

        # rendercache holds rendered partial templates.
        self.rendercache = MemcacheChain((
            localcache_cls(),
            self.mcrouter,
        ))
        cache_chains.update(rendercache=self.rendercache)

        # commentpanecaches hold fully rendered comment panes
        self.commentpanecache = MemcacheChain((
            localcache_cls(),
            self.mcrouter,
        ))
        cache_chains.update(commentpanecache=self.commentpanecache)

        # cassandra_local_cache is used for request-local caching in tdb_cassandra
        self.cassandra_local_cache = localcache_cls()
        cache_chains.update(cassandra_local_cache=self.cassandra_local_cache)

        # permacache sits on its own memcaches rather than on mcrouter.
        if stalecaches:
            permacache_cache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                permacache_memcaches,
            )
        else:
            permacache_cache = CacheChain(
                (localcache_cls(), permacache_memcaches),
            )
        cache_chains.update(permacache=permacache_cache)

        self.permacache = Permacache(
            permacache_cache,
            permacache_cf,
            lock_factory=self.make_lock,
        )

        # hardcache is used for various things that tend to expire
        # TODO: replace hardcache w/ cassandra stuff
        self.hardcache = HardcacheChain(
            (localcache_cls(), HardCache(self)),
            cache_negative_results=True,
        )
        cache_chains.update(hardcache=self.hardcache)

        # I know this sucks, but we need non-request-threads to be
        # able to reset the caches, so we need them be able to close
        # around 'cache_chains' without being able to call getattr on
        # 'g'
        def reset_caches():
            # Reset every chain (via its read_chain for TransitionalCache) and
            # attach fresh stats collectors to everything but bare LocalCaches.
            for name, chain in cache_chains.iteritems():
                if isinstance(chain, TransitionalCache):
                    chain = chain.read_chain

                chain.reset()
                if isinstance(chain, LocalCache):
                    continue
                elif isinstance(chain, StaleCacheChain):
                    chain.stats = StaleCacheStats(self.stats, name)
                else:
                    chain.stats = CacheStats(self.stats, name)
        self.cache_chains = cache_chains

        self.reset_caches = reset_caches
        self.reset_caches()

        self.startup_timer.intermediate("cache_chains")

        # try to set the source control revision numbers
        self.versions = {}
        r2_root = os.path.dirname(os.path.dirname(self.paths["root"]))
        r2_gitdir = os.path.join(r2_root, ".git")
        self.short_version = self.record_repo_version("r2", r2_gitdir)

        if I18N_PATH:
            i18n_git_path = os.path.join(os.path.dirname(I18N_PATH), ".git")
            self.record_repo_version("i18n", i18n_git_path)

        # Initialize the amqp module globals, start the worker, etc.
        r2.lib.amqp.initialize(self)

        self.events = EventQueue()

        self.startup_timer.intermediate("revisions")

    def setup_complete(self):
        """Finish application startup.

        Stops the ``app_startup`` timer and flushes stats. When
        ``log_start`` is set, logs one line with host, pid, short version,
        wall-clock time and startup duration. If this process is an einhorn
        worker, it then acknowledges startup to einhorn.
        """
        timer = self.startup_timer
        timer.stop()
        self.stats.flush()

        if self.log_start:
            # Logged at error level, as in the original code.
            host, pid, version = (
                self.reddit_host, self.reddit_pid, self.short_version)
            started_at = datetime.now().strftime("%H:%M:%S")
            self.log.error(
                "%s:%s started %s at %s (took %.02fs)",
                host, pid, version, started_at, timer.elapsed_seconds(),
            )

        if not einhorn.is_worker():
            return
        einhorn.ack_startup()

    def record_repo_version(self, repo_name, git_dir):
        """Get the currently checked out git revision for a given repository,
        record it in g.versions, and return the short version of the hash."""
        try:
            subprocess.check_output
        except AttributeError:
            # python 2.6 compat
            pass
        else:
            try:
                revision = subprocess.check_output(["git",
                                                    "--git-dir", git_dir,
                                                    "rev-parse", "HEAD"])
            except subprocess.CalledProcessError, e:
                self.log.warning("Unable to fetch git revision: %r", e)
            else:
예제 #21
0
class ConfiguratorTests(unittest.TestCase):
    """End-to-end tests for BaseplateConfigurator wired into a Pyramid app.

    ``setUp`` builds a WSGI app with the ``example`` and ``trace_context``
    routes plus two exception views. It registers a mock baseplate observer
    that attaches a mock server-span observer to every server span, adds a
    mock ``ServerSpanInitialized`` subscriber, and wraps the app in webtest.
    """

    VALID_TOKEN = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0X3VzZXJfaWQiLCJleHAiOjQ2NTY1OTM0NTV9.Q8bz2qccFOHLTQ6H3MPdjSh7wDkRQtbBuBwGMzNRKjDFSkCoVF5kiwhBUdwbW8UXO5iZn4Bh7oKdj69lIEOATUxFBblU8Do05EfjECXLYGdbr6ClNmldrB8SsdAtQYQ4Ud-70Z8_75QvkqX_TY5OA4asGJZwH9MC7oHey47-38I"
    TOKEN_SECRET = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC0Kd3qYtc6zI5tj3iKBux70BhE\nZLLJ7fAKNBUO7h9FCwUcYku+SFigzNOu3AAYt3seNgxl+cvMR2+SNwsa605J9D1v\n9eGmpcITQi85SeJnfR7LJUMu7RieY5wEl0RyuwnSkX3Gkv0+hZISC/XYcWEYolIi\n8725u7u/8HRtUeHoLwIDAQAB\n-----END PUBLIC KEY-----"
    SERIALIZED_REQUEST_HEADER = b"\x0c\x00\x01\x0b\x00\x01\x00\x00\x00\x0bt2_deadbeef\n\x00\x02\x00\x00\x00\x00\x00\x01\x86\xa0\x00\x0c\x00\x02\x0b\x00\x01\x00\x00\x00\x08beefdead\x00\x00"  # noqa

    def setUp(self):
        configurator = Configurator()
        for route_name, path in (("example", "/example"),
                                 ("trace_context", "/trace_context")):
            configurator.add_route(route_name, path, request_method="GET")

        configurator.add_view(
            example_application, route_name="example", renderer="json")
        configurator.add_view(
            local_tracing_within_context, route_name="trace_context",
            renderer="json")

        # Exception views for the two control-flow exception types.
        for view, exc_type in ((render_exception_view, ControlFlowException),
                               (render_bad_exception_view, ControlFlowException2)):
            configurator.add_view(view, context=exc_type, renderer="json")

        filewatcher = mock.Mock(spec=FileWatcher)
        filewatcher.get_data.return_value = {
            "secrets": {
                "jwt/authentication/secret": {
                    "type": "simple",
                    "value": self.TOKEN_SECRET,
                },
            },
            "vault": {
                "token": "test",
                "url": "http://vault.example.com:8200/",
            },
        }
        secrets = store.SecretsStore("/secrets")
        # Replace the store's file watcher so its data comes from the mock.
        secrets._filewatcher = filewatcher

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def attach_server_observer(context, server_span):
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = attach_server_observer

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
            auth_factory=AuthenticationContextFactory(secrets),
        )
        configurator.include(self.baseplate_configurator.includeme)

        self.context_init_event_subscriber = mock.Mock()
        configurator.add_subscriber(
            self.context_init_event_subscriber, ServerSpanInitialized)

        self.test_app = webtest.TestApp(configurator.make_wsgi_app())

    def _trace_headers(self, extra=None):
        """Return sampled trace headers (trace 1234, parent 2345, span 3456,
        flags 1), updated with ``extra`` when given."""
        headers = {
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
            "X-Sampled": "1",
            "X-Flags": "1",
        }
        if extra:
            headers.update(extra)
        return headers

    def _created_span_args(self):
        """Return the (context, server_span) positional args of the most
        recent on_server_span_created call."""
        return self.observer.on_server_span_created.call_args[0]

    def _assert_request_lifecycle_observed(self):
        """Assert the server span started and finished and that the
        ServerSpanInitialized subscriber was called."""
        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        _, server_span = self._created_span_args()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertIsNone(server_span.parent_id)
        self.assertEqual(server_span.id, 1234)

        self._assert_request_lifecycle_observed()

    def test_trace_headers(self):
        self.test_app.get("/example", headers=self._trace_headers())

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        context, server_span = self._created_span_args()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)
        self.assertFalse(context.request_context._authentication_context.defined)

        self._assert_request_lifecycle_observed()

    def test_auth_headers(self):
        self.test_app.get("/example", headers=self._trace_headers({
            "X-Authentication": self.VALID_TOKEN,
        }))
        context, _ = self._created_span_args()
        try:
            request_context = context.request_context
            user = request_context.user
            self.assertEqual(user.id, "test_user_id")
            self.assertEqual(user.roles, set())
            self.assertEqual(user.is_logged_in, True)
            oauth_client = request_context.oauth_client
            self.assertIsNone(oauth_client.id)
            self.assertFalse(oauth_client.is_type("third_party"))
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    def test_edge_request_headers(self):
        self.test_app.get("/example", headers=self._trace_headers({
            "X-Authentication": self.VALID_TOKEN,
            "X-Edge-Request": self.SERIALIZED_REQUEST_HEADER,
        }))
        context, _ = self._created_span_args()
        try:
            request_context = context.request_context
            user = request_context.user
            self.assertEqual(user.id, "test_user_id")
            self.assertEqual(user.roles, set())
            self.assertEqual(user.is_logged_in, True)
            self.assertEqual(user.loid, "t2_deadbeef")
            self.assertEqual(user.cookie_created_ms, 100000)
            oauth_client = request_context.oauth_client
            self.assertIsNone(oauth_client.id)
            self.assertFalse(oauth_client.is_type("third_party"))
            self.assertEqual(request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    def test_blind_passthrough(self):
        self.test_app.get("/example", headers=self._trace_headers({
            "X-Authentication": "invalid_but_doesnt_matter",
            "X-Edge-Request": "also_invalid_but_doesnt_matter",
        }))
        context, _ = self._created_span_args()
        passed_headers = context.request_context.header_values()
        self.assertEqual(
            passed_headers["Authentication"],
            "invalid_but_doesnt_matter",
        )
        self.assertEqual(
            passed_headers["Edge-Request"],
            "also_invalid_but_doesnt_matter",
        )

    def test_not_found(self):
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)
        self.assertFalse(self.context_init_event_subscriber.called)

    def test_exception_caught(self):
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self._assert_request_lifecycle_observed()
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    def test_control_flow_exception_not_caught(self):
        self.test_app.get("/example?control_flow_exception")

        self._assert_request_lifecycle_observed()
        finish_args, _ = self.server_observer.on_finish.call_args
        self.assertIsNone(finish_args[0])

    def test_exception_in_exception_view_caught(self):
        with self.assertRaises(ExceptionViewException):
            self.test_app.get("/example?exception_view_exception")

        self._assert_request_lifecycle_observed()
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, ExceptionViewException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.baseplate_configurator.trust_trace_headers = False

        self.test_app.get("/example", headers={
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
        })

        _, server_span = self._created_span_args()
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertIsNone(server_span.parent_id)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        self.test_app.get('/trace_context')

        child_created = self.server_observer.on_child_span_created
        self.assertEqual(child_created.call_count, 1)
        child_span = child_created.call_args[0][0]
        context, _ = self._created_span_args()
        self.assertNotEqual(child_span.context, context)
예제 #22
0
class TracingTests(unittest.TestCase):
    """Tracing tests for a Pyramid app wired up through Baseplate."""

    def _register_mock(self, context, server_span):
        """Side effect for the patched ``on_server_span_created`` hook.

        Attaches a ``TraceServerSpanObserver`` (recording to a
        ``NullRecorder``) to the new server span and keeps a reference to
        it so the test can inspect the serialized span afterwards.
        """
        server_span_observer = TraceServerSpanObserver('test-service',
                                                       'test-hostname',
                                                       server_span,
                                                       NullRecorder())
        server_span.register(server_span_observer)
        self.server_span_observer = server_span_observer

    def _assert_single_tracing_observer(self, baseplate, service_name):
        """Assert ``baseplate`` has exactly one observer, whose
        ``service_name`` equals ``service_name``."""
        self.assertEqual(1, len(baseplate.observers))
        tracing_observer = baseplate.observers[0]
        self.assertEqual(service_name, tracing_observer.service_name)

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(
            example_application, route_name="example", renderer="json")

        self.client = make_client("test-service")
        self.observer = TraceBaseplateObserver(self.client)

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def test_trace_on_inbound_request(self):
        with mock.patch.object(TraceBaseplateObserver, 'on_server_span_created',
                               side_effect=self._register_mock) as mocked:

            self.test_app.get('/example')
            # Fail clearly if the hook never fired, rather than with an
            # AttributeError on self.server_span_observer below.
            self.assertTrue(mocked.called)
            span = self.server_span_observer._serialize()
            self.assertEqual(span['name'], 'example')
            self.assertEqual(len(span['annotations']), 2)
            self.assertEqual(span['parentId'], 0)

    def test_configure_tracing_with_defaults_legacy_style(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test')
        self._assert_single_tracing_observer(baseplate, 'test')

    def test_configure_tracing_with_defaults_new_style(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        client = make_client("test")
        baseplate.configure_tracing(client)
        self._assert_single_tracing_observer(baseplate, 'test')

    def test_configure_tracing_with_args(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test',
                                    None,
                                    max_span_queue_size=500,
                                    num_span_workers=5,
                                    span_batch_interval=0.5,
                                    num_conns=100,
                                    sample_rate=0.1)
        self._assert_single_tracing_observer(baseplate, 'test')
예제 #23
0
def _create_baseplate_object(timeout: str):
    """Build a ``Baseplate`` with a ``TimeoutBaseplateObserver`` registered.

    ``timeout`` is passed through unparsed as the observer's
    ``server_timeout.default`` config value.
    """
    bp = Baseplate()
    bp.register(
        TimeoutBaseplateObserver.from_config({"server_timeout.default": timeout})
    )
    return bp
예제 #24
0
class TracingTests(unittest.TestCase):
    """Tracing tests covering server spans and nested local spans for a
    Pyramid app wired up through Baseplate."""

    def _register_server_mock(self, context, server_span):
        """Side effect for the patched ``on_server_span_created`` hook:
        attach a ``TraceServerSpanObserver`` (recording to a
        ``NullRecorder``) and keep it for later inspection."""
        server_span_observer = TraceServerSpanObserver('test-service',
                                                       'test-hostname',
                                                       server_span,
                                                       NullRecorder())
        server_span.register(server_span_observer)
        self.server_span_observer = server_span_observer

    def _register_local_mock(self, span):
        """Side effect for the patched ``on_child_span_created`` hooks:
        attach a ``TraceLocalSpanObserver`` to each new child span and
        record the span's id and its observer, in creation order."""
        local_span_observer = TraceLocalSpanObserver('test-service',
                                                     'test-component',
                                                     'test-hostname', span,
                                                     NullRecorder())
        self.local_span_ids.append(span.id)
        self.local_span_observers.append(local_span_observer)
        span.register(local_span_observer)

    def _assert_single_tracing_observer(self, baseplate, service_name):
        """Assert ``baseplate`` has exactly one observer, whose
        ``service_name`` equals ``service_name``."""
        self.assertEqual(1, len(baseplate.observers))
        tracing_observer = baseplate.observers[0]
        self.assertEqual(service_name, tracing_observer.service_name)

    def setUp(self):
        # Replace threading.Thread for the duration of each test
        # (presumably so no real background worker threads get started).
        thread_patch = mock.patch("threading.Thread", autospec=True)
        thread_patch.start()
        self.addCleanup(thread_patch.stop)
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        configurator.add_route("local_test",
                               "/local_test",
                               request_method="GET")
        configurator.add_view(local_parent_trace_within_context,
                              route_name="local_test",
                              renderer="json")

        self.client = make_client("test-service")
        self.observer = TraceBaseplateObserver(self.client)

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.local_span_ids = []
        self.local_span_observers = []
        self.test_app = webtest.TestApp(app)

    def test_trace_on_inbound_request(self):
        with mock.patch.object(
                TraceBaseplateObserver,
                'on_server_span_created',
                side_effect=self._register_server_mock) as mocked:
            self.test_app.get('/example')
            # Fail clearly if the hook never fired, rather than with an
            # AttributeError on self.server_span_observer below.
            self.assertTrue(mocked.called)
            span = self.server_span_observer._serialize()
            self.assertEqual(span['name'], 'example')
            self.assertEqual(len(span['annotations']), 2)
            self.assertEqual(span['parentId'], 0)

    def test_local_tracing_embedded(self):
        with mock.patch.object(TraceBaseplateObserver, 'on_server_span_created',
                               side_effect=self._register_server_mock), \
             mock.patch.object(TraceServerSpanObserver, 'on_child_span_created',
                               side_effect=self._register_local_mock), \
             mock.patch.object(TraceLocalSpanObserver, 'on_child_span_created',
                               side_effect=self._register_local_mock):

            self.test_app.get('/local_test')
            # Verify that child span can be created within a local span context
            #  and parent IDs are inherited accordingly.
            # The check below needs at least a parent and a child local span;
            # fail with a clear message instead of an IndexError.
            self.assertGreaterEqual(len(self.local_span_ids), 2)
            span = self.local_span_observers[-1]._serialize()
            self.assertEqual(span['name'], 'local-req')
            self.assertEqual(len(span['annotations']), 0)
            self.assertEqual(span['parentId'], self.local_span_ids[-2])

    def test_configure_tracing_with_defaults_legacy_style(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test')
        self._assert_single_tracing_observer(baseplate, 'test')

    def test_configure_tracing_with_defaults_new_style(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        client = make_client("test")
        baseplate.configure_tracing(client)
        self._assert_single_tracing_observer(baseplate, 'test')

    def test_configure_tracing_with_args(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test',
                                    None,
                                    max_span_queue_size=500,
                                    num_span_workers=5,
                                    span_batch_interval=0.5,
                                    num_conns=100,
                                    sample_rate=0.1)
        self._assert_single_tracing_observer(baseplate, 'test')
예제 #25
0
class Globals(object):
    spec = {

        ConfigValue.int: [
            'db_pool_size',
            'db_pool_overflow_size',
            'page_cache_time',
            'commentpane_cache_time',
            'num_mc_clients',
            'MAX_CAMPAIGNS_PER_LINK',
            'MIN_DOWN_LINK',
            'MIN_UP_KARMA',
            'MIN_DOWN_KARMA',
            'MIN_RATE_LIMIT_KARMA',
            'MIN_RATE_LIMIT_COMMENT_KARMA',
            'HOT_PAGE_AGE',
            'ADMIN_COOKIE_TTL',
            'ADMIN_COOKIE_MAX_IDLE',
            'OTP_COOKIE_TTL',
            'hsts_max_age',
            'num_comments',
            'max_comments',
            'max_comments_gold',
            'max_comment_parent_walk',
            'max_sr_images',
            'num_serendipity',
            'comment_visits_period',
            'butler_max_mentions',
            'min_membership_create_community',
            'bcrypt_work_factor',
            'cassandra_pool_size',
            'sr_banned_quota',
            'sr_wikibanned_quota',
            'sr_wikicontributor_quota',
            'sr_moderator_invite_quota',
            'sr_contributor_quota',
            'sr_quota_time',
            'sr_invite_limit',
            'thumbnail_hidpi_scaling',
            'wiki_keep_recent_days',
            'wiki_max_page_length_bytes',
            'wiki_max_page_name_length',
            'wiki_max_page_separators',
            'RL_RESET_MINUTES',
            'RL_OAUTH_RESET_MINUTES',
            'comment_karma_display_floor',
            'link_karma_display_floor',
            'mobile_auth_gild_time',
            'default_total_budget_pennies',
            'min_total_budget_pennies',
            'max_total_budget_pennies',
            'default_bid_pennies',
            'min_bid_pennies',
            'max_bid_pennies',
            'frequency_cap_min',
            'frequency_cap_default',
            'eu_cookie_max_attempts',
        ],

        ConfigValue.float: [
            'statsd_sample_rate',
            'querycache_prune_chance',
            'RL_AVG_REQ_PER_SEC',
            'RL_OAUTH_AVG_REQ_PER_SEC',
            'RL_LOGIN_AVG_PER_SEC',
            'RL_LOGIN_IP_AVG_PER_SEC',
            'RL_SHARE_AVG_PER_SEC',
        ],

        ConfigValue.bool: [
            'debug',
            'log_start',
            'sqlprinting',
            'template_debug',
            'reload_templates',
            'uncompressedJS',
            'css_killswitch',
            'db_create_tables',
            'disallow_db_writes',
            'disable_ratelimit',
            'amqp_logging',
            'read_only_mode',
            'disable_wiki',
            'heavy_load_mode',
            'disable_captcha',
            'disable_ads',
            'disable_require_admin_otp',
            'trust_local_proxies',
            'shard_link_vote_queues',
            'shard_commentstree_queues',
            'authnet_validate',
            'ENFORCE_RATELIMIT',
            'RL_SITEWIDE_ENABLED',
            'RL_OAUTH_SITEWIDE_ENABLED',
            'enable_loggedout_experiments',
        ],

        ConfigValue.tuple: [
            'plugins',
            'stalecaches',
            'memcaches',
            'lockcaches',
            'permacache_memcaches',
            'memoizecaches',
            'srmembercaches',
            'relcaches',
            'ratelimitcaches',
            'hardcache_memcaches',
            'cassandra_seeds',
            'automatic_reddits',
            'hardcache_categories',
            'case_sensitive_domains',
            'known_image_domains',
            'reserved_subdomains',
            'offsite_subdomains',
            'TRAFFIC_LOG_HOSTS',
            'exempt_login_user_agents',
            'timed_templates',
            'autoexpand_media_types',
            'multi_icons',
            'hide_subscribers_srs',
            'mcrouter_addr',
        ],

        ConfigValue.tuple_of(ConfigValue.int): [
            'thumbnail_size',
            'mobile_ad_image_size',
        ],

        ConfigValue.tuple_of(ConfigValue.float): [
            'ios_versions',
            'android_versions',
        ],

        ConfigValue.dict(ConfigValue.str, ConfigValue.int): [
            'user_agent_ratelimit_regexes',
        ],

        ConfigValue.str: [
            'wiki_page_registration_info',
            'wiki_page_privacy_policy',
            'wiki_page_user_agreement',
            'wiki_page_gold_bottlecaps',
            'fraud_email',
            'feedback_email',
            'share_reply',
            'nerds_email',
            'community_email',
            'smtp_server',
            'events_collector_url',
            'events_collector_test_url',
            'search_provider',
        ],

        ConfigValue.choice(ONE=CL_ONE, QUORUM=CL_QUORUM): [
             'cassandra_rcl',
             'cassandra_wcl',
        ],

        ConfigValue.timeinterval: [
            'ARCHIVE_AGE',
            "vote_queue_grace_period",
        ],

        config_gold_price: [
            'gold_month_price',
            'gold_year_price',
            'cpm_selfserve',
            'cpm_selfserve_geotarget_metro',
            'cpm_selfserve_geotarget_country',
            'cpm_selfserve_collection',
        ],

        ConfigValue.baseplate(baseplate_config.Optional(baseplate_config.Endpoint)): [
            "activity_endpoint",
        ],

        ConfigValue.dict(ConfigValue.str, ConfigValue.str): [
            'emr_traffic_tags',
        ],
    }

    live_config_spec = {
        ConfigValue.bool: [
            'frontend_logging',
            'mobile_gild_first_login',
            'precomputed_comment_suggested_sort',
        ],
        ConfigValue.int: [
            'captcha_exempt_comment_karma',
            'captcha_exempt_link_karma',
            'create_sr_account_age_days',
            'create_sr_comment_karma',
            'create_sr_link_karma',
            'cflag_min_votes',
            'ads_popularity_threshold',
            'precomputed_comment_sort_min_comments',
            'comment_vote_update_threshold',
            'comment_vote_update_period',
        ],
        ConfigValue.float: [
            'cflag_lower_bound',
            'cflag_upper_bound',
            'spotlight_interest_sub_p',
            'spotlight_interest_nosub_p',
            'gold_revenue_goal',
            'invalid_key_sample_rate',
            'events_collector_vote_sample_rate',
            'events_collector_poison_sample_rate',
            'events_collector_mod_sample_rate',
            'events_collector_quarantine_sample_rate',
            'events_collector_modmail_sample_rate',
            'events_collector_report_sample_rate',
            'events_collector_submit_sample_rate',
            'events_collector_comment_sample_rate',
            'events_collector_use_gzip_chance',
            'https_cert_testing_probability',
            'precomputed_comment_sort_read_chance',
        ],
        ConfigValue.tuple: [
            'fastlane_links',
            'listing_chooser_sample_multis',
            'discovery_srs',
            'proxy_gilding_accounts',
            'mweb_blacklist_expressions',
            'global_loid_experiments',
            'precomputed_comment_sorts',
            'mailgun_domains',
        ],
        ConfigValue.str: [
            'listing_chooser_gold_multi',
            'listing_chooser_explore_sr',
        ],
        ConfigValue.dict(ConfigValue.int, ConfigValue.float): [
            'comment_tree_version_weights',
        ],
        ConfigValue.messages: [
            'welcomebar_messages',
            'sidebar_message',
            'gold_sidebar_message',
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.int): [
            'ticket_groups',
            'ticket_user_fields', 
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.float): [
            'pennies_per_server_second',
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.str): [
            'employee_approved_clients',
            'modmail_forwarding_email',
            'modmail_account_map',
        ],
        ConfigValue.dict(ConfigValue.str, ConfigValue.choice(**PERMISSIONS)): [
            'employees',
        ],
    }

    def __init__(self, config, global_conf, app_conf, paths, **extra):
        """
        Globals acts as a container for objects available throughout
        the life of the application.

        One instance of Globals is created by Pylons during
        application initialization and is available during requests
        via the 'g' variable.

        ``config``
            The PylonsConfig object passed in from ``config/environment.py``

        ``global_conf``
            The same variable used throughout ``config/middleware.py``
            namely, the variables from the ``[DEFAULT]`` section of the
            configuration file. Note: a ``debug`` key defaulting to
            ``False`` is inserted into this mapping in place.

        ``app_conf``
            The same ``kw`` dictionary used throughout
            ``config/middleware.py`` namely, the variables from the
            section in the config file for your application.

        ``paths``
            Mapping of application paths; stored as ``self.paths`` and
            queried with ``.get()`` (e.g. ``'static_files'``) in ``setup``.

        ``extra``
            The configuration returned from ``load_config`` in
            ``config/middleware.py`` which may be of use in the setup of
            your global variables.

        """

        global_conf.setdefault("debug", False)

        # reloading site ensures that we have a fresh sys.path to build our
        # working set off of. this means that forked worker processes won't get
        # the sys.path that was current when the master process was spawned
        # meaning that new plugins will be picked up on regular app reload
        # rather than having to restart the master process as well.
        reload(site)
        self.pkg_resources_working_set = pkg_resources.WorkingSet()

        self.config = ConfigValueParser(global_conf)
        self.config.add_spec(self.spec)
        self.plugins = PluginLoader(self.pkg_resources_working_set,
                                    self.config.get("plugins", []))

        self.stats = Stats(self.config.get('statsd_addr'),
                           self.config.get('statsd_sample_rate'))
        self.startup_timer = self.stats.get_timer("app_startup")
        self.startup_timer.start()

        self.baseplate = Baseplate()
        self.baseplate.configure_logging()
        self.baseplate.register(R2BaseplateObserver())

        self.paths = paths

        self.running_as_script = global_conf.get('running_as_script', False)

        # turn on for language support. The lookup goes through __getattr__,
        # so this is the 'site_lang' config value if present, else 'en'.
        self.lang = getattr(self, 'site_lang', 'en')
        self.languages, self.lang_name = get_active_langs(
            config, default_lang=self.lang)

        # sorted() always returns a new list, so this works whether keys()
        # returns a list (Python 2) or a view (Python 3).
        self.all_languages = sorted(self.lang_name.keys())

        # set default time zone if one is not set
        tz = global_conf.get('timezone', 'UTC')
        self.tz = pytz.timezone(tz)

        # the display time zone falls back to the (default) server time zone
        dtz = global_conf.get('display_timezone', tz)
        self.display_tz = pytz.timezone(dtz)

        self.startup_timer.intermediate("init")

    def __getattr__(self, name):
        """Expose parsed config values as attributes (``g.<name>``).

        Only invoked when normal attribute lookup fails. Names starting with
        ``_`` are never served from the config, so private and dunder
        lookups get a plain AttributeError.

        Raises:
            AttributeError: if ``name`` is private, if the config has not
                been set on this instance, or if ``name`` is not a config key.
        """
        # Read ``config`` from the instance dict rather than via
        # ``self.config``. If it has not been assigned yet (e.g. an instance
        # created without running __init__), ``self.config`` would re-enter
        # __getattr__('config') and recurse until RecursionError.
        config = self.__dict__.get('config')
        if not name.startswith('_') and config is not None and name in config:
            return config[name]
        raise AttributeError("g has no attr %r" % name)

    def setup(self):
        self.env = ''
        if (
            # handle direct invocation of "nosetests"
            "test" in sys.argv[0] or
            # handle "setup.py test" and all permutations thereof.
            "setup.py" in sys.argv[0] and "test" in sys.argv[1:]
        ):
            self.env = "unit_test"

        self.queues = queues.declare_queues(self)

        self.extension_subdomains = dict(
            simple="mobile",
            i="compact",
            api="api",
            rss="rss",
            xml="xml",
            json="json",
        )

        ################# PROVIDERS
        self.auth_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.auth",
            self.authentication_provider,
        )
        self.media_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.media",
            self.media_provider,
        )
        self.cdn_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.cdn",
            self.cdn_provider,
        )
        self.ticket_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.support",
            # TODO: fix this later, it refuses to pick up 
            # g.config['ticket_provider'] value, so hardcoding for now.
            # really, the next uncommented line should be:
            #self.ticket_provider,
            # instead of:
            "zendesk",
        )
        self.image_resizing_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.image_resizing",
            self.image_resizing_provider,
        )
        self.email_provider = select_provider(
            self.config,
            self.pkg_resources_working_set,
            "r2.provider.email",
            self.email_provider,
        )
        self.startup_timer.intermediate("providers")

        ################# CONFIGURATION
        # AMQP is required
        if not self.amqp_host:
            raise ValueError("amqp_host not set in the .ini")

        if not self.cassandra_seeds:
            raise ValueError("cassandra_seeds not set in the .ini")

        # heavy load mode is read only mode with a different infobar
        if self.heavy_load_mode:
            self.read_only_mode = True

        origin_prefix = self.domain_prefix + "." if self.domain_prefix else ""
        self.origin = self.default_scheme + "://" + origin_prefix + self.domain

        self.trusted_domains = set([self.domain])
        if self.https_endpoint:
            https_url = urlparse(self.https_endpoint)
            self.trusted_domains.add(https_url.hostname)

        # load the unique hashed names of files under static
        static_files = os.path.join(self.paths.get('static_files'), 'static')
        names_file_path = os.path.join(static_files, 'names.json')
        if os.path.exists(names_file_path):
            with open(names_file_path) as handle:
                self.static_names = json.load(handle)
        else:
            self.static_names = {}

        # make python warnings go through the logging system
        logging.captureWarnings(capture=True)

        log = logging.getLogger('reddit')

        # when we're a script (paster run) just set up super simple logging
        if self.running_as_script:
            log.setLevel(logging.INFO)
            log.addHandler(logging.StreamHandler())

        # if in debug mode, override the logging level to DEBUG
        if self.debug:
            log.setLevel(logging.DEBUG)

        # attempt to figure out which pool we're in and add that to the
        # LogRecords.
        try:
            with open("/etc/ec2_asg", "r") as f:
                pool = f.read().strip()
            # clean up the pool name since we're putting stuff after "-"
            pool = pool.partition("-")[0]
        except IOError:
            pool = "reddit-app"
        self.log = logging.LoggerAdapter(log, {"pool": pool})

        # set locations
        locations = pkg_resources.resource_stream(__name__,
                                                  "../data/locations.json")
        self.locations = json.loads(locations.read())

        if not self.media_domain:
            self.media_domain = self.domain
        if self.media_domain == self.domain:
            print >> sys.stderr, ("Warning: g.media_domain == g.domain. " +
                   "This may give untrusted content access to user cookies")
        if self.oauth_domain == self.domain:
            print >> sys.stderr, ("Warning: g.oauth_domain == g.domain. "
                    "CORS requests to g.domain will be allowed")

        for arg in sys.argv:
            tokens = arg.split("=")
            if len(tokens) == 2:
                k, v = tokens
                self.log.debug("Overriding g.%s to %s" % (k, v))
                setattr(self, k, v)

        self.reddit_host = socket.gethostname()
        self.reddit_pid  = os.getpid()

        if hasattr(signal, 'SIGUSR1'):
            # not all platforms have user signals
            signal.signal(signal.SIGUSR1, thread_dump)

        locale.setlocale(locale.LC_ALL, self.locale)

        # Pre-calculate ratelimit values
        self.RL_RESET_SECONDS = self.config["RL_RESET_MINUTES"] * 60
        self.RL_MAX_REQS = int(self.config["RL_AVG_REQ_PER_SEC"] *
                                      self.RL_RESET_SECONDS)

        self.RL_OAUTH_RESET_SECONDS = self.config["RL_OAUTH_RESET_MINUTES"] * 60
        self.RL_OAUTH_MAX_REQS = int(self.config["RL_OAUTH_AVG_REQ_PER_SEC"] *
                                     self.RL_OAUTH_RESET_SECONDS)

        self.RL_LOGIN_MAX_REQS = int(self.config["RL_LOGIN_AVG_PER_SEC"] *
                                     self.RL_RESET_SECONDS)
        self.RL_LOGIN_IP_MAX_REQS = int(self.config["RL_LOGIN_IP_AVG_PER_SEC"] *
                                        self.RL_RESET_SECONDS)
        self.RL_SHARE_MAX_REQS = int(self.config["RL_SHARE_AVG_PER_SEC"] *
                                     self.RL_RESET_SECONDS)

        # Compile ratelimit regexs
        user_agent_ratelimit_regexes = {}
        for agent_re, limit in self.user_agent_ratelimit_regexes.iteritems():
            user_agent_ratelimit_regexes[re.compile(agent_re)] = limit
        self.user_agent_ratelimit_regexes = user_agent_ratelimit_regexes

        self.startup_timer.intermediate("configuration")

        ################# ZOOKEEPER
        # for now, zookeeper will be an optional part of the stack.
        # if it's not configured, we will grab the expected config from the
        # [live_config] section of the ini file
        zk_hosts = self.config.get("zookeeper_connection_string")
        if zk_hosts:
            from r2.lib.zookeeper import (connect_to_zookeeper,
                                          LiveConfig, LiveList)
            zk_username = self.config["zookeeper_username"]
            zk_password = self.config["zookeeper_password"]
            self.zookeeper = connect_to_zookeeper(zk_hosts, (zk_username,
                                                             zk_password))
            self.live_config = LiveConfig(self.zookeeper, LIVE_CONFIG_NODE)
            self.secrets = fetch_secrets(self.zookeeper)
            self.throttles = LiveList(self.zookeeper, "/throttles",
                                      map_fn=ipaddress.ip_network,
                                      reduce_fn=ipaddress.collapse_addresses)

            # close our zk connection when the app shuts down
            SHUTDOWN_CALLBACKS.append(self.zookeeper.stop)
        else:
            self.zookeeper = None
            parser = ConfigParser.RawConfigParser()
            parser.optionxform = str
            parser.read([self.config["__file__"]])
            self.live_config = extract_live_config(parser, self.plugins)
            self.secrets = extract_secrets(parser)
            self.throttles = tuple()  # immutable since it's not real

        ################# PRIVILEGED USERS
        self.admins = PermissionFilteredEmployeeList(
            self.live_config, type="admin")
        self.sponsors = PermissionFilteredEmployeeList(
            self.live_config, type="sponsor")
        self.employees = PermissionFilteredEmployeeList(
            self.live_config, type="employee")

        # Store which OAuth clients employees may use, the keys are just for
        # readability.
        self.employee_approved_clients = \
            self.live_config["employee_approved_clients"].values()

        self.startup_timer.intermediate("zookeeper")

        ################# MEMCACHE
        num_mc_clients = self.num_mc_clients

        # the main memcache pool. used for most everything.
        memcaches = CMemcache(
            "main",
            self.memcaches,
            min_compress_len=1400,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        # a pool just used for @memoize results
        memoizecaches = CMemcache(
            "memoize",
            self.memoizecaches,
            min_compress_len=50 * 1024,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        # a pool just for srmember rels
        srmembercaches = CMemcache(
            "srmember",
            self.srmembercaches,
            min_compress_len=96,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        # a pool just for rels
        relcaches = CMemcache(
            "rel",
            self.relcaches,
            min_compress_len=96,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        ratelimitcaches = CMemcache(
            "ratelimit",
            self.ratelimitcaches,
            min_compress_len=96,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        # a smaller pool of caches used only for distributed locks.
        self.lock_cache = CMemcache(
            "lock",
            self.lockcaches,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )
        self.make_lock = make_lock_factory(self.lock_cache, self.stats)

        # memcaches used in front of the permacache CF in cassandra.
        # XXX: this is a legacy thing; permacache was made when C* didn't have
        # a row cache.
        permacache_memcaches = CMemcache("perma",
                                         self.permacache_memcaches,
                                         min_compress_len=1400,
                                         num_clients=num_mc_clients,
                                         validators=[],)

        # the stalecache is a memcached local to the current app server used
        # for data that's frequently fetched but doesn't need to be fresh.
        if self.stalecaches:
            stalecaches = CMemcache(
                "stale",
                self.stalecaches,
                num_clients=num_mc_clients,
                validators=[validate_size_error],
            )
        else:
            stalecaches = None

        # hardcache memcache pool
        hardcache_memcaches = CMemcache(
            "hardcache",
            self.hardcache_memcaches,
            binary=True,
            min_compress_len=1400,
            num_clients=num_mc_clients,
            validators=[validate_size_error],
        )

        self.startup_timer.intermediate("memcache")

        ################# MCROUTER
        self.mcrouter = Mcrouter(
            "mcrouter",
            self.mcrouter_addr,
            min_compress_len=1400,
            num_clients=1,
        )

        ################# THRIFT-BASED SERVICES
        activity_endpoint = self.config.get("activity_endpoint")
        if activity_endpoint:
            # make ActivityInfo objects rendercache-key friendly
            # TODO: figure out a more general solution for this if
            # we need to do this for other thrift-generated objects
            ActivityInfo.cache_key = lambda self, style: repr(self)

            activity_pool = ThriftConnectionPool(activity_endpoint, timeout=0.1)
            self.baseplate.add_to_context("activity_service",
                ThriftContextFactory(activity_pool, ActivityService.Client))

        self.startup_timer.intermediate("thrift")

        ################# CASSANDRA
        keyspace = "reddit"
        self.cassandra_pools = {
            "main":
                StatsCollectingConnectionPool(
                    keyspace,
                    stats=self.stats,
                    logging_name="main",
                    server_list=self.cassandra_seeds,
                    pool_size=self.cassandra_pool_size,
                    timeout=4,
                    max_retries=3,
                    prefill=False
                ),
        }

        permacache_cf = Permacache._setup_column_family(
            'permacache',
            self.cassandra_pools[self.cassandra_default_pool],
        )

        self.startup_timer.intermediate("cassandra")

        ################# POSTGRES
        self.dbm = self.load_db_params()
        self.startup_timer.intermediate("postgres")

        ################# CHAINS
        # initialize caches. Any cache-chains built here must be added
        # to cache_chains (closed around by reset_caches) so that they
        # can properly reset their local components
        cache_chains = {}
        localcache_cls = (SelfEmptyingCache if self.running_as_script
                          else LocalCache)

        if stalecaches:
            self.cache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                memcaches,
            )
        else:
            self.cache = CacheChain((localcache_cls(), memcaches))
        cache_chains.update(cache=self.cache)

        if stalecaches:
            self.thingcache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                self.mcrouter,
            )
        else:
            self.thingcache = CacheChain((localcache_cls(), self.mcrouter))
        cache_chains.update(thingcache=self.thingcache)

        if stalecaches:
            self.memoizecache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                memoizecaches,
            )
        else:
            self.memoizecache = MemcacheChain(
                (localcache_cls(), memoizecaches))
        cache_chains.update(memoizecache=self.memoizecache)

        if stalecaches:
            self.srmembercache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                srmembercaches,
            )
        else:
            self.srmembercache = MemcacheChain(
                (localcache_cls(), srmembercaches))
        cache_chains.update(srmembercache=self.srmembercache)

        if stalecaches:
            self.relcache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                relcaches,
            )
        else:
            self.relcache = MemcacheChain(
                (localcache_cls(), relcaches))
        cache_chains.update(relcache=self.relcache)

        self.ratelimitcache = MemcacheChain(
                (localcache_cls(), ratelimitcaches))
        cache_chains.update(ratelimitcache=self.ratelimitcache)

        # rendercache holds rendered partial templates.
        self.rendercache = MemcacheChain((
            localcache_cls(),
            self.mcrouter,
        ))
        cache_chains.update(rendercache=self.rendercache)

        # pagecaches hold fully rendered pages (includes comment panes)
        self.pagecache = MemcacheChain((
            localcache_cls(),
            self.mcrouter,
        ))
        cache_chains.update(pagecache=self.pagecache)

        # cassandra_local_cache is used for request-local caching in tdb_cassandra
        self.cassandra_local_cache = localcache_cls()
        cache_chains.update(cassandra_local_cache=self.cassandra_local_cache)

        if stalecaches:
            permacache_cache = StaleCacheChain(
                localcache_cls(),
                stalecaches,
                permacache_memcaches,
            )
        else:
            permacache_cache = CacheChain(
                (localcache_cls(), permacache_memcaches),
            )
        cache_chains.update(permacache=permacache_cache)

        self.permacache = Permacache(
            permacache_cache,
            permacache_cf,
            lock_factory=self.make_lock,
        )

        # hardcache is used for various things that tend to expire
        # TODO: replace hardcache w/ cassandra stuff
        self.hardcache = HardcacheChain(
            (localcache_cls(), hardcache_memcaches, HardCache(self)),
            cache_negative_results=True,
        )
        cache_chains.update(hardcache=self.hardcache)

        # I know this sucks, but we need non-request-threads to be
        # able to reset the caches, so we need them be able to close
        # around 'cache_chains' without being able to call getattr on
        # 'g'
        def reset_caches():
            for name, chain in cache_chains.iteritems():
                if isinstance(chain, TransitionalCache):
                    chain = chain.read_chain

                chain.reset()
                if isinstance(chain, LocalCache):
                    continue
                elif isinstance(chain, StaleCacheChain):
                    chain.stats = StaleCacheStats(self.stats, name)
                else:
                    chain.stats = CacheStats(self.stats, name)
        self.cache_chains = cache_chains

        self.reset_caches = reset_caches
        self.reset_caches()

        self.startup_timer.intermediate("cache_chains")

        # try to set the source control revision numbers
        self.versions = {}
        r2_root = os.path.dirname(os.path.dirname(self.paths["root"]))
        r2_gitdir = os.path.join(r2_root, ".git")
        self.short_version = self.record_repo_version("r2", r2_gitdir)

        if I18N_PATH:
            i18n_git_path = os.path.join(os.path.dirname(I18N_PATH), ".git")
            self.record_repo_version("i18n", i18n_git_path)

        # Initialize the amqp module globals, start the worker, etc.
        r2.lib.amqp.initialize(self)

        self.events = EventQueue()

        self.startup_timer.intermediate("revisions")

    def setup_complete(self):
        """Finalize application startup.

        Stops the startup timer and flushes stats. When ``log_start`` is set,
        it emits a one-line startup summary: host, pid, short revision, wall
        clock time and elapsed startup seconds. When running as an einhorn
        worker, it acknowledges startup to einhorn.
        """
        timer = self.startup_timer
        timer.stop()
        self.stats.flush()

        if self.log_start:
            started_at = datetime.now().strftime("%H:%M:%S")
            # Logged via ``error`` in the original code; the level is kept
            # as-is so existing log filtering still surfaces this line.
            self.log.error(
                "%s:%s started %s at %s (took %.02fs)",
                self.reddit_host,
                self.reddit_pid,
                self.short_version,
                started_at,
                timer.elapsed_seconds(),
            )

        if einhorn.is_worker():
            einhorn.ack_startup()

    def record_repo_version(self, repo_name, git_dir):
        """Get the currently checked out git revision for a given repository,
        record it in g.versions, and return the short version of the hash."""
        try:
            subprocess.check_output
        except AttributeError:
            # python 2.6 compat
            pass
        else:
            try:
                revision = subprocess.check_output(["git",
                                                    "--git-dir", git_dir,
                                                    "rev-parse", "HEAD"])
            except subprocess.CalledProcessError, e:
                self.log.warning("Unable to fetch git revision: %r", e)
            else:
예제 #26
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for BaseplateConfigurator root-span handling in a Pyramid app.

    A mocked BaseplateObserver is registered on the Baseplate instance. Its
    ``on_root_span_created`` side effect attaches a mocked RootSpanObserver
    to every root span, so tests can inspect span IDs and start/stop calls.
    """

    # Trace headers sent by the header-trust tests. Copied per request so a
    # request can never mutate the shared class-level dict.
    TRACE_HEADERS = {
        "X-Trace": "1234",
        "X-Parent": "2345",
        "X-Span": "3456",
    }

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.root_observer = mock.Mock(spec=RootSpanObserver)

        def _register_mock(context, root_span):
            # Attach the mocked root-span observer to every new root span.
            root_span.register(self.root_observer)

        self.observer.on_root_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def _created_root_span(self):
        """Return ``(context, root_span)`` from the one root-span-created call.

        Asserts the hook fired exactly once first. Without that check, a
        missing call would surface as an opaque TypeError from indexing
        ``call_args`` (which is None) instead of a clear assertion failure.
        """
        self.assertEqual(self.observer.on_root_span_created.call_count, 1)
        return self.observer.on_root_span_created.call_args[0]

    def test_no_trace_headers(self):
        self.test_app.get("/example")

        context, root_span = self._created_root_span()
        self.assertIsInstance(context, Request)
        self.assertEqual(root_span.trace_id, "no-trace")
        self.assertEqual(root_span.parent_id, "no-parent")
        self.assertEqual(root_span.id, "no-span")

        self.assertTrue(self.root_observer.on_start.called)
        self.assertTrue(self.root_observer.on_stop.called)

    def test_trace_headers(self):
        self.test_app.get("/example", headers=dict(self.TRACE_HEADERS))

        context, root_span = self._created_root_span()
        self.assertIsInstance(context, Request)
        self.assertEqual(root_span.trace_id, "1234")
        self.assertEqual(root_span.parent_id, "2345")
        self.assertEqual(root_span.id, "3456")

        self.assertTrue(self.root_observer.on_start.called)
        self.assertTrue(self.root_observer.on_stop.called)

    def test_not_found(self):
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_root_span_created.called)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        self.baseplate_configurator.trust_trace_headers = False

        self.test_app.get("/example", headers=dict(self.TRACE_HEADERS))

        # Untrusted headers must be ignored in favour of freshly generated IDs.
        _, root_span = self._created_root_span()
        self.assertEqual(root_span.trace_id, getrandbits.return_value)
        self.assertEqual(root_span.parent_id, None)
        self.assertEqual(root_span.id, getrandbits.return_value)
예제 #27
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for BaseplateConfigurator server-span handling in a Pyramid app.

    A mocked BaseplateObserver attaches a mocked ServerSpanObserver to each
    server span. A mocked subscriber records ServerSpanInitialized events.
    """

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_route("trace_context",
                               "/trace_context",
                               request_method="GET")

        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        configurator.add_view(local_tracing_within_context,
                              route_name="trace_context",
                              renderer="json")

        configurator.add_view(render_exception_view,
                              context=ControlFlowException,
                              renderer="json")

        configurator.add_view(render_bad_exception_view,
                              context=ControlFlowException2,
                              renderer="json")

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _register_mock(context, server_span):
            # Attach the mocked server-span observer to every new span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            edge_context_factory=FakeEdgeContextFactory(),
            header_trust_handler=StaticTrustHandler(trust_headers=True),
        )
        configurator.include(self.baseplate_configurator.includeme)
        self.context_init_event_subscriber = mock.Mock()
        configurator.add_subscriber(self.context_init_event_subscriber,
                                    ServerSpanInitialized)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def _created_server_span(self):
        """Return ``(context, server_span)`` from the one span-created call.

        Asserts the hook fired exactly once first. Without that check, a
        missing call would surface as an opaque TypeError from indexing
        ``call_args`` (which is None) instead of a clear assertion failure.
        """
        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        return self.observer.on_server_span_created.call_args[0]

    @staticmethod
    def _full_trace_headers():
        """Build a fresh dict of trace, sampling and edge-context headers."""
        return {
            "X-Trace": "1234",
            "X-Edge-Request":
                base64.b64encode(FakeEdgeContextFactory.RAW_BYTES).decode(),
            "X-Parent": "2345",
            "X-Span": "3456",
            "X-Sampled": "1",
            "X-Flags": "1",
        }

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_trace_headers(self):
        self.test_app.get("/example", headers=self._full_trace_headers())

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_edge_request_headers(self):
        self.test_app.get("/example", headers=self._full_trace_headers())

        context, _ = self._created_server_span()
        assert context.edge_context == FakeEdgeContextFactory.DECODED_CONTEXT

    def test_empty_edge_request_headers(self):
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Edge-Request": "",
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )
        context, _ = self._created_server_span()
        self.assertEqual(context.raw_edge_context, b"")

    def test_not_found(self):
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)
        self.assertFalse(self.context_init_event_subscriber.called)

    def test_exception_caught(self):
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        # on_finish receives the exc_info triple as its first argument.
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    def test_control_flow_exception_not_caught(self):
        self.test_app.get("/example?control_flow_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        # A handled control-flow exception must not be reported as an error.
        args, _ = self.server_observer.on_finish.call_args
        self.assertEqual(args[0], None)

    def test_exception_in_exception_view_caught(self):
        with self.assertRaises(ExceptionViewException):
            self.test_app.get("/example?exception_view_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, ExceptionViewException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.baseplate_configurator.header_trust_handler.trust_headers = False

        self.test_app.get("/example",
                          headers={
                              "X-Trace": "1234",
                              "X-Parent": "2345",
                              "X-Span": "3456"
                          })

        # Untrusted headers must be ignored in favour of freshly generated IDs.
        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        self.test_app.get("/trace_context")
        self.assertEqual(self.server_observer.on_child_span_created.call_count,
                         1)
        child_span = self.server_observer.on_child_span_created.call_args[0][0]
        context, _ = self._created_server_span()
        self.assertNotEqual(child_span.context, context)

    def test_streaming_response(self):
        class StreamingTestResponse(webtest.TestResponse):
            def decode_content(self):
                # keep your grubby hands off the app_iter, webtest!!!!
                pass

            @property
            def body(self):
                # seriously
                pass

        class StreamingTestRequest(webtest.TestRequest):
            ResponseClass = StreamingTestResponse

        self.test_app.RequestClass = StreamingTestRequest

        response = self.test_app.get("/example?stream")
        # Close the WSGI iterator even if an assertion below fails, so a
        # failing test doesn't leave the response iterator open.
        self.addCleanup(response.app_iter.close)

        # ok, we've returned from the wsgi app but the iterator's not done
        # so... we should have started the span but not finished it yet
        self.assertTrue(self.server_observer.on_start.called)
        self.assertFalse(self.server_observer.on_finish.called)

        self.assertEqual(b"foo", next(response.app_iter))
        self.assertFalse(self.server_observer.on_finish.called)

        self.assertEqual(b"bar", next(response.app_iter))
        self.assertFalse(self.server_observer.on_finish.called)

        with self.assertRaises(StopIteration):
            next(response.app_iter)
        self.assertTrue(self.server_observer.on_finish.called)
예제 #28
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for BaseplateConfigurator server-span handling in a Pyramid app.

    A mocked BaseplateObserver attaches a mocked ServerSpanObserver to each
    server span so tests can inspect span IDs and lifecycle calls.
    """

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_route("trace_context", "/trace_context", request_method="GET")

        configurator.add_view(
            example_application, route_name="example", renderer="json")

        configurator.add_view(
            local_tracing_within_context, route_name="trace_context", renderer="json")

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _register_mock(context, server_span):
            # Attach the mocked server-span observer to every new span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def _created_server_span(self):
        """Return ``(context, server_span)`` from the one span-created call.

        Asserts the hook fired exactly once first. Without that check, a
        missing call would surface as an opaque TypeError from indexing
        ``call_args`` (which is None) instead of a clear assertion failure.
        """
        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        return self.observer.on_server_span_created.call_args[0]

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)

    def test_trace_headers(self):
        self.test_app.get("/example", headers={
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
            "X-Sampled": "1",
            "X-Flags": "1",
        })

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)

    def test_not_found(self):
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)

    def test_exception_caught(self):
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        # on_finish receives the exc_info triple as its first argument.
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.baseplate_configurator.trust_trace_headers = False

        self.test_app.get("/example", headers={
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
        })

        # Untrusted headers must be ignored in favour of freshly generated IDs.
        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        self.test_app.get('/trace_context')
        self.assertEqual(self.server_observer.on_child_span_created.call_count, 1)
        child_span = self.server_observer.on_child_span_created.call_args[0][0]
        context, _ = self._created_server_span()
        self.assertNotEqual(child_span.context, context)
예제 #29
0
class TracingTests(unittest.TestCase):
    """Tests for TraceBaseplateObserver integration with a Pyramid app."""

    def _register_mock(self, context, server_span):
        """Stand-in for ``on_server_span_created``.

        Attaches a real TraceServerSpanObserver (backed by a NullRecorder) to
        the span and stores it on the test case for later inspection.
        """
        server_span_observer = TraceServerSpanObserver('test-service',
                                                       'test-hostname',
                                                       server_span,
                                                       NullRecorder())
        server_span.register(server_span_observer)
        self.server_span_observer = server_span_observer

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        self.observer = TraceBaseplateObserver('test-service')

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def test_trace_on_inbound_request(self):
        with mock.patch.object(TraceBaseplateObserver,
                               'on_server_span_created',
                               side_effect=self._register_mock) as mocked:

            self.test_app.get('/example')
            # Fail with a clear assertion (rather than an AttributeError on
            # self.server_span_observer) if the hook never fired.
            self.assertEqual(mocked.call_count, 1)
            span = self.server_span_observer._serialize()
            self.assertEqual(span['name'], 'example')
            self.assertEqual(len(span['annotations']), 2)
            self.assertEqual(span['parentId'], 0)

    def test_configure_tracing_with_defaults(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test', None)
        self.assertEqual(1, len(baseplate.observers))
        tracing_observer = baseplate.observers[0]
        self.assertEqual('test', tracing_observer.service_name)

    def test_configure_tracing_with_args(self):
        baseplate = Baseplate()
        self.assertEqual(0, len(baseplate.observers))
        baseplate.configure_tracing('test',
                                    None,
                                    max_span_queue_size=500,
                                    num_span_workers=5,
                                    span_batch_interval=0.5,
                                    num_conns=100,
                                    sample_rate=0.1)
        self.assertEqual(1, len(baseplate.observers))
        tracing_observer = baseplate.observers[0]
        self.assertEqual('test', tracing_observer.service_name)
예제 #30
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for BaseplateConfigurator server-span handling in a Pyramid app.

    A mocked BaseplateObserver attaches a mocked ServerSpanObserver to each
    server span so tests can inspect span IDs and lifecycle calls.
    """

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_route("trace_context",
                               "/trace_context",
                               request_method="GET")

        configurator.add_view(example_application,
                              route_name="example",
                              renderer="json")

        configurator.add_view(local_tracing_within_context,
                              route_name="trace_context",
                              renderer="json")

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _register_mock(context, server_span):
            # Attach the mocked server-span observer to every new span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def _created_server_span(self):
        """Return ``(context, server_span)`` from the one span-created call.

        Asserts the hook fired exactly once first. Without that check, a
        missing call would surface as an opaque TypeError from indexing
        ``call_args`` (which is None) instead of a clear assertion failure.
        """
        self.assertEqual(self.observer.on_server_span_created.call_count, 1)
        return self.observer.on_server_span_created.call_args[0]

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)

    def test_trace_headers(self):
        self.test_app.get("/example",
                          headers={
                              "X-Trace": "1234",
                              "X-Parent": "2345",
                              "X-Span": "3456",
                              "X-Sampled": "1",
                              "X-Flags": "1",
                          })

        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)

    def test_not_found(self):
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)

    def test_exception_caught(self):
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        # on_finish receives the exc_info triple as its first argument.
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        getrandbits.return_value = 1234
        self.baseplate_configurator.trust_trace_headers = False

        self.test_app.get("/example",
                          headers={
                              "X-Trace": "1234",
                              "X-Parent": "2345",
                              "X-Span": "3456",
                          })

        # Untrusted headers must be ignored in favour of freshly generated IDs.
        _, server_span = self._created_server_span()
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        self.test_app.get('/trace_context')
        self.assertEqual(self.server_observer.on_child_span_created.call_count,
                         1)
        child_span = self.server_observer.on_child_span_created.call_args[0][0]
        context, _ = self._created_server_span()
        self.assertNotEqual(child_span.context, context)
예제 #31
0
class TracingTests(unittest.TestCase):
    """End-to-end checks of the Trace* observers through a Pyramid WSGI app.

    The observer hooks are patched so that real TraceServerSpanObserver /
    TraceLocalSpanObserver instances (backed by a NullRecorder) are attached
    to each span and kept on the test case for later inspection.
    """

    def _register_server_mock(self, context, server_span):
        """Attach a TraceServerSpanObserver to ``server_span`` and remember it."""
        observer = TraceServerSpanObserver(
            "test-service",
            "test-hostname",
            server_span,
            NullRecorder(),
        )
        server_span.register(observer)
        self.server_span_observer = observer

    def _register_local_mock(self, span):
        """Attach a TraceLocalSpanObserver to ``span`` and record its id/observer."""
        observer = TraceLocalSpanObserver(
            "test-service",
            "test-component",
            "test-hostname",
            span,
            NullRecorder(),
        )
        self.local_span_ids.append(span.id)
        self.local_span_observers.append(observer)
        span.register(observer)

    def setUp(self):
        # threading.Thread is replaced for the duration of each test
        # (presumably so the trace client/recorder cannot start real threads).
        thread_patcher = mock.patch("threading.Thread", autospec=True)
        thread_patcher.start()
        self.addCleanup(thread_patcher.stop)

        configurator = Configurator()
        route_table = (
            ("example", "/example", example_application),
            ("local_test", "/local_test", local_parent_trace_within_context),
        )
        for route_name, route_path, view in route_table:
            configurator.add_route(route_name, route_path, request_method="GET")
            configurator.add_view(view, route_name=route_name, renderer="json")

        self.client = make_client("test-service")
        self.observer = TraceBaseplateObserver(self.client)

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)

        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            header_trust_handler=StaticTrustHandler(trust_headers=True),
        )
        configurator.include(self.baseplate_configurator.includeme)
        wsgi_app = configurator.make_wsgi_app()

        # Filled in by _register_local_mock as local spans are created.
        self.local_span_ids = []
        self.local_span_observers = []
        self.test_app = webtest.TestApp(wsgi_app)

    def test_trace_on_inbound_request(self):
        server_hook = mock.patch.object(
            TraceBaseplateObserver,
            "on_server_span_created",
            side_effect=self._register_server_mock,
        )
        with server_hook:
            self.test_app.get("/example")
            serialized = self.server_span_observer._serialize()
            self.assertEqual(serialized["name"], "example")
            self.assertEqual(len(serialized["annotations"]), 2)
            self.assertEqual(serialized["parentId"], 0)

    def test_local_tracing_embedded(self):
        server_hook = mock.patch.object(
            TraceBaseplateObserver,
            "on_server_span_created",
            side_effect=self._register_server_mock,
        )
        server_child_hook = mock.patch.object(
            TraceServerSpanObserver,
            "on_child_span_created",
            side_effect=self._register_local_mock,
        )
        local_child_hook = mock.patch.object(
            TraceLocalSpanObserver,
            "on_child_span_created",
            side_effect=self._register_local_mock,
        )
        with server_hook, server_child_hook, local_child_hook:
            self.test_app.get("/local_test")
            # The most recently created local span must have been started
            # inside the previous local span's context, so that span's id
            # is expected as its parentId.
            innermost = self.local_span_observers[-1]._serialize()
            self.assertEqual(innermost["name"], "local-req")
            self.assertEqual(len(innermost["annotations"]), 0)
            self.assertEqual(innermost["parentId"], self.local_span_ids[-2])
예제 #32
0
class ConfiguratorTests(unittest.TestCase):
    """Integration tests for BaseplateConfigurator wired into a Pyramid app.

    setUp builds a WSGI app with two GET routes (``/example`` and
    ``/trace_context``) plus two exception views. It registers a mocked
    BaseplateObserver whose ``on_server_span_created`` attaches a mocked
    ServerSpanObserver to every server span. It also subscribes a mock to
    ``ServerSpanInitialized``, so each test can inspect span lifecycle calls.
    """

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_route("trace_context", "/trace_context", request_method="GET")

        configurator.add_view(example_application, route_name="example", renderer="json")

        configurator.add_view(
            local_tracing_within_context, route_name="trace_context", renderer="json"
        )

        configurator.add_view(render_exception_view, context=ControlFlowException, renderer="json")

        configurator.add_view(
            render_bad_exception_view, context=ControlFlowException2, renderer="json"
        )

        # Presumably used by EdgeRequestContextFactory to verify the auth
        # token carried in the X-Edge-Request header.
        secrets = FakeSecretsStore(
            {
                "secrets": {
                    "secret/authentication/public-key": {
                        "type": "versioned",
                        "current": AUTH_TOKEN_PUBLIC_KEY,
                    }
                },
            }
        )

        self.observer = mock.Mock(spec=BaseplateObserver)
        self.server_observer = mock.Mock(spec=ServerSpanObserver)

        def _register_mock(context, server_span):
            # Attach the mocked span observer to every new server span.
            server_span.register(self.server_observer)

        self.observer.on_server_span_created.side_effect = _register_mock

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
            edge_context_factory=EdgeRequestContextFactory(secrets),
        )
        configurator.include(self.baseplate_configurator.includeme)
        self.context_init_event_subscriber = mock.Mock()
        configurator.add_subscriber(self.context_init_event_subscriber, ServerSpanInitialized)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    @mock.patch("random.getrandbits")
    def test_no_trace_headers(self, getrandbits):
        """Without trace headers, span IDs come from random.getrandbits."""
        getrandbits.return_value = 1234
        self.test_app.get("/example")

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        _, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, 1234)

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_trace_headers(self):
        """Trusted trace headers populate the server span's IDs and flags."""
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Edge-Request": base64.b64encode(SERIALIZED_EDGECONTEXT_WITH_NO_AUTH).decode(),
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )

        self.assertEqual(self.observer.on_server_span_created.call_count, 1)

        context, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, 1234)
        self.assertEqual(server_span.parent_id, 2345)
        self.assertEqual(server_span.id, 3456)
        self.assertEqual(server_span.sampled, True)
        self.assertEqual(server_span.flags, 1)

        # An edge context without authentication yields no user id.
        with self.assertRaises(NoAuthenticationError):
            context.request_context.user.id

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)

    def test_edge_request_headers(self):
        """A valid edge-request header is decoded onto context.request_context."""
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Edge-Request": base64.b64encode(SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH).decode(),
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )
        context, _ = self.observer.on_server_span_created.call_args[0]
        try:
            self.assertEqual(context.request_context.user.id, "t2_example")
            self.assertEqual(context.request_context.user.roles, set())
            self.assertEqual(context.request_context.user.is_logged_in, True)
            self.assertEqual(context.request_context.user.loid, "t2_deadbeef")
            self.assertEqual(context.request_context.user.cookie_created_ms, 100000)
            self.assertEqual(context.request_context.oauth_client.id, None)
            self.assertFalse(context.request_context.oauth_client.is_type("third_party"))
            self.assertEqual(context.request_context.session.id, "beefdead")
        except jwt.exceptions.InvalidAlgorithmError:
            raise unittest.SkipTest("cryptography is not installed")

    def test_empty_edge_request_headers(self):
        """An empty X-Edge-Request header is stored as empty bytes."""
        self.test_app.get(
            "/example",
            headers={
                "X-Trace": "1234",
                "X-Edge-Request": "",
                "X-Parent": "2345",
                "X-Span": "3456",
                "X-Sampled": "1",
                "X-Flags": "1",
            },
        )
        context, _ = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(context.raw_request_context, b"")

    def test_not_found(self):
        """An unrouted path creates no span and fires no init event."""
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_server_span_created.called)
        self.assertFalse(self.context_init_event_subscriber.called)

    def test_exception_caught(self):
        """A view exception propagates and is passed to on_finish."""
        with self.assertRaises(TestException):
            self.test_app.get("/example?error")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, TestException)

    def test_control_flow_exception_not_caught(self):
        """An exception handled by an exception view finishes the span cleanly."""
        self.test_app.get("/example?control_flow_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        args, _ = self.server_observer.on_finish.call_args
        self.assertEqual(args[0], None)

    def test_exception_in_exception_view_caught(self):
        """An exception raised by an exception view is passed to on_finish."""
        with self.assertRaises(ExceptionViewException):
            self.test_app.get("/example?exception_view_exception")

        self.assertTrue(self.server_observer.on_start.called)
        self.assertTrue(self.server_observer.on_finish.called)
        self.assertTrue(self.context_init_event_subscriber.called)
        _, captured_exc, _ = self.server_observer.on_finish.call_args[0][0]
        self.assertIsInstance(captured_exc, ExceptionViewException)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        """With header trust disabled, inbound trace headers are ignored."""
        getrandbits.return_value = 1234
        self.baseplate_configurator.header_trust_handler.trust_headers = False

        self.test_app.get(
            "/example", headers={"X-Trace": "1234", "X-Parent": "2345", "X-Span": "3456"}
        )

        _, server_span = self.observer.on_server_span_created.call_args[0]
        self.assertEqual(server_span.trace_id, getrandbits.return_value)
        self.assertEqual(server_span.parent_id, None)
        self.assertEqual(server_span.id, getrandbits.return_value)

    def test_local_trace_in_context(self):
        """A local span created during the request has its own context object."""
        self.test_app.get("/trace_context")
        self.assertEqual(self.server_observer.on_child_span_created.call_count, 1)
        child_span = self.server_observer.on_child_span_created.call_args[0][0]
        context, _ = self.observer.on_server_span_created.call_args[0]
        self.assertNotEqual(child_span.context, context)

    def test_streaming_response(self):
        """The server span stays open until a streamed body is fully consumed."""

        class StreamingTestResponse(webtest.TestResponse):
            def decode_content(self):
                # Prevent webtest from decoding (and thus consuming) app_iter.
                pass

            @property
            def body(self):
                # Likewise, don't let webtest read the whole body eagerly.
                pass

        class StreamingTestRequest(webtest.TestRequest):
            ResponseClass = StreamingTestResponse

        self.test_app.RequestClass = StreamingTestRequest

        response = self.test_app.get("/example?stream")
        # Always close the iterator, even if an assertion below fails, so the
        # streaming generator is not leaked.
        self.addCleanup(response.app_iter.close)

        # We've returned from the WSGI app but the iterator isn't done,
        # so the span should have started but not finished yet.
        self.assertTrue(self.server_observer.on_start.called)
        self.assertFalse(self.server_observer.on_finish.called)

        self.assertEqual(b"foo", next(response.app_iter))
        self.assertFalse(self.server_observer.on_finish.called)

        self.assertEqual(b"bar", next(response.app_iter))
        self.assertFalse(self.server_observer.on_finish.called)

        with self.assertRaises(StopIteration):
            next(response.app_iter)
        self.assertTrue(self.server_observer.on_finish.called)
예제 #33
0
class ConfiguratorTests(unittest.TestCase):
    """Tests for the root-span hooks of BaseplateConfigurator in a Pyramid app.

    setUp wires a single GET route (``/example``) and a mocked
    BaseplateObserver, so each test can inspect ``on_root_span_created``
    calls and the mocked root observer it returns.
    """

    def setUp(self):
        configurator = Configurator()
        configurator.add_route("example", "/example", request_method="GET")
        configurator.add_view(
            example_application, route_name="example", renderer="json")

        self.observer = mock.Mock(spec=BaseplateObserver)

        self.baseplate = Baseplate()
        self.baseplate.register(self.observer)
        self.baseplate_configurator = BaseplateConfigurator(
            self.baseplate,
            trust_trace_headers=True,
        )
        configurator.include(self.baseplate_configurator.includeme)
        app = configurator.make_wsgi_app()
        self.test_app = webtest.TestApp(app)

    def test_no_trace_headers(self):
        """Without trace headers the root span carries placeholder IDs."""
        self.test_app.get("/example")

        self.assertEqual(self.observer.on_root_span_created.call_count, 1)

        context, root_span = self.observer.on_root_span_created.call_args[0]
        self.assertIsInstance(context, Request)
        self.assertEqual(root_span.trace_id, "no-trace")
        self.assertEqual(root_span.parent_id, "no-parent")
        self.assertEqual(root_span.id, "no-span")

        # The observer mock's return value stands in for the root observer.
        mock_root_observer = self.observer.on_root_span_created.return_value
        self.assertTrue(mock_root_observer.on_start.called)
        self.assertTrue(mock_root_observer.on_stop.called)

    def test_trace_headers(self):
        """Trusted trace headers are copied verbatim (as strings) onto the span."""
        self.test_app.get("/example", headers={
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
        })

        self.assertEqual(self.observer.on_root_span_created.call_count, 1)

        context, root_span = self.observer.on_root_span_created.call_args[0]
        self.assertIsInstance(context, Request)
        self.assertEqual(root_span.trace_id, "1234")
        self.assertEqual(root_span.parent_id, "2345")
        self.assertEqual(root_span.id, "3456")

        mock_root_observer = self.observer.on_root_span_created.return_value
        self.assertTrue(mock_root_observer.on_start.called)
        self.assertTrue(mock_root_observer.on_stop.called)

    def test_not_found(self):
        """An unrouted path never creates a root span."""
        self.test_app.get("/nope", status=404)

        self.assertFalse(self.observer.on_root_span_created.called)

    @mock.patch("random.getrandbits")
    def test_distrust_headers(self, getrandbits):
        """With header trust disabled, inbound trace headers are ignored."""
        # Pin the generated ID so the assertions compare against a concrete
        # value rather than an auto-created MagicMock.
        getrandbits.return_value = 1234
        self.baseplate_configurator.trust_trace_headers = False

        self.test_app.get("/example", headers={
            "X-Trace": "1234",
            "X-Parent": "2345",
            "X-Span": "3456",
        })

        _, root_span = self.observer.on_root_span_created.call_args[0]
        self.assertEqual(root_span.trace_id, getrandbits.return_value)
        self.assertEqual(root_span.parent_id, None)
        self.assertEqual(root_span.id, getrandbits.return_value)
예제 #34
0
def baseplate_app():
    """Return a fresh Baseplate with a SentryBaseplateObserver registered."""
    app = Baseplate()
    sentry_observer = SentryBaseplateObserver()
    app.register(sentry_observer)
    return app