Exemple #1
0
    def test_extract_single_header(self):
        """Extract a context from a single b3 header, with and without sampling state.

        Both ``<trace_id>-<span_id>`` and ``<trace_id>-<span_id>-1`` must yield
        a remote, sampled parent whose ids match the serialized test ids.
        """
        propagator = self.get_propagator()
        ids_only = f"{self.serialized_trace_id}-{self.serialized_span_id}"

        for header_value in (ids_only, f"{ids_only}-1"):
            child, parent, _ = self.get_child_parent_new_carrier(
                {propagator.SINGLE_HEADER_KEY: header_value}
            )

            self.assertEqual(
                self.serialized_trace_id,
                trace_api.format_trace_id(child.context.trace_id),
            )
            self.assertEqual(
                self.serialized_span_id,
                trace_api.format_span_id(child.parent.span_id),
            )
            self.assertTrue(parent.context.is_remote)
            self.assertTrue(parent.context.trace_flags.sampled)
Exemple #2
0
def test_s3_event_handler_trace_propagation(create_event_handler,
                                            s3_get_object_mock,
                                            create_s3_event, mocker):
    """Each S3 record is exported as its own SERVER span.

    The object that is not valid JSON gets a span outside ``TRACE_ID`` with
    an error status; the object carrying a trace gets a span in ``TRACE_ID``
    with an OK status and the functional key attribute.
    """
    export_mock = mocker.patch(
        'cdk_example_app.common.tracing.tracer.logger_span_exporter.export')
    handler, _handler_args, _ = create_event_handler()
    s3_event = create_s3_event(
        ['my-key-with-parse-error', 'my-json-key-with-trace'])
    handler(s3_event, Context(function_name='my-lambda'))

    # Every export() call receives a sequence of spans; take the first one.
    exports = export_mock.call_args_list

    failed_span = exports[0].args[0][0]
    assert failed_span.attributes == {
        's3Bucket': 'my-bucket',
        's3Key': 'my-key-with-parse-error',
    }
    assert format_trace_id(failed_span.context.trace_id) != TRACE_ID
    assert failed_span.kind == SpanKind.SERVER
    assert not failed_span.status.is_ok
    expected_error = (
        'S3 object my-key-with-parse-error in bucket my-bucket is not valid Json: '
        'Expecting value: line 1 column 1 (char 0)'
    )
    assert failed_span.status.description == expected_error

    traced_span = exports[1].args[0][0]
    assert traced_span.attributes == {
        's3Bucket': 'my-bucket',
        's3Key': 'my-json-key-with-trace',
        'func_key': "{'foo': 'bar'}",
    }
    assert format_trace_id(traced_span.context.trace_id) == TRACE_ID
    assert traced_span.kind == SpanKind.SERVER
    assert traced_span.status.is_ok
Exemple #3
0
    async def test_trace_parent(self):
        """An incoming ``traceparent`` header is continued by the server span.

        The finished span must share the header's trace id, and its parent
        must carry the header's trace id and span id.
        """
        ids = RandomIdGenerator()
        trace_id = format_trace_id(ids.generate_trace_id())
        span_id = format_span_id(ids.generate_span_id())

        await self.async_client.get(
            "/span_name/1234/",
            traceparent=f"00-{trace_id}-{span_id}-01",
        )
        server_span = self.memory_exporter.get_finished_spans()[0]
        span_ctx = server_span.get_span_context()

        self.assertEqual(trace_id, format_trace_id(span_ctx.trace_id))
        self.assertIsNotNone(server_span.parent)
        self.assertEqual(trace_id, format_trace_id(server_span.parent.trace_id))
        self.assertEqual(span_id, format_span_id(server_span.parent.span_id))
        self.memory_exporter.clear()
Exemple #4
0
    def inject(
            self,
            carrier: textmap.CarrierT,
            context: typing.Optional[Context] = None,
            setter: textmap.Setter = default_setter,  # type: ignore
    ) -> None:
        """Expose the current span as a ``Server-Timing`` response header.

        The header value is ``traceparent;desc="00-<trace>-<span>-<flags>"``.
        The header name is also written under the access-control-expose-headers
        key. Nothing is written when the current span context is invalid.
        """
        span_context = trace.get_current_span(context).get_span_context()
        if span_context == trace.INVALID_SPAN_CONTEXT:
            return

        timing_header = "Server-Timing"
        traceparent = "00-{}-{}-{:02x}".format(
            format_trace_id(span_context.trace_id),
            format_span_id(span_context.span_id),
            span_context.trace_flags,
        )

        setter.set(carrier, timing_header, f'traceparent;desc="{traceparent}"')
        setter.set(
            carrier,
            _HTTP_HEADER_ACCESS_CONTROL_EXPOSE_HEADERS,
            timing_header,
        )
Exemple #5
0
def _format_uber_trace_id(trace_id, span_id, parent_span_id, flags):
    """Build an ``uber-trace-id`` value: ``<trace>:<span>:<parent>:<flags>``.

    Ids are rendered by ``format_trace_id`` / ``format_span_id``; *flags* is
    rendered as lowercase hex, zero-padded to at least two digits.
    """
    trace_part = format_trace_id(trace_id)
    span_part = format_span_id(span_id)
    parent_part = format_span_id(parent_span_id)
    return f"{trace_part}:{span_part}:{parent_part}:{flags:02x}"
    def test_trace_response(self):
        """A global TraceResponsePropagator adds a ``traceresponse`` header.

        The previously installed global response propagator is restored in
        ``finally`` so a failing assertion cannot leak it into later tests.
        """
        orig = get_global_response_propagator()
        set_global_response_propagator(TraceResponsePropagator())
        try:
            response = self.client.get("/hello/123")
            headers = response.headers

            span_list = self.memory_exporter.get_finished_spans()
            self.assertEqual(len(span_list), 1)
            span = span_list[0]

            self.assertIn("traceresponse", headers)
            self.assertEqual(
                headers["access-control-expose-headers"],
                "traceresponse",
            )
            self.assertEqual(
                headers["traceresponse"],
                "00-{0}-{1}-01".format(
                    trace.format_trace_id(span.get_span_context().trace_id),
                    trace.format_span_id(span.get_span_context().span_id),
                ),
            )
        finally:
            set_global_response_propagator(orig)
Exemple #7
0
    def inject(
        self,
        carrier: CarrierT,
        context: typing.Optional[Context] = None,
        setter: Setter = default_setter,
    ) -> None:
        """Write the current span context into *carrier* as B3 multi headers.

        Keys are written in order: trace id, span id, parent span id (only when
        the current span has a non-None ``parent``), then the sampled flag as
        ``"1"``/``"0"``. Nothing is written for an invalid span context.
        """
        current_span = trace.get_current_span(context=context)
        ctx = current_span.get_span_context()
        if ctx == trace.INVALID_SPAN_CONTEXT:
            return

        sampled_value = "1" if trace.TraceFlags.SAMPLED & ctx.trace_flags else "0"

        setter.set(carrier, self.TRACE_ID_KEY, format_trace_id(ctx.trace_id))
        setter.set(carrier, self.SPAN_ID_KEY, format_span_id(ctx.span_id))

        # The span object may have no ``parent`` attribute at all, hence getattr.
        parent_ctx = getattr(current_span, "parent", None)
        if parent_ctx is not None:
            setter.set(
                carrier,
                self.PARENT_SPAN_ID_KEY,
                format_span_id(parent_ctx.span_id),
            )

        setter.set(carrier, self.SAMPLED_KEY, sampled_value)
    def inject(
        self,
        carrier: textmap.CarrierT,
        context: typing.Optional[Context] = None,
        setter: textmap.Setter = default_setter,
    ) -> None:
        """Add a ``traceresponse`` header describing the current span.

        The value is ``00-<trace_id>-<span_id>-<flags as 2-digit hex>``. The
        header name is also written under the access-control-expose-headers
        key. Nothing is written when the current span context is invalid.
        """
        span_context = trace.get_current_span(context).get_span_context()
        if span_context == trace.INVALID_SPAN_CONTEXT:
            return

        response_header = "traceresponse"
        trace_id = format_trace_id(span_context.trace_id)
        span_id = format_span_id(span_context.span_id)
        header_value = f"00-{trace_id}-{span_id}-{span_context.trace_flags:02x}"

        setter.set(carrier, response_header, header_value)
        setter.set(
            carrier,
            _HTTP_HEADER_ACCESS_CONTROL_EXPOSE_HEADERS,
            response_header,
        )
Exemple #9
0
    def inject(
        self,
        carrier: CarrierT,
        context: typing.Optional[Context] = None,
        setter: Setter = default_setter,
    ) -> None:
        """Write the current span context into *carrier* as one ``b3`` header.

        The value is ``<trace_id>-<span_id>-<1|0>``, with ``-<parent_span_id>``
        appended when the current span has a truthy ``parent``. Nothing is
        written for an invalid span context.
        """
        current_span = trace.get_current_span(context=context)
        ctx = current_span.get_span_context()
        if ctx == trace.INVALID_SPAN_CONTEXT:
            return

        is_sampled = (trace.TraceFlags.SAMPLED & ctx.trace_flags) != 0
        header_fields = [
            format_trace_id(ctx.trace_id),
            format_span_id(ctx.span_id),
            "1" if is_sampled else "0",
        ]

        parent_ctx = getattr(current_span, "parent", None)
        if parent_ctx:
            header_fields.append(format_span_id(parent_ctx.span_id))

        setter.set(carrier, self.SINGLE_HEADER_KEY, "-".join(header_fields))
    def test_traceresponse_header(self):
        """Test a traceresponse header is sent when a global propagator is set.

        The previous global response propagator is restored in ``finally`` so
        a failing assertion cannot leak it into later tests.
        """
        orig = get_global_response_propagator()
        set_global_response_propagator(TraceResponsePropagator())
        try:
            app = otel_asgi.OpenTelemetryMiddleware(simple_asgi)
            self.seed_app(app)
            self.send_default_request()

            span = self.memory_exporter.get_finished_spans()[-1]
            self.assertEqual(trace_api.SpanKind.SERVER, span.kind)

            response_start, response_body, *_ = self.get_all_output()
            self.assertEqual(response_body["body"], b"*")
            self.assertEqual(response_start["status"], 200)

            traceresponse = "00-{0}-{1}-01".format(
                format_trace_id(span.get_span_context().trace_id),
                format_span_id(span.get_span_context().span_id),
            )
            self.assertListEqual(
                response_start["headers"],
                [
                    [b"Content-Type", b"text/plain"],
                    [b"traceresponse", traceresponse.encode()],
                    [b"access-control-expose-headers", b"traceresponse"],
                ],
            )
        finally:
            set_global_response_propagator(orig)
Exemple #11
0
    def test_response_headers(self):
        """A global TraceResponsePropagator adds ``traceresponse`` for the server span.

        Cleanup (clearing the exporter, restoring the previous global response
        propagator) runs in ``finally`` so a failing assertion cannot leak
        state into later tests.
        """
        orig = get_global_response_propagator()
        set_global_response_propagator(TraceResponsePropagator())
        try:
            response = self.fetch("/")
            headers = response.headers

            spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
            self.assertEqual(len(spans), 3)
            server_span = spans[1]

            self.assertIn("traceresponse", headers)
            self.assertEqual(
                headers["access-control-expose-headers"],
                "traceresponse",
            )
            self.assertEqual(
                headers["traceresponse"],
                "00-{0}-{1}-01".format(
                    trace.format_trace_id(server_span.get_span_context().trace_id),
                    trace.format_span_id(server_span.get_span_context().span_id),
                ),
            )
        finally:
            self.memory_exporter.clear()
            set_global_response_propagator(orig)
def _extract_links(links: Sequence[trace_api.Link]) -> trace_pb2.Span.Links:
    """Convert ``span.links`` into a ``trace_pb2.Span.Links`` message.

    Only the first ``MAX_NUM_LINKS`` links are kept. A warning is logged and
    the overflow is reported as ``dropped_links_count``. Each link's attributes
    go through ``_extract_attributes`` with a limit of ``MAX_LINK_ATTRS``, and
    a warning is logged when a link exceeds that limit.

    Returns ``None`` (despite the annotation) when *links* is empty or None.
    """
    if not links:
        return None

    dropped_count = 0
    kept_links = links
    if len(links) > MAX_NUM_LINKS:
        logger.warning(
            "Exporting more then %s links, some will be truncated",
            MAX_NUM_LINKS,
        )
        dropped_count = len(links) - MAX_NUM_LINKS
        kept_links = links[:MAX_NUM_LINKS]

    converted = []
    for link in kept_links:
        attributes = link.attributes or {}
        if len(attributes) > MAX_LINK_ATTRS:
            logger.warning(
                "Link has more then %s attributes, some will be truncated",
                MAX_LINK_ATTRS,
            )
        converted.append(
            {
                "trace_id": format_trace_id(link.context.trace_id),
                "span_id": format_span_id(link.context.span_id),
                "type": "TYPE_UNSPECIFIED",
                "attributes": _extract_attributes(attributes, MAX_LINK_ATTRS),
            }
        )

    return trace_pb2.Span.Links(
        link=converted, dropped_links_count=dropped_count
    )
Exemple #13
0
        def wrapper(event, context):
            """Handle every S3 record of ``event`` inside its own root span.

            For each record:

            - fetch the object;
            - parse its body as JSON when ``parse_json`` is set, otherwise
              decode it as text;
            - attach the optional functional key to the logger, span and
              event log;
            - call ``func``.

            A record whose body is not valid JSON is marked failed on the
            event log, and its span gets an ERROR status. Processing then
            continues with the next record.
            """
            from urllib.parse import unquote_plus  # local import keeps the fix self-contained

            for record in event['Records']:
                s3_bucket = record['s3']['bucket']['name']
                # Object keys in S3 event notifications are URL-encoded (a space
                # arrives as '+'); decode so GetObject, logs and the span all see
                # the real key.
                s3_key = unquote_plus(record['s3']['object']['key'])

                with event_log(s3_bucket, s3_key, context.function_name, logger) as event_log_:
                    s3_obj = s3_client().get_object(Bucket=s3_bucket, Key=s3_key)
                    with start_s3_root_span(context.function_name, s3_obj, s3_bucket, s3_key) as span:
                        trace_id = format_trace_id(span.get_span_context().trace_id)
                        logger.append_keys(traceId=trace_id)
                        event_log_.trace_id = trace_id
                        # TODO: gzip support
                        if parse_json:
                            try:
                                body = json.loads(s3_obj['Body'].read())
                            except JSONDecodeError as e:
                                msg = f'S3 object {s3_key} in bucket {s3_bucket} is not valid Json: {e}'
                                event_log_.mark_failed(msg)
                                span.set_status(Status(StatusCode.ERROR, msg))
                                # Leaving both `with` blocks closes the span and event log.
                                continue
                        else:
                            # NOTE(review): S3's ContentEncoding is the HTTP Content-Encoding
                            # (e.g. "gzip"), not a charset; a non-codec value here makes
                            # bytes.decode raise LookupError — confirm intent.
                            body = s3_obj['Body'].read().decode(s3_obj.get('ContentEncoding', 'utf-8'))

                        if functional_key_extractor:
                            functional_key_name, functional_key_value = functional_key_extractor(body)
                            logger.append_keys(**{functional_key_name: functional_key_value})
                            span.set_attribute(functional_key_name, functional_key_value)
                            event_log_.set_functional_key(functional_key_name, functional_key_value)

                        # invoke wrapped function
                        func(body, s3_obj=s3_obj, record=record)
 def setUpClass(cls):
     """Create a random trace id and two random span ids shared by the tests.

     Ids are stored on the class in serialized form, as ``serialized_trace_id``,
     ``serialized_span_id`` and ``serialized_parent_id``.
     """
     rng = id_generator.RandomIdGenerator()

     def new_serialized_span_id():
         return trace_api.format_span_id(rng.generate_span_id())

     cls.serialized_trace_id = trace_api.format_trace_id(rng.generate_trace_id())
     cls.serialized_span_id = new_serialized_span_id()
     cls.serialized_parent_id = new_serialized_span_id()
 def _format_context(context):
     """Return an OrderedDict describing *context*.

     Keys, in order:

     - ``trace_id`` and ``span_id``: ``0x``-prefixed, formatted by ``trace_api``;
     - ``trace_state``: the ``repr`` of the context's trace state.
     """
     trace_hex = trace_api.format_trace_id(context.trace_id)
     span_hex = trace_api.format_span_id(context.span_id)
     return OrderedDict(
         (
             ("trace_id", f"0x{trace_hex}"),
             ("span_id", f"0x{span_hex}"),
             ("trace_state", repr(context.trace_state)),
         )
     )
def _extract_refs_from_span(span):
    """Build one ``link`` event dict per entry in ``span.links``.

    Each dict carries:

    - this span's trace id (``trace.trace_id``) and span id
      (``trace.parent_id``);
    - the linked context's trace and span ids;
    - ``meta.span_type='link'`` and ``ref_type=0``;
    - the link's own attributes, merged in last.
    """
    ctx = span.get_context()
    # Every ref points back at this span, so format its ids once.
    # NOTE(review): [2:] drops the first two characters, presumably the "0x"
    # prefix format_*_id returns in the targeted opentelemetry-api version.
    span_trace_id = trace_api.format_trace_id(ctx.trace_id)[2:]
    span_id = trace_api.format_span_id(ctx.span_id)[2:]

    refs = []
    for link in span.links:
        ref = {
            'trace.trace_id': span_trace_id,
            'trace.parent_id': span_id,
            'trace.link.trace_id': trace_api.format_trace_id(link.context.trace_id)[2:],
            'trace.link.span_id': trace_api.format_span_id(link.context.span_id)[2:],
            'meta.span_type': 'link',
            'ref_type': 0,
        }
        # A link may carry no attributes (None); dict.update(None) would raise.
        ref.update(link.attributes or {})
        refs.append(ref)
    return refs
Exemple #17
0
    def assertTraceResponseHeaderMatchesSpan(self, headers, span):  # pylint: disable=invalid-name
        """Assert *headers* carry a ``traceresponse`` value built from *span*.

        The header must be listed in ``access-control-expose-headers`` and must
        equal ``00-<trace_id>-<span_id>-01`` for *span*'s context.
        """
        self.assertIn("traceresponse", headers)
        self.assertEqual(
            headers["access-control-expose-headers"],
            "traceresponse",
        )

        span_ctx = span.get_span_context()
        expected = "00-{}-{}-01".format(
            trace.format_trace_id(span_ctx.trace_id),
            trace.format_span_id(span_ctx.span_id),
        )
        self.assertEqual(expected, headers["traceresponse"])
Exemple #18
0
    def test_encode_id_zero_padding(self):
        """Ids with leading zero nibbles keep their zero padding in JSON v1 output."""
        trace_id = 0x0E0C63257DE34C926F9EFCD03927272E
        span_id = 0x04BF92DEEFC58C92
        parent_id = 0x0AAAAAAAAAAAAAAA
        start_time = 683647322 * 10**9  # in ns
        duration = 50 * 10**6  # in ns, same unit as start_time
        end_time = start_time + duration

        span_context = trace_api.SpanContext(
            trace_id,
            span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
        )
        parent_context = trace_api.SpanContext(trace_id, parent_id, is_remote=False)
        otel_span = trace._Span(
            name=TEST_SERVICE_NAME,
            context=span_context,
            parent=parent_context,
            resource=trace.Resource({}),
        )
        otel_span.start(start_time=start_time)
        otel_span.end(end_time=end_time)

        # Key order matters: the comparison is on the serialized JSON string.
        expected_span = {
            "traceId": format_trace_id(trace_id),
            "id": format_span_id(span_id),
            "name": TEST_SERVICE_NAME,
            "timestamp": JsonV1Encoder._nsec_to_usec_round(start_time),
            "duration": JsonV1Encoder._nsec_to_usec_round(duration),
            "debug": True,
            "parentId": format_span_id(parent_id),
        }

        self.assertEqual(
            json.dumps([expected_span]),
            JsonV1Encoder().serialize([otel_span], NodeEndpoint()),
        )
def _translate_to_hny(spans):
    """Translate finished spans into a flat list of Honeycomb event dicts.

    For each span, in order, the list receives:

    - its link events (``_extract_refs_from_span``);
    - its span-event events (``_extract_logs_from_span``);
    - the span's own event, built from ids, name, start time, duration (ms),
      status and kind, plus the parent id when available, then the span
      attributes, plus ``error=True`` when the status code is not OK.
    """
    hny_data = []
    for span in spans:
        ctx = span.get_context()
        # NOTE(review): [2:] drops the first two characters, presumably the
        # "0x" prefix format_*_id returns in the targeted API version.
        event = {
            'trace.trace_id': trace_api.format_trace_id(ctx.trace_id)[2:],
            'trace.span_id': trace_api.format_span_id(ctx.span_id)[2:],
            'name': span.name,
            'start_time': datetime.datetime.utcfromtimestamp(span.start_time / float(1e9)),
            'duration_ms': (span.end_time - span.start_time) / float(1e6),  # nanoseconds to ms
            'response.status_code': span.status.canonical_code.value,
            'status.message': span.status.description,
            'span.kind': span.kind.name,  # meta.span_type?
        }

        parent = span.parent
        if isinstance(parent, trace_api.Span):
            event['trace.parent_id'] = trace_api.format_span_id(
                parent.get_context().span_id)[2:]
        elif isinstance(parent, trace_api.SpanContext):
            event['trace.parent_id'] = trace_api.format_span_id(parent.span_id)[2:]

        # TODO: use sampling_decision attributes for sample rate.
        event.update(span.attributes)

        # Mark the Honeycomb event as an error whenever the status is not OK.
        if span.status.canonical_code is not StatusCanonicalCode.OK:
            event['error'] = True

        hny_data.extend(_extract_refs_from_span(span))
        hny_data.extend(_extract_logs_from_span(span))
        hny_data.append(event)
    return hny_data
Exemple #20
0
    def test_trace_response(self):
        """A global TraceResponsePropagator adds a ``traceresponse`` header.

        The previously installed global response propagator is restored in
        ``finally`` so a failing assertion cannot leak it into later tests.
        """
        orig = get_global_response_propagator()
        set_global_response_propagator(TraceResponsePropagator())
        try:
            response = self.client().simulate_get(path="/hello?q=abc")
            headers = response.headers
            span = self.memory_exporter.get_finished_spans()[0]

            self.assertIn("traceresponse", headers)
            self.assertEqual(
                headers["access-control-expose-headers"], "traceresponse",
            )
            self.assertEqual(
                headers["traceresponse"],
                "00-{0}-{1}-01".format(
                    format_trace_id(span.get_span_context().trace_id),
                    format_span_id(span.get_span_context().span_id),
                ),
            )
        finally:
            set_global_response_propagator(orig)
Exemple #21
0
    def test_extract_multi_header(self):
        """Separate B3 trace-id / span-id / sampled headers yield a sampled remote parent."""
        propagator = self.get_propagator()
        carrier = {
            propagator.TRACE_ID_KEY: self.serialized_trace_id,
            propagator.SPAN_ID_KEY: self.serialized_span_id,
            propagator.SAMPLED_KEY: "1",
        }
        child, parent, _ = self.get_child_parent_new_carrier(carrier)

        expected_trace_id = carrier[propagator.TRACE_ID_KEY]
        self.assertEqual(
            expected_trace_id,
            trace_api.format_trace_id(child.context.trace_id),
        )

        expected_span_id = carrier[propagator.SPAN_ID_KEY]
        self.assertEqual(
            expected_span_id,
            trace_api.format_span_id(child.parent.span_id),
        )
        self.assertTrue(parent.context.is_remote)
        self.assertTrue(parent.context.trace_flags.sampled)
    def test_websocket_traceresponse_header(self):
        """Test a traceresponse header is set for websocket messages.

        The previous global response propagator is restored in ``finally`` so
        a failing assertion cannot leak it into later tests.
        """
        orig = get_global_response_propagator()
        set_global_response_propagator(TraceResponsePropagator())
        try:
            self.scope = {
                "type": "websocket",
                "http_version": "1.1",
                "scheme": "ws",
                "path": "/",
                "query_string": b"",
                "headers": [],
                "client": ("127.0.0.1", 32767),
                "server": ("127.0.0.1", 80),
            }
            app = otel_asgi.OpenTelemetryMiddleware(simple_asgi)
            self.seed_app(app)
            self.send_input({"type": "websocket.connect"})
            self.send_input({"type": "websocket.receive", "text": "ping"})
            self.send_input({"type": "websocket.disconnect"})
            _, socket_send, *_ = self.get_all_output()

            span = self.memory_exporter.get_finished_spans()[-1]
            self.assertEqual(trace_api.SpanKind.SERVER, span.kind)

            traceresponse = "00-{0}-{1}-01".format(
                format_trace_id(span.get_span_context().trace_id),
                format_span_id(span.get_span_context().span_id),
            )
            self.assertListEqual(
                socket_send["headers"],
                [
                    [b"traceresponse", traceresponse.encode()],
                    [b"access-control-expose-headers", b"traceresponse"],
                ],
            )
        finally:
            set_global_response_propagator(orig)
Exemple #23
0
    async def test_trace_response_headers(self):
        """Installing TraceResponsePropagator adds a ``traceresponse`` header.

        The header must match the finished server span.

        NOTE(review): this method never restores the global response
        propagator; unless tearDown resets it, it leaks into later tests.
        Also, the pre-check looks for ``Server-Timing`` rather than
        ``traceresponse`` — confirm which header is intended.
        """
        first_response = await self.async_client.get("/span_name/1234/")
        self.assertFalse(first_response.has_header("Server-Timing"))
        self.memory_exporter.clear()

        set_global_response_propagator(TraceResponsePropagator())

        response = await self.async_client.get("/span_name/1234/")
        span_ctx = self.memory_exporter.get_finished_spans()[0].get_span_context()

        self.assertTrue(response.has_header("traceresponse"))
        self.assertEqual(
            response["Access-Control-Expose-Headers"],
            "traceresponse",
        )
        expected = "00-{}-{}-01".format(
            format_trace_id(span_ctx.trace_id),
            format_span_id(span_ctx.span_id),
        )
        self.assertEqual(response["traceresponse"], expected)
        self.memory_exporter.clear()
Exemple #24
0
    async def test_trace_response_headers(self):
        """Installing TraceResponsePropagator adds a ``traceresponse`` header.

        The header must match the finished server span.

        NOTE(review): this method never restores the global response
        propagator; unless tearDown resets it, it leaks into later tests.
        Also, the pre-check looks for ``Server-Timing`` rather than
        ``traceresponse`` — confirm which header is intended.
        """
        first_response = await self.async_client.get("/span_name/1234/")
        self.assertNotIn("Server-Timing", first_response.headers)
        self.memory_exporter.clear()

        set_global_response_propagator(TraceResponsePropagator())

        response = await self.async_client.get("/span_name/1234/")
        span_ctx = self.memory_exporter.get_finished_spans()[0].get_span_context()
        headers = response.headers

        self.assertIn("traceresponse", headers)
        self.assertEqual(
            headers["Access-Control-Expose-Headers"],
            "traceresponse",
        )
        trace_id = format_trace_id(span_ctx.trace_id)
        span_id = format_span_id(span_ctx.span_id)
        self.assertEqual(headers["traceresponse"], f"00-{trace_id}-{span_id}-01")
        self.memory_exporter.clear()
    def inject(
        self,
        carrier: textmap.CarrierT,
        context: typing.Optional[Context] = None,
        setter: textmap.Setter = textmap.default_setter,
    ) -> None:
        """Write the current span context into *carrier* as W3C trace-context headers.

        Always writes ``traceparent`` (``00-<trace>-<span>-<2-digit hex flags>``)
        for a valid span context. Also writes ``tracestate`` when the context's
        trace state is non-empty. Nothing is written for an invalid context.

        See `opentelemetry.propagators.textmap.TextMapPropagator.inject`
        """
        span_context = trace.get_current_span(context).get_span_context()
        if span_context == trace.INVALID_SPAN_CONTEXT:
            return

        trace_id = format_trace_id(span_context.trace_id)
        span_id = format_span_id(span_context.span_id)
        setter.set(
            carrier,
            self._TRACEPARENT_HEADER_NAME,
            f"00-{trace_id}-{span_id}-{span_context.trace_flags:02x}",
        )

        trace_state = span_context.trace_state
        if trace_state:
            setter.set(
                carrier,
                self._TRACESTATE_HEADER_NAME,
                trace_state.to_header(),
            )
def _extract_logs_from_span(span):
    """Build one ``span_event`` dict per entry in ``span.events``.

    Each dict carries:

    - the event name;
    - its timestamp, divided by 1e9 and converted with
      ``datetime.utcfromtimestamp`` (so presumably nanoseconds — confirm);
    - a zero duration;
    - this span's trace id and span id (the latter as ``trace.parent_id``);
    - the event's own attributes, merged in last.
    """
    ctx = span.get_context()
    # The owning span's ids are the same for every event; format them once.
    # NOTE(review): [2:] drops the first two characters, presumably the "0x"
    # prefix format_*_id returns in the targeted opentelemetry-api version.
    trace_id = trace_api.format_trace_id(ctx.trace_id)[2:]
    parent_id = trace_api.format_span_id(ctx.span_id)[2:]

    logs = []
    for event in span.events:
        log_event = {
            'start_time':
            datetime.datetime.utcfromtimestamp(event.timestamp / float(1e9)),
            'duration_ms':
            0,
            'name':
            event.name,
            'trace.trace_id':
            trace_id,
            'trace.parent_id':
            parent_id,
            'meta.span_type':
            'span_event',
        }
        # An event may carry no attributes (None); dict.update(None) would raise.
        log_event.update(event.attributes or {})
        logs.append(log_event)
    return logs
    def _translate_to_cloud_trace(
        self, spans: Sequence[ReadableSpan]
    ) -> List[Dict[str, Any]]:
        """Translate the spans to Cloud Trace format.

        Args:
            spans: Sequence of spans to convert

        Returns:
            One dict per input span, in input order. Each dict has the
            resource name, ids, display name, start/end times, parent id,
            attributes (span attributes merged with resource values), links,
            status, time events and span kind.
        """

        def to_cloud_span(span):
            # Convert a single span; logs a warning when it has too many attributes.
            ctx = span.get_span_context()
            trace_id = format_trace_id(ctx.trace_id)
            span_id = format_span_id(ctx.span_id)
            parent_id = format_span_id(span.parent.span_id) if span.parent else None

            start_time = _get_time_from_ns(span.start_time)
            end_time = _get_time_from_ns(span.end_time)

            if span.attributes and len(span.attributes) > MAX_SPAN_ATTRS:
                logger.warning(
                    "Span has more then %s attributes, some will be truncated",
                    MAX_SPAN_ATTRS,
                )

            # Span does not support a MonitoredResource object, so resource
            # values are merged into the attributes instead (resource wins).
            resources_and_attrs = {
                **(span.attributes or {}),
                **_extract_resources(span.resource),
            }

            # TODO: Leverage more of the Cloud Trace API, e.g.
            #  same_process_as_parent_span and child_span_count
            return {
                "name": f"projects/{self.project_id}/traces/{trace_id}/spans/{span_id}",
                "span_id": span_id,
                "display_name": _get_truncatable_str_object(span.name, 128),
                "start_time": start_time,
                "end_time": end_time,
                "parent_span_id": parent_id,
                "attributes": _extract_attributes(
                    resources_and_attrs,
                    MAX_SPAN_ATTRS,
                    add_agent_attr=True,
                ),
                "links": _extract_links(span.links),  # type: ignore[has-type]
                "status": _extract_status(span.status),  # type: ignore[arg-type]
                "time_events": _extract_events(span.events),
                "span_kind": _extract_span_kind(span.kind),
            }

        return [to_cloud_span(span) for span in spans]
 def _encode_trace_id(trace_id: int) -> str:
     """Return *trace_id* rendered as text by ``format_trace_id``."""
     encoded = format_trace_id(trace_id)
     return encoded
Exemple #29
0
 def on_start(self, span, parent_context=None):
     """Append the starting span's formatted trace id to the structured logs.

     The same value is written under both ``traceId`` and ``correlationId``;
     *parent_context* is not used.
     """
     formatted_id = format_trace_id(span.context.trace_id)
     log_keys = {"traceId": formatted_id, "correlationId": formatted_id}
     logger.structure_logs(append=True, **log_keys)
Exemple #30
0
    def extract(
        self,
        getter: Getter[TextMapPropagatorT],
        carrier: TextMapPropagatorT,
        context: typing.Optional[Context] = None,
    ) -> Context:
        """Extract a B3 span context from *carrier* and store it in *context*.

        Header handling:

        - The single ``b3`` header takes precedence; otherwise the separate
          trace-id, span-id, sampled and flags headers are read.
        - When either id fails its format check, fresh ids are generated and
          the result is marked unsampled.
        - A single header with more than four ``-``-separated fields yields
          the invalid span.

        Args:
            getter: Reads header values from *carrier*.
            carrier: Object holding the propagated headers.
            context: Context the extracted span is stored in; passed straight
                to ``trace.set_span_in_context`` so other entries it already
                holds are kept.

        Returns:
            A Context holding a non-recording span built from the headers.
        """
        trace_id = format_trace_id(trace.INVALID_TRACE_ID)
        span_id = format_span_id(trace.INVALID_SPAN_ID)
        sampled = "0"
        flags = None

        single_header = _extract_first_element(
            getter.get(carrier, self.SINGLE_HEADER_KEY))
        if single_header:
            # The b3 spec calls for the sampling state to be
            # "deferred", which is unspecified. This concept does not
            # translate to SpanContext, so we set it as recorded.
            sampled = "1"
            fields = single_header.split("-", 4)

            if len(fields) == 1:
                sampled = fields[0]
            elif len(fields) == 2:
                trace_id, span_id = fields
            elif len(fields) == 3:
                trace_id, span_id, sampled = fields
            elif len(fields) == 4:
                trace_id, span_id, sampled, _ = fields
            else:
                return trace.set_span_in_context(trace.INVALID_SPAN, context)
        else:
            trace_id = (_extract_first_element(
                getter.get(carrier, self.TRACE_ID_KEY)) or trace_id)
            span_id = (_extract_first_element(
                getter.get(carrier, self.SPAN_ID_KEY)) or span_id)
            sampled = (_extract_first_element(
                getter.get(carrier, self.SAMPLED_KEY)) or sampled)
            flags = (_extract_first_element(getter.get(
                carrier, self.FLAGS_KEY)) or flags)

        if (self._trace_id_regex.fullmatch(trace_id) is None
                or self._span_id_regex.fullmatch(span_id) is None):
            # NOTE(review): this assumes the global tracer provider exposes
            # `id_generator`; confirm a provider without it cannot reach here.
            id_generator = trace.get_tracer_provider().id_generator
            trace_id = id_generator.generate_trace_id()
            span_id = id_generator.generate_span_id()
            sampled = "0"

        else:
            # Trace and span ids arrive hex-encoded, so convert them to ints.
            trace_id = int(trace_id, 16)
            span_id = int(span_id, 16)

        options = 0
        # The b3 spec provides no defined behavior for both sample and
        # flag values set. Since the setting of at least one implies
        # the desire for some form of sampling, propagate if either
        # header is set to allow.
        if sampled in self._SAMPLE_PROPAGATE_VALUES or flags == "1":
            options |= trace.TraceFlags.SAMPLED

        return trace.set_span_in_context(
            trace.NonRecordingSpan(
                trace.SpanContext(
                    trace_id=trace_id,
                    span_id=span_id,
                    is_remote=True,
                    trace_flags=trace.TraceFlags(options),
                    trace_state=trace.TraceState(),
                )),
            context,
        )