def test_multiple_registrations(self):
    """Register, invoke and unregister the same service several times.

    Each round registers ``reg_info``, sends one synchronous request to its
    topic, checks that a non-error response correlated with that request is
    returned, and unregisters the service again. At the end the service's
    request handler must have run exactly once per round.
    """
    service_registration_count = 10
    # One entry is appended per request the service handles.
    handled_requests = []
    topic = UuidGenerator.generate_id_as_string()
    with self.create_client() as service_client:
        service_client.connect()

        def my_request(request):
            handled_requests.append(request)
            service_client.send_response(Response(request))

        reg_info = ServiceRegistrationInfo(service_client,
                                           "multiple_registrations_test")
        callback = RequestCallback()
        callback.on_request = my_request
        reg_info.add_topic(topic, callback)

        with self.create_client() as request_client:
            request_client.connect()
            for _ in range(service_registration_count):
                service_client.register_service_sync(reg_info,
                                                     self.DEFAULT_TIMEOUT)
                request = Request(topic)
                response = request_client.sync_request(
                    request, timeout=self.RESPONSE_WAIT)
                self.assertNotIsInstance(response, ErrorResponse)
                self.assertEqual(request.message_id,
                                 response.request_message_id)
                service_client.unregister_service_sync(reg_info,
                                                       self.DEFAULT_TIMEOUT)
            self.assertEqual(service_registration_count,
                             len(handled_requests))
def test_client_register_service_subscribes_client_to_channel(self):
    """Registering a service adds all of its topics to the subscriptions.

    Five topics are added to a dummy service: two via ``add_topic`` and
    three via ``add_topics``. After ``register_service_async`` the client's
    subscriptions must equal the previous subscriptions plus those five
    topics, compared order-independently.
    """
    channel = '/mcafee/service/unittest'
    # Create dummy service
    service_info = dxlclient.service.ServiceRegistrationInfo(
        service_type='/mcafee/service/unittest', client=self.client)
    service_topics = [channel + str(index) for index in range(1, 6)]

    # Add topics to the service: the first two one at a time, the rest in
    # a single add_topics() call.
    service_info.add_topic(service_topics[0], RequestCallback())
    service_info.add_topic(service_topics[1], RequestCallback())
    service_info.add_topics(
        {service_topic: RequestCallback()
         for service_topic in service_topics[2:]})

    before = self.client.subscriptions
    expected = sorted(before + tuple(service_topics))

    # Register service in client, then check subscribed channels
    self.client.register_service_async(service_info)
    self.assertEqual(expected, sorted(self.client.subscriptions))
def test_client_register_service_unsuscribes_from_channel_by_guid(self):
    """Unregistering by service id removes the registered service's topics.

    A service is registered, and then a *different* ServiceRegistrationInfo
    object with the same topics and the same service id is passed to
    ``unregister_service_async``. The client must drop the subscriptions
    that the first object's registration created, which shows that
    unregistration is keyed on the service id rather than object identity.
    """
    channel1 = '/mcafee/service/unittest/one'
    channel2 = '/mcafee/service/unittest/two'

    # Create dummy service
    service_info = dxlclient.service.ServiceRegistrationInfo(
        service_type='/mcafee/service/unittest', client=self.client)
    service_info.add_topic(channel1, RequestCallback())
    service_info.add_topic(channel2, RequestCallback())

    # Create same dummy service - different object, but forced to carry the
    # first service's id. This must be a separate variable: rebinding
    # service_info here would make the test register and unregister the
    # same object.
    service_info2 = dxlclient.service.ServiceRegistrationInfo(
        service_type='/mcafee/service/unittest', client=self.client)
    service_info2._service_id = service_info.service_id
    service_info2.add_topic(channel1, RequestCallback())
    service_info2.add_topic(channel2, RequestCallback())

    # Register service in client
    self.client.register_service_async(service_info)

    # Check subscribed channels. unittest assertions are used instead of
    # bare ``assert`` so the checks survive ``python -O``.
    subscriptions = self.client.subscriptions
    self.assertIn(channel1, subscriptions,
                  "Client wasn't subscribed to service channel")
    self.assertIn(channel2, subscriptions,
                  "Client wasn't subscribed to service channel")

    # Unregister using the other object that shares the service id.
    self.client.unregister_service_async(service_info2)

    subscriptions = self.client.subscriptions
    self.assertNotIn(channel1, subscriptions,
                     "Client wasn't unsubscribed to service channel")
    self.assertNotIn(channel2, subscriptions,
                     "Client wasn't unsubscribed to service channel")
def add_client_callbacks(self, client, on_client_request_callback=None):
    """Prepare the "/mcafee/service/JTI" test service and its event hooks.

    Builds ``self.info``, a ServiceRegistrationInfo with a file-reputation
    and a cert-reputation topic, both suffixed with the service id. Both
    topics share one request handler. The handler logs the request,
    optionally calls ``on_client_request_callback``, and replies with
    payload "Ok".

    It also resets ``self.info_registered``, ``self.info_registrations``
    and ``self.info_unregistrations``. It then adds event callbacks on the
    service register/unregister channels. For events whose "serviceGuid"
    matches this service, those callbacks update the three attributes.
    """
    def on_request(request):
        logging.info(request.destination_topic)
        logging.info(request.payload)
        if on_client_request_callback:
            on_client_request_callback()
        reply = Response(request)
        reply.payload = "Ok"
        try:
            client.send_response(reply)
        except DxlException as ex:
            print("Failed to send response" + str(ex))

    service_callback = RequestCallback()
    service_callback.on_request = on_request

    self.info = ServiceRegistrationInfo(client, "/mcafee/service/JTI")
    self.info_registered = False
    self.info_registrations = 0
    self.info_unregistrations = 0

    guid = self.info.service_id
    for topic_prefix in ("/mcafee/service/JTI/file/reputation/",
                         "/mcafee/service/JTI/cert/reputation/"):
        self.info.add_topic(topic_prefix + guid, service_callback)

    def is_event_for_service(event):
        # Event payloads are UTF-8 JSON, possibly NUL-terminated.
        document = event.payload.decode("utf8").rstrip("\0")
        return json.loads(document)["serviceGuid"] == guid

    class _ServiceEventCallback(EventCallback):
        """Track register (registered=True) or unregister events."""

        def __init__(self, test, registered):
            self.test = test
            self._registered = registered
            super(_ServiceEventCallback, self).__init__()

        def on_event(self, event):
            if not is_event_for_service(event):
                return
            if self._registered:
                self.test.info_registrations += 1
            else:
                self.test.info_unregistrations += 1
            self.test.info_registered = self._registered

    client.add_event_callback(_ServiceManager.DXL_SERVICE_REGISTER_CHANNEL,
                              _ServiceEventCallback(self, True))
    client.add_event_callback(
        _ServiceManager.DXL_SERVICE_UNREGISTER_CHANNEL,
        _ServiceEventCallback(self, False))
def test_client_handle_message_with_request_calls_request_callback(self):
    """A serialized Request fed to ``_handle_message`` reaches on_request."""
    callback = RequestCallback()
    callback.on_request = Mock()
    self.client.add_request_callback(self.test_channel, callback)

    # Serialize a Request for the test channel and push it straight into
    # the client's internal message handler.
    raw_request = Request(destination_topic=self.test_channel)._to_bytes()
    self.client._handle_message(self.test_channel, raw_request)

    # Exactly one dispatch to the registered request callback is expected.
    self.assertEqual(1, callback.on_request.call_count)
def create_service_reg_info():
    """Create, register and return a "registry_specified_service_id_test" service.

    The request handler forwards to the enclosing scope's ``my_request``,
    passing the new registration's service id as the first argument.
    ``service_client``, ``topic``, ``my_request`` and ``self`` are free
    variables from the enclosing scope (not visible here).
    """
    info = ServiceRegistrationInfo(
        service_client, "registry_specified_service_id_test")

    def forward_request(request):
        return my_request(info.service_id, request)

    handler = RequestCallback()
    handler.on_request = forward_request
    info.add_topic(topic, handler)
    service_client.register_service_sync(info, self.DEFAULT_TIMEOUT)
    return info
def test_response_service_not_found_no_service_id_at_client(self):
    """A request for a service unknown to the client-side manager fails.

    The service is first registered with the broker. It is then removed
    only from the service client's internal ``_service_manager.services``
    map, so the broker still routes requests to it.

    A request to the service's topic must come back as an ErrorResponse
    that carries the service id and the "service unavailable" code and
    message. The service's request handler must never run. Finally, the
    registry query for the service must return None.
    """
    request_received = [False]
    topic = UuidGenerator.generate_id_as_string()
    with self.create_client() as service_client:
        service_client.connect()

        def my_request(request):
            request_received[0] = True
            service_client.send_response(Response(request))

        reg_info = ServiceRegistrationInfo(
            service_client,
            "response_service_not_found_no_service_id_at_client_test")
        callback = RequestCallback()
        callback.on_request = my_request
        reg_info.add_topic(topic, callback)
        service_client.register_service_sync(reg_info,
                                             self.DEFAULT_TIMEOUT)
        self.assertIsNotNone(
            self.query_service_registry_by_service(
                service_client, reg_info))

        with self.create_client() as request_client:
            request_client.connect()
            # Remove the service's registration with the client-side
            # ServiceManager, avoiding unregistration of the service from
            # the broker. This should allow the broker to forward the
            # request on to the service client.
            registered_services = service_client._service_manager.services
            service = registered_services[reg_info.service_id]
            del registered_services[reg_info.service_id]
            try:
                request = Request(topic)
                response = request_client.sync_request(
                    request, timeout=self.RESPONSE_WAIT)
            finally:
                # Re-register the service with the internal ServiceManager
                # so that its resources (TTL timeout, etc.) can be cleaned
                # up properly at shutdown -- even if the request raised.
                registered_services[reg_info.service_id] = service

            # The request should receive an 'unavailable service' error
            # response because the service client should be unable to route
            # the request to an internally registered service.
            self.assertFalse(request_received[0])
            self.assertIsInstance(response, ErrorResponse)
            self.assertEqual(reg_info.service_id, response.service_id)
            self.assertEqual(
                self.DXL_SERVICE_UNAVAILABLE_ERROR_CODE,
                BrokerServiceRegistryTest.normalized_error_code(response))
            self.assertEqual(self.DXL_SERVICE_UNAVAILABLE_ERROR_MESSAGE,
                             response.error_message)
            self.assertIsNone(self.query_service_registry_by_service(
                service_client, reg_info))
def register_test_service(self, client, service_type=None):
    """Register an echo-less test service on ``client`` and return its info.

    The service listens on a fresh random topic. Its handler answers every
    request with an empty Response. When ``service_type`` is falsy, a
    random "broker_service_registry_test_service_..." type is generated.
    The registration is synchronous, bounded by ``self.DEFAULT_TIMEOUT``.
    """
    prefix = "broker_service_registry_test_service_"
    topic = prefix + UuidGenerator.generate_id_as_string()
    if not service_type:
        service_type = prefix + UuidGenerator.generate_id_as_string()
    reg_info = ServiceRegistrationInfo(client, service_type)

    def reply(request):
        return client.send_response(Response(request))

    handler = RequestCallback()
    handler.on_request = reply
    reg_info.add_topic(topic, handler)
    client.register_service_sync(reg_info, self.DEFAULT_TIMEOUT)
    return reg_info
def test_execute_message_payload(self):
    """A msgpack payload survives delivery to a registered service.

    A second client sends a request whose payload is three msgpack-packed
    values. The test waits up to ``self.MAX_WAIT`` seconds for the service
    to receive it. It then unpacks the payload *as received by the service*
    (not the sender's local copy) and compares it with the original values.
    """
    # Create a server that handles a request, and assert that the
    # information in the payload was delivered successfully.
    with self.create_client(max_retries=0) as service_client:
        service_client.connect()
        topic = UuidGenerator.generate_id_as_string()
        reg_info = ServiceRegistrationInfo(
            service_client, "message_payload_runner_service")

        # callback definition: record the delivered request and wake the
        # test thread waiting on the condition below.
        def on_request(request):
            with self.request_complete_condition:
                self.request_received = request
                self.request_complete_condition.notify_all()

        request_callback = RequestCallback()
        request_callback.on_request = on_request
        reg_info.add_topic(topic, request_callback)
        # Register the service
        service_client.register_service_sync(reg_info, self.DEFAULT_TIMEOUT)

        with self.create_client() as request_client:
            request_client.connect()
            packer = msgpack.Packer()

            # Send a request to the server with information contained
            # in the payload
            request = Request(destination_topic=topic)
            request.payload = packer.pack(self.TEST_STRING)
            request.payload += packer.pack(self.TEST_BYTE)
            request.payload += packer.pack(self.TEST_INT)
            # The service never replies, so no response callback is passed
            # (the RequestCallback above is not a response callback).
            request_client.async_request(request)

            # Wait until the request has been processed, bounded overall by
            # MAX_WAIT rather than MAX_WAIT per wake-up.
            deadline = time.time() + self.MAX_WAIT
            with self.request_complete_condition:
                while not self.request_received:
                    remaining = deadline - time.time()
                    if remaining <= 0:
                        break
                    self.request_complete_condition.wait(remaining)

            self.assertIsNotNone(self.request_received)
            # Unpack what the service actually received.
            unpacker = msgpack.Unpacker(
                file_like=BytesIO(self.request_received.payload))
            self.assertEqual(
                next(unpacker).decode('utf8'), self.TEST_STRING)
            self.assertEqual(next(unpacker), self.TEST_BYTE)
            self.assertEqual(next(unpacker), self.TEST_INT)
def test_execute_message_payload(self):
    """A msgpack payload is unpacked correctly by the receiving service.

    The service's request handler unpacks the payload and compares the three
    values with TEST_STRING, TEST_BYTE and TEST_INT. On success it sets
    ``self.received_request``. A comparison failure on the callback thread
    is captured and reported from the test thread, so the real mismatch
    is shown instead of a generic "Request not received.".
    """
    # Create a server that handles a request, unpacks the payload, and
    # asserts that the information in the payload was delivered successfully.
    with self.create_client(max_retries=0) as server:
        # NOTE(review): test_service is never referenced below; kept in case
        # the TestService constructor has side effects on ``server``.
        test_service = TestService(server, 1)
        server.connect()
        topic = UuidGenerator.generate_id_as_string()
        reg_info = ServiceRegistrationInfo(
            server, "message_payload_runner_service")
        # Exceptions raised while checking the payload on the callback
        # thread; inspected by the test thread after the wait.
        payload_errors = []

        # callback definition
        def on_request(request):
            unpacker = Unpacker(
                file_like=StringIO.StringIO(request.payload))
            with self.request_complete_condition:
                try:
                    self.assertEqual(next(unpacker), self.TEST_STRING)
                    self.assertEqual(next(unpacker), self.TEST_BYTE)
                    self.assertEqual(next(unpacker), self.TEST_INT)
                    self.received_request = True
                except Exception as ex:  # pylint: disable=broad-except
                    payload_errors.append(ex)
                    print(ex)
                self.request_complete_condition.notify_all()

        request_callback = RequestCallback()
        request_callback.on_request = on_request
        reg_info.add_topic(topic, request_callback)
        # Register the service
        server.register_service_sync(reg_info, self.DEFAULT_TIMEOUT)

        with self.create_client() as client:
            client.connect()
            packer = Packer()

            # Send a request to the server with information contained
            # in the payload
            request = Request(destination_topic=topic)
            request.payload = packer.pack(self.TEST_STRING)
            request.payload += packer.pack(self.TEST_BYTE)
            request.payload += packer.pack(self.TEST_INT)
            # The service never replies, so no response callback is passed.
            client.async_request(request)

            with self.request_complete_condition:
                if not self.received_request and not payload_errors:
                    # Wait until the request has been processed
                    self.request_complete_condition.wait(self.MAX_WAIT)

            if payload_errors:
                self.fail("Request payload mismatch: " +
                          str(payload_errors[0]))
            if not self.received_request:
                self.fail("Request not received.")
def test_client_register_service_subscribes_client_to_channel(self):
    """Registering a service subscribes the client to each of its topics."""
    channel1 = '/mcafee/service/unittest/one'
    channel2 = '/mcafee/service/unittest/two'
    # Create dummy service
    service_info = dxlclient.service.ServiceRegistrationInfo(
        service_type='/mcafee/service/unittest', client=self.client)
    service_info.add_topic(channel1, RequestCallback())
    service_info.add_topic(channel2, RequestCallback())
    # Register service in client
    self.client.register_service_async(service_info)
    # Check subscribed channels. unittest assertions are used instead of
    # bare ``assert``, which is stripped when Python runs with -O.
    subscriptions = self.client.subscriptions
    self.assertIn(channel1, subscriptions,
                  "Client wasn't subscribed to service channel")
    self.assertIn(channel2, subscriptions,
                  "Client wasn't subscribed to service channel")
def setUp(self):
    """Patch the MQTT client class and build a broker-less client config.

    Also creates ``self.req_callback``, a RequestCallback whose on_request
    is a Mock.
    """
    # NOTE(review): the patcher is started but never stopped in this
    # method; presumably tearDown calls patch.stopall() -- confirm.
    mqtt_patcher = patch('pahoproxy.client.Client')
    mqtt_patcher.start()

    config_kwargs = dict(broker_ca_bundle=get_ca_bundle_pem(),
                         cert_file=get_cert_file_pem(),
                         private_key=get_dxl_private_key(),
                         brokers=[])
    self.config = DxlClientConfig(**config_kwargs)

    callback = RequestCallback()
    callback.on_request = Mock()
    self.req_callback = callback
def add_client_callbacks(self, client):
    """Create ``self.request_callback``, which answers requests with "Ok".

    The handler logs the request's destination topic and payload, then
    replies with payload ``b"Ok"``. A DxlException raised while sending the
    reply is printed rather than propagated.
    """
    self.request_callback = RequestCallback()

    def on_request(request):
        logging.info(request.destination_topic)
        logging.info(request.payload)
        response = Response(request)
        # b"Ok" equals Python 2's bytes("Ok"); bytes("Ok") raises on Python 3.
        response.payload = b"Ok"
        try:
            client.send_response(response)
        except DxlException as ex:
            print("Failed to send response" + str(ex))

    # Attach the handler; without this the function above is never called
    # and requests to this callback go unanswered.
    self.request_callback.on_request = on_request
def test_response_service_not_found_no_channel(self):
    """A service whose client unsubscribed from its topic is unavailable.

    The service stays registered with the broker, but the service client is
    unsubscribed from the service topic. A request must yield an
    ErrorResponse with the service id and the "service unavailable" code
    and message. The handler must never run. Afterwards the registry query
    must return None.
    """
    # Gets an entry whenever the service handler runs; must stay empty.
    handler_calls = []
    topic = UuidGenerator.generate_id_as_string()
    with self.create_client() as service_client:
        service_client.connect()

        def my_request(request):
            handler_calls.append(True)
            service_client.send_response(Response(request))

        reg_info = ServiceRegistrationInfo(
            service_client, "response_service_not_found_no_channel_test")
        callback = RequestCallback()
        callback.on_request = my_request
        reg_info.add_topic(topic, callback)
        service_client.register_service_sync(reg_info, self.DEFAULT_TIMEOUT)
        # Drop the topic subscription; the registry query below confirms the
        # service itself is still registered.
        service_client.unsubscribe(topic)
        self.assertIsNotNone(
            self.query_service_registry_by_service(service_client, reg_info))

        with self.create_client() as request_client:
            request_client.connect()
            response = request_client.sync_request(
                Request(topic), timeout=self.RESPONSE_WAIT)
            self.assertFalse(handler_calls)
            self.assertIsInstance(response, ErrorResponse)
            self.assertEqual(reg_info.service_id, response.service_id)
            self.assertEqual(
                self.DXL_SERVICE_UNAVAILABLE_ERROR_CODE,
                BrokerServiceRegistryTest.normalized_error_code(response))
            self.assertEqual(self.DXL_SERVICE_UNAVAILABLE_ERROR_MESSAGE,
                             response.error_message)
            self.assertIsNone(
                self.query_service_registry_by_service(service_client,
                                                       reg_info))
def test_multiple_services(self):
    """Two services on one client each answer requests on their own topic.

    Service 1 replies "service1" and service 2 replies "service2". A second
    client requests each topic and checks that it receives a Response with
    the matching payload.
    """
    with self.create_client() as service_client:
        service_client.connect()
        # Keep the registration objects referenced for the whole test.
        registrations = []

        def register(topic_prefix, service_type, reply_payload):
            """Register a service replying with reply_payload; return its topic."""
            service_topic = topic_prefix + UuidGenerator.generate_id_as_string()
            reg_info = ServiceRegistrationInfo(service_client, service_type)

            def on_request(request):
                response = Response(request)
                response.payload = reply_payload
                service_client.send_response(response)

            callback = RequestCallback()
            callback.on_request = on_request
            reg_info.add_topic(service_topic, callback)
            service_client.register_service_sync(reg_info,
                                                 self.DEFAULT_TIMEOUT)
            registrations.append(reg_info)
            return service_topic

        topic_1 = register("multiple_services_test_1_",
                           "multiple_services_test_1", "service1")
        topic_2 = register("multiple_services_test_2_",
                           "multiple_services_test_2", "service2")

        with self.create_client() as request_client:
            request_client.connect()
            for service_topic, expected in ((topic_1, "service1"),
                                             (topic_2, "service2")):
                response = request_client.sync_request(
                    Request(service_topic), self.DEFAULT_TIMEOUT)
                self.assertIsInstance(response, Response)
                self.assertEqual(response.payload.decode("utf8"), expected)
def test_async_flood(self):
    """Flood a slow echo service with REQUEST_COUNT asynchronous requests.

    The service sleeps 0.05 s per request and echoes the payload. The
    requesting client counts responses and error responses under
    ``self.resp_condition``. The test waits until all responses arrive, an
    error arrives, or ``self.WAIT_TIME`` elapses. It raises if any error
    response was seen, then asserts that every request got a response.
    """
    channel = UuidGenerator.generate_id_as_string()
    with self.create_client() as client:
        self.m_info = ServiceRegistrationInfo(client, channel)
        client.connect()
        client.subscribe(channel)

        def my_request_callback(request):
            try:
                time.sleep(0.05)
                resp = Response(request)
                resp.payload = request.payload
                client.send_response(resp)
            except Exception as ex:  # pylint: disable=broad-except
                print(ex)

        req_callback = RequestCallback()
        req_callback.on_request = my_request_callback

        self.m_info.add_topic(channel, req_callback)
        client.register_service_sync(self.m_info, 10)

        with self.create_client() as client2:
            client2.connect()

            def my_response_callback(response):
                if response.message_type == Message.MESSAGE_TYPE_ERROR:
                    # ErrorResponse exposes ``error_message``; the previous
                    # ``_error_response`` attribute does not exist. Reading
                    # it raised here, before error_count was incremented.
                    print("Received error response: " +
                          str(response.error_message))
                    with self.resp_condition:
                        self.error_count += 1
                        self.resp_condition.notify_all()
                else:
                    with self.resp_condition:
                        if self.response_count % 10 == 0:
                            print("Received request " +
                                  str(self.response_count))
                        self.response_count += 1
                        self.resp_condition.notify_all()

            callback = ResponseCallback()
            callback.on_response = my_response_callback
            client2.add_response_callback("", callback)

            for i in range(0, self.REQUEST_COUNT):
                if i % 100 == 0:
                    print("Sent: " + str(i))
                request = Request(channel)
                request.payload = str(i)
                client2.async_request(request)
                if self.error_count > 0:
                    break

            # Wait for all responses, an error to occur, or we timeout
            start_time = time.time()
            with self.resp_condition:
                while (self.response_count != self.REQUEST_COUNT) and (
                        self.error_count == 0) and (
                            time.time() - start_time < self.WAIT_TIME):
                    self.resp_condition.wait(5)

            if self.error_count != 0:
                raise Exception("Received an error response!")

            self.assertEqual(self.REQUEST_COUNT, self.response_count,
                             "Did not receive all messages!")
def test_register_service_call_from_request_callback(self):
    """Register a second service from inside a first service's request handler.

    The first service (built by ``self.add_client_callbacks``) calls
    ``on_first_service_request`` while handling a request. That call starts
    and joins a thread that synchronously registers "/mcafee/service/JTI2".
    Both services are then invoked from the same client, and their response
    payloads are checked.
    """
    # While in the request callback for a service invocation, attempt to
    # register a second service. Confirm that a call to the second service
    # is successful. This test ensures that concurrent add/remove service
    # calls and processing of incoming messages do not produce deadlocks.
    with self.create_client(self.DEFAULT_RETRIES, 2) as client:
        expected_second_service_response_payload = \
            "Second service request okay too"
        second_service_callback = RequestCallback()

        # Reply to second-service requests with the expected payload; send
        # failures are printed rather than raised.
        def on_second_service_request(request):
            response = Response(request)
            response.payload = expected_second_service_response_payload
            try:
                client.send_response(response)
            except DxlException as ex:
                print("Failed to send response" + str(ex))

        second_service_callback.on_request = on_second_service_request

        # The second service's topic is suffixed with its own service id.
        second_service_info = ServiceRegistrationInfo(
            client, "/mcafee/service/JTI2")
        second_service_info.add_topic(
            "/mcafee/service/JTI2/file/reputation/" +
            second_service_info.service_id, second_service_callback)

        def register_second_service():
            client.register_service_sync(second_service_info, self.REG_DELAY)

        register_second_service_thread = threading.Thread(
            target=register_second_service)
        # Daemon so a hung registration cannot keep the process alive.
        register_second_service_thread.daemon = True

        # Perform the second service registration from a separate thread
        # in order to ensure that locks taken by the callback and
        # service managers do not produce deadlocks between the
        # thread from which the service registration request is made and
        # any threads on which response messages are received from the
        # broker.
        # NOTE(review): a Thread can only be started once, so this hook
        # assumes the first service is invoked exactly once per test.
        def on_first_service_request():
            register_second_service_thread.start()
            register_second_service_thread.join()

        # The two-argument add_client_callbacks variant builds self.info and
        # calls on_first_service_request from the first service's handler
        # before it sends its "Ok" reply.
        self.add_client_callbacks(client, on_first_service_request)
        client.connect()
        client.register_service_sync(self.info, self.REG_DELAY)

        first_service_request = Request(
            "/mcafee/service/JTI/file/reputation/" + self.info.service_id)
        first_service_request.payload = "Test"
        # By the time this response arrives, the handler has joined the
        # registration thread, so the second service's registration call
        # has returned.
        first_service_response = client.sync_request(
            first_service_request, self.POST_OP_DELAY)
        first_service_response_payload = first_service_response. \
            payload.decode("utf8")
        logging.info("First service response payload: %s",
                     first_service_response_payload)
        self.assertEqual("Ok", first_service_response_payload)

        second_service_request = Request(
            "/mcafee/service/JTI2/file/reputation/" +
            second_service_info.service_id)
        second_service_request.payload = "Test"
        second_service_response = client.sync_request(
            second_service_request, self.POST_OP_DELAY)
        actual_second_service_response_payload = second_service_response. \
            payload.decode("utf8")
        logging.info("Second service response payload: %s",
                     actual_second_service_response_payload)
        self.assertEqual(expected_second_service_response_payload,
                         actual_second_service_response_payload)
def test_wildcard_services(self):
    """An event on "/test/bar" is delivered to a "/test/#" event callback.

    A service "myWildcardService" is registered on "/request/test/#". Its
    metadata sets EventToRequestTopic "/test/#" and EventToRequestPrefix
    "/request", so matching events may be turned into service requests.
    The service and response handlers record what they see. However, only
    delivery of the event itself is asserted (see the note near the end).
    """
    max_wait = 10
    with self.create_client() as client:
        # The request message that the service receives
        service_request_message = []
        # The request message corresponding to the response received by the client
        client_response_message_request = []
        # The event that we received
        client_event_message = []
        client_event_message_condition = Condition()
        # The payload that the service receives
        service_request_message_receive_payload = []

        client.connect()

        info = ServiceRegistrationInfo(client, "myWildcardService")
        meta = {}
        # Transform events mapped to "test/#/" to "request/test/..."
        meta["EventToRequestTopic"] = "/test/#"
        meta["EventToRequestPrefix"] = "/request"
        info.metadata = meta
        request_cb = RequestCallback()

        def on_request(request):
            # Received payloads are bytes; decode before concatenating with
            # str (str + bytes raises TypeError on Python 3).
            payload_text = request.payload.decode("utf8")
            print("## Request in service: " + request.destination_topic + ", " +
                  str(request.message_id))
            print("## Request in service - payload: " + payload_text)

            service_request_message.append(request.message_id)
            service_request_message_receive_payload.append(request.payload)

            response = Response(request)
            response.payload = "Request response - Event payload: " + payload_text
            client.send_response(response)

        request_cb.on_request = on_request
        info.add_topic("/request/test/#", request_cb)
        client.register_service_sync(info, 10)

        evt = Event("/test/bar")

        response_cb = ResponseCallback()

        def on_response(response):
            # Only handle the response corresponding to the event we sent
            if response.request_message_id == evt.message_id:
                print("## received_response: " + response.request_message_id +
                      ", " + response.__class__.__name__)
                print("## received_response_payload: " +
                      response.payload.decode("utf8"))
                # append: the list starts empty, so index assignment would
                # raise IndexError.
                client_response_message_request.append(
                    response.request_message_id)

        response_cb.on_response = on_response
        client.add_response_callback("", response_cb)

        ecb = EventCallback()

        def on_event(event):
            print("## received event: " + event.destination_topic + ", " +
                  event.message_id)
            with client_event_message_condition:
                client_event_message.append(event.message_id)
                client_event_message_condition.notify_all()

        ecb.on_event = on_event
        client.add_event_callback("/test/#", ecb)

        # Send our event
        print("## Sending event: " + evt.destination_topic + ", " + evt.message_id)
        evt.payload = "Unit test payload"
        client.send_event(evt)

        start = time.time()
        with client_event_message_condition:
            while (time.time() - start < max_wait) and \
                    not client_event_message:
                client_event_message_condition.wait(max_wait)

        # NOTE: the service-side checks (service_request_message,
        # service_request_message_receive_payload) and the response check
        # (client_response_message_request) are intentionally not asserted;
        # only event delivery is verified here.

        # Make sure we received the correct event
        self.assertGreater(len(client_event_message), 0)
        self.assertEqual(evt.message_id, client_event_message[0])