def _get_next_action_probabilities(
    self, tracker: DialogueStateTracker
) -> PolicyPrediction:
    """Predict the next action for `tracker`.

    A follow-up action set on the tracker wins over the policy ensemble: it is
    cleared from the tracker and, if the domain knows it, predicted directly.
    An unknown follow-up action is logged as an error and then ignored.

    Otherwise the ensemble's best-policy prediction is used. An ensemble that
    still returns the legacy ``(probabilities, policy_name)`` tuple triggers a
    deprecation warning, and the tuple is wrapped into a `PolicyPrediction`.

    Args:
        tracker: Conversation to predict for. Its follow-up action (if any) is
            cleared as a side effect.

    Returns:
        The prediction for the next action.
    """
    followup_action = tracker.followup_action
    if followup_action:
        tracker.clear_followup_action()
        is_known_action = followup_action in self.domain.action_names
        if is_known_action:
            return PolicyPrediction.for_action_name(
                self.domain, followup_action, FOLLOWUP_ACTION
            )
        logger.error(
            f"Trying to run unknown follow-up action '{followup_action}'. "
            "Instead of running that, Rasa Open Source will ignore the action "
            "and predict the next action."
        )

    ensemble_output = self.policy_ensemble.probabilities_using_best_policy(
        tracker, self.domain, self.interpreter
    )
    if not isinstance(ensemble_output, PolicyPrediction):
        # Legacy ensembles return a bare tuple instead of a `PolicyPrediction`.
        rasa.shared.utils.io.raise_deprecation_warning(
            f"Returning a tuple of probabilities and policy name for "
            f"`{PolicyEnsemble.probabilities_using_best_policy.__name__}` is "
            f"deprecated and will be removed in Rasa Open Source 3.0.0. Please return "
            f"a `{PolicyPrediction.__name__}` object instead."
        )
        legacy_scores, legacy_policy_name = ensemble_output
        ensemble_output = PolicyPrediction(legacy_scores, legacy_policy_name)
    return ensemble_output
async def test_action_unlikely_intent_metadata(
    default_processor: MessageProcessor,
):
    """Running `action_unlikely_intent` must attach the prediction's metadata."""
    expected_metadata = {"key1": 1, "key2": "2"}
    tracker = DialogueStateTracker.from_events(
        "some-sender", evts=[ActionExecuted(ACTION_LISTEN_NAME)]
    )
    empty_domain = Domain.empty()
    prediction_with_metadata = PolicyPrediction(
        [], "some policy", action_metadata=expected_metadata
    )

    await default_processor._run_action(
        ActionUnlikelyIntent(),
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(empty_domain.responses),
        prediction_with_metadata,
    )

    recorded = tracker.applied_events()
    assert recorded == [
        ActionExecuted(ACTION_LISTEN_NAME),
        ActionExecuted(ACTION_UNLIKELY_INTENT_NAME, metadata=expected_metadata),
    ]
    # Equality of events may not cover metadata, so check it explicitly.
    assert recorded[1].metadata == expected_metadata
def probabilities_using_best_policy(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: RegexInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Force `action_unlikely_intent` after selected intents, else delegate.

    If the newest tracker event is a user message whose intent is listed in
    `intent_names`, predict `action_unlikely_intent`, attaching metadata from
    `metadata_for_intent` (when that mapping is given). Any other situation is
    handed to the original, un-patched implementation `_original`.

    Training produces different model weights on every run, which may change
    when `UnexpecTEDIntentPolicy` predicts `action_unlikely_intent`. That
    policy is not what is under test, so the prediction is triggered here at a
    fixed point to keep the tests deterministic.
    """
    last_event = tracker.events[-1]
    is_trigger = isinstance(last_event, UserUttered) and (
        last_event.parse_data["intent"]["name"] in intent_names
    )
    if not is_trigger:
        return _original(self, tracker, domain, interpreter, **kwargs)

    triggered_intent = last_event.parse_data["intent"]["name"]
    if metadata_for_intent:
        extra_metadata = metadata_for_intent.get(triggered_intent)
    else:
        extra_metadata = None
    return PolicyPrediction.for_action_name(
        domain, ACTION_UNLIKELY_INTENT_NAME, action_metadata=extra_metadata
    )
async def _update_tracker_session(
    self,
    tracker: DialogueStateTracker,
    output_channel: OutputChannel,
    metadata: Optional[Dict] = None,
) -> None:
    """Start a new conversation session on `tracker` when one is due.

    A session is due when the tracker has no applied events yet (only events
    after the last restart count) or when `_has_session_expired` reports the
    current session as expired. In that case `action_session_start` is run.

    Args:
        tracker: Tracker whose session state is inspected and updated.
        output_channel: Channel handed to the session start action, e.g. for
            utterances produced by a custom `ActionSessionStart`.
        metadata: Data sent by the client with the incoming user message;
            forwarded unchanged to `_run_action`.
    """
    # Nothing to do while a session with events is still running.
    if tracker.applied_events() and not self._has_session_expired(tracker):
        return

    logger.debug(
        f"Starting a new session for conversation ID '{tracker.sender_id}'."
    )
    session_start_action = self._get_action(ACTION_SESSION_START_NAME)
    session_start_prediction = PolicyPrediction.for_action_name(
        self.domain, ACTION_SESSION_START_NAME
    )
    await self._run_action(
        action=session_start_action,
        tracker=tracker,
        output_channel=output_channel,
        nlg=self.nlg,
        metadata=metadata,
        prediction=session_start_prediction,
    )
def test_default_predict_excludes_rejected_action(
    default_ensemble: DefaultPolicyPredictionEnsemble,
):
    """A rejected action must receive zero probability after combination."""
    domain = Domain.load("data/test_domains/default.yml")
    action_names = domain.action_names_or_texts
    excluded_action = action_names[0]
    tracker = DialogueStateTracker.from_events(
        sender_id="arbitrary",
        evts=[
            UserUttered("hi"),
            ActionExecuted(excluded_action),
            ActionExecutionRejected(excluded_action),  # not "Rejection"
        ],
    )

    # Two policies scoring every action equally and differing only in priority;
    # each gets its own probability list.
    predictions_by_name = {}
    for priority in range(2):
        name = str(priority)
        predictions_by_name[name] = PolicyPrediction(
            policy_name=name,
            probabilities=[1.0] * len(action_names),
            policy_priority=priority,
        )

    excluded_index = domain.index_for_action(excluded_action)
    combined = default_ensemble.combine_predictions_from_kwargs(
        domain=domain, tracker=tracker, **predictions_by_name
    )

    assert combined.probabilities[excluded_index] == 0.0
def _fallback_after_listen(
    self, domain: Domain, prediction: PolicyPrediction
) -> PolicyPrediction:
    """Swap `prediction` for a fallback prediction if a fallback is configured.

    Meant for the situation where `action_listen` was predicted right after a
    user utterance, i.e. when:
        - a fallback policy is present,
        - we received a user message and the predicted action is
          `action_listen` by a policy other than the `MemoizationPolicy` or
          one of its subclasses.
    This method itself only checks the first condition; the others are not
    checked here and are presumably verified by the caller.

    Args:
        domain: the :class:`rasa.shared.core.domain.Domain`
        prediction: The winning prediction.

    Returns:
        `prediction` unchanged when no `FallbackPolicy` is among
        `self.policies`; otherwise a new prediction holding the first
        `FallbackPolicy`'s fallback scores, named
        ``policy_<index>_<class name>`` and ranked with
        `FALLBACK_POLICY_PRIORITY`.
    """
    first_fallback = next(
        (
            (idx, policy)
            for idx, policy in enumerate(self.policies)
            if isinstance(policy, FallbackPolicy)
        ),
        None,
    )
    if first_fallback is None:
        return prediction

    fallback_idx, fallback_policy = first_fallback
    logger.debug(
        f"Action '{ACTION_LISTEN_NAME}' was predicted after "
        f"a user message using {prediction.policy_name}. Predicting "
        f"fallback action: {fallback_policy.fallback_action_name}"
    )
    return PolicyPrediction(
        fallback_policy.fallback_scores(domain),
        f"policy_{fallback_idx}_{type(fallback_policy).__name__}",
        FALLBACK_POLICY_PRIORITY,
    )
def _predict_next_with_tracker(
    self, tracker: DialogueStateTracker
) -> PolicyPrediction:
    """Predict the next action for `tracker` via the model's core graph target.

    A follow-up action set on the tracker takes precedence: it is cleared and,
    if the domain knows it, predicted directly. Unknown follow-up actions are
    logged and ignored.

    Args:
        tracker: Conversation to predict for; its follow-up action (if any) is
            cleared as a side effect.

    Returns:
        The prediction for the next action.

    Raises:
        ValueError: If the model metadata has no core target.
    """
    followup_action = tracker.followup_action
    if followup_action:
        tracker.clear_followup_action()
        if followup_action not in self.domain.action_names_or_texts:
            logger.error(
                f"Trying to run unknown follow-up action '{followup_action}'. "
                "Instead of running that, Rasa Open Source will ignore the action "
                "and predict the next action."
            )
        else:
            return PolicyPrediction.for_action_name(
                self.domain, followup_action, FOLLOWUP_ACTION
            )

    core_target = self.model_metadata.core_target
    if not core_target:
        raise ValueError("Cannot predict next action if there is no core target.")

    graph_outputs = self.graph_runner.run(
        inputs={PLACEHOLDER_TRACKER: tracker}, targets=[core_target]
    )
    return graph_outputs[core_target]
def test_predict_next_action_with_deprecated_ensemble(
    default_processor: MessageProcessor, monkeypatch: MonkeyPatch
):
    """A legacy ensemble returning a tuple still works, with a FutureWarning."""
    action_name = "utter_greet"
    confidence = 2.0
    scores = rasa.core.policies.policy.confidence_scores_for(
        action_name, confidence, default_processor.domain
    )
    policy_label = "deprecated ensemble"

    class DeprecatedEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> Tuple[List[float], Optional[Text]]:
            # Legacy return shape: a bare (probabilities, policy name) tuple.
            return scores, policy_label

    monkeypatch.setattr(default_processor, "policy_ensemble", DeprecatedEnsemble([]))
    tracker = DialogueStateTracker.from_events(
        "some sender", [ActionExecuted(ACTION_LISTEN_NAME)]
    )

    with pytest.warns(FutureWarning):
        action, prediction = default_processor.predict_next_action(tracker)

    assert action.name() == action_name
    assert prediction == PolicyPrediction(scores, policy_label)
def probabilities_using_best_policy(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predict `end_to_end_action` once, then `action_listen` forever after.

    The first call returns an end-to-end prediction of `end_to_end_action`
    attributed to "some policy" and increments `self.number_of_calls`; every
    later call predicts `action_listen`.
    """
    if self.number_of_calls != 0:
        return PolicyPrediction.for_action_name(domain, ACTION_LISTEN_NAME)

    first_prediction = PolicyPrediction.for_action_name(
        domain, end_to_end_action, "some policy"
    )
    first_prediction.is_end_to_end_prediction = True
    self.number_of_calls += 1
    return first_prediction
def predict_action_probabilities(
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs,
) -> PolicyPrediction:
    """Stub prediction that fails unless called with `test_interpreter`."""
    assert interpreter == test_interpreter
    fixed_scores = [1, 0]
    return PolicyPrediction(fixed_scores, "some-policy", policy_priority=1)
def probabilities_using_best_policy(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Always predict `expected_action` carrying `expected_events`.

    The prediction is built against `default_processor.domain`; the `domain`
    argument is ignored.
    """
    stub_prediction = PolicyPrediction.for_action_name(
        default_processor.domain, expected_action, "some policy"
    )
    stub_prediction.events = expected_events
    return stub_prediction
async def test_2nd_affirm_successful(default_processor: MessageProcessor):
    """A successful second affirmation collapses the two-stage fallback turns."""

    def name_utterance() -> UserUttered:
        return UserUttered("my name is John", {"name": "say_name", "confidence": 1.0})

    def greet_utterance() -> UserUttered:
        return UserUttered("hi", {"name": "greet"})

    conversation = [
        ActionExecuted(ACTION_LISTEN_NAME),
        name_utterance(),
        SlotSet("some_slot", "example_value"),
        # User sends message with low NLU confidence
        *_message_requiring_fallback(),
        ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME),
        # Action asks user to affirm
        *_two_stage_clarification_request(),
        ActionExecuted(ACTION_LISTEN_NAME),
        # User denies suggested intents
        UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}),
        # Action asks user to rephrase
        *_two_stage_clarification_request(),
        # User rephrased with low confidence
        *_message_requiring_fallback(),
        *_two_stage_clarification_request(),
        # Action asks user to affirm for the last time
        ActionExecuted(ACTION_LISTEN_NAME),
        # User affirms successfully
        greet_utterance(),
    ]
    tracker = DialogueStateTracker.from_events("some-sender", evts=conversation)
    domain = Domain.empty()
    fallback_action = TwoStageFallbackAction()

    await default_processor._run_action(
        fallback_action,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        PolicyPrediction([], "some policy"),
    )

    # Only the pre-fallback turn and the affirmed user message should remain.
    expected_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        name_utterance(),
        SlotSet("some_slot", "example_value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        greet_utterance(),
    ]
    assert tracker.applied_events() == expected_events
def predict_action_probabilities(
    self,
    tracker,
    domain,
    interpreter,
    **kwargs,
) -> PolicyPrediction:
    """Predict `action_unlikely_intent` right after the user utters `intent_name`.

    In every other situation the original implementation `_original` is used.
    """
    newest_event = tracker.events[-1]
    if not isinstance(newest_event, UserUttered):
        return _original(self, tracker, domain, interpreter, **kwargs)
    if newest_event.parse_data["intent"]["name"] != intent_name:
        return _original(self, tracker, domain, interpreter, **kwargs)
    return PolicyPrediction.for_action_name(domain, ACTION_UNLIKELY_INTENT_NAME)
async def execute_action(
    self,
    sender_id: Text,
    action: Text,
    output_channel: OutputChannel,
    policy: Optional[Text],
    confidence: Optional[float],
) -> Optional[DialogueStateTracker]:
    """Execute `action` in the conversation of `sender_id` via `self.processor`.

    The action is attributed to `policy` with `confidence`; a missing (or
    falsy) confidence is recorded as 0.0.

    Returns:
        Whatever the processor's `execute_action` returns.
    """
    recorded_confidence = confidence or 0.0
    prediction = PolicyPrediction.for_action_name(
        self.domain, action, policy, recorded_confidence
    )
    return await self.processor.execute_action(  # type: ignore[union-attr]
        sender_id, action, output_channel, self.nlg, prediction
    )
async def execute_action(
    self,
    sender_id: Text,
    action: Text,
    output_channel: OutputChannel,
    policy: Optional[Text],
    confidence: Optional[float],
) -> Optional[DialogueStateTracker]:
    """Execute `action` for `sender_id` using a freshly created processor.

    The action is attributed to `policy` with `confidence`; a missing (or
    falsy) confidence is recorded as 0.0.

    Returns:
        Whatever the processor's `execute_action` returns.
    """
    fresh_processor = self.create_processor()
    recorded_confidence = confidence or 0.0
    prediction = PolicyPrediction.for_action_name(
        self.domain, action, policy, recorded_confidence
    )
    return await fresh_processor.execute_action(
        sender_id, action, output_channel, self.nlg, prediction
    )
async def _update_tracker_session(
    self,
    tracker: DialogueStateTracker,
    output_channel: OutputChannel,
    metadata: Optional[Dict] = None,
) -> None:
    """Check the current session in `tracker` and update it if expired.

    An 'action_session_start' is run if the latest tracker session has expired,
    or if the tracker does not yet contain any events (only those after the last
    restart are considered).

    Before the action runs, `metadata` is attached to the session start action
    when it is an `ActionSessionStart`. Non-empty `metadata` is also written to
    the tracker as the `SESSION_START_METADATA_SLOT` slot.

    Args:
        metadata: Data sent from client associated with the incoming user message.
        tracker: Tracker to inspect.
        output_channel: Output channel for potential utterances in a custom
            `ActionSessionStart`.
    """
    if not tracker.applied_events() or self._has_session_expired(tracker):
        logger.debug(
            f"Starting a new session for conversation ID '{tracker.sender_id}'."
        )

        action_session_start = self._get_action(ACTION_SESSION_START_NAME)
        # TODO: Remove in 3.0.0 and describe migration to `session_start_metadata`
        # slot in migration guide.
        if isinstance(
            action_session_start, rasa.core.actions.action.ActionSessionStart
        ):
            # Here we set optional metadata to the ActionSessionStart, which will
            # then be passed to the SessionStart event.
            action_session_start.metadata = metadata

        # Expose the metadata as a slot as well, so it is available through the
        # tracker (not only through the session start action).
        if metadata:
            tracker.update(
                SlotSet(SESSION_START_METADATA_SLOT, metadata), self.domain
            )

        await self._run_action(
            action=action_session_start,
            tracker=tracker,
            output_channel=output_channel,
            nlg=self.nlg,
            prediction=PolicyPrediction.for_action_name(
                self.domain, ACTION_SESSION_START_NAME
            ),
        )
def _get_next_action_probabilities(
    self, tracker: DialogueStateTracker
) -> PolicyPrediction:
    """Return the prediction for the next action in `tracker`.

    A follow-up action set on the tracker takes precedence: it is cleared and,
    if it is a known action name or text in the domain, predicted directly.
    An unknown follow-up action is logged and ignored, and the decision falls
    back to the policy ensemble's best policy.

    Args:
        tracker: Conversation to predict for; its follow-up action (if any) is
            cleared as a side effect.

    Returns:
        The prediction for the next action.
    """
    followup_action = tracker.followup_action
    if followup_action:
        tracker.clear_followup_action()
        if followup_action in self.domain.action_names_or_texts:
            return PolicyPrediction.for_action_name(
                self.domain, followup_action, FOLLOWUP_ACTION
            )
        logger.error(
            f"Trying to run unknown follow-up action '{followup_action}'. "
            "Instead of running that, Rasa Open Source will ignore the action "
            "and predict the next action."
        )

    ensemble = self.policy_ensemble
    return ensemble.probabilities_using_best_policy(
        tracker, self.domain, self.interpreter
    )
def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    """Return a fixed prediction configured on this instance.

    All `domain.num_actions` scores are 0.0 except index `self.predict_index`,
    which gets `self.confidence`. Priority, end-to-end flag, events and
    optional events are copied from the instance attributes.
    """
    scores = [0.0] * domain.num_actions
    scores[self.predict_index] = self.confidence
    return PolicyPrediction(
        scores,
        type(self).__name__,
        policy_priority=self.priority,
        is_end_to_end_prediction=self.is_end_to_end_prediction,
        events=self.events,
        optional_events=self.optional_events,
    )
def test_default_predict_ignores_other_kwargs(
    default_ensemble: DefaultPolicyPredictionEnsemble,
):
    """Non-prediction keyword arguments must not influence the combination."""
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events(sender_id="arbitrary", evts=[])
    only_prediction = PolicyPrediction(
        policy_name="arbitrary", probabilities=[1.0], policy_priority=1
    )

    # One genuine prediction mixed with unrelated component outputs.
    component_outputs = {
        "policy-graph-component-1": only_prediction,
        "another-random-component": domain,
        "yet-another-component": tracker,
    }
    final_prediction = default_ensemble.combine_predictions_from_kwargs(
        domain=domain, tracker=tracker, **component_outputs
    )

    assert final_prediction.policy_name == only_prediction.policy_name
def _rule_prediction(
    self,
    probabilities: List[float],
    prediction_source: Text,
    returning_from_unhappy_path: bool = False,
    is_end_to_end_prediction: bool = False,
    is_no_user_prediction: bool = False,
) -> PolicyPrediction:
    """Build this policy's `PolicyPrediction` for a rule-based decision.

    Args:
        probabilities: Scores over the domain's actions.
        prediction_source: Key of the rule that produced the prediction; if it
            is listed under `RULES_NOT_IN_STORIES` in `self.lookup`, the rule
            turn is marked as hidden.
        returning_from_unhappy_path: If `True`, attach a `LoopInterrupted(True)`
            event to the prediction.
        is_end_to_end_prediction: Forwarded to the prediction.
        is_no_user_prediction: Forwarded to the prediction.

    Returns:
        The prediction, attributed to this policy's class and priority.
    """
    rules_not_in_stories = self.lookup.get(RULES_NOT_IN_STORIES, [])
    if returning_from_unhappy_path:
        loop_events = [LoopInterrupted(True)]
    else:
        loop_events = []
    return PolicyPrediction(
        probabilities,
        type(self).__name__,
        self.priority,
        events=loop_events,
        is_end_to_end_prediction=is_end_to_end_prediction,
        is_no_user_prediction=is_no_user_prediction,
        # A membership test already evaluates to a bool.
        hide_rule_turn=prediction_source in rules_not_in_stories,
    )
def _get_prediction(
    policy: Policy,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
) -> PolicyPrediction:
    """Ask `policy` for a prediction, adapting to legacy policy interfaces.

    Two kinds of legacy custom policies are handled:

    * If `predict_action_probabilities` has at most two parameters, or none of
      them is named ``interpreter``, a `DeprecationWarning` is raised and the
      method is called with a fresh `RegexInterpreter` instead of
      `interpreter`.
    * If the method returns a plain list, a deprecation warning is raised and
      the list is wrapped in a `PolicyPrediction` carrying the policy's class
      name and priority.

    Returns:
        The policy's prediction as a `PolicyPrediction` (when the policy
        returned either a `PolicyPrediction` or a list).
    """
    number_of_arguments_in_rasa_1_0 = 2
    predict_fn = policy.predict_action_probabilities
    parameter_names = rasa.shared.utils.common.arguments_of(predict_fn)
    accepts_interpreter = (
        len(parameter_names) > number_of_arguments_in_rasa_1_0
        and "interpreter" in parameter_names
    )

    if not accepts_interpreter:
        rasa.shared.utils.io.raise_warning(
            "The function `predict_action_probabilities` of "
            "the `Policy` interface was changed to support "
            "additional parameters. Please make sure to "
            "adapt your custom `Policy` implementation.",
            category=DeprecationWarning,
        )
        interpreter = RegexInterpreter()
    prediction = predict_fn(tracker, domain, interpreter)

    if not isinstance(prediction, list):
        return prediction

    rasa.shared.utils.io.raise_deprecation_warning(
        f"The function `predict_action_probabilities` of "
        f"the `{Policy.__name__}` interface was changed to return "
        f"a `{PolicyPrediction.__name__}` object. Please make sure to "
        f"adapt your custom `{Policy.__name__}` implementation. Support for "
        f"returning a list of floats will be removed in Rasa Open Source 3.0.0"
    )
    return PolicyPrediction(
        prediction, policy.__class__.__name__, policy_priority=policy.priority
    )
def _prediction(
    self,
    probabilities: List[float],
    events: Optional[List[Event]] = None,
    optional_events: Optional[List[Event]] = None,
    is_end_to_end_prediction: bool = False,
    is_no_user_prediction: bool = False,
    diagnostic_data: Optional[Dict[Text, Any]] = None,
    action_metadata: Optional[Dict[Text, Any]] = None,
) -> "PolicyPrediction":
    """Wrap `probabilities` in a `PolicyPrediction` attributed to this policy.

    The policy name is this instance's class name and the priority is
    `self.priority`; all other arguments are forwarded unchanged.
    """
    # Imported here rather than at module level (presumably to avoid an
    # import cycle — confirm).
    from rasa.core.policies.policy import PolicyPrediction

    return PolicyPrediction(
        probabilities=probabilities,
        policy_name=self.__class__.__name__,
        policy_priority=self.priority,
        events=events,
        optional_events=optional_events,
        is_end_to_end_prediction=is_end_to_end_prediction,
        is_no_user_prediction=is_no_user_prediction,
        diagnostic_data=diagnostic_data,
        action_metadata=action_metadata,
    )
def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
    circuit_breaker_tripped: bool,
) -> Tuple[EvaluationStore, PolicyPrediction]:
    """Predict the action for `event` and record the outcome on `partial_tracker`.

    Args:
        processor: Processor used to predict the next action.
        partial_tracker: Tracker holding the conversation up to `event`. It is
            updated with an `ActionExecuted` event for a correct prediction or
            a `WronglyPredictedAction` event for a wrong one.
        event: The gold `ActionExecuted` event from the test story.
        fail_on_prediction_errors: If `True`, a wrong prediction raises.
        circuit_breaker_tripped: If `True`, no prediction is made and the
            placeholder "circuit breaker tripped" is recorded as predicted,
            with an empty prediction whose `policy_name` is `None`.

    Returns:
        The evaluation store holding this prediction/target pair, and the
        prediction that was used.

    Raises:
        WrongPredictionException: If the prediction is wrong and
            `fail_on_prediction_errors` is set.
    """
    from rasa.core.policies.form_policy import FormPolicy

    action_executed_eval_store = EvaluationStore()

    gold = event.action_name or event.action_text

    if circuit_breaker_tripped:
        prediction = PolicyPrediction([], policy_name=None)
        predicted = "circuit breaker tripped"
    else:
        action, prediction = processor.predict_next_action(partial_tracker)
        predicted = action.name()

        if (
            prediction.policy_name
            and predicted != gold
            and _form_might_have_been_rejected(
                processor.domain, partial_tracker, predicted
            )
        ):
            # Wrong action was predicted,
            # but it might be Ok if form action is rejected.
            emulate_loop_rejection(partial_tracker)
            # try again
            action, prediction = processor.predict_next_action(partial_tracker)

            # Even if the prediction is also wrong, we don't have to undo the emulation
            # of the action rejection as we know that the user explicitly specified
            # that something else than the form was supposed to run.
            predicted = action.name()

    action_executed_eval_store.add_to_store(
        action_predictions=[predicted], action_targets=[gold]
    )

    if action_executed_eval_store.has_prediction_target_mismatch():
        partial_tracker.update(
            WronglyPredictedAction(
                gold,
                predicted,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
            )
        )
        if fail_on_prediction_errors:
            story_dump = YAMLStoryWriter().dumps(
                partial_tracker.as_story().story_steps
            )
            error_msg = (
                f"Model predicted a wrong action. Failed Story: "
                f"\n\n{story_dump}"
            )
            # `policy_name` is `None` when the circuit breaker tripped; a bare
            # substring test would then raise `TypeError` instead of the
            # intended `WrongPredictionException`.
            if prediction.policy_name and FormPolicy.__name__ in prediction.policy_name:
                error_msg += (
                    "FormAction is not run during "
                    "evaluation therefore it is impossible to know "
                    "if validation failed or this story is wrong. "
                    "If the story is correct, add it to the "
                    "training stories and retrain."
                )
            raise WrongPredictionException(error_msg)
    else:
        partial_tracker.update(
            ActionExecuted(
                predicted,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
            )
        )

    return action_executed_eval_store, prediction
domain, InMemoryTrackerStore(domain), InMemoryLockStore(), Mock(), ) # This should not raise processor._get_next_action_probabilities( DialogueStateTracker.from_events("lala", [ActionExecuted(ACTION_LISTEN_NAME)])) @pytest.mark.parametrize( "predict_function", [ lambda tracker, domain, _: PolicyPrediction([1, 0, 2, 3], "some-policy" ), lambda tracker, domain, _=True: PolicyPrediction([1, 0], "some-policy" ), ], ) def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg( predict_function: Callable, ): policy = TEDPolicy() policy.predict_action_probabilities = predict_function ensemble = SimplePolicyEnsemble(policies=[policy]) interpreter = Mock() domain = Domain.empty() processor = MessageProcessor(
def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
) -> Tuple[EvaluationStore, PolicyPrediction, Optional[EntityEvaluationResult]]:
    """Predict the action for `event` and record the outcome on `partial_tracker`.

    If the first prediction is `action_unlikely_intent` and that is not the
    expected action, it is recorded as a `WronglyPredictedAction` that names
    `action_unlikely_intent` in both the target-name and predicted positions.
    A second prediction is then made, and that second result is the one
    evaluated.

    The evaluated prediction is written to `partial_tracker` as:
      - a `WronglyPredictedAction` when it mismatches the expected action;
        if `fail_on_prediction_errors` is set and the prediction is neither
        `action_unlikely_intent` nor the expected action, a
        `WrongPredictionException` is raised;
      - a `WarningPredictedAction` when it matches after an earlier
        `action_unlikely_intent`;
      - a plain `ActionExecuted` otherwise.

    `ActionLimitReached` from `_run_action_prediction` is recorded as the
    predicted action "circuit breaker tripped" with an empty prediction.

    Args:
        processor: Processor used for the predictions.
        partial_tracker: Tracker holding the conversation up to `event`;
            updated in place.
        event: The expected `ActionExecuted` event from the test story.
        fail_on_prediction_errors: Whether a wrong prediction raises.

    Returns:
        The evaluation store for this prediction/target pair, the prediction
        that was evaluated, and the entity evaluation result from the last
        successful `_run_action_prediction` call (`None` if none succeeded).

    Raises:
        WrongPredictionException: See above.
    """
    action_executed_eval_store = EvaluationStore()

    expected_action_name = event.action_name
    expected_action_text = event.action_text
    expected_action = expected_action_name or expected_action_text

    policy_entity_result = None
    prev_action_unlikely_intent = False

    try:
        predicted_action, prediction, policy_entity_result = _run_action_prediction(
            processor, partial_tracker, expected_action
        )
    except ActionLimitReached:
        prediction = PolicyPrediction([], policy_name=None)
        predicted_action = "circuit breaker tripped"

    predicted_action_unlikely_intent = predicted_action == ACTION_UNLIKELY_INTENT_NAME
    if predicted_action_unlikely_intent and predicted_action != expected_action:
        # Record the unexpected `action_unlikely_intent` on the tracker, then
        # predict again; the second prediction is what gets evaluated below.
        partial_tracker.update(
            WronglyPredictedAction(
                predicted_action,
                expected_action_text,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        )
        prev_action_unlikely_intent = True

        try:
            predicted_action, prediction, policy_entity_result = _run_action_prediction(
                processor, partial_tracker, expected_action
            )
        except ActionLimitReached:
            prediction = PolicyPrediction([], policy_name=None)
            predicted_action = "circuit breaker tripped"

    action_executed_eval_store.add_to_store(
        action_predictions=[predicted_action], action_targets=[expected_action]
    )

    if action_executed_eval_store.has_prediction_target_mismatch():
        partial_tracker.update(
            WronglyPredictedAction(
                expected_action_name,
                expected_action_text,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
                predicted_action_unlikely_intent=prev_action_unlikely_intent,
            )
        )
        # An `action_unlikely_intent` prediction is never treated as a hard
        # failure, even when failing on prediction errors is requested.
        if (
            fail_on_prediction_errors
            and predicted_action != ACTION_UNLIKELY_INTENT_NAME
            and predicted_action != expected_action
        ):
            story_dump = YAMLStoryWriter().dumps(
                partial_tracker.as_story().story_steps
            )
            error_msg = (
                f"Model predicted a wrong action. Failed Story: "
                f"\n\n{story_dump}"
            )
            raise WrongPredictionException(error_msg)
    elif prev_action_unlikely_intent:
        # Correct in the end, but only after an `action_unlikely_intent`:
        # record it as a warning rather than a plain executed action.
        partial_tracker.update(
            WarningPredictedAction(
                ACTION_UNLIKELY_INTENT_NAME,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                prediction.action_metadata,
            )
        )
    else:
        partial_tracker.update(
            ActionExecuted(
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        )

    return action_executed_eval_store, prediction, policy_entity_result
**{prediction.policy_name: prediction for prediction in predictions}, ) assert prediction.probabilities[index_of_excluded_action] == 0.0 @pytest.mark.parametrize( "predictions_and_expected_winner_idx, last_action_was_action_listen", itertools.product( [ ( # highest probability and highest priority [ PolicyPrediction( policy_name=str(idx), probabilities=[idx] * 3, policy_priority=idx, ) for idx in range(4) ], 3, ), ( # highest probability wins even if priority is low [ PolicyPrediction( policy_name=str(idx), probabilities=[idx] * 3, policy_priority=idx, ) for idx in reversed(range(4)) ], 0,
def _pick_best_policy(
    self, predictions: Dict[Text, PolicyPrediction]
) -> PolicyPrediction:
    """Picks the best policy prediction based on probabilities and policy priority.

    Predictions are first filtered by type. If any prediction is a "no user"
    prediction, only those are considered. Otherwise, if any prediction is an
    end-to-end prediction, only end-to-end predictions are considered.

    Among the remaining predictions, the highest
    `(max_confidence, policy_priority)` tuple wins. The comparison is strict,
    so on an exact tie the earlier entry in `predictions` wins.

    Form policy predictions are tracked separately. A form prediction replaces
    the winner only if `self._is_not_mapping_policy` holds for the winner and
    the form prediction's tuple is strictly larger.

    Events from *all* predictions (including filtered-out ones) are collected
    in iteration order, and the winner's optional events are appended.

    Args:
        predictions: the dictionary containing policy name as keys and
            predictions as values

    Returns:
        The best prediction.
    """
    best_confidence = (-1, -1)
    best_policy_name = None
    # form and mapping policies are special:
    # form should be above fallback
    # mapping should be below fallback
    # mapping is above form if it wins over fallback
    # therefore form predictions are stored separately
    form_confidence = None
    form_policy_name = None
    # different type of predictions have different priorities
    # No user predictions overrule all other predictions.
    is_no_user_prediction = any(
        prediction.is_no_user_prediction for prediction in predictions.values()
    )
    # End-to-end predictions overrule all other predictions based on user input.
    is_end_to_end_prediction = any(
        prediction.is_end_to_end_prediction for prediction in predictions.values()
    )

    policy_events = []
    for policy_name, prediction in predictions.items():
        # Events are collected before filtering, so every prediction's events
        # end up in the result.
        policy_events += prediction.events

        # No user predictions (e.g. happy path loop predictions)
        # overrule all other predictions.
        if prediction.is_no_user_prediction != is_no_user_prediction:
            continue

        # End-to-end predictions overrule all other predictions based on user input.
        if (
            not is_no_user_prediction
            and prediction.is_end_to_end_prediction != is_end_to_end_prediction
        ):
            continue

        confidence = (prediction.max_confidence, prediction.policy_priority)
        if self._is_form_policy(policy_name):
            # store form prediction separately
            form_confidence = confidence
            form_policy_name = policy_name
        elif confidence > best_confidence:
            # pick the best policy
            best_confidence = confidence
            best_policy_name = policy_name

    if form_confidence is not None and self._is_not_mapping_policy(
        best_policy_name, best_confidence[0]
    ):
        # if mapping didn't win, check form policy predictions
        if form_confidence > best_confidence:
            best_policy_name = form_policy_name

    # NOTE(review): if every prediction was filtered out, `best_policy_name` is
    # still None here and this lookup raises KeyError — confirm callers never
    # pass such input.
    best_prediction = predictions[best_policy_name]

    policy_events += best_prediction.optional_events

    # NOTE(review): only the fields below are copied from the winning
    # prediction; any other PolicyPrediction attributes (e.g. `hide_rule_turn`
    # or `action_metadata`, if this version has them) are not forwarded —
    # confirm this is intended.
    return PolicyPrediction(
        best_prediction.probabilities,
        best_policy_name,
        best_prediction.policy_priority,
        policy_events,
        is_end_to_end_prediction=best_prediction.is_end_to_end_prediction,
        is_no_user_prediction=best_prediction.is_no_user_prediction,
        diagnostic_data=best_prediction.diagnostic_data,
    )
async def test_switch_forms_with_same_slot(default_agent: Agent):
    """Tests switching of forms where both forms request the same first slot.

    Regression test for issue 7710. While `form_1` is active and asks for
    `slot_a`, the user switches to `form_2`, which asks for the same slot.
    `form_1` must reject the message, and `form_2` must then activate and ask
    for the slot with its own response.
    """
    # Define two forms in the domain, with same first slot
    slot_a = "my_slot_a"
    form_1 = "my_form_1"
    utter_ask_form_1 = f"Please provide the value for {slot_a} of form 1"
    form_2 = "my_form_2"
    utter_ask_form_2 = f"Please provide the value for {slot_a} of form 2"

    domain = f"""
version: "2.0"
nlu:
- intent: order_status
  examples: |
    - check status of my order
    - when are my shoes coming in
- intent: return
  examples: |
    - start a return
    - I don't want my shoes anymore
forms:
  {form_1}:
    {slot_a}:
    - type: from_entity
      entity: number
  {form_2}:
    {slot_a}:
    - type: from_entity
      entity: number
responses:
    utter_ask_{form_1}_{slot_a}:
    - text: {utter_ask_form_1}
    utter_ask_{form_2}_{slot_a}:
    - text: {utter_ask_form_2}
"""

    domain = Domain.from_yaml(domain)

    # Driving it like rasa/core/processor
    processor = MessageProcessor(
        default_agent.interpreter,
        default_agent.policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        TemplatedNaturalLanguageGenerator(domain.templates),
    )

    # activate the first form
    tracker = DialogueStateTracker.from_events(
        "some-sender",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
            DefinePrevUserUtteredFeaturization(False),
        ],
    )
    # rasa/core/processor.predict_next_action
    # The same dummy prediction is reused for every `_run_action` call below;
    # none of the assertions inspect it.
    prediction = PolicyPrediction([], "some_policy")
    action_1 = FormAction(form_1, None)

    await processor._run_action(
        action_1,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        prediction,
    )

    events_expected = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
        ActionExecuted(form_1),
        ActiveLoop(form_1),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_1,
            metadata={"template_name": f"utter_ask_{form_1}_{slot_a}"},
        ),
    ]
    assert tracker.applied_events() == events_expected

    # The user now asks for the second form while the first one is active.
    next_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("return my shoes", {"name": "form_2", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
    ]
    tracker.update_with_events(
        next_events,
        domain,
    )
    events_expected.extend(next_events)

    # form_1 is still active, and bot will first validate if the user utterance
    # provides valid data for the requested slot, which is rejected
    await processor._run_action(
        action_1,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        prediction,
    )
    events_expected.extend([ActionExecutionRejected(action_name=form_1)])
    assert tracker.applied_events() == events_expected

    # Next, bot predicts form_2
    action_2 = FormAction(form_2, None)
    await processor._run_action(
        action_2,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        prediction,
    )
    events_expected.extend(
        [
            ActionExecuted(form_2),
            ActiveLoop(form_2),
            SlotSet(REQUESTED_SLOT, slot_a),
            BotUttered(
                text=utter_ask_form_2,
                metadata={"template_name": f"utter_ask_{form_2}_{slot_a}"},
            ),
        ]
    )
    assert tracker.applied_events() == events_expected
def _pick_best_policy(
    predictions: List[PolicyPrediction]) -> PolicyPrediction:
    """Picks the best policy prediction based on probabilities and policy priority.

    Predictions are first filtered by type. If any prediction is a "no user"
    prediction, only those are considered. Otherwise, if any prediction is an
    end-to-end prediction, only end-to-end predictions are considered.

    Among the remaining predictions, the highest
    `(max_confidence, policy_priority)` tuple wins. The comparison is strict,
    so on an exact tie the earliest prediction in the list wins.

    Events from *all* predictions (including filtered-out ones) are collected
    in list order, and the winner's optional events are appended.

    Args:
        predictions: a list containing policy predictions

    Returns:
        A new `PolicyPrediction` copying the winner's probabilities, name,
        priority, flags, diagnostic data, `hide_rule_turn` and
        `action_metadata`, with the combined events described above.

    Raises:
        InvalidConfigException: If no prediction qualifies (e.g. when
            `predictions` is empty).
    """
    best_confidence = (-1.0, -1)
    best_index = -1

    # different type of predictions have different priorities
    # No user predictions overrule all other predictions.
    is_no_user_prediction = any(
        prediction.is_no_user_prediction for prediction in predictions
    )
    # End-to-end predictions overrule all other predictions based on user input.
    is_end_to_end_prediction = any(
        prediction.is_end_to_end_prediction for prediction in predictions
    )

    policy_events = []
    for idx, prediction in enumerate(predictions):
        # Events are collected before filtering, so every prediction's events
        # end up in the result.
        policy_events += prediction.events

        # No user predictions (e.g. happy path loop predictions)
        # overrule all other predictions.
        if prediction.is_no_user_prediction != is_no_user_prediction:
            continue

        # End-to-end predictions overrule all other predictions based on user input.
        if (
            not is_no_user_prediction
            and prediction.is_end_to_end_prediction != is_end_to_end_prediction
        ):
            continue

        confidence = (prediction.max_confidence, prediction.policy_priority)
        if confidence > best_confidence:
            # pick the best policy
            best_confidence = confidence
            best_index = idx

    if best_index < 0:
        raise InvalidConfigException(
            "No best prediction found. Please check your model configuration."
        )

    best_prediction = predictions[best_index]

    policy_events += best_prediction.optional_events

    return PolicyPrediction(
        best_prediction.probabilities,
        best_prediction.policy_name,
        best_prediction.policy_priority,
        policy_events,
        is_end_to_end_prediction=best_prediction.is_end_to_end_prediction,
        is_no_user_prediction=best_prediction.is_no_user_prediction,
        diagnostic_data=best_prediction.diagnostic_data,
        hide_rule_turn=best_prediction.hide_rule_turn,
        action_metadata=best_prediction.action_metadata,
    )
InMemoryTrackerStore(domain), InMemoryLockStore(), Mock(), ) # This should not raise processor._get_next_action_probabilities( DialogueStateTracker.from_events("lala", [ActionExecuted(ACTION_LISTEN_NAME)]) ) @pytest.mark.parametrize( "predict_function", [ lambda tracker, domain, something_else: PolicyPrediction( [1, 0, 2, 3], "some-policy" ), lambda tracker, domain, some_bool=True: PolicyPrediction([1, 0], "some-policy"), ], ) def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg( predict_function: Callable, ): policy = TEDPolicy() policy.predict_action_probabilities = predict_function ensemble = SimplePolicyEnsemble(policies=[policy]) interpreter = Mock() domain = Domain.empty()