Exemple #1
0
def baseplate_thrift_client(endpoint, client_spec, client_span_observer=None):
    """Yield a request context that exposes an ``example_service`` Thrift client.

    A ``Baseplate`` instance is built with ``endpoint`` as the
    ``example_service.endpoint`` setting, a ``ThriftClient`` for
    ``client_spec.Client`` is added to the context configuration, and a server
    span named ``example_service.example`` is started on a fresh context
    object using fixed upstream trace values.  The raw edge context bytes from
    ``FakeEdgeContextFactory`` are attached before the context is yielded.

    This is a generator that yields exactly once; any wrapping decorator
    (fixture / context manager) is not visible from this block.

    Args:
        endpoint: Address of the service; stored in the config as
            ``str(endpoint)``.
        client_spec: Object whose ``Client`` attribute is passed to
            ``ThriftClient``.
        client_span_observer: Optional observer.  When truthy it is registered
            on every span handed to ``on_child_span_created`` of the server
            span observer installed here.

    Yields:
        The configured context object.
    """
    app_config = {
        "baseplate.service_name": "fancy test client",
        "example_service.endpoint": str(endpoint),
    }
    baseplate = Baseplate(app_config=app_config)

    if client_span_observer:

        class TestServerSpanObserver(ServerSpanObserver):
            # Hands the caller-supplied observer to each child span.
            def on_child_span_created(self, span):
                span.register(client_span_observer)

        server_span_observer = TestServerSpanObserver()

        class TestBaseplateObserver(BaseplateObserver):
            # Hooks the server span so child spans can be observed above.
            def on_server_span_created(self, context, span):
                span.register(server_span_observer)

        baseplate.register(TestBaseplateObserver())

    request_context = baseplate.make_context_object()
    upstream_trace = TraceInfo.from_upstream(
        trace_id="1234",
        parent_id="2345",
        span_id="3456",
        flags=4567,
        sampled=True,
    )

    clients = {"example_service": ThriftClient(client_spec.Client)}
    baseplate.configure_context(clients)

    baseplate.make_server_span(
        request_context,
        name="example_service.example",
        trace_info=upstream_trace,
    )

    request_context.raw_edge_context = FakeEdgeContextFactory.RAW_BYTES

    yield request_context
Exemple #2
0
    def test_null_server_observer(self):
        # An observer whose on_server_span_created hook returns None must not
        # leave any observer attached to the resulting server span.
        baseplate = Baseplate()
        context = baseplate.make_context_object()

        null_observer = mock.Mock(spec=BaseplateObserver)
        null_observer.on_server_span_created.return_value = None
        baseplate.register(null_observer)

        trace_info = TraceInfo(1, 2, 3, None, 0)
        span = baseplate.make_server_span(context, "name", trace_info)

        self.assertEqual(span.observers, [])
Exemple #3
0
 def _get_trace_info(self, headers: Mapping[str, str]) -> TraceInfo:
     """Build a ``TraceInfo`` from the ``X-*`` tracing headers.

     ``X-Trace``, ``X-Parent`` and ``X-Span`` are required and parsed as
     integers.  ``X-Sampled`` counts as true only when it equals ``"1"``.
     ``X-Flags`` is optional and parsed as an integer when present.

     Raises:
         KeyError: if one of the three required headers is missing.
         ValueError: if a header that must be an integer does not parse
             (``TraceInfo.from_upstream`` may also reject values).
     """
     is_sampled = headers.get("X-Sampled") == "1"
     raw_flags = headers.get("X-Flags")
     # The flags conversion stays inline so the IDs are parsed first, which
     # keeps the order in which malformed headers are reported.
     return TraceInfo.from_upstream(
         trace_id=int(headers["X-Trace"]),
         parent_id=int(headers["X-Parent"]),
         span_id=int(headers["X-Span"]),
         sampled=is_sampled,
         flags=None if raw_flags is None else int(raw_flags),
     )
Exemple #4
0
        def call_processor_with_span_context(self: Any, seqid: int,
                                             iprot: TProtocolBase,
                                             oprot: TProtocolBase) -> Any:
            """Run ``processor_fn`` with a request context and server span attached.

            Trace info, edge context, deadline budget and the peer service
            name are read from the input protocol's headers.  The call is then
            delegated to a new processor of the same class whose handler is
            wrapped in ``_ContextAwareHandler``.

            Returns:
                Whatever ``processor_fn`` returns.
            """
            context = baseplate.make_context_object()

            # THeader header names are looked up case-insensitively.
            headers: Mapping[bytes, bytes] = CaseInsensitiveDict(  # type: ignore
                data=iprot.get_headers())

            # Missing or malformed trace headers mean "no upstream trace".
            trace_info: Optional[TraceInfo]
            try:
                is_sampled = headers.get(b"Sampled") == b"1"
                raw_flags = headers.get(b"Flags")
                trace_info = TraceInfo.from_upstream(
                    trace_id=headers[b"Trace"].decode(),
                    parent_id=headers[b"Parent"].decode(),
                    span_id=headers[b"Span"].decode(),
                    sampled=is_sampled,
                    flags=None if raw_flags is None else int(raw_flags),
                )
            except (KeyError, ValueError):
                trace_info = None

            # The raw payload is always kept on the context; it is only parsed
            # when a factory is available.
            raw_edge_context = headers.get(b"Edge-Request")
            context.raw_edge_context = raw_edge_context
            if edge_context_factory:
                context.edge_context = edge_context_factory.from_upstream(
                    raw_edge_context)

            # Header value is divided by 1000 before being stored (presumably
            # milliseconds -> seconds; confirm against the sender).
            try:
                header_budget = float(headers[b"Deadline-Budget"].decode())
            except (KeyError, ValueError):
                context.deadline_budget = None
            else:
                context.deadline_budget = header_budget / 1000

            span = baseplate.make_server_span(
                context, name=fn_name, trace_info=trace_info)

            # Tag the span with the caller's User-Agent when it is present and
            # decodable; otherwise leave the tag unset.
            try:
                peer_service = headers[b"User-Agent"].decode()
            except (KeyError, UnicodeDecodeError):
                peer_service = None
            if peer_service is not None:
                span.set_tag("peer.service", peer_service)

            context.headers = headers

            wrapped_handler = _ContextAwareHandler(
                processor._handler, context, logger)
            wrapped_processor = processor.__class__(wrapped_handler)
            return processor_fn(wrapped_processor, seqid, iprot, oprot)
Exemple #5
0
    def test_server_observer_made(self):
        # A registered observer is recorded on the Baseplate object and gets
        # exactly one on_server_span_created call with the context and span.
        baseplate = Baseplate()
        context = baseplate.make_context_object()
        observer = mock.Mock(spec=BaseplateObserver)
        baseplate.register(observer)

        trace_info = TraceInfo(1, 2, 3, None, 0)
        span = baseplate.make_server_span(context, "name", trace_info)

        self.assertEqual(baseplate.observers, [observer])

        hook = observer.on_server_span_created
        self.assertEqual(hook.call_count, 1)
        self.assertEqual(hook.call_args, mock.call(context, span))
Exemple #6
0
        def call_processor_with_span_context(self: Any, seqid: int,
                                             iprot: TProtocolBase,
                                             oprot: TProtocolBase) -> Any:
            """Run ``processor_fn`` with a request context and server span attached.

            Trace info and the edge request payload are read from the input
            protocol's headers.  The call is then delegated to a new processor
            of the same class whose handler is wrapped in
            ``_ContextAwareHandler``.

            Returns:
                Whatever ``processor_fn`` returns.
            """
            context = baseplate.make_context_object()

            # THeader header names are looked up case-insensitively.
            headers: Mapping[bytes, bytes] = CaseInsensitiveDict(  # type: ignore
                data=iprot.get_headers())

            # Missing or non-integer trace headers mean "no upstream trace".
            trace_info: Optional[TraceInfo]
            try:
                is_sampled = headers.get(b"Sampled") == b"1"
                raw_flags = headers.get(b"Flags")
                trace_info = TraceInfo.from_upstream(
                    trace_id=int(headers[b"Trace"]),
                    parent_id=int(headers[b"Parent"]),
                    span_id=int(headers[b"Span"]),
                    sampled=is_sampled,
                    flags=None if raw_flags is None else int(raw_flags),
                )
            except (KeyError, ValueError):
                trace_info = None

            raw_edge_payload = headers.get(b"Edge-Request")
            if not edge_context_factory:
                # No factory to interpret the payload: keep the raw bytes on
                # the context so they can still be passed on downstream.
                context.raw_request_context = raw_edge_payload
            else:
                parsed_edge_context = edge_context_factory.from_upstream(
                    raw_edge_payload)
                parsed_edge_context.attach_context(context)

            baseplate.make_server_span(
                context, name=fn_name, trace_info=trace_info)

            context.headers = headers

            wrapped_handler = _ContextAwareHandler(
                processor._handler, context, logger)
            wrapped_processor = processor.__class__(wrapped_handler)
            return processor_fn(wrapped_processor, seqid, iprot, oprot)
Exemple #7
0
def baseplate_thrift_client(endpoint, client_spec, client_span_observer=None):
    """Yield a request context that exposes an ``example_service`` Thrift client.

    A ``Baseplate`` instance is built with ``endpoint`` as the
    ``example_service.endpoint`` setting, a ``ThriftClient`` for
    ``client_spec.Client`` is added to the context configuration, and a server
    span named ``example_service.example`` is started on a fresh context
    object using fixed upstream trace values.  An edge context built from
    ``SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH`` is attached before the context
    is yielded.

    This is a generator that yields exactly once; any wrapping decorator
    (fixture / context manager) is not visible from this block.

    Args:
        endpoint: Address of the service; stored in the config as
            ``str(endpoint)``.
        client_spec: Object whose ``Client`` attribute is passed to
            ``ThriftClient``.
        client_span_observer: Optional observer.  When truthy it is registered
            on every span handed to ``on_child_span_created`` of the server
            span observer installed here.

    Yields:
        The configured context object.
    """
    app_config = {
        "baseplate.service_name": "fancy test client",
        "example_service.endpoint": str(endpoint),
    }
    baseplate = Baseplate(app_config=app_config)

    if client_span_observer:

        class TestServerSpanObserver(ServerSpanObserver):
            # Hands the caller-supplied observer to each child span.
            def on_child_span_created(self, span):
                span.register(client_span_observer)

        server_span_observer = TestServerSpanObserver()

        class TestBaseplateObserver(BaseplateObserver):
            # Hooks the server span so child spans can be observed above.
            def on_server_span_created(self, context, span):
                span.register(server_span_observer)

        baseplate.register(TestBaseplateObserver())

    request_context = baseplate.make_context_object()
    upstream_trace = TraceInfo.from_upstream(
        trace_id=1234,
        parent_id=2345,
        span_id=3456,
        flags=4567,
        sampled=True,
    )

    clients = {"example_service": ThriftClient(client_spec.Client)}
    baseplate.configure_context(clients)

    baseplate.make_server_span(
        request_context,
        name="example_service.example",
        trace_info=upstream_trace,
    )

    # Build the edge context from the canned payload and bind it to the
    # request context.
    edge_context = make_edge_context_factory().from_upstream(
        SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH
    )
    edge_context.attach_context(request_context)

    yield request_context
Exemple #8
0
 def test_from_upstream_handles_no_sampled_or_flags(self):
     # None is accepted for both sampled and flags and is kept as None.
     trace_info = TraceInfo.from_upstream(
         trace_id=1, parent_id=2, span_id=3, sampled=None, flags=None
     )
     self.assertIsNone(trace_info.sampled)
     self.assertIsNone(trace_info.flags)
Exemple #9
0
 def test_from_upstream_fails_on_invalid_flags(self):
     # A flags value of -1 must be rejected with the exact error message.
     with self.assertRaises(ValueError) as raised:
         TraceInfo.from_upstream(
             trace_id=1, parent_id=2, span_id=3, sampled=True, flags=-1
         )
     message = str(raised.exception)
     self.assertEqual(message, "invalid flags value")
Exemple #10
0
 def test_from_upstream_fails_on_invalid_sampled(self):
     # The string "True" is not an acceptable sampled value.
     with self.assertRaises(ValueError) as raised:
         TraceInfo.from_upstream(
             trace_id=1, parent_id=2, span_id=3, sampled="True", flags=None
         )
     message = str(raised.exception)
     self.assertEqual(message, "invalid sampled value")
Exemple #11
0
 def test_new_does_not_set_sampled(self):
     # TraceInfo.new() leaves the sampled attribute as None.
     self.assertIsNone(TraceInfo.new().sampled)
Exemple #12
0
 def test_new_does_not_set_flags(self):
     # TraceInfo.new() leaves the flags attribute as None.
     self.assertIsNone(TraceInfo.new().flags)
Exemple #13
0
 def test_new_does_not_have_parent_id(self):
     # TraceInfo.new() leaves the parent_id attribute as None.
     self.assertIsNone(TraceInfo.new().parent_id)