Exemplo n.º 1
0
def test_revert_action_event(default_domain: Domain):
    """Check the tracker after an `ActionReverted` event and after re-creation.

    The test feeds a short conversation into a fresh tracker, applies
    `ActionReverted` and expects the latest action to be `my_action` again.
    It then serialises the tracker to a dialogue, rebuilds a second tracker
    from it and expects the rebuilt tracker to show the same state.

    Args:
        default_domain: Domain whose slots are used to create the trackers.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action.get(ACTION_NAME) == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker.update(ActionReverted())

    # Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action
    assert tracker.latest_action.get(ACTION_NAME) == "my_action"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    # These checks must target `recovered`: the same assertions on `tracker`
    # were already made above and would not exercise the round trip.
    assert recovered.current_state() == tracker.current_state()
    assert recovered.latest_action.get(ACTION_NAME) == "my_action"
    assert len(list(recovered.generate_all_prior_trackers())) == 3
Exemplo n.º 2
0
def test_restart_event(default_domain: Domain):
    """Check the tracker after a `Restarted` event and after re-creation.

    A short conversation is fed into a fresh tracker, followed by a
    `Restarted` event. The test then rebuilds a second tracker from the
    serialised dialogue and expects it to match the original one.

    Args:
        default_domain: Domain whose slots are used to create the trackers.
    """

    def prior_tracker_count(state_tracker) -> int:
        # number of trackers yielded by `generate_all_prior_trackers()`
        return len(list(state_tracker.generate_all_prior_trackers()))

    tracker = DialogueStateTracker("default", default_domain.slots)
    # a tracker that was just created must not contain any events
    assert len(tracker.events) == 0

    greet_intent = {"name": "greet", "confidence": 1.0}
    conversation = (
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/greet", greet_intent, []),
        ActionExecuted("my_action"),
        ActionExecuted(ACTION_LISTEN_NAME),
    )
    for event in conversation:
        tracker.update(event)

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert prior_tracker_count(tracker) == 4

    tracker.update(Restarted())

    # the restart event itself is appended to the event list
    assert len(tracker.events) == 5
    assert tracker.followup_action is not None
    assert tracker.followup_action == ACTION_LISTEN_NAME
    assert tracker.latest_message.text is None
    assert prior_tracker_count(tracker) == 1

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(tracker.as_dialogue())

    assert recovered.current_state() == tracker.current_state()
    assert len(recovered.events) == 5
    assert recovered.latest_message.text is None
    assert prior_tracker_count(recovered) == 1
Exemplo n.º 3
0
    def predict_next_with_tracker(
        self,
        tracker: DialogueStateTracker,
        verbosity: EventVerbosity = EventVerbosity.AFTER_RESTART,
    ) -> Optional[Dict[Text, Any]]:
        """Predict the next action for the given conversation state.

        Args:
            tracker: A tracker representing a conversation state.
            verbosity: Passed to `tracker.current_state` to select the tracker
                state included in the result.

        Returns:
            A dictionary with the keys `scores`, `policy`, `confidence` and
            `tracker`. `scores` pairs each entry of
            `self.domain.action_names_or_texts` with its predicted probability.
            `None` is returned (after a warning) when the model's training type
            is `TrainingType.NLU`.
        """
        is_nlu_only = self.model_metadata.training_type == TrainingType.NLU
        if is_nlu_only:
            rasa.shared.utils.io.raise_warning(
                "No core model. Skipping action prediction and execution.",
                docs=DOCS_URL_POLICIES,
            )
            return None

        prediction = self._predict_next_with_tracker(tracker)

        action_names = self.domain.action_names_or_texts
        probabilities = prediction.probabilities
        action_scores = [
            {"action": name, "score": probability}
            for name, probability in zip(action_names, probabilities)
        ]

        result = {"scores": action_scores}
        result["policy"] = prediction.policy_name
        result["confidence"] = prediction.max_confidence
        result["tracker"] = tracker.current_state(verbosity)
        return result
Exemplo n.º 4
0
    def predict_next_with_tracker(
        self,
        tracker: DialogueStateTracker,
        verbosity: EventVerbosity = EventVerbosity.AFTER_RESTART,
    ) -> Optional[Dict[Text, Any]]:
        """Predict the next action for a given conversation state.

        Args:
            tracker: A tracker representing a conversation state.
            verbosity: Verbosity for the returned conversation state.

        Returns:
            A dictionary with the keys `scores`, `policy`, `confidence` and
            `tracker`. `scores` pairs each entry of `self.domain.action_names`
            with its predicted probability. `None` is returned (after a
            warning) if no policy ensemble or domain is set.
        """
        if not self.policy_ensemble or not self.domain:
            # The two literals are joined by implicit concatenation, so the
            # first one needs a trailing space to keep the sentences apart.
            rasa.shared.utils.io.raise_warning(
                "No policy ensemble or domain set. Skipping action prediction. "
                "You should set a policy before training a model.",
                docs=DOCS_URL_POLICIES,
            )
            return None

        prediction = self._get_next_action_probabilities(tracker)

        scores = [
            {"action": a, "score": p}
            for a, p in zip(self.domain.action_names, prediction.probabilities)
        ]
        return {
            "scores": scores,
            "policy": prediction.policy_name,
            "confidence": prediction.max_confidence,
            "tracker": tracker.current_state(verbosity),
        }
Exemplo n.º 5
0
    def _current_tracker_state_without_events(tracker: DialogueStateTracker) -> Dict:
        """Return the tracker state at `EventVerbosity.ALL` without its events.

        Args:
            tracker: Tracker whose current state is requested.

        Returns:
            The state returned by `tracker.current_state`, with the `events`
            key removed if it was present.
        """
        snapshot = tracker.current_state(EventVerbosity.ALL)

        # events are pushed separately in the `update_one()` operation, so
        # they must not be part of the returned state
        if "events" in snapshot:
            del snapshot["events"]

        return snapshot
Exemplo n.º 6
0
def test_revert_user_utterance_event(default_domain: Domain):
    """Check the tracker after `UserUtteranceReverted` and after re-creation.

    The test feeds two user turns (`/greet`, `/goodbye`) into a fresh tracker,
    applies `UserUtteranceReverted` and expects the latest action to be
    `my_action_1` again. It then serialises the tracker to a dialogue, rebuilds
    a second tracker from it and expects the rebuilt tracker to show the same
    state.

    Args:
        default_domain: Domain whose slots are used to create the trackers.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent1 = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent1, []))
    tracker.update(ActionExecuted("my_action_1"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    intent2 = {"name": "goodbye", "confidence": 1.0}
    tracker.update(UserUttered("/goodbye", intent2, []))
    tracker.update(ActionExecuted("my_action_2"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 6:
    #   +5 executed actions
    #   +1 final state
    assert tracker.latest_action.get(ACTION_NAME) == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 6

    tracker.update(UserUtteranceReverted())

    # Expecting count of 3:
    #   +5 executed actions
    #   +1 final state
    #   -2 rewound actions associated with the /goodbye
    #   -1 rewound action from the listen right before /goodbye
    assert tracker.latest_action.get(ACTION_NAME) == "my_action_1"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    # These checks must target `recovered`: the same assertions on `tracker`
    # were already made above and would not exercise the round trip.
    assert recovered.current_state() == tracker.current_state()
    assert recovered.latest_action.get(ACTION_NAME) == "my_action_1"
    assert len(list(recovered.generate_all_prior_trackers())) == 3
Exemplo n.º 7
0
def nlg_request_format(
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Create the JSON body for an NLG request.

    Args:
        template_name: Value sent under the `template` key.
        tracker: Tracker whose state at `EventVerbosity.ALL` is sent under
            the `tracker` key.
        output_channel: Value sent as `name` inside the `channel` entry.
        **kwargs: Additional keyword arguments, sent under `arguments`.

    Returns:
        The request body as a dictionary.
    """
    body: Dict[Text, Any] = {"template": template_name, "arguments": kwargs}
    body["tracker"] = tracker.current_state(EventVerbosity.ALL)
    body["channel"] = {"name": output_channel}
    return body
Exemplo n.º 8
0
def nlg_request_format(
    utter_action: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Create the JSON body for an NLG request.

    Args:
        utter_action: Value sent under both the `response` and the `template`
            key.
        tracker: Tracker whose state at `EventVerbosity.ALL` is sent under
            the `tracker` key.
        output_channel: Value sent as `name` inside the `channel` entry.
        **kwargs: Additional keyword arguments, sent under `arguments`.

    Returns:
        The request body as a dictionary.
    """
    full_state = tracker.current_state(EventVerbosity.ALL)
    channel_info = {"name": output_channel}

    # TODO: Remove `template` by Rasa Open Source 3.0.
    body: Dict[Text, Any] = {"response": utter_action, "template": utter_action}
    body["arguments"] = kwargs
    body["tracker"] = full_state
    body["channel"] = channel_info
    return body