def _get_trace_info(self, headers):
    """Build a TraceInfo from the tracing headers sent by the upstream caller.

    The header values are pulled out with
    ``TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)``.
    Each value is then converted and passed to ``TraceInfo.from_upstream``.

    :param headers: The incoming request headers.
    :returns: The TraceInfo produced by ``TraceInfo.from_upstream``.
    :raises KeyError: If ``trace_id``, ``parent_span_id`` or ``span_id`` is
        missing. This assumes the extracted values behave like a dict; confirm
        against ``extract_upstream_header_values``.
    :raises ValueError: If an id or the flags value is not an integer string.
        ``TraceInfo.from_upstream`` may also reject values it considers invalid.
    """
    values = TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)

    # Only the exact string "1" means sampled. A missing header or any other
    # value becomes False, never None.
    sampled = bool(values.get("sampled") == "1")

    # Flags are optional; keep None when the header was not sent.
    raw_flags = values.get("flags", None)

    trace_id = int(values["trace_id"])
    parent_id = int(values["parent_span_id"])
    span_id = int(values["span_id"])
    flags = None if raw_flags is None else int(raw_flags)

    return TraceInfo.from_upstream(trace_id, parent_id, span_id, sampled, flags)
def baseplate_thrift_client(endpoint, client_span_observer=None):
    """Yield a request context wired with an ``example_service`` Thrift client.

    NOTE(review): this is a generator. It is presumably wrapped by
    ``contextlib.contextmanager`` or used as a pytest fixture, but no decorator
    is visible here. Confirm before relying on either.

    Setup performed before yielding:

    * If ``client_span_observer`` is truthy, it is attached to every child span
      created under the server span.
    * ``example_service`` is configured as a ``ThriftClient(TestService.Client)``
      pointing at ``str(endpoint)``.
    * A server span named ``"example_service.example"`` is made with a fixed
      upstream TraceInfo (trace 1234, parent 2345, span 3456, flags 4567,
      sampled).
    * An edge context built from ``SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH`` is
      attached to the context.

    :param endpoint: The service endpoint; stringified into the app config.
    :param client_span_observer: Optional observer to register on child spans.
    """
    baseplate = Baseplate()

    if client_span_observer:
        # The observer is wired in two hops. The BaseplateObserver registers a
        # ServerSpanObserver on each new server span. That ServerSpanObserver
        # then registers client_span_observer on each child span it is told about.
        class TestServerSpanObserver(ServerSpanObserver):
            def on_child_span_created(self, span):
                span.register(client_span_observer)

        server_span_observer = TestServerSpanObserver()

        class TestBaseplateObserver(BaseplateObserver):
            def on_server_span_created(self, context, span):
                span.register(server_span_observer)

        baseplate.register(TestBaseplateObserver())

    context = baseplate.make_context_object()
    trace_info = TraceInfo.from_upstream(
        trace_id=1234,
        parent_id=2345,
        span_id=3456,
        flags=4567,
        sampled=True,
    )

    app_config = {"example_service.endpoint": str(endpoint)}
    context_spec = {"example_service": ThriftClient(TestService.Client)}
    baseplate.configure_context(app_config, context_spec)

    # Observers must be registered before this call so they see the new span.
    baseplate.make_server_span(context, "example_service.example", trace_info)

    edge_context = make_edge_context_factory().from_upstream(
        SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH
    )
    edge_context.attach_context(context)

    yield context
def test_null_server_observer(self):
    """A server span gets no span observers when the hook returns None.

    The hook in question is the registered BaseplateObserver's
    ``on_server_span_created``.
    """
    observer = mock.Mock(spec=BaseplateObserver)
    observer.on_server_span_created.return_value = None

    baseplate = Baseplate()
    context = baseplate.make_context_object()
    baseplate.register(observer)

    span = baseplate.make_server_span(context, "name", TraceInfo(1, 2, 3, None, 0))

    self.assertEqual(span.observers, [])
def test_null_root_observer(self):
    """A root span gets no span observers when the hook returns None.

    The hook in question is the registered BaseplateObserver's
    ``on_root_span_created``.

    NOTE(review): ``make_root_span``/``on_root_span_created`` look like an older
    name for the server-span API. With ``Mock(spec=BaseplateObserver)``, reading
    ``on_root_span_created`` raises AttributeError if the spec class no longer
    defines it. Confirm this hook still exists.
    """
    mock_context = mock.Mock()
    mock_observer = mock.Mock(spec=BaseplateObserver)
    mock_observer.on_root_span_created.return_value = None
    baseplate = Baseplate()
    baseplate.register(mock_observer)
    # Supply sampled=None and flags=0 explicitly. TraceInfo carries five values
    # (trace/parent/span ids, sampled, flags), so the bare three-argument form
    # only works if the trailing two fields have defaults.
    root_span = baseplate.make_root_span(
        mock_context, "name", TraceInfo(1, 2, 3, None, 0)
    )
    self.assertEqual(root_span.observers, [])
def test_server_observer_made(self):
    """Registering an observer and making a server span notifies it once.

    The observer must also be listed on ``baseplate.observers``. Its
    ``on_server_span_created`` hook is called once with the context and the
    new span.
    """
    observer = mock.Mock(spec=BaseplateObserver)
    context = mock.Mock()

    baseplate = Baseplate()
    baseplate.register(observer)
    span = baseplate.make_server_span(context, "name", TraceInfo(1, 2, 3, None, 0))

    hook = observer.on_server_span_created
    expected_call = mock.call(context, span)

    self.assertEqual(baseplate.observers, [observer])
    self.assertEqual(hook.call_count, 1)
    self.assertEqual(hook.call_args, expected_call)
def test_from_upstream_fails_on_invalid_sampled(self):
    """from_upstream rejects the string 'True' as a sampled value.

    It must raise ValueError with the message "invalid sampled value".
    """
    with self.assertRaises(ValueError) as raised:
        TraceInfo.from_upstream(
            trace_id=1, parent_id=2, span_id=3, sampled='True', flags=None
        )

    # Checked after the with-block; inside it this line would never run.
    self.assertEqual(str(raised.exception), "invalid sampled value")
def test_new_does_not_set_sampled(self):
    """A freshly generated TraceInfo leaves ``sampled`` as None."""
    self.assertIsNone(TraceInfo.new().sampled)
def test_new_does_not_set_flags(self):
    """A freshly generated TraceInfo leaves ``flags`` as None."""
    self.assertIsNone(TraceInfo.new().flags)
def test_new_does_not_have_parent_id(self):
    """A freshly generated TraceInfo has no parent (``parent_id`` is None)."""
    self.assertIsNone(TraceInfo.new().parent_id)
def test_from_upstream_handles_no_sampled_or_flags(self):
    """from_upstream accepts None for sampled and flags and keeps them as None."""
    trace_info = TraceInfo.from_upstream(
        trace_id=1, parent_id=2, span_id=3, sampled=None, flags=None
    )

    self.assertIsNone(trace_info.sampled)
    self.assertIsNone(trace_info.flags)
def test_from_upstream_fails_on_invalid_flags(self):
    """from_upstream rejects a negative flags value.

    It must raise ValueError with the message "invalid flags value".
    """
    with self.assertRaises(ValueError) as raised:
        TraceInfo.from_upstream(
            trace_id=1, parent_id=2, span_id=3, sampled=True, flags=-1
        )

    # Checked after the with-block; inside it this line would never run.
    self.assertEqual(str(raised.exception), "invalid flags value")