示例#1
0
    def log_bot_utterances_on_tracker(tracker: DialogueStateTracker,
                                      dispatcher: Dispatcher) -> None:
        """Record the dispatcher's pending bot messages on the tracker.

        Every message in ``dispatcher.latest_bot_messages`` is wrapped in a
        ``BotUttered`` event and applied to ``tracker``. Afterwards the
        dispatcher's message list is replaced by an empty list. Nothing is
        touched when there are no pending messages.
        """
        pending_messages = dispatcher.latest_bot_messages
        if not pending_messages:
            return

        for message in pending_messages:
            event = BotUttered(text=message.text, data=message.data)
            logger.debug("Bot utterance '{}'".format(event))
            tracker.update(event)

        dispatcher.latest_bot_messages = []
示例#2
0
def test_session_start(default_domain: Domain):
    """A `SessionStarted` event is stored on the tracker like any other event."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # a freshly created tracker carries no events
    events_before = len(tracker.events)
    assert events_before == 0

    # apply a single SessionStarted event
    session_started = SessionStarted()
    tracker.update(session_started)

    # exactly that one event is now on the tracker
    events_after = len(tracker.events)
    assert events_after == 1
示例#3
0
文件: test.py 项目: cxy115566/rasa-1
def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
    circuit_breaker_tripped: bool,
) -> Tuple[EvaluationStore, Optional[Text], Optional[float]]:
    """Compare the predicted next action with the action given in the story.

    Args:
        processor: Used to predict the next action for `partial_tracker`.
        partial_tracker: Tracker holding the conversation so far. It is
            updated with `event` when the prediction matches and with a
            `WronglyPredictedAction` event otherwise.
        event: The expected (gold) action.
        fail_on_prediction_errors: If `True`, a wrong prediction raises
            instead of only being recorded.
        circuit_breaker_tripped: If `True`, no prediction is made and the
            placeholder `"circuit breaker tripped"` is stored as prediction.

    Returns:
        The evaluation store holding prediction and target, the predicting
        policy and its confidence. Policy and confidence are `None` when the
        circuit breaker tripped.

    Raises:
        ValueError: If the prediction is wrong and
            `fail_on_prediction_errors` is set.
    """
    from rasa.core.policies.form_policy import FormPolicy

    action_executed_eval_store = EvaluationStore()

    gold = event.action_name

    if circuit_breaker_tripped:
        predicted = "circuit breaker tripped"
        policy = None
        confidence = None
    else:
        action, policy, confidence = processor.predict_next_action(
            partial_tracker)
        predicted = action.name()

        if policy and predicted != gold and FormPolicy.__name__ in policy:
            # FormPolicy predicted wrong action
            # but it might be Ok if form action is rejected
            _emulate_form_rejection(processor, partial_tracker)
            # try again
            action, policy, confidence = processor.predict_next_action(
                partial_tracker)
            predicted = action.name()

    action_executed_eval_store.add_to_store(action_predictions=predicted,
                                            action_targets=gold)

    if action_executed_eval_store.has_prediction_target_mismatch():
        partial_tracker.update(
            WronglyPredictedAction(gold, predicted, event.policy,
                                   event.confidence, event.timestamp))
        if fail_on_prediction_errors:
            error_msg = ("Model predicted a wrong action. Failed Story: "
                         "\n\n{}".format(partial_tracker.export_stories()))
            # `policy` is None when the circuit breaker tripped (and may be
            # None after a prediction); a membership test on None would raise
            # a TypeError and hide the ValueError we actually want to report.
            if policy and FormPolicy.__name__ in policy:
                error_msg += ("FormAction is not run during "
                              "evaluation therefore it is impossible to know "
                              "if validation failed or this story is wrong. "
                              "If the story is correct, add it to the "
                              "training stories and retrain.")
            raise ValueError(error_msg)
    else:
        partial_tracker.update(event)

    return action_executed_eval_store, policy, confidence
示例#4
0
def test_traveling_back_in_time(default_domain):
    """After travelling back only the events before the checkpoint remain."""
    import time

    tracker = DialogueStateTracker("default", default_domain.slots)
    # a new tracker starts out without any events
    assert len(tracker.events) == 0

    greet_intent = {"name": "greet", "confidence": 1.0}
    events_before_checkpoint = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/greet", greet_intent, []),
    ]
    for logged_event in events_before_checkpoint:
        tracker.update(logged_event)

    # take the checkpoint, padded with pauses on both sides
    time.sleep(1)
    checkpoint = time.time()
    time.sleep(1)

    events_after_checkpoint = [
        ActionExecuted("my_action"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    for logged_event in events_after_checkpoint:
        tracker.update(logged_event)

    # Full history, expecting 4 prior trackers:
    #   3 executed actions + 1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(tracker.events) == 4
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == 4

    tracker = tracker.travel_back_in_time(checkpoint)

    # Rewound history, expecting 2 prior trackers:
    #   1 executed action + 1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(tracker.events) == 2
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == 2
示例#5
0
    def prepare(self, text):
        """Build a fresh tracker that holds the parsed user message `text`.

        The text is parsed with `self.interpreter` and logged on a new
        tracker as a `UserUttered` event, followed by the events that
        `self.domain.slots_for_entities` produces for the parsed entities.
        Returns the tracker.
        """
        tracker = DialogueStateTracker("default", self.domain.slots)
        parsed = self.interpreter.parse(text)

        user_event = UserUttered(
            text, parsed["intent"], parsed["entities"], parsed
        )
        tracker.update(user_event)

        # the extracted entities are additionally stored as slots
        slot_events = self.domain.slots_for_entities(parsed["entities"])
        for slot_event in slot_events:
            tracker.update(slot_event)

        event_count = len(tracker.events)
        print("Logged UserUtterance - "
              "tracker now has {} events".format(event_count))
        return tracker
示例#6
0
def test_tracker_entity_retrieval(default_domain):
    """Entity values of the latest user message can be looked up by name."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    # nothing has been logged yet, so there is nothing to retrieve
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values("entity_name")) == []

    greet_entity = {
        "start": 1,
        "end": 5,
        "value": "greet",
        "entity": "entity_name",
        "extractor": "manual",
    }
    greet_intent = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", greet_intent, [greet_entity]))

    known_values = list(tracker.get_latest_entity_values("entity_name"))
    assert known_values == ["greet"]

    unknown_values = list(tracker.get_latest_entity_values("unknown"))
    assert unknown_values == []
示例#7
0
async def cancel_reminder_and_check(
    tracker: DialogueStateTracker,
    default_processor: MessageProcessor,
    reminder_canceled_event: ReminderCancelled,
    num_jobs_before: int,
    num_jobs_after: int,
) -> None:
    """Apply a reminder cancellation and verify the scheduler's job count.

    The number of scheduled jobs is asserted to equal `num_jobs_before`
    once the event is logged on the tracker, and `num_jobs_after` once the
    processor has cancelled the reminders.
    """
    # log the cancellation event on the tracker
    tracker.update(reminder_canceled_event)

    # logging the event alone must leave the expected jobs in the scheduler
    scheduler = await jobs.scheduler()
    jobs_before_cancelling = scheduler.get_jobs()
    assert len(jobs_before_cancelling) == num_jobs_before

    await default_processor._cancel_reminders(tracker.events, tracker)

    # after cancelling, the scheduler must hold the expected remaining jobs
    scheduler = await jobs.scheduler()
    jobs_after_cancelling = scheduler.get_jobs()
    assert len(jobs_after_cancelling) == num_jobs_after
示例#8
0
文件: test.py 项目: cxy115566/rasa-1
def _emulate_form_rejection(processor: "MessageProcessor",
                            partial_tracker: DialogueStateTracker) -> None:
    """Pretend that the active form rejected its execution.

    For every `FormPolicy` in the processor's ensemble an
    `ActionExecutionRejected` event for the active form is applied to the
    tracker. If the policy does not regard the resulting state as unhappy,
    the event is removed again and the form's `rejected` flag is reset.
    Does nothing when no form is active.
    """
    from rasa.core.policies.form_policy import FormPolicy
    from rasa.core.events import ActionExecutionRejected

    if not partial_tracker.active_form.get("name"):
        # without an active form there is nothing to reject
        return

    for policy in processor.policy_ensemble.policies:
        if not isinstance(policy, FormPolicy):
            continue

        # emulate the rejection of the active form
        rejection = ActionExecutionRejected(partial_tracker.active_form["name"])
        partial_tracker.update(rejection)

        # keep the rejection only if the train stories cover the unhappy path
        if policy.state_is_unhappy(partial_tracker, processor.domain):
            continue

        # this state is not covered by the stories: undo the emulation
        del partial_tracker.events[-1]
        partial_tracker.active_form["rejected"] = False
示例#9
0
    async def trigger_external_user_uttered(
        self,
        intent_name: Text,
        entities: Optional[Union[List[Dict[Text, Any]], Dict[Text, Text]]],
        tracker: DialogueStateTracker,
        output_channel: OutputChannel,
    ) -> None:
        """Injects an externally triggered intent into the conversation.

        The external message behaves like a user message but stays invisible;
        it is used, e.g., by a reminder or the trigger_intent endpoint.

        Args:
            intent_name: Name of the intent to be triggered.
            entities: Entities to be passed on, either as a list of entity
                dictionaries or as a short-hand `{"entity": "value"}` mapping.
            tracker: The tracker to which the event should be added.
            output_channel: The output channel.
        """
        if isinstance(entities, list):
            entity_list = entities
        elif isinstance(entities, dict):
            # Expand the short-hand notation {"ent1": "val1", "ent2": "val2", ...},
            # which is handy when 'start', 'end', or 'extractor' are unknown,
            # e.g. for external events.
            entity_list = [
                {"entity": name, "value": value} for name, value in entities.items()
            ]
        else:
            if entities:
                # Neither a list nor a dict, yet not empty: warn and ignore it.
                rasa.shared.utils.io.raise_warning(
                    f"Invalid entity specification: {entities}. Assuming no entities."
                )
            entity_list = []

        # Reuse the latest input channel for the new event so that this
        # property is not lost.
        latest_channel = tracker.get_latest_input_channel()
        external_event = UserUttered.create_external(
            intent_name, entity_list, latest_channel
        )
        tracker.update(external_event)

        await self._predict_and_execute_next_action(output_channel, tracker)
        # persist the tracker so the conversation can continue from this state
        self._save_tracker(tracker)
示例#10
0
def _collect_user_uttered_predictions(
    event: UserUttered,
    partial_tracker: DialogueStateTracker,
    fail_on_prediction_errors: bool,
) -> EvaluationStore:
    """Store intent and entity predictions of a user message next to the targets.

    Targets are read from the `true_intent` / `true_entities` entries of the
    event's parse data, predictions from its `intent` / `entities` entries.
    On a mismatch a `WronglyClassifiedUserUtterance` is applied to the
    tracker (and a `ValueError` is raised if `fail_on_prediction_errors` is
    set); otherwise an `EndToEndUserUtterance` is applied.
    """
    store = EvaluationStore()
    parse_data = event.parse_data

    target_intent = parse_data.get("true_intent")
    # fall back to `[None]` when no intent name was predicted
    predicted_intent = parse_data.get("intent", {}).get("name") or [None]
    store.add_to_store(
        intent_predictions=predicted_intent, intent_targets=target_intent
    )

    target_entities = parse_data.get("true_entities")
    predicted_entities = parse_data.get("entities")
    if target_entities or predicted_entities:
        cleaned_targets = _clean_entity_results(event.text, target_entities)
        cleaned_predictions = _clean_entity_results(event.text, predicted_entities)
        store.add_to_store(
            entity_targets=cleaned_targets, entity_predictions=cleaned_predictions
        )

    if not store.has_prediction_target_mismatch():
        # everything matched: log the utterance as it appears in the story
        partial_tracker.update(
            EndToEndUserUtterance(event.text, event.intent, event.entities)
        )
        return store

    partial_tracker.update(WronglyClassifiedUserUtterance(event, store))
    if fail_on_prediction_errors:
        raise ValueError(
            "NLU model predicted a wrong intent. Failed Story:"
            " \n\n{}".format(partial_tracker.export_stories())
        )

    return store
示例#11
0
文件: test.py 项目: zuiwanting/rasa
def _collect_user_uttered_predictions(
    event: UserUttered,
    predicted: Dict[Text, Any],
    partial_tracker: DialogueStateTracker,
    fail_on_prediction_errors: bool,
) -> EvaluationStore:
    """Store intent and entity predictions of a user message next to the targets.

    Targets come from `event` (its intent name and entities), predictions
    from the `predicted` dictionary. On a mismatch a
    `WronglyClassifiedUserUtterance` is applied to the tracker (and a
    `ValueError` is raised if `fail_on_prediction_errors` is set); otherwise
    an `EndToEndUserUtterance` is applied.
    """
    store = EvaluationStore()

    target_intent = event.intent.get("name")
    predicted_intent = predicted.get(INTENT, {}).get("name")
    store.add_to_store(
        intent_predictions=[predicted_intent], intent_targets=[target_intent]
    )

    target_entities = event.entities
    predicted_entities = predicted.get(ENTITIES)
    if target_entities or predicted_entities:
        cleaned_targets = _clean_entity_results(event.text, target_entities)
        cleaned_predictions = _clean_entity_results(event.text, predicted_entities)
        store.add_to_store(
            entity_targets=cleaned_targets, entity_predictions=cleaned_predictions
        )

    if not store.has_prediction_target_mismatch():
        # everything matched: log the utterance as it appears in the story
        partial_tracker.update(
            EndToEndUserUtterance(event.text, event.intent, event.entities)
        )
        return store

    partial_tracker.update(WronglyClassifiedUserUtterance(event, store))
    if fail_on_prediction_errors:
        raise ValueError(
            "NLU model predicted a wrong intent. Failed Story:"
            " \n\n{}".format(
                YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
            )
        )

    return store
示例#12
0
def test_get_latest_entity_values(entities: List[Dict[Text, Any]],
                                  expected_values: List[Text],
                                  default_domain: Domain):
    """Latest entity values can be filtered by type, role and group."""
    # the first entity defines the type / role / group that is queried
    reference_entity = entities[0]
    entity_type = reference_entity.get("entity")
    entity_role = reference_entity.get("role")
    entity_group = reference_entity.get("group")

    tracker = DialogueStateTracker("default", default_domain.slots)
    # a fresh tracker has neither events nor entity values
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values(entity_type)) == []

    greet_intent = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", greet_intent, entities))

    matching_values = list(
        tracker.get_latest_entity_values(
            entity_type, entity_role=entity_role, entity_group=entity_group
        )
    )
    assert matching_values == expected_values

    unknown_values = list(tracker.get_latest_entity_values("unknown"))
    assert unknown_values == []
示例#13
0
    def test_missing_classes_filled_correctly(
        self, default_domain, trackers, tracker, featurizer, priority
    ):
        """Classes absent from the training data must be predicted as 0.

        The training trackers are rebuilt so that every executed action is
        one of two classes only; all other classes must then receive a
        probability of exactly 0 while the probabilities still sum to 1.
        """
        policy = self.create_policy(featurizer=featurizer, priority=priority, cv=None)

        present_classes = [1, 3]

        def _restrict_to_present_classes(event):
            # executed actions are swapped for a randomly chosen present class,
            # every other event is passed through unchanged
            if not isinstance(event, ActionExecuted):
                return event
            action_name = default_domain.action_for_index(
                np.random.choice(present_classes), action_endpoint=None
            ).name()
            return ActionExecuted(action_name)

        training_trackers = []
        for source_tracker in trackers:
            rebuilt_tracker = DialogueStateTracker(
                UserMessage.DEFAULT_SENDER_ID, default_domain.slots
            )
            for event in source_tracker.applied_events():
                rebuilt_tracker.update(_restrict_to_present_classes(event))
            training_trackers.append(rebuilt_tracker)

        policy.train(
            training_trackers, domain=default_domain, interpreter=RegexInterpreter()
        )
        probabilities = policy.predict_action_probabilities(tracker, default_domain)

        assert len(probabilities) == default_domain.num_actions
        assert np.allclose(sum(probabilities), 1.0)
        for index, probability in enumerate(probabilities):
            if index not in present_classes:
                # a class that was never seen in training
                assert probability == 0.0
            else:
                assert probability >= 0.0
示例#14
0
    def predict_action_probabilities(self, tracker: DialogueStateTracker,
                                     domain: Domain) -> List[float]:
        """Predicts the form action or `action_listen` while a form is active.

        Returns a list with one probability per domain action. All entries
        are 0.0 unless a form is active, in which case either the form
        action (after `action_listen`) or `action_listen` (after the form
        action) is set to 1.0. When the form was rejected and the current
        state is memorized for the form, a `FormValidation(False)` event is
        added to the tracker and all probabilities stay 0.0.
        """
        probabilities = [0.0] * domain.num_actions

        form_name = tracker.active_form.get('name')
        if not form_name:
            logger.debug("There is no active form")
            return probabilities

        logger.debug("There is an active form '{}'".format(form_name))

        last_action = tracker.latest_action_name
        if last_action == ACTION_LISTEN_NAME:
            # the user just spoke: the form action is the default prediction
            if tracker.active_form.get('rejected'):
                # Training stories are assumed to contain only unhappy paths,
                # so tell the form not to validate if another policy
                # predicts it.
                states = self.featurizer.prediction_states(
                    [tracker], domain)[0]
                memorized_form = self.recall(states, tracker, domain)

                if memorized_form == form_name:
                    logger.debug("There is a memorized tracker state {}, "
                                 "added `FormValidation(False)` event"
                                 "".format(self._modified_states(states)))
                    tracker.update(FormValidation(False))
                    return probabilities

            probabilities[domain.index_for_action(form_name)] = 1.0
        elif last_action == form_name:
            # the form action just ran: wait for the next user message
            probabilities[domain.index_for_action(ACTION_LISTEN_NAME)] = 1.0

        return probabilities
示例#15
0
    async def _handle_message_with_tracker(
            self, message: UserMessage, tracker: DialogueStateTracker) -> None:
        """Log the user message and its entity slot events on the tracker.

        The message's own parse data is used when present; otherwise the
        message is parsed first. The resulting `UserUttered` event and the
        events from `self.domain.slots_for_entities` are applied to
        `tracker`.
        """
        parse_data = message.parse_data
        if not parse_data:
            # the message has not been parsed yet
            parse_data = await self._parse_message(message)

        # the tracker is never mutated directly, it only ever receives events
        user_event = UserUttered(message.text,
                                 parse_data["intent"],
                                 parse_data["entities"],
                                 parse_data,
                                 input_channel=message.input_channel,
                                 message_id=message.message_id)
        tracker.update(user_event)

        # the extracted entities are additionally stored as slots
        slot_events = self.domain.slots_for_entities(parse_data["entities"])
        for slot_event in slot_events:
            tracker.update(slot_event)

        event_count = len(tracker.events)
        logger.debug("Logged UserUtterance - "
                     "tracker now has {} events".format(event_count))
示例#16
0
def test_tracker_update_slots_with_entity(default_domain):
    """Updating with a domain fills the slot named like the message's entity."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    entity_name = default_domain.entities[0]
    slot_value = "test user"

    annotated_entity = {
        "start": 1,
        "end": 5,
        "value": slot_value,
        "entity": entity_name,
        "extractor": "manual",
    }
    greet_intent = {"name": "greet", "confidence": 1.0}
    greeting = UserUttered("/greet", greet_intent, [annotated_entity])

    # the domain is passed along with the event
    tracker.update(greeting, default_domain)

    assert tracker.get_slot(entity_name) == slot_value
示例#17
0
async def test_action_session_start_with_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
    session_config: SessionConfig,
    expected_events: List[Event],
):
    """`ActionSessionStart` returns the expected, chronologically ordered events."""
    # put a couple of slot values on the tracker first
    preset_slots = (
        SlotSet("my_slot", "value"),
        SlotSet("another-slot", "value2"),
    )
    for slot_event in preset_slots:
        template_sender_tracker.update(slot_event)

    default_domain.session_config = session_config

    session_start = ActionSessionStart()
    events = await session_start.run(
        default_channel, template_nlg, template_sender_tracker, default_domain
    )

    assert events == expected_events

    # the returned events must already be ordered by ascending timestamp
    chronological = sorted(events, key=lambda e: e.timestamp)
    assert chronological == events
示例#18
0
def test_session_start_is_not_serialised(default_domain: Domain):
    """Session start events must not show up in the exported story string."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    # a fresh tracker has no events
    assert len(tracker.events) == 0

    logged_events = [
        # a slot that is set before the session starts
        SlotSet("slot", "value"),
        # the two events that mark the session start
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        # and finally a user message
        UserUttered("say something"),
    ]
    for logged_event in logged_events:
        tracker.update(logged_event)

    story = Story.from_events(tracker.events, "some-story01")

    # only the slot and the user message are expected in the serialised story
    expected_story = (
        "## some-story01\n"
        '    - slot{"slot": "value"}\n'
        "* say something\n"
    )

    assert story.as_story_string(flat=True) == expected_story
示例#19
0
    def _find_action_from_rules(
        self, tracker: DialogueStateTracker, domain: Domain
    ) -> Optional[Text]:
        """Look up the action that the stored rules predict for the tracker.

        The tracker is featurized into states, which are matched against the
        rule keys in `self.lookup[RULES]`. While a loop is active, the
        unhappy-path rules in `self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH]` are
        consulted as well.

        Side effect: a `FormValidation(False)` event is added to `tracker`
        when `DO_NOT_VALIDATE_LOOP` is among the matching unhappy-path
        conditions.

        Args:
            tracker: The conversation tracker to predict for.
            domain: The domain used to featurize the tracker.

        Returns:
            The name of the predicted action, the active loop's name when a
            general rule predicted `ACTION_LISTEN_NAME` inside the loop and
            predicting the loop is not forbidden, or `None` if no rule
            applies.
        """
        tracker_as_states = self.featurizer.prediction_states([tracker], domain)
        # a single tracker was passed in, so take its (only) list of states
        states = tracker_as_states[0]

        logger.debug(f"Current tracker state: {states}")

        rule_keys = self._get_possible_keys(self.lookup[RULES], states)
        predicted_action_name = None
        best_rule_key = ""
        if rule_keys:
            # if there are several rules,
            # it should mean that some rule is a subset of another rule
            # therefore we pick a rule of maximum length
            best_rule_key = max(rule_keys, key=len)
            predicted_action_name = self.lookup[RULES].get(best_rule_key)

        active_loop_name = tracker.active_loop_name
        if active_loop_name:
            # find rules for unhappy path of the loop
            loop_unhappy_keys = self._get_possible_keys(
                self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH], states
            )
            # there could be several unhappy path conditions
            unhappy_path_conditions = [
                self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH].get(key)
                for key in loop_unhappy_keys
            ]

            # Check if a rule that predicted action_listen
            # was applied inside the loop.
            # Rules might not explicitly switch back to the loop.
            # Hence, we have to take care of that.
            # NOTE: the second operand is only evaluated when a rule predicted
            # `ACTION_LISTEN_NAME`, i.e. when `best_rule_key` was set above.
            predicted_listen_from_general_rule = (
                predicted_action_name == ACTION_LISTEN_NAME
                and not get_active_loop_name(self._rule_key_to_state(best_rule_key)[-1])
            )
            if predicted_listen_from_general_rule:
                if DO_NOT_PREDICT_LOOP_ACTION not in unhappy_path_conditions:
                    # negative rules don't contain a key that corresponds to
                    # the fact that active_loop shouldn't be predicted
                    logger.debug(
                        f"Predicted loop '{active_loop_name}' by overwriting "
                        f"'{ACTION_LISTEN_NAME}' predicted by general rule."
                    )
                    return active_loop_name

                # do not predict anything
                predicted_action_name = None

            if DO_NOT_VALIDATE_LOOP in unhappy_path_conditions:
                logger.debug("Added `FormValidation(False)` event.")
                tracker.update(FormValidation(False))

        if predicted_action_name is not None:
            logger.debug(
                f"There is a rule for the next action '{predicted_action_name}'."
            )
        else:
            logger.debug("There is no applicable rule.")

        return predicted_action_name
示例#20
0
    def predict_action_probabilities(
        self,
        tracker: DialogueStateTracker,
        domain: Domain,
        interpreter: NaturalLanguageInterpreter = RegexInterpreter(),
        **kwargs: Any,
    ) -> List[float]:
        """Predicts the next action the bot should take after seeing the tracker.

        Returns the list of probabilities for the next actions.
        If memorized action was found returns 1 for its index,
        else returns 0 for all actions.

        The prediction is made in this order, returning at the first match:
        the default predictions if the policy is disabled, a Rasa default
        action, the active form (if it was not rejected and did not just
        run), `ACTION_LISTEN_NAME` right after the non-rejected active form
        ran, and finally the action recalled from the longest matching rule
        key in `self.lookup`.

        Side effect: a `FormValidation(False)` event is added to `tracker`
        when a rule inside the active form predicts that form again.

        Args:
            tracker: The conversation tracker to predict for.
            domain: The domain providing action names and indices.
            interpreter: Not used in this method's body.
            **kwargs: Not used in this method's body.

        Returns:
            The probabilities for all actions, starting from
            `self._default_predictions(domain)`.
        """
        result = self._default_predictions(domain)

        if not self.is_enabled:
            return result

        # Rasa Open Source default actions overrule anything. If users want to achieve
        # the same, they need to write a rule or make sure that their form rejects
        # accordingly.
        rasa_default_action_name = _should_run_rasa_default_action(tracker)
        if rasa_default_action_name:
            result[domain.index_for_action(rasa_default_action_name)] = 1
            return result

        active_form_name = tracker.active_form_name()
        active_form_rejected = tracker.active_loop.get("rejected")
        # the form is predicted unless it was the action that just ran ...
        should_predict_form = (active_form_name and not active_form_rejected
                               and
                               tracker.latest_action_name != active_form_name)
        # ... in which case `action_listen` follows instead
        should_predict_listen = (active_form_name and not active_form_rejected
                                 and tracker.latest_action_name
                                 == active_form_name)

        # A form has priority over any other rule.
        # The rules or any other prediction will be applied only if a form was rejected.
        # If we are in a form, and the form didn't run previously or rejected, we can
        # simply force predict the form.
        if should_predict_form:
            logger.debug(f"Predicted form '{active_form_name}'.")
            result[domain.index_for_action(active_form_name)] = 1
            return result

        # predict `action_listen` if form action was run successfully
        if should_predict_listen:
            logger.debug(
                f"Predicted '{ACTION_LISTEN_NAME}' after form '{active_form_name}'."
            )
            result[domain.index_for_action(ACTION_LISTEN_NAME)] = 1
            return result

        possible_keys = set(self.lookup.keys())

        tracker_as_states = self.featurizer.prediction_states([tracker],
                                                              domain)
        # a single tracker was passed in, so take its (only) list of states
        states = tracker_as_states[0]

        logger.debug(f"Current tracker state: {states}")

        # Walk the states from the most recent one backwards and keep only
        # the keys that `_rule_is_good` accepts for every state.
        for i, state in enumerate(reversed(states)):
            possible_keys = set(
                filter(lambda _key: self._rule_is_good(_key, i, state),
                       possible_keys))

        if possible_keys:
            # TODO rethink that
            # of all remaining keys, the longest one is used
            key = max(possible_keys, key=len)

            recalled = self.lookup.get(key)

            if active_form_name:
                # Check if a rule that predicted action_listen
                # was applied inside the form.
                # Rules might not explicitly switch back to the `Form`.
                # Hence, we have to take care of that.
                predicted_listen_from_general_rule = recalled is None or (
                    domain.action_names[recalled] == ACTION_LISTEN_NAME
                    and f"active_form_{active_form_name}" not in key)
                if predicted_listen_from_general_rule:
                    logger.debug(f"Predicted form '{active_form_name}'.")
                    result[domain.index_for_action(active_form_name)] = 1
                    return result

                # Since rule snippets inside the form contain only unhappy paths,
                # notify the form that
                # it was predicted after an answer to a different question and
                # therefore it should not validate user input for requested slot
                # (`recalled` cannot be None here: that case returned above)
                predicted_form_from_form_rule = (
                    domain.action_names[recalled] == active_form_name
                    and f"active_form_{active_form_name}" in key)
                if predicted_form_from_form_rule:
                    logger.debug("Added `FormValidation(False)` event.")
                    tracker.update(FormValidation(False))

            if recalled is not None:
                logger.debug(f"There is a rule for next action "
                             f"'{domain.action_names[recalled]}'.")

                result[recalled] = 1
            else:
                logger.debug("There is no applicable rule.")

        return result
示例#21
0
def test_revert_user_utterance_event(default_domain):
    """`UserUtteranceReverted` rewinds the tracker to before the last user turn."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    # a fresh tracker has no events
    assert len(tracker.events) == 0

    greet_intent = {"name": "greet", "confidence": 1.0}
    greet_turn = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/greet", greet_intent, []),
        ActionExecuted("my_action_1"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    for logged_event in greet_turn:
        tracker.update(logged_event)

    goodbye_intent = {"name": "goodbye", "confidence": 1.0}
    goodbye_turn = [
        UserUttered("/goodbye", goodbye_intent, []),
        ActionExecuted("my_action_2"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    for logged_event in goodbye_turn:
        tracker.update(logged_event)

    # Both turns logged, expecting 6 prior trackers:
    #   5 executed actions + 1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == 6

    tracker.update(UserUtteranceReverted())

    # After the revert, expecting 3 prior trackers:
    #   the 6 from above
    #   minus 2 actions that followed /goodbye
    #   minus 1 listen action right before /goodbye
    assert tracker.latest_action_name == "my_action_1"
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == 3

    # a tracker recreated from the dialogue must end up in the same state
    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(tracker.as_dialogue())

    assert recovered.current_state() == tracker.current_state()
    assert tracker.latest_action_name == "my_action_1"
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == 3
示例#22
0
文件: test.py 项目: zuiwanting/rasa
def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
    circuit_breaker_tripped: bool,
) -> Tuple[EvaluationStore, Optional[Text], Optional[float]]:
    """Compare the predicted next action with the action given in the story.

    Args:
        processor: Used to predict the next action for `partial_tracker`.
        partial_tracker: Tracker holding the conversation so far. It is
            updated with `event` when the prediction matches and with a
            `WronglyPredictedAction` event otherwise. A form rejection may
            additionally be emulated on it before a second prediction.
        event: The expected (gold) action.
        fail_on_prediction_errors: If `True`, a wrong prediction raises
            instead of only being recorded.
        circuit_breaker_tripped: If `True`, no prediction is made and the
            placeholder `"circuit breaker tripped"` is stored as prediction.

    Returns:
        The evaluation store holding prediction and target, the predicting
        policy and its confidence. Policy and confidence are `None` when the
        circuit breaker tripped.

    Raises:
        ValueError: If the prediction is wrong and
            `fail_on_prediction_errors` is set.
    """
    from rasa.core.policies.form_policy import FormPolicy

    action_executed_eval_store = EvaluationStore()

    gold = event.action_name

    if circuit_breaker_tripped:
        predicted = "circuit breaker tripped"
        policy = None
        confidence = None
    else:
        action, policy, confidence = processor.predict_next_action(partial_tracker)
        predicted = action.name()

        if (
            policy
            and predicted != gold
            and _form_might_have_been_rejected(
                processor.domain, partial_tracker, predicted
            )
        ):
            # Wrong action was predicted,
            # but it might be Ok if form action is rejected.
            _emulate_form_rejection(partial_tracker)
            # try again
            action, policy, confidence = processor.predict_next_action(partial_tracker)

            # Even if the prediction is also wrong, we don't have to undo the emulation
            # of the action rejection as we know that the user explicitly specified
            # that something else than the form was supposed to run.
            predicted = action.name()

    action_executed_eval_store.add_to_store(
        action_predictions=[predicted], action_targets=[gold]
    )

    if action_executed_eval_store.has_prediction_target_mismatch():
        partial_tracker.update(
            WronglyPredictedAction(
                gold, predicted, event.policy, event.confidence, event.timestamp
            )
        )
        if fail_on_prediction_errors:
            error_msg = (
                "Model predicted a wrong action. Failed Story: "
                "\n\n{}".format(
                    YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
                )
            )
            # `policy` is None when the circuit breaker tripped (and may be
            # None after a prediction); a membership test on None would raise
            # a TypeError and hide the ValueError we actually want to report.
            if policy and FormPolicy.__name__ in policy:
                error_msg += (
                    "FormAction is not run during "
                    "evaluation therefore it is impossible to know "
                    "if validation failed or this story is wrong. "
                    "If the story is correct, add it to the "
                    "training stories and retrain."
                )
            raise ValueError(error_msg)
    else:
        partial_tracker.update(event)

    return action_executed_eval_store, policy, confidence
示例#23
0
async def test_persist_form_story(tmpdir):
    """A conversation with a form is exported as the expected story."""
    domain = Domain.load("data/test_domains/form.yml")

    tracker = DialogueStateTracker("", domain.slots)

    expected_story = ("* greet\n"
                      "    - utter_greet\n"
                      "* start_form\n"
                      "    - some_form\n"
                      '    - form{"name": "some_form"}\n'
                      "* default\n"
                      "    - utter_default\n"
                      "    - some_form\n"
                      "* stop\n"
                      "    - utter_ask_continue\n"
                      "* affirm\n"
                      "    - some_form\n"
                      "* stop\n"
                      "    - utter_ask_continue\n"
                      "    - action_listen\n"
                      "* form: inform\n"
                      "    - some_form\n"
                      '    - form{"name": null}\n'
                      "* goodbye\n"
                      "    - utter_goodbye\n")

    # the conversation with the form, turn by turn
    conversation = [
        UserUttered(intent={"name": "greet"}),
        ActionExecuted("utter_greet"),
        ActionExecuted("action_listen"),
        # the form gets started
        UserUttered(intent={"name": "start_form"}),
        ActionExecuted("some_form"),
        Form("some_form"),
        ActionExecuted("action_listen"),
        # input the form does not handle
        UserUttered(intent={"name": "default"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_default"),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # input the form does not handle
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # input the form does not handle, but the form continues
        UserUttered(intent={"name": "affirm"}),
        FormValidation(False),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # input the form does not handle
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # input for the form
        UserUttered(intent={"name": "inform"}),
        FormValidation(True),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        Form(None),
        UserUttered(intent={"name": "goodbye"}),
        ActionExecuted("utter_goodbye"),
        ActionExecuted("action_listen"),
    ]
    for logged_event in conversation:
        tracker.update(logged_event)

    exported = tracker.export_stories()
    assert expected_story in exported
示例#24
0
def _emulate_form_rejection(partial_tracker: DialogueStateTracker) -> None:
    """Apply an `ActionExecutionRejected` event for the tracker's active loop.

    The rejected action is the one named by `partial_tracker.active_loop["name"]`.
    """
    from rasa.core.events import ActionExecutionRejected

    rejection = ActionExecutionRejected(partial_tracker.active_loop["name"])
    partial_tracker.update(rejection)