def test_tracker_store_remembers_max_history(default_domain: Domain):
    """A tracker created with ``max_event_history`` keeps that limit after a save/retrieve."""
    tracker_store = InMemoryTrackerStore(default_domain)

    original = tracker_store.get_or_create_tracker("myuser", max_event_history=42)
    original.update(Restarted())
    tracker_store.save(original)

    reloaded = tracker_store.retrieve("myuser")
    assert original._max_event_history == reloaded._max_event_history == 42
async def _tracker_store_and_tracker_with_slot_set() -> Tuple[
    InMemoryTrackerStore, DialogueStateTracker
]:
    """Return an in-memory tracker store and a tracker whose ``cuisine`` slot is ``French``.

    The store is built from ``test_domain``, which is defined outside this
    function. The tracker is fetched (or created) for ``DEFAULT_SENDER_ID``.
    The ``SlotSet`` is applied to the returned tracker only; this function
    does not call ``store.save`` afterwards.
    """
    store = InMemoryTrackerStore(test_domain)
    tracker = await store.get_or_create_tracker(DEFAULT_SENDER_ID)
    tracker.update(SlotSet("cuisine", "French"))
    return store, tracker
async def test_update_conversation_with_events(
    rasa_app: SanicASGITestClient,
    monkeypatch: MonkeyPatch,
    initial_tracker_events: List[Event],
    events_to_append: List[Event],
    expected_events: List[Event],
):
    """Appending events through ``rasa.server`` produces the expected event history."""
    conversation_id = "some-conversation-ID"
    domain = Domain.empty()
    tracker_store = InMemoryTrackerStore(domain)
    agent = rasa_app.app.agent
    monkeypatch.setattr(agent, "tracker_store", tracker_store)

    # Seed the store only when there are initial events to start from.
    if initial_tracker_events:
        seeded = DialogueStateTracker.from_events(
            conversation_id, initial_tracker_events
        )
        tracker_store.save(seeded)

    fetched_tracker = await rasa.server.update_conversation_with_events(
        conversation_id, agent.create_processor(), domain, events_to_append
    )
    assert list(fetched_tracker.events) == expected_events
async def test_restart_after_retrieval_from_tracker_store(domain: Domain):
    """The index after the latest restart survives a save/retrieve round trip."""
    store = InMemoryTrackerStore(domain)
    tracker = await store.get_or_create_tracker("myuser")

    for _ in range(4):
        tracker.update(ActionExecuted("action_listen"))
    tracker.update(Restarted())
    restart_idx_before_save = tracker.idx_after_latest_restart()

    await store.save(tracker)
    restored = await store.retrieve("myuser")

    assert restart_idx_before_save == restored.idx_after_latest_restart()
def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg(
    predict_function: Callable,
):
    """Using a patched ``predict_action_probabilities`` must emit a DeprecationWarning.

    Per the test name, the patched function does not accept an interpreter
    argument.
    """
    policy = TEDPolicy()
    policy.predict_action_probabilities = predict_function

    domain = Domain.empty()
    processor = MessageProcessor(
        Mock(),
        SimplePolicyEnsemble(policies=[policy]),
        domain,
        InMemoryTrackerStore(domain),
        Mock(),
    )

    with pytest.warns(DeprecationWarning):
        tracker = DialogueStateTracker.from_events(
            "lala", [ActionExecuted(ACTION_LISTEN_NAME)]
        )
        processor._get_next_action_probabilities(tracker)
async def default_processor(default_domain, default_nlg):
    """Return a MessageProcessor backed by an agent trained on ``DEFAULT_STORIES_FILE``."""
    agent = Agent(
        default_domain,
        SimplePolicyEnsemble([AugmentedMemoizationPolicy()]),
        interpreter=RegexInterpreter(),
    )
    agent.train(await agent.load_data(DEFAULT_STORIES_FILE))

    return MessageProcessor(
        agent.interpreter,
        agent.policy_ensemble,
        default_domain,
        InMemoryTrackerStore(default_domain),
        default_nlg,
    )
def test_get_or_create():
    """A slot set on a saved tracker is still set when the tracker is fetched again."""
    key, value = 'location', 'Easter Island'
    store = InMemoryTrackerStore(domain)

    tracker = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    tracker.update(SlotSet(key, value))
    assert tracker.get_slot(key) == value
    store.save(tracker)

    reloaded = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    assert reloaded.get_slot(key) == value
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Evaluating markers over stored trackers writes the expected CSV rows."""
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    # Five trackers. Each sets slots "0".."4", then starts a session with one
    # user turn, then sets slots "5".."9".
    for sender in range(5):
        tracker = DialogueStateTracker(str(sender), None)
        tracker.update_with_events([SlotSet(str(j), "slot") for j in range(5)], domain)
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(j), "slot") for j in range(5, 10)], domain
        )
        await store.save(tracker)

    tracker_loader = MarkerTrackerLoader(store, "all")
    results_path = tmp_path / "results.csv"
    markers = OrMarker(
        markers=[
            SlotSetMarker("2", name="marker1"),
            SlotSetMarker("7", name="marker2"),
        ]
    )
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    senders = set()
    with open(results_path, "r") as results:
        for row in csv.DictReader(results):
            senders.add(row["sender_id"])
            marker = row["marker"]
            if marker == "marker1":
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"
            elif marker == "marker2":
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"
    assert len(senders) == 5
async def test_processor_logs_text_tokens_in_tracker(mood_agent: Agent):
    """The UserUttered parse data stores the whitespace-token spans of the text."""
    text = "Hello there"
    tokens = WhitespaceTokenizer().tokenize(Message(data={"text": text}), "text")
    expected_spans = [(token.start, token.end) for token in tokens]

    domain = mood_agent.domain
    processor = MessageProcessor(
        mood_agent.interpreter,
        mood_agent.policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
    )

    tracker = await processor.log_message(UserMessage(text))
    user_event = tracker.get_last_event_for(event_type=UserUttered)
    logged_spans = user_event.as_dict().get("parse_data").get("text_tokens")
    assert logged_spans == expected_spans
async def default_agent(_default_agent: Agent) -> Agent:
    """Return the shared agent with a fresh tracker store and the default session config.

    The reset is done per test so that conversations from one test do not
    leak into the next.
    """
    agent = _default_agent
    agent.tracker_store = InMemoryTrackerStore(agent.domain)
    agent.domain.session_config = SessionConfig.default()
    return agent
def test_inmemory_tracker_store_with_token_serialisation(
    domain: Domain, response_selector_interpreter: Interpreter
):
    """Run the shared token-serialisation check against an in-memory tracker store."""
    prepare_token_serialisation(
        InMemoryTrackerStore(domain), response_selector_interpreter, "inmemory"
    )
def test_get_or_create():
    """Run the shared get-or-create checks against an in-memory tracker store."""
    store = InMemoryTrackerStore(domain)
    get_or_create_tracker_store(store)
def reset_conversation_state(agent: Agent) -> Agent:
    """Give ``agent`` a fresh in-memory tracker store and the default session config.

    The agent is mutated in place and returned, so tests do not share
    conversation state.
    """
    fresh_store = InMemoryTrackerStore(agent.domain)
    agent.tracker_store = fresh_store
    agent.domain.session_config = SessionConfig.default()
    return agent
class RasaServiceLocal(MqttService):
    """Load a Rasa model and tracker store in-process and serve them over MQTT.

    Handled topics have the form ``hermod/<site>/...``. The second path
    segment is used as the site id, which doubles as the Rasa sender id and
    the tracker key.
    """

    def __init__(self, config, loop):
        """Load the model, the NLU interpreter and an in-memory tracker store.

        ``config['services']['RasaServiceLocal']`` supplies ``model_path`` and
        ``rasa_actions_url``.
        """
        super(RasaServiceLocal, self).__init__(config, loop)
        self.config = config
        # Comma-separated list consumed by connect_hook; '+' matches any site.
        # (The original listed 'hermod/+/intent' twice.)
        self.subscribe_to = ','.join([
            'hermod/+/rasa/get_domain',
            'hermod/+/rasa/set_slots',
            'hermod/+/dialog/ended',
            'hermod/+/dialog/init',
            'hermod/+/nlu/externalparse',
            'hermod/+/nlu/parse',
            'hermod/+/intent',
            'hermod/+/dialog/started',
        ])
        model_path = get_model(
            config['services']['RasaServiceLocal'].get('model_path'))
        endpoint = EndpointConfig(
            config['services']['RasaServiceLocal'].get('rasa_actions_url'))
        # NOTE(review): a file name rather than a Domain object is passed to the
        # tracker store. Presumably Agent.load replaces the store's domain with
        # the loaded one; confirm against the Rasa version in use.
        domain = 'domain.yml'
        self.tracker_store = InMemoryTrackerStore(domain)
        regex_interpreter = RegexInterpreter()
        # Free text is parsed by the trained NLU model. The agent is only fed
        # "/intent{...}" strings (built in handle_intent), hence the regex
        # interpreter.
        self.text_interpreter = RasaNLUInterpreter(model_path + '/nlu')
        self.agent = Agent.load(model_path, action_endpoint=endpoint,
                                tracker_store=self.tracker_store,
                                interpreter=regex_interpreter)

    async def connect_hook(self):
        """On MQTT connect, subscribe to every topic and announce readiness."""
        for sub in self.subscribe_to.split(","):
            await self.client.subscribe(sub)
        await self.client.publish('hermod/rasa/ready', json.dumps({}))

    async def on_message(self, message):
        """Dispatch an incoming MQTT message by its topic suffix."""
        topic = "{}".format(message.topic)
        parts = topic.split("/")
        site = parts[1]
        prefix = 'hermod/' + site
        payload_string = str(message.payload, encoding='utf-8')
        payload = {}
        try:
            payload = json.loads(payload_string)
        except json.JSONDecodeError:
            # Non-JSON payloads are treated as an empty payload.
            pass
        if not isinstance(payload, dict):
            # The handlers below call payload.get(); JSON null/list/scalar
            # bodies would crash them, so treat those as empty too.
            payload = {}

        if topic == prefix + '/rasa/set_slots':
            if payload:
                await self.set_slots(site, payload)
        elif topic == prefix + '/nlu/parse':
            if payload:
                await self.client.publish(prefix + '/display/startwaiting',
                                          json.dumps({}))
                await self.nlu_parse_request(site, payload.get('query'), payload)
                await self.client.publish(prefix + '/display/stopwaiting',
                                          json.dumps({}))
        elif topic == prefix + '/nlu/externalparse':
            if payload:
                await self.nlu_external_parse_request(
                    site, payload.get('query'), payload)
        elif topic == prefix + '/intent':
            if payload:
                await self.client.publish(prefix + '/display/startwaiting',
                                          json.dumps({}))
                await self.handle_intent(site, payload)
                await self.client.publish(prefix + '/display/stopwaiting',
                                          json.dumps({}))
        elif topic == prefix + '/tts/finished':
            # handle_intent subscribes per utterance; drop it once speech ends.
            await self.client.unsubscribe(prefix + '/tts/finished')
            await self.finish(site, payload)
        elif topic == prefix + '/dialog/started':
            await self.reset_tracker(site)
        elif topic == prefix + '/dialog/init':
            # Save the dialog init data in a slot so custom actions can read it.
            # (The original compared against '<site>/ ' and never matched.)
            tracker = self.tracker_store.get_or_create_tracker(site)
            tracker.update(SlotSet("hermod_client", json.dumps(payload)))
            self.tracker_store.save(tracker)
        elif topic == prefix + '/rasa/get_domain':
            await self.send_domain(site)
        elif topic == prefix + '/core/ended':
            # NOTE(review): 'core/ended' is not in subscribe_to (which has
            # 'dialog/ended'), so this branch only fires if something else
            # subscribes to it. Confirm which topic was intended.
            await self.send_story(site, payload)

    async def send_story(self, site, payload):
        """Publish the site's conversation history as exported Rasa stories."""
        tracker = self.tracker_store.get_or_create_tracker(site)
        response = tracker.export_stories()
        await self.client.publish(
            'hermod/' + site + '/rasa/story',
            json.dumps({'id': payload.get('id', ''), 'story': response}))

    async def send_domain(self, site):
        """Publish the agent's domain as a dict for a site."""
        await self.client.publish(
            'hermod/' + site + '/rasa/domain',
            json.dumps(self.agent.domain.as_dict()))

    async def reset_tracker(self, site):
        """Reset conversation history for a site (currently a no-op)."""
        # TODO: tracker reset is disabled; the previous attempt called the
        # private tracker._reset() on the result of get_or_create_tracker(site).
        pass

    async def handle_intent(self, site, payload):
        """Run an NLU result through the agent, then speak its reply or finish."""
        await self.client.publish('hermod/' + site + '/core/started',
                                  json.dumps(payload))
        if not payload:
            await self.finish(site, payload)
            return

        intent_name = (payload.get('intent') or {}).get('name', '')
        entities_json = {}
        for entity in payload.get('entities', []):
            entities_json[entity.get('entity')] = entity.get('value')
        # "/intent{entities-json}" form, presumably for the RegexInterpreter
        # passed to Agent.load.
        intent_json = "/" + intent_name + json.dumps(entities_json)
        responses = await self.agent.handle_text(intent_json, sender_id=site,
                                                 output_channel=None)
        # Skip responses without text (e.g. image/button-only responses) so
        # that join() cannot fail on None.
        messages = [response.get("text") for response in (responses or [])
                    if response.get("text")]
        if messages:
            message = '. '.join(messages)
            await self.client.subscribe('hermod/' + site + '/tts/finished')
            await self.client.publish(
                'hermod/' + site + '/tts/say',
                json.dumps({"text": message, "id": payload.get('id', '')}))
        else:
            await self.finish(site, payload)

    async def set_slots(self, site, payload):
        """Apply ``payload['slots']`` (a list of {slot, value}) and publish the result."""
        tracker = self.tracker_store.get_or_create_tracker(site)
        if payload:
            for slot in payload.get('slots', []):
                tracker.update(SlotSet(slot.get('slot'), slot.get('value')))
            self.tracker_store.save(tracker)
        await self.client.publish('hermod/' + site + '/dialog/slots',
                                  json.dumps(tracker.current_slot_values()))

    async def send_slots(self, site):
        """Publish the current tracker slots for a site."""
        tracker = self.tracker_store.get_or_create_tracker(site)
        slots = tracker.current_slot_values()
        await self.client.publish('hermod/' + site + '/dialog/slots',
                                  json.dumps(slots))

    def _clear_force_flags(self, tracker):
        """Clear both hermod_force_* slots and persist the tracker."""
        tracker.update(SlotSet("hermod_force_continue", None))
        tracker.update(SlotSet("hermod_force_end", None))
        self.tracker_store.save(tracker)

    async def finish(self, site, payload):
        """Decide whether the dialog continues or ends, then publish slots and core/ended.

        The ``hermod_force_continue`` / ``hermod_force_end`` slots (string
        'true') override the ``keep_listening`` config and are cleared once used.
        """
        id_message = json.dumps({"id": (payload or {}).get("id", "")})
        tracker = self.tracker_store.get_or_create_tracker(site)
        slots = tracker.current_slot_values()
        if slots.get('hermod_force_continue', False) == 'true':
            self._clear_force_flags(tracker)
            await self.client.publish('hermod/' + site + '/dialog/continue',
                                      id_message)
        elif slots.get('hermod_force_end', False) == 'true':
            self._clear_force_flags(tracker)
            await self.client.publish('hermod/' + site + '/dialog/end',
                                      id_message)
        elif self.config.get('keep_listening') == "true":
            await self.client.publish('hermod/' + site + '/dialog/continue',
                                      id_message)
        else:
            await self.client.publish('hermod/' + site + '/dialog/end',
                                      id_message)
        await self.send_slots(site)
        await self.client.publish('hermod/' + site + '/core/ended',
                                  json.dumps(payload))

    async def nlu_parse_request(self, site, text, payload):
        """Parse text into NLU JSON and publish it on nlu/intent."""
        response = await self.text_interpreter.parse(text)
        response['id'] = payload.get('id', '')
        await self.client.publish('hermod/' + site + '/nlu/intent',
                                  json.dumps(response))

    async def nlu_external_parse_request(self, site, text, payload):
        """Parse text into NLU JSON without triggering the hermod dialog flow."""
        response = await self.text_interpreter.parse(text)
        response['id'] = payload.get('id', '')
        await self.client.publish('hermod/' + site + '/nlu/externalintent',
                                  json.dumps(response))
def test_dialogue_from_parameters(domain: Domain):
    """A tracker's dialogue survives InMemoryTrackerStore serialisation and reload."""
    tracker = tracker_from_dialogue(TEST_DEFAULT_DIALOGUE, domain)
    parameters = json.loads(InMemoryTrackerStore.serialise_tracker(tracker))
    restored = Dialogue.from_parameters(parameters)
    assert tracker.as_dialogue().as_dict() == restored.as_dict()
def test_inmemory_tracker_store_with_token_serialisation(
    domain: Domain, response_selector_agent: Agent
):
    """Run the shared token-serialisation check against an in-memory tracker store."""
    store = InMemoryTrackerStore(domain)
    prepare_token_serialisation(store, response_selector_agent, "inmemory")
async def test_switch_forms_with_same_slot(default_agent: Agent):
    """Tests switching of forms, where the first slot is the same in both forms.

    Tests the fix for issue 7710.
    """
    # Two forms in the domain that share the same first slot.
    slot_a = "my_slot_a"
    form_1 = "my_form_1"
    form_2 = "my_form_2"
    utter_ask_form_1 = f"Please provide the value for {slot_a} of form 1"
    utter_ask_form_2 = f"Please provide the value for {slot_a} of form 2"

    domain = Domain.from_yaml(
        f"""
version: "2.0"
nlu:
- intent: order_status
  examples: |
    - check status of my order
    - when are my shoes coming in
- intent: return
  examples: |
    - start a return
    - I don't want my shoes anymore
forms:
  {form_1}:
    {slot_a}:
    - type: from_entity
      entity: number
  {form_2}:
    {slot_a}:
    - type: from_entity
      entity: number
responses:
    utter_ask_{form_1}_{slot_a}:
    - text: {utter_ask_form_1}
    utter_ask_{form_2}_{slot_a}:
    - text: {utter_ask_form_2}
"""
    )

    # Drive the processor the way rasa/core/processor does.
    processor = MessageProcessor(
        default_agent.interpreter,
        default_agent.policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        TemplatedNaturalLanguageGenerator(domain.templates),
    )

    # Activate the first form.
    tracker = DialogueStateTracker.from_events(
        "some-sender",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
            DefinePrevUserUtteredFeaturization(False),
        ],
    )

    # Mirrors rasa/core/processor.predict_next_action.
    prediction = PolicyPrediction([], "some_policy")

    async def run(action):
        # A fresh output channel and NLG per run, as in the original calls.
        await processor._run_action(
            action,
            tracker,
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            prediction,
        )

    action_1 = FormAction(form_1, None)
    await run(action_1)

    events_expected = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
        ActionExecuted(form_1),
        ActiveLoop(form_1),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_1,
            metadata={"template_name": f"utter_ask_{form_1}_{slot_a}"},
        ),
    ]
    assert tracker.applied_events() == events_expected

    next_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("return my shoes", {"name": "form_2", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
    ]
    tracker.update_with_events(next_events, domain)
    events_expected.extend(next_events)

    # form_1 is still active. The bot first validates whether the user
    # utterance fills the requested slot; that validation is rejected.
    await run(action_1)
    events_expected.append(ActionExecutionRejected(action_name=form_1))
    assert tracker.applied_events() == events_expected

    # Next, the bot predicts form_2.
    await run(FormAction(form_2, None))
    events_expected.extend(
        [
            ActionExecuted(form_2),
            ActiveLoop(form_2),
            SlotSet(REQUESTED_SLOT, slot_a),
            BotUttered(
                text=utter_ask_form_2,
                metadata={"template_name": f"utter_ask_{form_2}_{slot_a}"},
            ),
        ]
    )
    assert tracker.applied_events() == events_expected
async def test_logging_of_end_to_end_action():
    """An end-to-end (text) bot action is logged as ``ActionExecuted(action_text=...)``."""
    end_to_end_action = "hi, how are you?"
    domain = Domain(
        intents=["greet"],
        entities=[],
        slots=[],
        templates={},
        action_names=[],
        forms={},
        action_texts=[end_to_end_action],
    )
    conversation_id = "test_logging_of_end_to_end_action"
    user_message = "/greet"

    class ConstantEnsemble(PolicyEnsemble):
        """Predicts the end-to-end action on the first call and action_listen afterwards."""

        def __init__(self) -> None:
            super().__init__([])
            self.number_of_calls = 0

        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> PolicyPrediction:
            if self.number_of_calls != 0:
                return PolicyPrediction.for_action_name(domain, ACTION_LISTEN_NAME)
            prediction = PolicyPrediction.for_action_name(
                domain, end_to_end_action, "some policy"
            )
            prediction.is_end_to_end_prediction = True
            self.number_of_calls += 1
            return prediction

    tracker_store = InMemoryTrackerStore(domain)
    processor = MessageProcessor(
        RegexInterpreter(),
        ConstantEnsemble(),
        domain,
        tracker_store,
        InMemoryLockStore(),
        NaturalLanguageGenerator.create(None, domain),
    )

    await processor.handle_message(UserMessage(user_message, sender_id=conversation_id))

    tracker = tracker_store.retrieve(conversation_id)
    expected_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(user_message, intent={"name": "greet"}),
        ActionExecuted(action_text=end_to_end_action),
        BotUttered("hi, how are you?", {}, {}, 123),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    # zip() stops at the shorter sequence, so only the overlapping prefix is compared.
    for actual, expected in zip(tracker.events, expected_events):
        assert actual == expected
def stores_to_be_tested():
    """Return one instance of each tracker-store implementation under test.

    The SQL store uses ``rasa.db`` inside a freshly created temporary directory.
    """
    db_path = os.path.join(tempfile.mkdtemp(), 'rasa.db')
    return [
        MockRedisTrackerStore(domain),
        InMemoryTrackerStore(domain),
        SQLTrackerStore(domain, db=db_path),
    ]
def test_predict_next_action_with_hidden_rules():
    """Predictions after a rule turn set ``hide_rule_turn``; predictions after a story turn do not."""
    rule_intent, rule_action, rule_slot = "rule_intent", "rule_action", "rule_slot"
    story_intent, story_action, story_slot = "story_intent", "story_action", "story_slot"

    domain = Domain.from_yaml(
        f"""
version: "2.0"
intents:
- {rule_intent}
- {story_intent}
actions:
- {rule_action}
- {story_action}
slots:
  {rule_slot}:
    type: text
  {story_slot}:
    type: text
"""
    )

    rule = TrackerWithCachedStates.from_events(
        "rule",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
            ActionExecuted(rule_action),
            SlotSet(rule_slot, rule_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    story = TrackerWithCachedStates.from_events(
        "story",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": story_intent}),
            ActionExecuted(story_action),
            SlotSet(story_slot, story_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    interpreter = RegexInterpreter()
    ensemble = SimplePolicyEnsemble(policies=[RulePolicy(), MemoizationPolicy()])
    ensemble.train([rule, story], domain, interpreter)

    processor = MessageProcessor(
        interpreter,
        ensemble,
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
    )

    tracker = DialogueStateTracker.from_events(
        "casd",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
        ],
        slots=domain.slots,
    )

    # Rule turn: the rule action and the following listen are both flagged hidden.
    predicted, prediction = processor.predict_next_action(tracker)
    assert predicted._name == rule_action
    assert prediction.hide_rule_turn
    processor._log_action_on_tracker(
        tracker, predicted, [SlotSet(rule_slot, rule_slot)], prediction
    )

    predicted, prediction = processor.predict_next_action(tracker)
    assert isinstance(predicted, ActionListen)
    assert prediction.hide_rule_turn
    processor._log_action_on_tracker(tracker, predicted, None, prediction)

    tracker.events.append(UserUttered(intent={"name": story_intent}))

    # rules are hidden correctly if memo policy predicts next actions correctly
    predicted, prediction = processor.predict_next_action(tracker)
    assert predicted._name == story_action
    assert not prediction.hide_rule_turn
    processor._log_action_on_tracker(
        tracker, predicted, [SlotSet(story_slot, story_slot)], prediction
    )

    predicted, prediction = processor.predict_next_action(tracker)
    assert isinstance(predicted, ActionListen)
    assert not prediction.hide_rule_turn