def test_endpoint_can_return_just_body():
    """A JSON handler that returns only a bare body still produces a
    full ``Response`` on the client side."""
    server = TChannel(name='server')

    @server.json.register
    def endpoint(request):
        return {'resp': 'body'}

    server.listen()

    # Call the endpoint and confirm the bare dict was wrapped in a Response.
    caller = TChannel(name='client')
    response = yield caller.json(
        service='server',
        endpoint='endpoint',
        hostport=server.hostport,
    )

    assert isinstance(response, Response)
    assert response.body == {'resp': 'body'}
def test_routing_delegate_is_propagated_json():
    """The per-request ``routing_delegate`` transport header must be
    visible to the JSON handler on the server."""
    server = TChannel('server')
    server.listen()

    @server.json.register('foo')
    def handler(request):
        # The server-side view of the request carries the delegate.
        assert request.transport.routing_delegate == 'delegate'
        return {'success': True}

    caller = TChannel('client', known_peers=[server.hostport])
    result = yield caller.json(
        'service', 'foo', {}, routing_delegate='delegate'
    )
    assert result.body == {'success': True}
def test_span_tags(encoding, operation, tracer, thrift_service):
    """Verify trace baggage propagates from client span to server handler.

    NOTE(review): this name is defined three times in this file; pytest
    collects only the last definition.
    """
    server = TChannel('server', tracer=tracer)
    server.listen()

    def get_span_baggage():
        # Read the baggage item set on the client's root span; None when no
        # span is active in the server's request context.
        sp = server.context_provider.get_current_span()
        baggage = sp.get_baggage_item('bender') if sp else None
        return {'bender': baggage}

    @server.json.register('foo')
    def handler(_):
        return get_span_baggage()

    @server.thrift.register(thrift_service.X, method='thrift2')
    def thrift2(_):
        return json.dumps(get_span_baggage())

    client = TChannel('client', tracer=tracer, trace=True)
    span = tracer.start_span('root')
    span.set_baggage_item('bender', 'is great')
    with span:
        res = None
        with client.context_provider.span_in_context(span):
            if encoding == 'json':
                res = client.json(
                    service='test-service',  # match thrift_service name
                    endpoint='foo',
                    body={},
                    hostport=server.hostport,
                )
            elif encoding == 'thrift':
                res = client.thrift(
                    thrift_service.X.thrift2(),
                    hostport=server.hostport,
                )
            else:
                raise ValueError('Unknown encoding %s' % encoding)
        res = yield res  # cannot yield in StackContext

    res = res.body
    # Thrift handler returns a JSON-encoded string; decode it for comparison.
    if isinstance(res, basestring):
        res = json.loads(res)
    assert res == {'bender': 'is great'}

    # Wait up to ~1s for all 3 spans to be reported (presumably root +
    # RPC client + RPC server -- confirm against tracer implementation).
    for i in range(1000):
        spans = tracer.reporter.get_spans()
        if len(spans) == 3:
            break
        yield tornado.gen.sleep(0.001)  # yield execution and sleep for 1ms

    spans = tracer.reporter.get_spans()
    assert len(spans) == 3
    trace_ids = set([s.trace_id for s in spans])
    assert 1 == len(trace_ids), \
        'all spans must have the same trace_id: %s' % trace_ids
def test_per_request_caller_name_json():
    """A per-request ``caller_name`` must reach the server as the
    transport caller name."""
    server = TChannel('server')
    server.listen()

    @server.json.register('foo')
    def handler(request):
        # The override ('bar'), not the channel name ('client'), is seen.
        assert request.transport.caller_name == 'bar'
        return {'success': True}

    caller = TChannel('client', known_peers=[server.hostport])
    result = yield caller.json('service', 'foo', {}, caller_name='bar')
    assert result.body == {'success': True}
def test_endpoint_not_found_with_json_request():
    """Calling an unregistered JSON endpoint raises BadRequestError with a
    descriptive message."""
    server = TChannel(name='server')
    server.listen()

    tchannel = TChannel(name='client')

    with pytest.raises(errors.BadRequestError) as e:
        yield tchannel.json(
            service='server',
            hostport=server.hostport,
            endpoint='foo',
        )

    # BUG FIX: the original asserted ``"..." in e.value`` -- membership
    # against the exception object compares whole exception args for
    # equality, not substring containment in the message. Convert to
    # str() so this is an actual substring check.
    assert "Endpoint 'foo' is not defined" in str(e.value)
def test_invalid_headers():
    """Non-string header values must be rejected client-side with
    ValueError before any call is made."""
    server = TChannel('server')
    server.listen()

    caller = TChannel('client')
    with pytest.raises(ValueError) as exc_info:
        yield caller.json(
            service='foo',
            endpoint='bar',
            hostport=server.hostport,
            headers={'foo': ['bar']},  # list value is invalid
        )

    assert 'headers must be a map[string]string' in str(exc_info)
def test_never_choose_ephemeral():
    """peers.choose() must never select an ephemeral (client-only) peer,
    even when it is the only peer known to the server."""
    server = TChannel('server')
    server.listen()

    @server.json.register('hello')
    def hello(request):
        return 'hi'

    # make a request to set up the connection between the two.
    client = TChannel('client')
    yield client.json('server', 'hello', 'world', hostport=server.hostport)

    assert [client.hostport] == server._dep_tchannel.peers.hosts
    # BUG FIX: the adjacent string literals were previously concatenated
    # without a separating space ("onlyavailable peer").
    assert (server._dep_tchannel.peers.choose() is None), (
        'choose() MUST NOT select the ephemeral peer even if that is the '
        'only available peer'
    )
def test_never_choose_incoming():
    """A peer known only through an incoming connection must not be
    selected by choose()."""
    server = TChannel('server')
    server.listen()

    client = TChannel('client')
    client.listen()  # client has a non-ephemeral port

    @server.json.register('hello')
    def hello(request):
        return 'hi'

    # make a request to set up a connection
    yield client.json('server', 'hello', 'world', hostport=server.hostport)

    # Server sees the client as a host, but must not pick it.
    assert [client.hostport] == server._dep_tchannel.peers.hosts
    assert (server._dep_tchannel.peers.choose() is None), (
        'server should not know of any peers at this time')
def test_never_choose_ephemeral():
    """peers.choose() must never select an ephemeral (client-only) peer.

    NOTE(review): this duplicates an earlier test of the same name in this
    file; pytest collects only this later definition.
    """
    server = TChannel('server')
    server.listen()

    @server.json.register('hello')
    def hello(request):
        return 'hi'

    # make a request to set up the connection between the two.
    client = TChannel('client')
    yield client.json('server', 'hello', 'world', hostport=server.hostport)

    assert [client.hostport] == server._dep_tchannel.peers.hosts
    # BUG FIX: the adjacent string literals were previously concatenated
    # without a separating space ("onlyavailable peer").
    assert (server._dep_tchannel.peers.choose() is None), (
        'choose() MUST NOT select the ephemeral peer even if that is the '
        'only available peer'
    )
def test_json_server(json_server, sample_json):
    """The json_echo endpoint must echo headers and body back unchanged,
    using the JSON transport scheme."""
    endpoint = "json_echo"
    caller = TChannel(name='test')
    sent_headers = {'ab': 'bc'}
    sent_body = sample_json

    response = yield caller.json(
        service='endpoint1',
        hostport=json_server.hostport,
        endpoint=endpoint,
        headers=sent_headers,
        body=sent_body,
    )

    # protocol scheme, then header and body round-trips
    assert response.transport.scheme == JSON
    assert response.headers == sent_headers
    assert response.body == sent_body
def test_json_server(json_server, sample_json):
    """Echo round-trip where the headers payload is itself the sample
    JSON document.

    NOTE(review): this duplicates the name of an earlier test_json_server
    in this file; pytest only collects the later definition.
    """
    endpoint = "json_echo"
    caller = TChannel(name='test')
    payload = sample_json  # used for both headers and body

    response = yield caller.json(
        service='endpoint1',
        hostport=json_server.hostport,
        endpoint=endpoint,
        headers=payload,
        body=payload,
    )

    # protocol scheme, then header and body round-trips
    assert response.transport.scheme == JSON
    assert response.headers == payload
    assert response.body == payload
def test_call_should_get_response():
    """Full round trip: request headers/body are delivered to the handler,
    and the client receives the handler's Response plus transport headers."""
    server = TChannel(name='server')

    @server.json.register
    def endpoint(request):
        assert request.headers == {'req': 'headers'}
        assert request.body == {'req': 'body'}
        return Response({'resp': 'body'}, headers={'resp': 'headers'})

    server.listen()

    client = TChannel(name='client')
    response = yield client.json(
        service='server',
        endpoint='endpoint',
        headers={'req': 'headers'},
        body={'req': 'body'},
        hostport=server.hostport,
    )

    # application-level response
    assert isinstance(response, Response)
    assert response.headers == {'resp': 'headers'}
    assert response.body == {'resp': 'body'}

    # transport-level response headers
    assert isinstance(response.transport, TransportHeaders)
    assert response.transport.scheme == schemes.JSON
    assert response.transport.failure_domain is None
def test_span_tags(encoding, operation, tracer, thrift_service):
    """Verify baggage propagation plus client/server span tags and peers.

    NOTE(review): this name is defined three times in this file; pytest
    collects only the last definition.
    """
    server = TChannel('server', tracer=tracer)
    server.listen()

    def get_span_baggage():
        # Read the baggage item set on the client's root span; None when no
        # span is active in the server's request context.
        sp = server.context_provider.get_current_span()
        baggage = sp.get_baggage_item('bender') if sp else None
        return {'bender': baggage}

    @server.json.register('foo')
    def handler(_):
        return get_span_baggage()

    @server.thrift.register(thrift_service.X, method='thrift2')
    def thrift2(_):
        return json.dumps(get_span_baggage())

    client = TChannel('client', tracer=tracer, trace=True)
    span = tracer.start_span('root')
    span.set_baggage_item('bender', 'is great')
    with span:
        res = None
        with client.context_provider.span_in_context(span):
            if encoding == 'json':
                res = client.json(
                    service='test-service',  # match thrift_service name
                    endpoint='foo',
                    body={},
                    hostport=server.hostport,
                )
            elif encoding == 'thrift':
                res = client.thrift(
                    thrift_service.X.thrift2(),
                    hostport=server.hostport,
                )
            else:
                raise ValueError('Unknown encoding %s' % encoding)
        res = yield res  # cannot yield in StackContext

    res = res.body
    # Thrift handler returns a JSON-encoded string; decode it for comparison.
    if isinstance(res, basestring):
        res = json.loads(res)
    assert res == {'bender': 'is great'}

    # Wait up to ~1s for all 3 spans to be reported (presumably root +
    # RPC client + RPC server -- confirm against tracer implementation).
    for i in range(1000):
        spans = tracer.reporter.get_spans()
        if len(spans) == 3:
            break
        yield tornado.gen.sleep(0.001)  # yield execution and sleep for 1ms

    spans = tracer.reporter.get_spans()
    assert len(spans) == 3
    assert 1 == len(set([s.trace_id for s in spans])), \
        'all spans must have the same trace_id'

    # Classify the reported spans by their RPC kind tag.
    parent = child = None
    for s in spans:
        if s.tags is None:
            continue
        print('tags %s' % s.tags)
        # replace list with dictionary
        s.tags = {tag.key: tag.value for tag in s.tags}
        if s.kind == tags.SPAN_KIND_RPC_SERVER:
            child = s
        elif s.kind == tags.SPAN_KIND_RPC_CLIENT:
            parent = s
    assert parent is not None
    assert child is not None
    assert parent.operation_name == operation
    assert child.operation_name == operation
    assert parent.peer['service_name'] == 'test-service'
    assert child.peer['service_name'] == 'client'
    assert parent.peer['ipv4'] is not None
    assert child.peer['ipv4'] is not None
    # 'as' transport header records the encoding scheme on both spans.
    assert parent.tags.get('as') == encoding
    assert child.tags.get('as') == encoding
def test_span_tags(encoding, operation, tracer, thrift_service):
    """Verify baggage propagation, span ID sharing, and client/server
    span tags.

    NOTE(review): this name is defined three times in this file; pytest
    collects only this last definition.
    """
    server = TChannel('server', tracer=tracer)
    server.listen()

    def get_span_baggage():
        # Read the baggage item set on the client's root span; None when no
        # span is active in the server's request context.
        sp = server.context_provider.get_current_span()
        baggage = sp.get_baggage_item('bender') if sp else None
        return {'bender': baggage}

    @server.json.register('foo')
    def handler(_):
        return get_span_baggage()

    @server.thrift.register(thrift_service.X, method='thrift2')
    def thrift2(_):
        return json.dumps(get_span_baggage())

    client = TChannel('client', tracer=tracer, trace=True)
    span = tracer.start_span('root')
    span.set_baggage_item('bender', 'is great')
    with span:
        res = None
        with client.context_provider.span_in_context(span):
            if encoding == 'json':
                res = client.json(
                    service='test-service',  # match thrift_service name
                    endpoint='foo',
                    body={},
                    hostport=server.hostport,
                )
            elif encoding == 'thrift':
                res = client.thrift(
                    thrift_service.X.thrift2(),
                    hostport=server.hostport,
                )
            else:
                raise ValueError('Unknown encoding %s' % encoding)
        res = yield res  # cannot yield in StackContext

    res = res.body
    # Thrift handler returns a JSON-encoded string; decode it for comparison.
    if isinstance(res, basestring):
        res = json.loads(res)
    assert res == {'bender': 'is great'}

    # Wait up to ~1s for all 3 spans to be reported.
    for i in range(1000):
        spans = tracer.reporter.get_spans()
        if len(spans) == 3:
            break
        yield tornado.gen.sleep(0.001)  # yield execution and sleep for 1ms

    spans = tracer.reporter.get_spans()
    assert len(spans) == 3
    trace_ids = set([s.trace_id for s in spans])
    assert 1 == len(trace_ids), \
        'all spans must have the same trace_id: %s' % trace_ids
    span_ids = set([s.span_id for s in spans])
    assert 2 == len(span_ids), \
        'must have two unique span IDs, root span and RPC span: %s' % span_ids

    # Classify the reported spans by their RPC kind tag.
    parent = child = None
    for s in spans:
        if s.tags is None:
            continue
        print('tags %s' % s.tags)
        # replace list with dictionary
        s.tags = {tag.key: tag.value for tag in s.tags}
        if s.kind == tags.SPAN_KIND_RPC_SERVER:
            child = s
        elif s.kind == tags.SPAN_KIND_RPC_CLIENT:
            parent = s
    assert parent is not None
    assert child is not None
    assert parent.operation_name == operation
    assert child.operation_name == operation
    assert parent.peer['service_name'] == 'test-service'
    assert child.peer['service_name'] == 'client'
    assert parent.peer['ipv4'] is not None
    assert child.peer['ipv4'] is not None
    # 'as' transport header records the encoding scheme on both spans.
    assert parent.tags.get('as') == encoding
    assert child.tags.get('as') == encoding
def test_forwarding(tmpdir):
    """A FALLBACK handler can transparently proxy thrift and JSON calls,
    preserving scheme, headers, body, timeout and routing metadata."""
    from tornado import gen

    # Write a throwaway thrift IDL for a simple key-value service.
    path = tmpdir.join('keyvalue.thrift')
    path.write('''
        exception ItemDoesNotExist {
            1: optional string key
        }

        service KeyValue {
            string getItem(1: string key)
                throws (1: ItemDoesNotExist doesNotExist)
        }
    ''')
    kv = thrift.load(str(path), service='keyvalue')

    real_server = TChannel(name='real_server')
    real_server.listen()

    items = {}

    @real_server.thrift.register(kv.KeyValue)
    def getItem(request):
        assert request.service == 'keyvalue'
        key = request.body.key
        if key in items:
            assert request.headers == {'expect': 'success'}
            return items[key]
        else:
            assert request.headers == {'expect': 'failure'}
            raise kv.ItemDoesNotExist(key)

    @real_server.json.register('putItem')
    def json_put_item(request):
        assert request.service == 'keyvalue'
        # Half of the client's 1.0s timeout, forwarded by the proxy.
        assert request.timeout == 0.5
        key = request.body['key']
        value = request.body['value']
        items[key] = value
        return {'success': True}

    proxy_server = TChannel(name='proxy_server')
    proxy_server.listen()

    # The client that the proxy uses to make requests should be a different
    # TChannel. That's because TChannel treats all peers (incoming and
    # outgoing) as the same. So, if the server receives a request and then
    # uses the same channel to make the request, there's a chance that it
    # gets forwarded back to the peer that originally made the request.
    #
    # This is desirable behavior because we do want to treat all Hyperbahn
    # nodes as equal.
    proxy_server_client = TChannel(
        name='proxy-client',
        known_peers=[real_server.hostport],
    )

    @proxy_server.register(TChannel.FALLBACK)
    @gen.coroutine
    def handler(request):
        # Forward the raw args and transport metadata, halving the timeout
        # and disabling retries on the forwarded hop.
        response = yield proxy_server_client.call(
            scheme=request.transport.scheme,
            service=request.service,
            arg1=request.endpoint,
            arg2=request.headers,
            arg3=request.body,
            timeout=request.timeout / 2,
            retry_on=request.transport.retry_flags,
            retry_limit=0,
            shard_key=request.transport.shard_key,
            routing_delegate=request.transport.routing_delegate,
        )
        raise gen.Return(response)

    client = TChannel(name='client', known_peers=[proxy_server.hostport])

    # Missing key: the thrift exception propagates back through the proxy.
    with pytest.raises(kv.ItemDoesNotExist):
        response = yield client.thrift(
            kv.KeyValue.getItem('foo'),
            headers={'expect': 'failure'},
        )

    # JSON write through the proxy.
    json_response = yield client.json('keyvalue', 'putItem', {
        'key': 'hello',
        'value': 'world',
    }, timeout=1.0)

    assert json_response.body == {'success': True}

    # Thrift read of the value written above.
    response = yield client.thrift(
        kv.KeyValue.getItem('hello'),
        headers={'expect': 'success'},
    )
    assert response.body == 'world'
def test_trace_propagation(endpoint, transport, encoding, enabled,
                           expect_spans, expect_baggage, http_patchers,
                           tracer, mock_server, thrift_service, app,
                           http_server, base_url, http_client):
    """
    Main TChannel-OpenTracing integration test, using basictracer as
    implementation of OpenTracing API.

    The main logic of this test is as follows:
    1. Start a new trace with a root span
    2. Store a random value in the baggage
    3. Call the first service at the endpoint from `endpoint` parameter.
       The first service is either tchannel or http, depending on the value
       if `transport` parameter.
    4. The first service calls the second service using pre-defined logic
       that depends on the endpoint invoked on the first service.
    5. The second service accesses the tracing span and returns the value
       of the baggage item as the response.
    6. The first service responds with the value from the second service.
    7. The main test validates that the response is equal to the original
       random value of the baggage, proving trace & baggage propagation.
    8. The test also validates that all spans have been finished and
       recorded, and that they all have the same trace ID.

    We expect 5 spans to be created from each test run:
    * top-level (root) span started in the test
    * client span (calling service-1)
    * service-1 server span
    * service-1 client span (calling service-2)
    * service-2 server span

    :param endpoint: name of the endpoint to call on the first service
    :param transport: type of the first service: tchannel or http
    :param enabled: if False, channels are instructed to disable tracing
    :param expect_spans: number of spans we expect to be generated
    :param http_patchers: monkey-patching of tornado AsyncHTTPClient
    :param tracer: a concrete implementation of OpenTracing Tracer
    :param mock_server: tchannel server (from conftest.py)
    :param thrift_service: fixture that creates a Thrift service from fake IDL
    :param app: tornado.web.Application fixture
    :param http_server: http server (provided by pytest-tornado)
    :param base_url: address of http server (provided by pytest-tornado)
    :param http_client: Tornado's AsyncHTTPClient (provided by pytest-tornado)

    NOTE(review): this name is defined twice in this file; pytest collects
    only the last definition.
    """
    # mock_server is created as a fixture, so we need to set tracer on it
    mock_server.tchannel._dep_tchannel._tracer = tracer
    mock_server.tchannel._dep_tchannel._trace = enabled

    register(tchannel=mock_server.tchannel, thrift_service=thrift_service,
             http_client=http_client, base_url=base_url)

    tchannel = TChannel(name='test', tracer=tracer, trace=enabled)

    app.add_handlers(".*$", [(r"/", HttpHandler, {
        'client_channel': tchannel
    })])

    with mock.patch('opentracing.tracer', tracer),\
            mock.patch.object(tracing.log, 'exception') as log_exception:
        assert opentracing.tracer == tracer  # sanity check that patch worked

        span = tracer.start_span('root')
        baggage = 'from handler3 %d' % time.time()
        span.set_baggage_item(BAGGAGE_KEY, baggage)
        if not enabled:
            # sampling.priority=0 instructs the tracer not to sample.
            span.set_tag('sampling.priority', 0)
        with span:  # use span as context manager so that it's always finished
            response_future = None
            with tchannel.context_provider.span_in_context(span):
                if transport == 'tchannel':
                    if encoding == 'json':
                        response_future = tchannel.json(
                            service='test-client',
                            endpoint=endpoint,
                            hostport=mock_server.hostport,
                            body=mock_server.hostport,
                        )
                    elif encoding == 'thrift':
                        if endpoint == 'thrift1':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift1(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        elif endpoint == 'thrift3':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift3(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        elif endpoint == 'thrift4':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift4(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        else:
                            raise ValueError('wrong endpoint %s' % endpoint)
                    else:
                        raise ValueError('wrong encoding %s' % encoding)
                elif transport == 'http':
                    response_future = http_client.fetch(request=HTTPRequest(
                        url='%s%s' % (base_url, endpoint),
                        method='POST',
                        body=mock_server.hostport,
                    ))
                else:
                    raise NotImplementedError(
                        'unknown transport %s' % transport)
            response = yield response_future
        assert log_exception.call_count == 0

    body = response.body
    if expect_baggage:
        assert body == baggage

    def get_sampled_spans():
        return [s for s in tracer.reporter.get_spans() if s.is_sampled]

    # Sometimes the test runs into weird race condition where the
    # after_send_response() hook is executed, but the span is not yet
    # recorded. To prevent flaky test runs we check and wait until
    # all spans are recorded, for up to 1 second.
    for i in range(0, 1000):
        spans = get_sampled_spans()
        if len(spans) >= expect_spans:
            break
        yield tornado.gen.sleep(0.001)  # yield execution and sleep for 1ms

    spans = get_sampled_spans()
    assert expect_spans == len(spans), 'Unexpected number of spans reported'
    # We expect all trace IDs in collected spans to be the same
    if expect_spans > 0:
        spans = tracer.reporter.get_spans()
        assert 1 == len(set([s.trace_id for s in spans])), \
            'all spans must have the same trace_id'
def test_trace_propagation(
        endpoint, transport, encoding, enabled, expect_spans, expect_baggage,
        http_patchers, tracer, mock_server, thrift_service,
        app, http_server, base_url, http_client):
    """
    Main TChannel-OpenTracing integration test, using basictracer as
    implementation of OpenTracing API.

    The main logic of this test is as follows:
    1. Start a new trace with a root span
    2. Store a random value in the baggage
    3. Call the first service at the endpoint from `endpoint` parameter.
       The first service is either tchannel or http, depending on the value
       if `transport` parameter.
    4. The first service calls the second service using pre-defined logic
       that depends on the endpoint invoked on the first service.
    5. The second service accesses the tracing span and returns the value
       of the baggage item as the response.
    6. The first service responds with the value from the second service.
    7. The main test validates that the response is equal to the original
       random value of the baggage, proving trace & baggage propagation.
    8. The test also validates that all spans have been finished and
       recorded, and that they all have the same trace ID.

    We expect 5 spans to be created from each test run:
    * top-level (root) span started in the test
    * client span (calling service-1)
    * service-1 server span
    * service-1 client span (calling service-2)
    * service-2 server span

    :param endpoint: name of the endpoint to call on the first service
    :param transport: type of the first service: tchannel or http
    :param enabled: if False, channels are instructed to disable tracing
    :param expect_spans: number of spans we expect to be generated
    :param http_patchers: monkey-patching of tornado AsyncHTTPClient
    :param tracer: a concrete implementation of OpenTracing Tracer
    :param mock_server: tchannel server (from conftest.py)
    :param thrift_service: fixture that creates a Thrift service from fake IDL
    :param app: tornado.web.Application fixture
    :param http_server: http server (provided by pytest-tornado)
    :param base_url: address of http server (provided by pytest-tornado)
    :param http_client: Tornado's AsyncHTTPClient (provided by pytest-tornado)
    """
    # mock_server is created as a fixture, so we need to set tracer on it
    mock_server.tchannel._dep_tchannel._tracer = tracer
    mock_server.tchannel._dep_tchannel._trace = enabled

    register(tchannel=mock_server.tchannel, thrift_service=thrift_service,
             http_client=http_client, base_url=base_url)

    tchannel = TChannel(name='test', tracer=tracer, trace=enabled)

    app.add_handlers(".*$", [
        (r"/", HttpHandler, {'client_channel': tchannel})
    ])

    with mock.patch('opentracing.tracer', tracer),\
            mock.patch.object(tracing.log, 'exception') as log_exception:
        assert opentracing.tracer == tracer  # sanity check that patch worked

        span = tracer.start_span('root')
        baggage = 'from handler3 %d' % time.time()
        span.set_baggage_item(BAGGAGE_KEY, baggage)
        if not enabled:
            # sampling.priority=0 instructs the tracer not to sample.
            span.set_tag('sampling.priority', 0)
        with span:  # use span as context manager so that it's always finished
            response_future = None
            with tchannel.context_provider.span_in_context(span):
                if transport == 'tchannel':
                    if encoding == 'json':
                        response_future = tchannel.json(
                            service='test-client',
                            endpoint=endpoint,
                            hostport=mock_server.hostport,
                            body=mock_server.hostport,
                        )
                    elif encoding == 'thrift':
                        if endpoint == 'thrift1':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift1(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        elif endpoint == 'thrift3':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift3(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        elif endpoint == 'thrift4':
                            response_future = tchannel.thrift(
                                thrift_service.X.thrift4(mock_server.hostport),
                                hostport=mock_server.hostport,
                            )
                        else:
                            raise ValueError('wrong endpoint %s' % endpoint)
                    else:
                        raise ValueError('wrong encoding %s' % encoding)
                elif transport == 'http':
                    response_future = http_client.fetch(
                        request=HTTPRequest(
                            url='%s%s' % (base_url, endpoint),
                            method='POST',
                            body=mock_server.hostport,
                        )
                    )
                else:
                    raise NotImplementedError(
                        'unknown transport %s' % transport)
            response = yield response_future
        assert log_exception.call_count == 0

    body = response.body
    if expect_baggage:
        assert body == baggage

    def get_sampled_spans():
        return [s for s in tracer.reporter.get_spans() if s.is_sampled]

    # Sometimes the test runs into weird race condition where the
    # after_send_response() hook is executed, but the span is not yet
    # recorded. To prevent flaky test runs we check and wait until
    # all spans are recorded, for up to 1 second.
    for i in range(0, 1000):
        spans = get_sampled_spans()
        if len(spans) >= expect_spans:
            break
        yield tornado.gen.sleep(0.001)  # yield execution and sleep for 1ms

    spans = get_sampled_spans()
    assert expect_spans == len(spans), 'Unexpected number of spans reported'
    # We expect all trace IDs in collected spans to be the same
    if expect_spans > 0:
        spans = tracer.reporter.get_spans()
        assert 1 == len(set([s.trace_id for s in spans])), \
            'all spans must have the same trace_id'