async def test_http_interpreter():
    """Parsing via the HTTP interpreter POSTs the expected JSON payload to /parse."""
    with aioresponses() as mocked:
        mocked.post("https://example.com/parse")

        interpreter = RasaNLUHttpInterpreter(
            endpoint=EndpointConfig("https://example.com")
        )
        await interpreter.parse(text="message_text", message_id="1134")

        parse_request = latest_request(mocked, "POST", "https://example.com/parse")
        sent_payload = json_of_latest_request(parse_request)

        expected_payload = {
            "project": "default",
            "q": "message_text",
            "message_id": "1134",
            "model": None,
            "token": None,
        }
        assert sent_payload == expected_payload
def test_file_broker_properly_logs_newlines(tmpdir):
    """An event whose text contains a newline is still logged on a single line."""
    log_path = tmpdir.join("events.log").strpath
    file_broker = broker_utils.from_endpoint_config(
        EndpointConfig(type="file", path=log_path)
    )

    event_with_newline = UserUttered("hello \n there")
    file_broker.publish(event_with_newline.as_dict())

    # the log holds one JSON-serialised event per line
    with open(log_path, "r") as log_file:
        recovered = [Event.from_parameters(json.loads(line)) for line in log_file]

    assert recovered == [event_with_newline]
def test_file_broker_properly_logs_newlines(tmpdir):
    """Publishing an event with an embedded newline still yields one log line."""
    log_file_path = tmpdir.join("events.log").strpath
    file_broker = EventBroker.create(EndpointConfig(type="file", path=log_file_path))

    original_event = UserUttered("hello \n there")
    file_broker.publish(original_event.as_dict())

    # parse the log back, one JSON event per line
    with open(log_file_path, "r") as log_file:
        recovered = [Event.from_parameters(json.loads(raw)) for raw in log_file]

    assert recovered == [original_event]
async def test_http_parsing():
    """parse_message with an HTTP interpreter sends a request to /model/parse."""
    parse_url = "https://interpreter.com/model/parse"
    message = UserMessage("lunch?")
    endpoint = EndpointConfig("https://interpreter.com")

    with aioresponses() as mocked:
        mocked.post(parse_url, repeat=True, status=200)

        interpreter = RasaNLUHttpInterpreter(endpoint_config=endpoint)
        try:
            processor = MessageProcessor(interpreter, None, None, None, None)
            await processor.parse_message(message)
        except KeyError:
            # the mocked response is empty, and the processor's logging looks up
            # "intent"/"entities"; only the outgoing request matters here
            pass

        parse_request = latest_request(mocked, "POST", parse_url)
        assert parse_request
async def test_remote_action_endpoint_responds_400(
    default_channel, default_nlg, default_tracker, default_domain
):
    """A 400 from the action server surfaces as ActionExecutionRejection."""
    endpoint = EndpointConfig("https://example.com/webhooks/actions")
    remote_action = action.RemoteAction("my_action", endpoint)

    with aioresponses() as mocked:
        # noinspection PyTypeChecker
        mocked.post(
            "https://example.com/webhooks/actions",
            exception=ClientResponseError(400, None, '{"action_name": "my_action"}'),
        )

        # Expect the specific exception type: catching any Exception would hide
        # an unrelated error behind a bare type-mismatch assertion.
        with pytest.raises(ActionExecutionRejection) as execinfo:
            await remote_action.run(
                default_channel, default_nlg, default_tracker, default_domain
            )

    assert "Custom action 'my_action' rejected to run" in str(execinfo.value)
async def test_agent_with_model_server_in_thread(model_server: TestClient, domain: Domain):
    """An agent loaded from a model server gets the served model and domain.

    The test server must have received exactly one model request after the
    background pulling job (2 s interval) has had time to run.
    """
    model_endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
    )

    agent = Agent()
    try:
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config
        )

        # give the scheduled pulling job time to fire at least once more
        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert agent.domain.as_dict() == domain.as_dict()
        assert agent.processor.graph_runner
        assert model_server.app.number_of_model_requests == 1
    finally:
        # always stop the pulling job, even when an assertion fails, so it
        # does not keep running into later tests
        jobs.kill_scheduler()
async def test_project_with_model_server(trained_nlu_model):
    """A zipped model served at the endpoint URL is loaded with its ETag as fingerprint."""
    fingerprint = "somehash"
    model_endpoint = EndpointConfig("http://server.com/models/nlu/tags/latest")
    zip_path = zip_folder(trained_nlu_model)

    with open(zip_path, "rb") as zipped_model:
        zipped_bytes = zipped_model.read()

    # mock a response that returns a zipped model
    served_headers = {"ETag": fingerprint, "filename": "my_model_xyz.zip"}
    responses.add(
        responses.GET,
        model_endpoint.url,
        headers=served_headers,
        body=zipped_bytes,
        content_type="application/zip",
        stream=True,
    )

    project = await load_from_server(model_server=model_endpoint)
    assert project.fingerprint == fingerprint
async def get_response(self, request):
    """Run one user message through Rasa Core and return the first bot reply.

    On first use the domain is loaded from ``data/<skill-id>/core/model`` and
    an ``ArcusTrackerStore`` is created; both are cached in ``self.config``.

    Args:
        request: mapping with at least ``text``, ``user`` and ``channel`` keys;
            it is also handed to ``LocalNLUInterpreter``.

    Returns:
        ``{"text": <text of the first reply>}``, or ``{"text": "error"}`` when
        the processor produced no messages.
    """
    skill_dir = "data/" + self.config['skill-id']

    # Assign directly: setdefault() would keep an explicit ``None`` value and
    # would also build its default argument eagerly on every call.
    if self.config.get('domain') is None:
        self.config['domain'] = Domain.from_file(skill_dir + "/core/model")
    if self.config.get('tracker_store') is None:
        self.config['tracker_store'] = ArcusTrackerStore(self.config['domain'], self.asm)

    domain = self.config['domain']
    tracker_store = self.config['tracker_store']

    nlg = NaturalLanguageGenerator.create(None, domain)
    policy_ensemble = SimplePolicyEnsemble.load(skill_dir + "/core")
    interpreter = LocalNLUInterpreter(request)

    url = 'http://localhost:8080/api/v1/skill/generic_action'
    processor = MessageProcessor(
        interpreter,
        policy_ensemble,
        domain,
        tracker_store,
        nlg,
        action_endpoint=EndpointConfig(url),
        message_preprocessor=None,
    )

    message_nlu = UserMessage(
        request['text'], None, request['user'], input_channel=request['channel']
    )
    result = await processor.handle_message(message_nlu)

    if result:
        return {"text": result[0]['text']}

    _LOGGER.info(result)
    return {"text": "error"}
def _load_from_module_name_in_endpoint_config(
    domain: Domain, store: EndpointConfig, event_broker: Optional[EventBroker] = None
) -> "TrackerStore":
    """Initializes a custom tracker.

    Defaults to the InMemoryTrackerStore if the module path can not be found.

    Args:
        domain: defines the universe in which the assistant operates
        store: the specific tracker store
        event_broker: an event broker to publish events

    Returns:
        a tracker store from a specified type in a stores endpoint configuration

    Raises:
        Exception: if the custom tracker store still takes a `url` argument
            instead of `host`.
    """
    # Only the class lookup is guarded. Errors raised while constructing a
    # store that *was* found must not be masked as "not found".
    try:
        tracker_store_class = rasa.shared.utils.common.class_from_module_path(
            store.type
        )
    except (AttributeError, ImportError):
        rasa.shared.utils.io.raise_warning(
            f"Tracker store with type '{store.type}' not found. "
            f"Using `InMemoryTrackerStore` instead."
        )
        # keep publishing events even on the fallback store
        return InMemoryTrackerStore(domain, event_broker=event_broker)

    init_args = rasa.shared.utils.common.arguments_of(tracker_store_class.__init__)
    if "url" in init_args and "host" not in init_args:
        # DEPRECATION EXCEPTION - remove in 2.1
        raise Exception(
            "The `url` initialization argument for custom tracker stores has "
            "been removed. Your custom tracker store should take a `host` "
            "argument in its `__init__()` instead."
        )

    # Build a copy so the caller's endpoint config is not mutated as a side effect.
    store_kwargs = {**store.kwargs, "host": store.url}
    return tracker_store_class(domain=domain, event_broker=event_broker, **store_kwargs)
async def test_parse_with_http_interpreter(trained_default_agent_model: Text):
    """An agent configured with an NLU endpoint returns the remote parse result."""
    endpoints = AvailableEndpoints(nlu=EndpointConfig("https://interpreter.com"))
    agent = await load_agent(model_path=trained_default_agent_model, endpoints=endpoints)

    expected_parse = {
        "intent": {INTENT_NAME_KEY: "some_intent", "confidence": 1.0},
        "entities": [],
        "text": "lunch?",
    }

    with aioresponses() as mocked:
        # the mocked NLU server answers every parse request with expected_parse
        mocked.post(
            "https://interpreter.com/model/parse",
            repeat=True,
            status=HTTPStatus.OK,
            body=json.dumps(expected_parse),
        )

        parsed = await agent.parse_message("lunch?")

        assert parsed == expected_parse
async def test_remote_action_invalid_entities_payload(
    default_channel: OutputChannel,
    default_nlg: NaturalLanguageGenerator,
    default_tracker: DialogueStateTracker,
    domain: Domain,
    event: Event,
):
    """An action-server reply carrying the parametrized ``event`` fails validation."""
    webhook_url = "https://example.com/webhooks/actions"
    remote_action = action.RemoteAction("my_action", EndpointConfig(webhook_url))
    action_server_reply = {"events": [event], "responses": []}

    with aioresponses() as mocked:
        mocked.post(webhook_url, payload=action_server_reply)

        with pytest.raises(ValidationError) as error_info:
            await remote_action.run(default_channel, default_nlg, default_tracker, domain)

    assert "Failed to validate Action server response from API" in str(error_info.value)
def test_callback_channel():
    """Serve an agent over the callback channel and check its webhook routes.

    The code between the DOC INCLUDE markers is embedded in the documentation,
    so it has to stay self-contained (it performs its own imports).
    """
    # START DOC INCLUDE
    from rasa.core.channels.callback import CallbackInput
    from rasa.core.agent import Agent
    from rasa.core.interpreter import RegexInterpreter

    # load your trained agent
    agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter())

    input_channel = CallbackInput(
        # URL Core will call to send the bot responses
        endpoint=EndpointConfig("http://localhost:5004"))

    s = agent.handle_channels([input_channel], 5004)
    # END DOC INCLUDE
    # the above marker marks the end of the code snippet included
    # in the docs
    routes_list = utils.list_routes(s)
    assert routes_list.get("callback_webhook.health").startswith(
        "/webhooks/callback")
    assert routes_list.get("callback_webhook.webhook").startswith(
        "/webhooks/callback/webhook")
async def test_no_slots_extracted_with_custom_slot_mappings(
    custom_events: List[Event],
):
    """The form rejects execution when the user's message fills no requested slot.

    The mocked validation action replies with the parametrized ``custom_events``.
    """
    form_name = "my form"
    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "num_tables"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("off topic"),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

    domain = f"""
    slots:
      num_tables:
        type: any
    forms:
      {form_name}:
        num_tables:
        - type: from_entity
          entity: num_tables
    actions:
    - validate_{form_name}
    """
    domain = Domain.from_yaml(domain)

    # was "http:/my-action-server..." (missing slash), which is not a valid host URL
    action_server_url = "http://my-action-server:5055/webhook"

    with aioresponses() as mocked:
        mocked.post(action_server_url, payload={"events": custom_events})

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        with pytest.raises(ActionExecutionRejection):
            await action.run(
                CollectingOutputChannel(),
                TemplatedNaturalLanguageGenerator(domain.templates),
                tracker,
                domain,
            )
def _serve_application(app, stories, finetune, skip_visualization):
    """Start a core server and attach the interactive learning IO.

    The interactive session runs as a server task; once it finishes (or fails)
    the server is stopped.

    Returns:
        the Sanic ``app`` after ``app.run`` has returned.
    """
    endpoint = EndpointConfig(url=DEFAULT_SERVER_URL)

    async def run_interactive_io(running_app: Sanic):
        """Small wrapper to shut down the server once cmd io is done."""
        try:
            await record_messages(
                endpoint=endpoint,
                stories=stories,
                finetune=finetune,
                skip_visualization=skip_visualization,
                sender_id=uuid.uuid4().hex)
        finally:
            # Stop the server even if the interactive session raised; otherwise
            # the server would keep running with nobody driving it.
            logger.info("Killing Sanic server now.")
            running_app.stop()  # kill the sanic server

    app.add_task(run_interactive_io)
    app.run(host='0.0.0.0', port=DEFAULT_SERVER_PORT, access_log=True)

    return app
def __init__(self, project: str = "GDA", model: str = "AOD", user_type: str = "user"):
    """Load the latest trained agent for ``project``/``model``.

    Raises:
        NotADirectoryError: if no model directory exists for this combination.
    """
    self.project = project
    self.model = model
    self.user_type = user_type
    self.agent_path = os.path.join(here, DEFAULT_MODELS_PATH, self.project, self.model)

    # no action server endpoint is configured for the smalltalk model
    self.action_endpoint = (
        None if model == "smalltalk" else EndpointConfig(url=config.ACTION_URL_ENDPOINT)
    )

    if not os.path.exists(self.agent_path):
        raise NotADirectoryError(
            "NLU or dialogue model not found, make sure training succeeded"
        )

    latest_model = get_latest_model(self.agent_path)
    self.agent = Agent.load(latest_model, action_endpoint=self.action_endpoint)
async def load_model(request: Request):
    """Replace ``app.agent`` with one loaded from the locations in the request body.

    Reads ``model_file``, ``model_server`` and ``remote_storage`` from the JSON
    body; responds with 204 on success and a 400 ErrorResponse when
    ``model_server`` cannot be turned into an EndpointConfig.
    """
    validate_request_body(request, "No path to model file defined in request_body.")

    body = request.json
    model_path = body.get("model_file", None)
    model_server = body.get("model_server", None)
    remote_storage = body.get("remote_storage", None)

    if model_server:
        try:
            model_server = EndpointConfig.from_dict(model_server)
        except TypeError as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                400,
                "BadRequest",
                f"Supplied 'model_server' is not valid. Error: {e}",
                {"parameter": "model_server", "in": "body"},
            )

    app.agent = await _load_agent(model_path, model_server, remote_storage, endpoints)

    logger.debug(f"Successfully loaded model '{model_path}'.")
    return response.json(None, status=204)
async def test_remote_action_endpoint_responds_400(
    default_dispatcher_collecting, default_domain
):
    """A 400 from the action server surfaces as ActionExecutionRejection."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    endpoint = EndpointConfig("https://example.com/webhooks/actions")
    remote_action = action.RemoteAction("my_action", endpoint)

    with aioresponses() as mocked:
        # noinspection PyTypeChecker
        mocked.post(
            'https://example.com/webhooks/actions',
            exception=ClientResponseError(400, None, '{"action_name": "my_action"}'),
        )

        # Expect the specific exception type: catching any Exception would hide
        # an unrelated error behind a bare type-mismatch assertion.
        with pytest.raises(ActionExecutionRejection) as execinfo:
            await remote_action.run(default_dispatcher_collecting, tracker, default_domain)

    assert "Custom action 'my_action' rejected to run" in str(execinfo.value)
async def test_pull_model_with_invalid_domain(
    model_server: TestClient, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture
):
    """An InvalidDomain while loading a pulled model is logged rather than raised."""
    error_message = "domain is invalid"

    # make Domain.load() behave as if the domain contained invalid YAML
    failing_load = Mock(side_effect=InvalidDomain(error_message))
    monkeypatch.setattr(Domain, "load", failing_load)

    endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": None}
    )
    await rasa.core.agent.load_from_server(Agent(), model_server=endpoint_config)

    failing_load.assert_called_once()
    assert error_message in caplog.text
async def test_agent_with_model_server_in_thread(
    model_server: TestClient, moodbot_domain: Domain, moodbot_metadata: Any
):
    """An agent loaded from a model server gets the moodbot model and policies.

    The test server must have received exactly one model request after the
    background pulling job (2 s interval) has had time to run.
    """
    model_endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
    )

    agent = Agent()
    try:
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config
        )

        # give the scheduled pulling job time to fire at least once more
        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert hash(agent.domain) == hash(moodbot_domain)

        agent_policies = {
            utils.module_path_from_instance(p) for p in agent.policy_ensemble.policies
        }
        moodbot_policies = set(moodbot_metadata["policy_names"])
        assert agent_policies == moodbot_policies
        assert model_server.app.number_of_model_requests == 1
    finally:
        # always stop the pulling job, even when an assertion fails, so it
        # does not keep running into later tests
        jobs.kill_scheduler()
def test_get_number_of_sanic_workers(
    env_value: Optional[Text],
    lock_store: Union[LockStore, Text, None],
    expected: Optional[int],
):
    """number_of_sanic_workers honours SANIC_WORKERS and the lock store type.

    The SANIC_WORKERS environment variable is restored in all cases, including
    a failing assertion, so it cannot leak into other tests.
    """
    # remember pre-test value of SANIC_WORKERS env var
    pre_test_value = os.environ.get(ENV_SANIC_WORKERS)

    try:
        # set env var to desired value and make assertion
        if env_value is not None:
            os.environ[ENV_SANIC_WORKERS] = str(env_value)

        # lock_store may be string or LockStore object
        # create EndpointConfig if it's a string, otherwise pass the object
        if isinstance(lock_store, str):
            lock_store = EndpointConfig(type=lock_store)

        assert utils.number_of_sanic_workers(lock_store) == expected
    finally:
        # reset env var to pre-test value
        os.environ.pop(ENV_SANIC_WORKERS, None)
        if pre_test_value is not None:
            os.environ[ENV_SANIC_WORKERS] = pre_test_value
async def test_remote_action_runs(default_dispatcher_collecting, default_domain):
    """Running a remote action POSTs the serialised tracker and domain to the webhook."""
    webhook_url = "https://example.com/webhooks/actions"
    tracker = DialogueStateTracker("default", default_domain.slots)
    remote_action = action.RemoteAction("my_action", EndpointConfig(webhook_url))

    with aioresponses() as mocked:
        mocked.post(webhook_url, payload={"events": [], "responses": []})

        await remote_action.run(default_dispatcher_collecting, tracker, default_domain)

        sent = latest_request(mocked, "post", webhook_url)
        assert sent

        expected_tracker = {
            "latest_message": {"entities": [], "intent": {}, "text": None},
            "active_form": {},
            "latest_action_name": None,
            "sender_id": "default",
            "paused": False,
            "latest_event_time": None,
            "followup_action": "action_listen",
            "slots": {"name": None},
            "events": [],
            "latest_input_channel": None,
        }
        assert json_of_latest_request(sent) == {
            "domain": default_domain.as_dict(),
            "next_action": "my_action",
            "sender_id": "default",
            "version": rasa.__version__,
            "tracker": expected_tracker,
        }
async def test_load_model_from_model_server(
    rasa_app: SanicASGITestClient, trained_core_model: Text
):
    """PUT /model with a model_server loads the served model (fingerprint changes)."""
    import rasa.core.jobs

    _, response = await rasa_app.get("/status")
    assert response.status == 200
    assert "fingerprint" in response.json()
    old_fingerprint = response.json()["fingerprint"]

    endpoint = EndpointConfig("https://example.com/model/trained_core_model")

    try:
        with open(trained_core_model, "rb") as f:
            with aioresponses(passthrough=["http://127.0.0.1"]) as mocked:
                mocked.get(
                    "https://example.com/model/trained_core_model",
                    content_type="application/x-tar",
                    body=f.read(),
                )

                data = {"model_server": {"url": endpoint.url}}
                _, response = await rasa_app.put("/model", json=data)
                assert response.status == 204

                _, response = await rasa_app.get("/status")
                assert response.status == 200
                assert "fingerprint" in response.json()
                assert old_fingerprint != response.json()["fingerprint"]
    finally:
        # Reset the module-level scheduler even when an assertion fails, so a
        # scheduler possibly created by model-server loading does not leak
        # into other tests.
        rasa.core.jobs.__scheduler = None
async def test_remote_action_logs_events(default_channel, default_nlg,
                                         default_tracker, default_domain):
    """Remote action responses become BotUttered events followed by returned events."""
    webhook_url = "https://example.com/webhooks/actions"
    remote_action = action.RemoteAction("my_action", EndpointConfig(webhook_url))

    action_server_reply = {
        "events": [{"event": "slot", "value": "rasa", "name": "name"}],
        "responses": [
            {
                "text": "test text",
                "template": None,
                "buttons": [{"title": "cheap", "payload": "cheap"}],
            },
            {"template": "utter_greet"},
        ],
    }

    with aioresponses() as mocked:
        mocked.post(webhook_url, payload=action_server_reply)

        events = await remote_action.run(
            default_channel, default_nlg, default_tracker, default_domain
        )

        sent = latest_request(mocked, "post", webhook_url)
        assert sent

        expected_tracker = {
            "latest_message": {
                "entities": [],
                "intent": {},
                "text": None,
                "message_id": None,
                "metadata": {},
            },
            ACTIVE_LOOP: {},
            "latest_action": {},
            "latest_action_name": None,
            "sender_id": "my-sender",
            "paused": False,
            "followup_action": "action_listen",
            "latest_event_time": None,
            "slots": {"name": None},
            "events": [],
            "latest_input_channel": None,
        }
        assert json_of_latest_request(sent) == {
            "domain": default_domain.as_dict(),
            "next_action": "my_action",
            "sender_id": "my-sender",
            "version": rasa.__version__,
            "tracker": expected_tracker,
        }

    assert len(events) == 3  # first two events are bot utterances
    text_utterance, greet_utterance, slot_event = events

    assert text_utterance == BotUttered(
        "test text", {"buttons": [{"title": "cheap", "payload": "cheap"}]}
    )
    assert greet_utterance == BotUttered(
        "hey there None!", metadata={"template_name": "utter_greet"}
    )
    assert slot_event == SlotSet("name", "rasa")
async def _pull_model_and_fingerprint(
    model_server: EndpointConfig, fingerprint: Optional[Text], model_directory: Text
) -> Optional[Text]:
    """Queries the model server.

    Args:
        model_server: Model server endpoint information.
        fingerprint: Current model fingerprint. When ``None`` the request is
            sent without an ``If-None-Match`` header.
        model_directory: Directory where to download model to.

    Returns:
        Value of the response's <ETag> header which contains the model hash.
        Returns `None` if no new model is found, or if the server could not be
        reached or did not answer within the request timeout.
    """
    import asyncio  # only needed to catch request timeouts below

    # ``None`` is not a meaningful ETag to compare against, so only make the
    # request conditional when a fingerprint is actually known.
    headers = {"If-None-Match": fingerprint} if fingerprint is not None else {}

    logger.debug(f"Requesting model from server {model_server.url}...")

    async with model_server.session() as session:
        try:
            params = model_server.combine_parameters()
            async with session.request(
                "GET",
                model_server.url,
                timeout=DEFAULT_REQUEST_TIMEOUT,
                headers=headers,
                params=params,
            ) as resp:

                if resp.status in [204, 304]:
                    logger.debug(
                        "Model server returned {} status code, "
                        "indicating that no new model is available. "
                        "Current fingerprint: {}"
                        "".format(resp.status, fingerprint)
                    )
                    return None
                elif resp.status == 404:
                    logger.debug(
                        "Model server could not find a model at the requested "
                        "endpoint '{}'. It's possible that no model has been "
                        "trained, or that the requested tag hasn't been "
                        "assigned.".format(model_server.url)
                    )
                    return None
                elif resp.status != 200:
                    logger.debug(
                        "Tried to fetch model from server, but server response "
                        "status code is {}. We'll retry later..."
                        "".format(resp.status)
                    )
                    return None

                rasa.utils.io.unarchive(await resp.read(), model_directory)
                logger.debug(
                    "Unzipped model to '{}'".format(os.path.abspath(model_directory))
                )

                # return the new fingerprint
                return resp.headers.get("ETag")

        # A request timeout raises asyncio.TimeoutError, which is not an
        # aiohttp.ClientError; treat it like an unreachable server.
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.debug(
                "Tried to fetch model from server, but "
                "couldn't reach server. We'll retry later... "
                "Error: {}.".format(e)
            )
            return None
async def test_remote_action_logs_events(default_dispatcher_collecting,
                                         default_domain):
    """Remote action responses are dispatched to the channel; events are returned."""
    webhook_url = "https://example.com/webhooks/actions"
    tracker = DialogueStateTracker("default", default_domain.slots)
    remote_action = action.RemoteAction("my_action", EndpointConfig(webhook_url))

    action_server_reply = {
        "events": [{"event": "slot", "value": "rasa", "name": "name"}],
        "responses": [
            {"text": "test text", "buttons": [{"title": "cheap", "payload": "cheap"}]},
            {"template": "utter_greet"},
        ],
    }

    with aioresponses() as mocked:
        mocked.post(webhook_url, payload=action_server_reply)

        events = await remote_action.run(
            default_dispatcher_collecting, tracker, default_domain
        )

        sent = latest_request(mocked, "post", webhook_url)
        assert sent

        expected_tracker = {
            "latest_message": {"entities": [], "intent": {}, "text": None},
            "active_form": {},
            "latest_action_name": None,
            "sender_id": "default",
            "paused": False,
            "followup_action": "action_listen",
            "latest_event_time": None,
            "slots": {"name": None},
            "events": [],
            "latest_input_channel": None,
        }
        assert json_of_latest_request(sent) == {
            "domain": default_domain.as_dict(),
            "next_action": "my_action",
            "sender_id": "default",
            "version": rasa.__version__,
            "tracker": expected_tracker,
        }

    assert events == [SlotSet("name", "rasa")]

    output_channel = default_dispatcher_collecting.output_channel
    expected_messages = [
        {
            "text": "test text",
            "recipient_id": "my-sender",
            "buttons": [{"title": "cheap", "payload": "cheap"}],
        },
        {"text": "hey there None!", "recipient_id": "my-sender"},
    ]
    assert output_channel.messages == expected_messages
def get_endpoint(self, bot: Text):
    """Build the action-server webhook endpoint for ``bot``.

    When running in Docker the action server is addressed as host
    ``<bot>_actions``; otherwise as ``localhost``. Port 5055 in both cases.
    """
    from sagas.conf.runtime import runtime

    if runtime.is_docker():
        action_host = f"{bot}_actions"
    else:
        action_host = "localhost"
    return EndpointConfig(f"http://{action_host}:5055/webhook")
def test_load_non_existent_custom_broker_name():
    """An unknown broker class path yields no broker instead of raising."""
    unknown_broker = EndpointConfig(type="rasa.core.broker.MyProducer")
    loaded = broker.from_endpoint_config(unknown_broker)
    assert loaded is None
def test_load_custom_broker_name():
    """A broker referenced by its full class path is instantiated."""
    broker_config = EndpointConfig(type="rasa.core.broker.FileProducer")
    loaded = broker.from_endpoint_config(broker_config)
    assert loaded
def test_custom_token_name():
    """``token_name`` from the source dict is kept on the resulting EndpointConfig."""
    endpoint = EndpointConfig.from_dict(
        {"url": "http://test", "token": "token", "token_name": "test_token"}
    )
    assert endpoint.token_name == "test_token"
"""Serve a trained Rasa agent over a Socket.IO channel on port 5005."""
import os

from rasa.core.channels.socketio import SocketIOInput
from rasa.core.agent import Agent
from rasa.core.interpreter import RasaNLUInterpreter
from rasa.utils.endpoints import EndpointConfig

# Trained model directory; the NLU and Core sub-models live underneath it, so
# defining it once keeps the two paths from drifting apart.
MODEL_DIR = (
    "/data/xingyang/Documents/Hangzhou Dianzi University - Chatbot/"
    "capstone-master/models/20190717-114901"
)

action_endpoint = EndpointConfig(url="http://localhost:5055/webhook")
nlu_interpreter = RasaNLUInterpreter(os.path.join(MODEL_DIR, "nlu"))
agent = Agent.load(
    os.path.join(MODEL_DIR, "core"),
    interpreter=nlu_interpreter,
    action_endpoint=action_endpoint,
)

input_channel = SocketIOInput(
    # event name for messages sent from the user
    user_message_evt="user_uttered",
    # event name for messages sent from the bot
    bot_message_evt="bot_uttered",
    # socket.io namespace to use for the messages
    namespace=None,
)

agent.handle_channels([input_channel], http_port=5005, route="/webhooks/", cors="*")