Beispiel #1
0
    def _find_action_from_rules(self, tracker: DialogueStateTracker,
                                domain: Domain) -> Optional[Text]:
        """Return the action name that the memorized rules predict for `tracker`.

        The tracker is turned into prediction states with `self.featurizer`.
        Among the keys of `self.lookup[RULES]` returned by `_get_possible_keys`,
        the longest one selects the rule whose stored action is used. While a
        form is active, the matching entries of
        `self.lookup[RULES_FOR_FORM_UNHAPPY_PATH]` are also taken into account:

        * If the chosen rule predicted `ACTION_LISTEN_NAME` and its key does not
          contain `ACTIVE_FORM_PREFIX + <active form name>`, the active form is
          returned instead. The exception is when one of the unhappy-path
          entries is `DO_NOT_PREDICT_FORM_ACTION`; then nothing is predicted.
        * If one of the unhappy-path entries is `DO_NOT_VALIDATE_FORM`, a
          `FormValidation(False)` event is applied to `tracker` in place. This
          step is skipped when the form is returned early by the previous rule.

        Args:
            tracker: the conversation tracker; may be updated in place.
            domain: the domain passed on to the featurizer.

        Returns:
            The predicted action name, or `None` when no rule applies.
        """
        states = self.featurizer.prediction_states([tracker], domain)[0]
        logger.debug(f"Current tracker state: {states}")

        rules = self.lookup[RULES]
        matching_keys = self._get_possible_keys(rules, states)

        chosen_key = ""
        action_name = None
        if matching_keys:
            # TODO check that max is correct: several matching rules should
            # only happen when one rule is a subset of another one.
            chosen_key = max(matching_keys, key=len)
            action_name = rules.get(chosen_key)

        form_name = tracker.active_form_name()
        if form_name:
            unhappy_rules = self.lookup[RULES_FOR_FORM_UNHAPPY_PATH]
            # Several unhappy-path conditions may apply at the same time.
            conditions = [
                unhappy_rules.get(key)
                for key in self._get_possible_keys(unhappy_rules, states)
            ]

            # A general rule (one whose key does not mention the active form)
            # may predict `action_listen` while the form is active. Such rules
            # do not switch back to the form themselves, so that is done here.
            # NOTE(review): this is a substring test on the key. A key that
            # mentions another form whose name starts with the active form's
            # name would also count as mentioning the active form — verify.
            overrides_listen = (
                action_name == ACTION_LISTEN_NAME
                and ACTIVE_FORM_PREFIX + form_name not in chosen_key)
            if overrides_listen:
                if DO_NOT_PREDICT_FORM_ACTION in conditions:
                    # An unhappy-path rule forbids running the form here, so
                    # nothing is predicted at all.
                    action_name = None
                else:
                    logger.debug(
                        f"Predicted form '{form_name}' by overwriting "
                        f"'{ACTION_LISTEN_NAME}' predicted by general rule.")
                    return form_name

            if DO_NOT_VALIDATE_FORM in conditions:
                logger.debug("Added `FormValidation(False)` event.")
                tracker.update(FormValidation(False))

        if action_name is None:
            logger.debug("There is no applicable rule.")
        else:
            logger.debug(
                f"There is a rule for the next action '{action_name}'.")
        return action_name
Beispiel #2
0
    def _find_action_from_form_happy_path(
        tracker: DialogueStateTracker, ) -> Optional[Text]:

        active_form_name = tracker.active_form_name()
        active_form_rejected = tracker.active_loop.get("rejected")
        should_predict_form = (active_form_name and not active_form_rejected
                               and
                               tracker.latest_action_name != active_form_name)
        should_predict_listen = (active_form_name and not active_form_rejected
                                 and tracker.latest_action_name
                                 == active_form_name)

        if should_predict_form:
            logger.debug(f"Predicted form '{active_form_name}'.")
            return active_form_name

        # predict `action_listen` if form action was run successfully
        if should_predict_listen:
            logger.debug(
                f"Predicted '{ACTION_LISTEN_NAME}' after form '{active_form_name}'."
            )
            return ACTION_LISTEN_NAME
Beispiel #3
0
    def predict_action_probabilities(
        self,
        tracker: DialogueStateTracker,
        domain: Domain,
        interpreter: NaturalLanguageInterpreter = RegexInterpreter(),
        **kwargs: Any,
    ) -> List[float]:
        """Predicts the next action the bot should take after seeing the tracker.

        Predictions are made in this order of precedence:

        1. a Rasa Open Source default action returned by
           `_should_run_rasa_default_action`;
        2. the active form, if it is not rejected and was not the latest action;
        3. `action_listen`, if the active (non-rejected) form was the latest
           action;
        4. the action memorized for the longest lookup key that survives
           filtering with `_rule_is_good`. While a form is active, a memorized
           `action_listen` from a rule that does not mention the form (or a
           missing memorized action) is replaced by the form itself.

        Args:
            tracker: the conversation tracker. A `FormValidation(False)` event
                may be applied to it in place.
            domain: the domain; maps action names to indices and back.
            interpreter: not used by this method.
            **kwargs: not used by this method.

        Returns:
            The list of probabilities for the next actions, built from
            `_default_predictions(domain)` (NOTE(review): presumably all
            zeros — confirm). The index of the predicted action is set to 1. If
            the policy is disabled or nothing is predicted, the defaults are
            returned unchanged.
        """
        result = self._default_predictions(domain)

        if not self.is_enabled:
            return result

        # Rasa Open Source default actions overrule anything. To achieve the
        # same, users need to write a rule or make sure that their form
        # rejects accordingly.
        rasa_default_action_name = _should_run_rasa_default_action(tracker)
        if rasa_default_action_name:
            result[domain.index_for_action(rasa_default_action_name)] = 1
            return result

        active_form_name = tracker.active_form_name()
        active_form_rejected = tracker.active_loop.get("rejected")

        # A form has priority over any other rule; rules or any other
        # prediction apply only if the form was rejected. Inside a non-rejected
        # form, predict the form unless it was the latest action. If it was
        # the latest action (it ran successfully), predict `action_listen`.
        if active_form_name and not active_form_rejected:
            if tracker.latest_action_name != active_form_name:
                logger.debug(f"Predicted form '{active_form_name}'.")
                result[domain.index_for_action(active_form_name)] = 1
            else:
                logger.debug(
                    f"Predicted '{ACTION_LISTEN_NAME}' after form '{active_form_name}'."
                )
                result[domain.index_for_action(ACTION_LISTEN_NAME)] = 1
            return result

        possible_keys = set(self.lookup.keys())

        tracker_as_states = self.featurizer.prediction_states([tracker], domain)
        states = tracker_as_states[0]

        logger.debug(f"Current tracker state: {states}")

        # Walk the states in reverse order (`i` counts from the end). Keep
        # only the keys that `_rule_is_good` accepts for each of them.
        for i, state in enumerate(reversed(states)):
            if not possible_keys:
                # Nothing left to filter; the remaining states cannot help.
                break
            possible_keys = {
                candidate
                for candidate in possible_keys
                if self._rule_is_good(candidate, i, state)
            }

        if not possible_keys:
            return result

        # TODO rethink that: the longest key is taken as the best match.
        key = max(possible_keys, key=len)
        recalled = self.lookup.get(key)
        recalled_action_name = (
            domain.action_names[recalled] if recalled is not None else None
        )

        if active_form_name:
            # NOTE(review): substring test on the key. A key that mentions
            # another form whose name starts with this form's name also
            # matches — verify against the key format.
            form_key = f"active_form_{active_form_name}"

            # Check if a rule that predicted action_listen was applied inside
            # the form. Rules might not explicitly switch back to the form, so
            # predict the form here instead.
            predicted_listen_from_general_rule = recalled is None or (
                recalled_action_name == ACTION_LISTEN_NAME
                and form_key not in key
            )
            if predicted_listen_from_general_rule:
                logger.debug(f"Predicted form '{active_form_name}'.")
                result[domain.index_for_action(active_form_name)] = 1
                return result

            # Rule snippets inside the form contain only unhappy paths. The
            # form was therefore predicted after an answer to a different
            # question, so it should not validate the requested slot's input.
            if recalled_action_name == active_form_name and form_key in key:
                logger.debug("Added `FormValidation(False)` event.")
                tracker.update(FormValidation(False))

        if recalled is None:
            logger.debug("There is no applicable rule.")
            return result

        logger.debug(f"There is a rule for next action "
                     f"'{recalled_action_name}'.")
        result[recalled] = 1
        return result