def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the corresponding form action if there is an active form.

    Args:
        tracker: Conversation tracker holding the active loop state.
        domain: The model's domain.
        interpreter: NLU interpreter (unused here, required by the interface).

    Returns:
        A `PolicyPrediction`; default probabilities when no form rule applies.
    """
    # Start from the default (no-op) probabilities; overwritten below when a
    # form-specific prediction applies.
    result = self._default_predictions(domain)

    if tracker.active_loop_name:
        logger.debug("There is an active form '{}'".format(
            tracker.active_loop_name))
        if tracker.latest_action_name == ACTION_LISTEN_NAME:
            # predict form action after user utterance
            if tracker.active_loop.get(LOOP_REJECTED):
                if self.state_is_unhappy(tracker, domain):
                    # The loop previously rejected and the current state is an
                    # "unhappy" one: keep the default probabilities and emit
                    # `LoopInterrupted(True)` so the loop knows it was
                    # interrupted.
                    return self._prediction(result,
                                            events=[LoopInterrupted(True)])
            result = self._prediction_result(tracker.active_loop_name,
                                             tracker, domain)
        elif tracker.latest_action_name == tracker.active_loop_name:
            # predict action_listen after form action
            result = self._prediction_result(ACTION_LISTEN_NAME, tracker,
                                             domain)
    else:
        logger.debug("There is no active form")

    return self._prediction(result)
async def test_policy_events_are_applied_to_tracker(
        default_processor: MessageProcessor, monkeypatch: MonkeyPatch):
    """Events attached to a policy prediction must land on the tracker.

    The predicted action must already see those events when it runs, and they
    must be persisted on the tracker afterwards.
    """
    next_action_name = ACTION_LISTEN_NAME
    injected_events = [LoopInterrupted(True)]
    conversation_id = "test_policy_events_are_applied_to_tracker"
    user_message = "/greet"

    # Tracker content expected at the moment the predicted action executes.
    events_before_action = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(user_message, intent={"name": "greet"}),
        *injected_events,
    ]

    class StubEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> PolicyPrediction:
            # Always predict the same action and attach the extra events.
            prediction = PolicyPrediction.for_action_name(
                default_processor.domain, next_action_name, "some policy")
            prediction.events = injected_events
            return prediction

    monkeypatch.setattr(default_processor, "policy_ensemble", StubEnsemble([]))

    action_saw_policy_events = False

    async def fake_run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        # The action already has access to the policy events.
        nonlocal action_saw_policy_events
        action_saw_policy_events = list(tracker.events) == events_before_action
        return []

    monkeypatch.setattr(ActionListen, ActionListen.run.__name__, fake_run)

    await default_processor.handle_message(
        UserMessage(user_message, sender_id=conversation_id))

    assert action_saw_policy_events

    tracker = default_processor.get_tracker(conversation_id)
    # The executed action itself is logged on the tracker as well.
    expected_on_tracker = events_before_action + [
        ActionExecuted(ACTION_LISTEN_NAME)
    ]
    for actual, expected in zip(tracker.events, expected_on_tracker):
        assert actual == expected
def _prediction_with_unhappy_path(
    self,
    probabilities: List[float],
    returning_from_unhappy_path: bool,
    is_end_to_end_prediction: bool,
) -> "PolicyPrediction":
    """Wraps probabilities into a prediction, flagging interrupted loops.

    When returning from an unhappy loop path, a `LoopInterrupted(True)` event
    is attached to signal that the active loop was interrupted.
    """
    optional_events = []
    if returning_from_unhappy_path:
        optional_events = [LoopInterrupted(True)]

    return self._prediction(
        probabilities,
        events=optional_events,
        is_end_to_end_prediction=is_end_to_end_prediction,
    )
def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the next action (see parent class for more information)."""
    fallback_result = self._default_predictions(domain)

    # Rasa Open Source default actions overrule anything else. Users who want
    # the same effect must write a rule or make sure their loop rejects
    # accordingly.
    action_from_defaults = self._find_action_from_default_actions(tracker)
    if action_from_defaults:
        self._prediction_source = DEFAULT_RULES
        return self._prediction(
            self._prediction_result(action_from_defaults, tracker, domain))

    # An active loop has priority over any other rule: rules (or any other
    # prediction) only apply after the loop rejected. If the loop neither ran
    # previously nor rejected, simply force-predict the loop itself.
    action_from_loop = self._find_action_from_loop_happy_path(tracker)
    if action_from_loop:
        self._prediction_source = LOOP_RULES
        return self._prediction(
            self._prediction_result(action_from_loop, tracker, domain))

    action_from_rules, source, returning_from_unhappy_path = (
        self._find_action_from_rules(tracker, domain))
    # remember the source even if the rules didn't predict any action
    self._prediction_source = source

    if returning_from_unhappy_path:
        policy_events = [LoopInterrupted(True)]
    else:
        policy_events = []

    if action_from_rules:
        return self._prediction(
            self._prediction_result(action_from_rules, tracker, domain),
            events=policy_events,
        )

    return self._prediction(fallback_result, events=policy_events)
async def test_policy_events_not_applied_if_rejected(
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
    reject_fn: Callable[[], List[Event]],
):
    """Policy events must be discarded when the predicted action rejects."""
    next_action_name = ACTION_LISTEN_NAME
    proposed_events = [LoopInterrupted(True)]
    conversation_id = "test_policy_events_are_applied_to_tracker"
    user_message = "/greet"

    class StubEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> PolicyPrediction:
            # Always predict the same action with the attached policy events.
            prediction = PolicyPrediction.for_action_name(
                default_processor.domain, next_action_name, "some policy"
            )
            prediction.events = proposed_events
            return prediction

    monkeypatch.setattr(default_processor, "policy_ensemble", StubEnsemble([]))

    async def rejecting_run(*args: Any, **kwargs: Any) -> List[Event]:
        return reject_fn()

    monkeypatch.setattr(ActionListen, ActionListen.run.__name__, rejecting_run)

    await default_processor.handle_message(
        UserMessage(user_message, sender_id=conversation_id)
    )

    tracker = default_processor.get_tracker(conversation_id)
    # `LoopInterrupted(True)` must NOT appear on the tracker — only the
    # rejection is logged.
    expected_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(user_message, intent={"name": "greet"}),
        ActionExecutionRejected(ACTION_LISTEN_NAME),
    ]
    for actual, expected in zip(tracker.events, expected_events):
        assert actual == expected
def _rule_prediction(
    self,
    probabilities: List[float],
    prediction_source: Text,
    returning_from_unhappy_path: bool = False,
    is_end_to_end_prediction: bool = False,
    is_no_user_prediction: bool = False,
) -> PolicyPrediction:
    """Builds a `PolicyPrediction` for a rule-based prediction.

    Args:
        probabilities: Action probabilities for the prediction.
        prediction_source: Key of the rule that produced the prediction.
        returning_from_unhappy_path: If `True`, attach `LoopInterrupted(True)`
            to signal that the active loop was interrupted.
        is_end_to_end_prediction: Whether the prediction was end-to-end.
        is_no_user_prediction: Whether the prediction ignored the user turn.
    """
    return PolicyPrediction(
        probabilities,
        self.__class__.__name__,
        self.priority,
        events=[LoopInterrupted(True)] if returning_from_unhappy_path else [],
        is_end_to_end_prediction=is_end_to_end_prediction,
        is_no_user_prediction=is_no_user_prediction,
        # `in` already yields a bool — the original `True if ... else False`
        # wrapper was redundant. Hide the rule turn when this rule is not
        # present in the training stories.
        hide_rule_turn=(
            prediction_source in self.lookup.get(RULES_NOT_IN_STORIES, [])
        ),
    )
async def test_persist_legacy_form_story():
    """Trackers with legacy `form` events export stories using `active_loop`."""
    domain = Domain.load("data/test_domains/form.yml")

    tracker = DialogueStateTracker("", domain.slots)

    # Expected story export, after the legacy `form` step name is replaced
    # with the current `active_loop` name below.
    story = (
        "* greet\n"
        "    - utter_greet\n"
        "* start_form\n"
        "    - some_form\n"
        '    - form{"name": "some_form"}\n'
        "* default\n"
        "    - utter_default\n"
        "    - some_form\n"
        "* stop\n"
        "    - utter_ask_continue\n"
        "* affirm\n"
        "    - some_form\n"
        "* stop\n"
        "    - utter_ask_continue\n"
        "* inform\n"
        "    - some_form\n"
        '    - form{"name": null}\n'
        "* goodbye\n"
        "    - utter_goodbye\n"
    )

    # simulate talking to the form
    events = [
        UserUttered(intent={"name": "greet"}),
        ActionExecuted("utter_greet"),
        ActionExecuted("action_listen"),
        # start the form
        UserUttered(intent={"name": "start_form"}),
        ActionExecuted("some_form"),
        ActiveLoop("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "default"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_default"),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # out of form input but continue with the form
        UserUttered(intent={"name": "affirm"}),
        LoopInterrupted(True),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # form input
        UserUttered(intent={"name": "inform"}),
        LoopInterrupted(False),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        ActiveLoop(None),
        UserUttered(intent={"name": "goodbye"}),
        ActionExecuted("utter_goodbye"),
        ActionExecuted("action_listen"),
    ]
    # Plain loop instead of a side-effect-only list comprehension: the
    # comprehension built and discarded a throwaway list of `None`s.
    for event in events:
        tracker.update(event)

    story = story.replace(f"- {LegacyForm.type_name}",
                          f"- {ActiveLoop.type_name}")

    assert story in tracker.export_stories()
async def test_form_unhappy_path_no_validation_from_story():
    """RulePolicy learns from a story to skip form validation on an unhappy path."""
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
- {handle_rejection_action_name}
- some-action
slots:
  {REQUESTED_SLOT}:
    type: unfeaturized
forms:
- {form_name}
""")

    # Training story: while the form is active, a greeting triggers a specific
    # handling action, and the following user turn is not validated.
    unhappy_story = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActionExecuted(form_name),
            ActiveLoop(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    policy = RulePolicy()
    policy.train([unhappy_story], domain, RegexInterpreter())

    # Check that RulePolicy predicts no validation to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
        ActionExecuted(handle_rejection_action_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]

    tracker = DialogueStateTracker.from_events("casd", evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    # there is no rule for next action
    # NOTE(review): relies on the private `_core_fallback_threshold` attribute.
    assert max(action_probabilities) == policy._core_fallback_threshold
    # check that RulePolicy entered unhappy path based on the training story:
    # predicting must have appended `LoopInterrupted(True)` to the tracker.
    assert tracker.events[-1] == LoopInterrupted(True)
async def test_form_unhappy_path_no_validation_from_rule():
    """RulePolicy learns from a rule to skip form validation on an unhappy path."""
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
- {handle_rejection_action_name}
- some-action
slots:
  {REQUESTED_SLOT}:
    type: unfeaturized
forms:
- {form_name}
""")

    unhappy_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "bla"),
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )

    # unhappy rule is multi user turn rule, therefore remove restriction for policy
    policy = RulePolicy(restrict_rules=False)
    # RulePolicy should memorize that unhappy_rule overrides GREET_RULE
    policy.train([GREET_RULE, unhappy_rule], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd", evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain,
                            handle_rejection_action_name)

    # Check that RulePolicy predicts action_listen
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd", evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, ACTION_LISTEN_NAME)

    # Check that RulePolicy triggers form again after handling unhappy path
    conversation_events.append(ActionExecuted(ACTION_LISTEN_NAME))
    tracker = DialogueStateTracker.from_events("casd", evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    assert_predicted_action(action_probabilities, domain, form_name)
    # check that RulePolicy entered unhappy path based on the training story:
    # predicting must have appended `LoopInterrupted(True)` to the tracker.
    assert tracker.events[-1] == LoopInterrupted(True)
"entity": "count", "value": 1 }, ], timestamp=None, ), DefinePrevUserUtteredFeaturization(use_text_for_featurization=False, timestamp=None, metadata=None), ReminderCancelled(timestamp=1621590172.3872123), ReminderScheduled(timestamp=None, trigger_date_time=datetime.now(), intent="greet"), ActionExecutionRejected(action_name="my_action"), LegacyFormValidation(validate=True, timestamp=None), LoopInterrupted(timestamp=None, is_interrupted=False), ActiveLoop(name="loop"), LegacyForm(name="my_form"), AllSlotsReset(), SlotSet(key="my_slot", value={}), SlotSet(key="my slot", value=[]), SlotSet(key="test", value=1), SlotSet(key="test", value="text"), ConversationResumed(), ConversationPaused(), FollowupAction(name="test"), StoryExported(), Restarted(), ActionReverted(), UserUtteranceReverted(), BotUttered(text="Test bot utterance"),
def _find_action_from_rules(self, tracker: DialogueStateTracker,
                            domain: Domain) -> Optional[Text]:
    """Finds the next action predicted by the trained rules, if any.

    May also mutate `tracker` by appending `LoopInterrupted(True)` when a
    "do not validate loop" unhappy-path condition matches.

    Returns:
        The predicted action name, or `None` if no rule applies.
    """
    # Featurize the tracker into the same state representation the rule
    # lookup keys were built from.
    tracker_as_states = self.featurizer.prediction_states([tracker], domain)
    states = tracker_as_states[0]

    logger.debug(f"Current tracker state: {states}")

    rule_keys = self._get_possible_keys(self.lookup[RULES], states)
    predicted_action_name = None
    best_rule_key = ""
    if rule_keys:
        # if there are several rules,
        # it should mean that some rule is a subset of another rule
        # therefore we pick a rule of maximum length
        best_rule_key = max(rule_keys, key=len)
        predicted_action_name = self.lookup[RULES].get(best_rule_key)

    active_loop_name = tracker.active_loop_name
    if active_loop_name:
        # find rules for unhappy path of the loop
        loop_unhappy_keys = self._get_possible_keys(
            self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH], states)
        # there could be several unhappy path conditions
        unhappy_path_conditions = [
            self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH].get(key)
            for key in loop_unhappy_keys
        ]

        # Check if a rule that predicted action_listen
        # was applied inside the loop.
        # Rules might not explicitly switch back to the loop.
        # Hence, we have to take care of that.
        predicted_listen_from_general_rule = (
            predicted_action_name == ACTION_LISTEN_NAME
            and not get_active_loop_name(
                self._rule_key_to_state(best_rule_key)[-1]))
        if predicted_listen_from_general_rule:
            if DO_NOT_PREDICT_LOOP_ACTION not in unhappy_path_conditions:
                # negative rules don't contain a key that corresponds to
                # the fact that active_loop shouldn't be predicted
                logger.debug(
                    f"Predicted loop '{active_loop_name}' by overwriting "
                    f"'{ACTION_LISTEN_NAME}' predicted by general rule.")
                return active_loop_name

            # do not predict anything
            predicted_action_name = None

        if DO_NOT_VALIDATE_LOOP in unhappy_path_conditions:
            # `LoopInterrupted(True)` is the modern equivalent of the legacy
            # `FormValidation(False)` event mentioned in the log message.
            logger.debug("Added `FormValidation(False)` event.")
            tracker.update(LoopInterrupted(True))

    if predicted_action_name is not None:
        logger.debug(
            f"There is a rule for the next action '{predicted_action_name}'."
        )
    else:
        logger.debug("There is no applicable rule.")

    return predicted_action_name