Example 1
            [
                SlotSet("num_people", None),
                ActiveLoop(None),
                SlotSet("num_tables", 5),
                SlotSet(REQUESTED_SLOT, None),
                ActiveLoop(None),
            ],
        ),
        # User rejected manually
        (
            [{
                "event": "action_execution_rejected",
                "name": "my form"
            }],
            [
                ActionExecutionRejected("my form"),
                SlotSet("num_tables", 5),
                SlotSet("num_people", "hi"),
                SlotSet(REQUESTED_SLOT, None),
            ],
        ),
    ],
)
async def test_validate_slots(validate_return_events: List[Dict],
                              expected_events: List[Event]):
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "hi"
    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, slot_name),
Example 2
@pytest.mark.parametrize(
    "events, expected_applied_events",
    [
        (
            [
                # Form is triggered and requests slot.
                ActionExecuted(ACTION_LISTEN_NAME),
                user_uttered("greet"),
                ActionExecuted("loop"),
                ActiveLoop("loop"),
                SlotSet(REQUESTED_SLOT, "bla"),
                # User sends chitchat instead of answering the form.
                ActionExecuted(ACTION_LISTEN_NAME),
                user_uttered("chitchat"),
                # Form rejected execution.
                ActionExecutionRejected("loop"),
                # Action which deals with the unhappy path.
                ActionExecuted("handling chitchat"),
                # We immediately return to the form after executing the action that handles it.
                ActionExecuted("loop"),
                # Form happy path continues until all slots are filled.
                SlotSet(REQUESTED_SLOT, "bla"),
                ActionExecuted(ACTION_LISTEN_NAME),
                user_uttered("fill slots"),
                ActionExecuted("loop"),
                SlotSet("slot", "value"),
                SlotSet(REQUESTED_SLOT, None),
                ActiveLoop(None),
            ],
            [
                ActionExecuted(ACTION_LISTEN_NAME),
Example 3
async def test_persist_legacy_form_story():
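    """Exports a story from a tracker that went through a legacy form.

    The exported story is expected to use the new `active_loop` syntax in place
    of the legacy `form` steps.
    """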
    domain = Domain.load("data/test_domains/form.yml")

    tracker = DialogueStateTracker("", domain.slots)

    story = ("* greet\n"
             "    - utter_greet\n"
             "* start_form\n"
             "    - some_form\n"
             '    - form{"name": "some_form"}\n'
             "* default\n"
             "    - utter_default\n"
             "    - some_form\n"
             "* stop\n"
             "    - utter_ask_continue\n"
             "* affirm\n"
             "    - some_form\n"
             "* stop\n"
             "    - utter_ask_continue\n"
             "* inform\n"
             "    - some_form\n"
             '    - form{"name": null}\n'
             "* goodbye\n"
             "    - utter_goodbye\n")

    # simulate talking to the form
    events = [
        UserUttered(intent={"name": "greet"}),
        ActionExecuted("utter_greet"),
        ActionExecuted("action_listen"),
        # start the form
        UserUttered(intent={"name": "start_form"}),
        ActionExecuted("some_form"),
        ActiveLoop("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "default"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_default"),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # out of form input but continue with the form
        UserUttered(intent={"name": "affirm"}),
        LoopInterrupted(True),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # form input
        UserUttered(intent={"name": "inform"}),
        LoopInterrupted(False),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        ActiveLoop(None),
        UserUttered(intent={"name": "goodbye"}),
        ActionExecuted("utter_goodbye"),
        ActionExecuted("action_listen"),
    ]
    for e in events:
        tracker.update(e)

    story = story.replace(f"- {LegacyForm.type_name}",
                          f"- {ActiveLoop.type_name}")

    assert story in tracker.export_stories()
Example 4
async def test_form_unhappy_path_no_validation_from_story():
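    """Trains `RulePolicy` on a story with an unhappy form path and no validation.

    After the rejection is handled, no rule should predict the next action and the
    tracker should end with `LoopInterrupted(True)`, i.e. the next user message is
    not validated by the form.
    """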
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_story = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActionExecuted(form_name),
            ActiveLoop(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    policy = RulePolicy()
    policy.train([unhappy_story], domain, RegexInterpreter())

    # Check that RulePolicy predicts no validation to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
        ActionExecuted(handle_rejection_action_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]

    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    # there is no rule for next action
    assert max(action_probabilities) == policy._core_fallback_threshold
    # check that RulePolicy entered unhappy path based on the training story
    assert tracker.events[-1] == LoopInterrupted(True)
Example 5
    assert action_received_events

    tracker = default_processor.get_tracker(conversation_id)
    # The action was logged on the tracker as well
    expected_events.append(ActionExecuted(ACTION_LISTEN_NAME))

    for event, expected in zip(tracker.events, expected_events):
        assert event == expected


# noinspection PyTypeChecker
@pytest.mark.parametrize(
    "reject_fn",
    [
        lambda: [ActionExecutionRejected(ACTION_LISTEN_NAME)],
        lambda: (_ for _ in ()).throw(ActionExecutionRejection(ACTION_LISTEN_NAME)),
    ],
)
async def test_policy_events_not_applied_if_rejected(
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
    reject_fn: Callable[[], List[Event]],
):
    expected_action = ACTION_LISTEN_NAME
    expected_events = [LoopInterrupted(True)]
    conversation_id = "test_policy_events_are_applied_to_tracker"
    user_message = "/greet"

    class ConstantEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
Example 6
async def test_form_unhappy_path_from_story():
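    """Trains `RulePolicy` on a story containing an unhappy form path.

    After the form rejects execution, the policy should predict `utter_greet` as in
    the training story; once the rejection has been handled, no rule predicts the
    next action.
    """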
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_story = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActionExecuted(form_name),
            ActiveLoop(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # in training stories there is either intent or text, never both
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            # After our bot says "hi", we want to run a specific action
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    policy = RulePolicy()
    policy.train([GREET_RULE, unhappy_story], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, UTTER_GREET_ACTION)

    # Check that RulePolicy doesn't trigger form or action_listen
    # after handling unhappy path
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert max(action_probabilities) == policy._core_fallback_threshold
Example 7
async def test_form_unhappy_path_no_validation_from_rule():
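    """Trains `RulePolicy` on a rule for an unhappy form path without validation.

    The policy should predict the rejection handling action, then `action_listen`,
    then the form again, and the tracker should end with `LoopInterrupted(True)` so
    that the next user message is not validated.
    """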
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "bla"),
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    # The unhappy rule spans multiple user turns, so remove the rule restriction for the policy
    policy = RulePolicy(restrict_rules=False)
    # RulePolicy should memorize that unhappy_rule overrides GREET_RULE
    policy.train([GREET_RULE, unhappy_rule], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain,
                            handle_rejection_action_name)

    # Check that RulePolicy predicts action_listen
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, ACTION_LISTEN_NAME)

    # Check that RulePolicy triggers form again after handling unhappy path
    conversation_events.append(ActionExecuted(ACTION_LISTEN_NAME))
    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    assert_predicted_action(action_probabilities, domain, form_name)
    # check that RulePolicy entered the unhappy path based on the training rule
    assert tracker.events[-1] == LoopInterrupted(True)
Example 8
         },
         {
             "entity": "count",
             "value": 1
         },
     ],
     timestamp=None,
 ),
 DefinePrevUserUtteredFeaturization(use_text_for_featurization=False,
                                    timestamp=None,
                                    metadata=None),
 ReminderCancelled(timestamp=1621590172.3872123),
 ReminderScheduled(timestamp=None,
                   trigger_date_time=datetime.now(),
                   intent="greet"),
 ActionExecutionRejected(action_name="my_action"),
 LegacyFormValidation(validate=True, timestamp=None),
 LoopInterrupted(timestamp=None, is_interrupted=False),
 ActiveLoop(name="loop"),
 LegacyForm(name="my_form"),
 AllSlotsReset(),
 SlotSet(key="my_slot", value={}),
 SlotSet(key="my slot", value=[]),
 SlotSet(key="test", value=1),
 SlotSet(key="test", value="text"),
 ConversationResumed(),
 ConversationPaused(),
 FollowupAction(name="test"),
 StoryExported(),
 Restarted(),
 ActionReverted(),
Example 9
async def test_form_unhappy_path_from_in_form_rule():
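    """Trains `RulePolicy` on an in-form rule for handling an unhappy path.

    After the form rejects execution, the policy should predict the rejection
    handling action and then return to the form.
    """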
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "bla"),
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )

    policy = RulePolicy()
    # RulePolicy should memorize that unhappy_rule overrides GREET_RULE
    policy.train([GREET_RULE, unhappy_rule], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain,
                            handle_rejection_action_name)

    # Check that RulePolicy triggers form again after handling unhappy path
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, form_name)
Example 10
class TestUnexpecTEDIntentPolicy(TestTEDPolicy):
    @staticmethod
    def _policy_class_to_test() -> Type[UnexpecTEDIntentPolicy]:
        return UnexpecTEDIntentPolicy

    @pytest.fixture(scope="class")
    def featurizer(self) -> TrackerFeaturizer:
        featurizer = IntentMaxHistoryTrackerFeaturizer(
            IntentTokenizerSingleStateFeaturizer(), max_history=self.max_history
        )
        return featurizer

    @staticmethod
    def persist_and_load_policy(
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        return trained_policy.__class__.load(
            trained_policy.config, model_storage, resource, execution_context
        )

    def test_ranking_length(self, trained_policy: UnexpecTEDIntentPolicy):
        assert trained_policy.config[RANKING_LENGTH] == LABEL_RANKING_LENGTH

    def test_ranking_length_and_renormalization(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        tracker: DialogueStateTracker,
        default_domain: Domain,
    ):
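        """The intent ranking in the prediction metadata should be truncated
        to the configured `ranking_length`."""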
        precomputations = None
        prediction_metadata = trained_policy.predict_action_probabilities(
            tracker, default_domain, precomputations,
        ).action_metadata
        assert (
            prediction_metadata is None
            or len(prediction_metadata[RANKING_KEY])
            == trained_policy.config[RANKING_LENGTH]
        )

    def test_label_data_assembly(
        self, trained_policy: UnexpecTEDIntentPolicy, default_domain: Domain
    ):
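        """Assembled label data should contain intent features and label ids
        for every intent in the domain."""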

        # Construct input data
        state_featurizer = trained_policy.featurizer.state_featurizer
        encoded_all_labels = state_featurizer.encode_all_labels(
            default_domain, precomputations=None
        )
        attribute_data, _ = model_data_utils.convert_to_data_format(encoded_all_labels)

        assembled_label_data = trained_policy._assemble_label_data(
            attribute_data, default_domain
        )
        assembled_label_data_signature = assembled_label_data.get_signature()

        assert list(assembled_label_data_signature.keys()) == [
            f"{LABEL}_{INTENT}",
            LABEL,
        ]
        assert assembled_label_data.num_examples == len(default_domain.intents)
        assert list(assembled_label_data_signature[f"{LABEL}_{INTENT}"].keys()) == [
            MASK,
            SENTENCE,
        ]
        assert list(assembled_label_data_signature[LABEL].keys()) == [IDS]
        assert assembled_label_data_signature[f"{LABEL}_{INTENT}"][SENTENCE][
            0
        ].units == len(default_domain.intents)

    def test_training_with_no_intent(
        self,
        featurizer: Optional[TrackerFeaturizer],
        default_domain: Domain,
        tmp_path: Path,
        caplog: LogCaptureFixture,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
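        """Training on stories without any intents should emit a `UserWarning`."""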
        stories = tmp_path / "stories.yml"
        stories.write_text(
            """
            version: "3.0"
            stories:
            - story: test path
              steps:
              - action: utter_greet
            """
        )
        policy = self.create_policy(
            featurizer=featurizer,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
        )
        import tests.core.test_policies

        training_trackers = tests.core.test_policies.train_trackers(
            default_domain, str(stories), augmentation_factor=20
        )

        with pytest.warns(UserWarning):
            policy.train(training_trackers, default_domain, precomputations=None)

    def test_prepared_data_for_threshold_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        default_domain: Domain,
        stories_path: Path,
    ):
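        """Data prepared for prediction should only contain features used at
        prediction time."""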
        training_trackers = train_trackers(
            default_domain, stories_path, augmentation_factor=0
        )
        training_model_data, _ = trained_policy._prepare_for_training(
            training_trackers, default_domain, precomputations=None,
        )

        data_for_prediction = trained_policy._prepare_data_for_prediction(
            training_model_data
        )

        assert set(data_for_prediction.data.keys()).issubset(PREDICTION_FEATURES)

    def test_similarities_collection_for_label_ids(self):
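        """Similarities should be grouped into positive and negative scores
        per label id, skipping the `-1` padding id."""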
        label_ids = np.array([[0, 1], [1, -1], [2, -1]])
        outputs = {
            "similarities": np.array(
                [[[1.2, 0.3, 0.2]], [[0.5, 0.2, 1.6]], [[0.01, 0.1, 1.7]],]
            )
        }
        label_id_similarities = UnexpecTEDIntentPolicy._collect_label_id_grouped_scores(
            outputs, label_ids
        )

        # Should contain similarities for all label ids except padding token.
        assert sorted(list(label_id_similarities.keys())) == [0, 1, 2]

        # Cross-check that the collected similarities are correct for each label id.
        assert label_id_similarities[0] == {
            POSITIVE_SCORES_KEY: [1.2],
            NEGATIVE_SCORES_KEY: [0.5, 0.01],
        }
        assert label_id_similarities[1] == {
            POSITIVE_SCORES_KEY: [0.3, 0.2],
            NEGATIVE_SCORES_KEY: [0.1],
        }
        assert label_id_similarities[2] == {
            POSITIVE_SCORES_KEY: [1.7],
            NEGATIVE_SCORES_KEY: [0.2, 1.6],
        }

    def test_label_quantiles_computation(self):
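        """Quantile thresholds should be computed per label id from its
        positive and negative similarity scores."""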
        label_id_scores = {
            0: {
                POSITIVE_SCORES_KEY: [1.3, 0.2],
                NEGATIVE_SCORES_KEY: [
                    -0.1,
                    -1.2,
                    -2.3,
                    -4.1,
                    -0.5,
                    0.2,
                    0.8,
                    0.9,
                    -3.2,
                    -2.7,
                ],
            },
            3: {POSITIVE_SCORES_KEY: [1.3, 0.2], NEGATIVE_SCORES_KEY: [-0.1]},
            6: {POSITIVE_SCORES_KEY: [1.3, 0.2], NEGATIVE_SCORES_KEY: []},
        }
        expected_thresholds = {
            0: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                -0.1,
                -0.1,
                -0.5,
                -0.5,
                -1.2,
                -1.2,
                -1.2,
                -2.3,
                -2.3,
                -2.7,
                -2.7,
                -3.2,
                -3.2,
                -4.1,
                -4.1,
            ],
            3: [
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
            ],
            6: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
            ],
        }
        thresholds = UnexpecTEDIntentPolicy._compute_label_quantiles(label_id_scores)
        assert sorted(list(thresholds.keys())) == sorted(
            list(expected_thresholds.keys())
        )
        for label_id, tolerance_thresholds in thresholds.items():
            assert expected_thresholds[label_id] == tolerance_thresholds

    def test_post_training_threshold_computation(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        default_domain: Domain,
        stories_path: Path,
    ):
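        """Label quantiles computed after training should cover every label id
        seen in training, except the `-1` padding id."""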
        training_trackers = train_trackers(
            default_domain, stories_path, augmentation_factor=0
        )
        training_model_data, label_ids = trained_policy._prepare_for_training(
            training_trackers, default_domain, precomputations=None,
        )

        trained_policy.compute_label_quantiles_post_training(
            training_model_data, label_ids
        )

        computed_thresholds = trained_policy.label_quantiles

        # -1 is used for padding and hence is not expected in the keys
        expected_keys = list(np.unique(label_ids))
        expected_keys.remove(-1)

        assert sorted(list(computed_thresholds.keys())) == sorted(expected_keys)

    @pytest.mark.parametrize(
        "tolerance, expected_thresholds",
        [
            (0.0, [0.2, -0.1, 0.2]),
            (0.75, [-2.9, -0.1, -4.3]),
            (0.72, [-2.7, -0.1, -4.0]),
            (0.78, [-2.9, -0.1, -4.3]),
            (1.0, [-4.1, -0.1, -5.5]),
        ],
    )
    def test_pick_thresholds_for_labels(
        self, tolerance: float, expected_thresholds: List[float]
    ):
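        """Picking thresholds at a given tolerance should select the matching
        quantile threshold for each label id."""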
        label_id_tolerance_thresholds = {
            0: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                -0.1,
                -0.1,
                -0.5,
                -0.5,
                -1.2,
                -1.2,
                -2.3,
                -2.3,
                -2.7,
                -2.9,
                -3.2,
                -3.2,
                -4.1,
                -4.1,
            ],
            3: [
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
            ],
            4: [0.2 - (index * 0.3) for index in range(20)],
        }
        thresholds = UnexpecTEDIntentPolicy._pick_thresholds(
            label_id_tolerance_thresholds, tolerance
        )
        assert sorted(list(thresholds.keys())) == sorted(
            list(label_id_tolerance_thresholds.keys())
        )
        computed_values = list(thresholds.values())
        assert expected_thresholds == computed_values

    @pytest.mark.parametrize(
        "predicted_similarity, threshold_value, is_unlikely",
        [(1.2, 0.2, False), (0.3, -0.1, False), (-1.5, 0.03, True)],
    )
    def test_unlikely_intent_check(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        predicted_similarity: float,
        threshold_value: float,
        is_unlikely: bool,
        tmp_path: Path,
    ):
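        """The query intent should be flagged as unlikely only if its predicted
        similarity is below its threshold."""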
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        # Construct dummy similarities
        similarities = np.array([[0.0] * len(default_domain.intents)])
        dummy_intent_index = 4
        similarities[0, dummy_intent_index] = predicted_similarity

        loaded_policy.label_thresholds[dummy_intent_index] = threshold_value
        query_intent = default_domain.intents[dummy_intent_index]

        unlikely_intent_prediction = loaded_policy._check_unlikely_intent(
            default_domain, similarities, query_intent
        )

        assert is_unlikely == unlikely_intent_prediction

    def test_should_check_for_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        intent_index = 0
        assert (
            loaded_policy._should_check_for_intent(
                default_domain.intents[intent_index], default_domain
            )
            is False
        )

        intent_index = 4
        assert loaded_policy._should_check_for_intent(
            default_domain.intents[intent_index], default_domain
        )

        loaded_policy.config[IGNORE_INTENTS_LIST] = [
            default_domain.intents[intent_index]
        ]
        assert (
            loaded_policy._should_check_for_intent(
                default_domain.intents[intent_index], default_domain
            )
            is False
        )

    def test_no_action_unlikely_intent_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
    ):
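        """The default all-zero prediction should be returned for an empty tracker,
        when the latest event is not a user turn, or when the model is unavailable."""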
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        expected_probabilities = [0] * default_domain.num_actions

        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

        tracker.update_with_events(
            [
                UserUttered(text="hello", intent={"name": "greet"}),
                ActionExecuted(action_name="utter_greet"),
            ],
            default_domain,
        )
        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

        loaded_policy.model = None

        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

    @pytest.mark.parametrize(
        "predicted_similarity, threshold_value, is_unlikely",
        [(1.2, 0.2, False), (0.3, -0.1, False), (-1.5, 0.03, True)],
    )
    def test_action_unlikely_intent_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        predicted_similarity: float,
        threshold_value: float,
        is_unlikely: bool,
        monkeypatch: MonkeyPatch,
        tmp_path: Path,
    ):
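        """`action_unlikely_intent` should be predicted with probability 1.0 only
        when the query intent's similarity falls below its threshold."""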
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        similarities = np.array([[[0.0] * len(default_domain.intents)]])

        dummy_intent_index = 4
        similarities[0, 0, dummy_intent_index] = predicted_similarity
        query_intent = default_domain.intents[dummy_intent_index]

        loaded_policy.label_thresholds[dummy_intent_index] = threshold_value

        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)

        tracker.update_with_events(
            [UserUttered(text="hello", intent={"name": query_intent})], default_domain,
        )

        # Preset the model predictions to the similarity values
        # so that we don't need to hardcode particular model predictions.
        monkeypatch.setattr(
            loaded_policy.model,
            "run_inference",
            lambda data: {"similarities": similarities},
        )

        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        if not is_unlikely:
            assert prediction.probabilities == [0.0] * default_domain.num_actions
        else:
            assert (
                prediction.probabilities[
                    default_domain.index_for_action(ACTION_UNLIKELY_INTENT_NAME)
                ]
                == 1.0
            )

            # Make sure metadata is set. The exact structure
            # of the metadata is tested separately and
            # not as part of this test.
            assert prediction.action_metadata is not None
            # Assert metadata is serializable
            assert json.dumps(prediction.action_metadata)

    @pytest.mark.parametrize(
        "tracker_events, should_skip",
        [
            ([], True),
            ([ActionExecuted("action_listen")], True),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    EntitiesAdded([{"name": "dummy"}]),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    SlotSet("name"),
                ],
                False,
            ),
            (
                [
                    ActiveLoop("loop"),
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    ActionExecutionRejected("loop"),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    ActionExecuted("utter_greet"),
                ],
                True,
            ),
        ],
    )
    def test_skip_predictions_to_prevent_loop(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
        should_skip: bool,
        tmp_path: Path,
    ):
        """Skips predictions to prevent loop."""
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        tracker.update_with_events(tracker_events, default_domain)
        with caplog.at_level(logging.DEBUG):
            prediction = loaded_policy.predict_action_probabilities(
                tracker, default_domain, precomputations
            )

        assert (
            "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text
        ) == should_skip

        if should_skip:
            assert prediction.probabilities == loaded_policy._default_predictions(
                default_domain
            )

    @pytest.mark.parametrize(
        "tracker_events",
        [
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
            ],
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                EntitiesAdded([{"name": "dummy"}]),
            ],
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                SlotSet("name"),
            ],
            [
                ActiveLoop("loop"),
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                ActionExecutionRejected("loop"),
            ],
        ],
    )
    def test_skip_predictions_if_new_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
    ):
        """Skips predictions if there's a new intent created."""
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        tracker.update_with_events(tracker_events, default_domain)

        with caplog.at_level(logging.DEBUG):
            prediction = loaded_policy.predict_action_probabilities(
                tracker, default_domain, precomputations=None,
            )

        assert "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text

        assert prediction.probabilities == loaded_policy._default_predictions(
            default_domain
        )

    @pytest.mark.parametrize(
        "tracker_events_with_action, tracker_events_without_action",
        [
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[{"entity": "name", "value": "Peter"},]),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[{"entity": "name", "value": "Peter"},]),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
        ],
    )
    def test_ignore_action_unlikely_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tracker_events_with_action: List[Event],
        tracker_events_without_action: List[Event],
        tmp_path: Path,
    ):
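        """Trackers that differ only by `action_unlikely_intent` events should
        yield the same prediction metadata."""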
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        precomputations = None
        tracker_with_action = DialogueStateTracker.from_events(
            "test 1", evts=tracker_events_with_action
        )
        tracker_without_action = DialogueStateTracker.from_events(
            "test 2", evts=tracker_events_without_action
        )
        prediction_with_action = loaded_policy.predict_action_probabilities(
            tracker_with_action, default_domain, precomputations
        )
        prediction_without_action = loaded_policy.predict_action_probabilities(
            tracker_without_action, default_domain, precomputations
        )

        # If the weights didn't change, then both trackers
        # should result in the same prediction. For `UnexpecTEDIntentPolicy`, the real
        # prediction is inside the action metadata.
        assert (
            prediction_with_action.action_metadata
            == prediction_without_action.action_metadata
        )

    def test_label_embedding_collection(self, trained_policy: UnexpecTEDIntentPolicy):
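        """Label embeddings should be collected per label id, with the `-1`
        padding id mapped to the embedding of label id `0`."""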
        label_ids = tf.constant([[[2], [-1]], [[1], [2]], [[0], [-1]]], dtype=tf.int32)

        all_label_embeddings = np.random.random((10, 20))

        # `-1` is used as padding label id. The embedding for it
        # will be the same as `label_id=0`
        expected_extracted_label_embeddings = tf.constant(
            np.concatenate(
                [
                    all_label_embeddings[2],
                    all_label_embeddings[0],
                    all_label_embeddings[1],
                    all_label_embeddings[2],
                    all_label_embeddings[0],
                    all_label_embeddings[0],
                ]
            ).reshape((3, 2, 20)),
            dtype=tf.float32,
        )

        actual_extracted_label_embeddings = trained_policy.model._get_labels_embed(
            label_ids, tf.constant(all_label_embeddings, dtype=tf.float32)
        )

        assert np.all(
            expected_extracted_label_embeddings == actual_extracted_label_embeddings
        )

    @pytest.mark.parametrize(
        "query_intent_index, ranking_length", [(0, 0), (1, 3), (2, 1), (5, 0)]
    )
    def test_collect_action_metadata(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
        query_intent_index: int,
        ranking_length: int,
    ):
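        """Collected metadata should contain score, threshold and severity for the
        query intent plus a ranking truncated to `ranking_length`."""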
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        def test_individual_label_metadata(
            label_metadata: Dict[Text, Optional[float]],
            all_thresholds: Dict[int, float],
            all_similarities: np.array,
            label_index: int,
        ):

            expected_score = all_similarities[0][label_index]
            expected_threshold = (
                all_thresholds[label_index] if label_index in all_thresholds else None
            )
            expected_severity = (
                expected_threshold - expected_score if expected_threshold else None
            )

            assert label_metadata.get(SCORE_KEY) == expected_score
            assert label_metadata.get(THRESHOLD_KEY) == expected_threshold
            assert label_metadata.get(SEVERITY_KEY) == expected_severity

        # Monkey-patch certain attributes of the policy to make the testing easier.
        label_thresholds = {0: 1.2, 1: -0.3, 4: -2.3, 5: 0.2}
        loaded_policy.label_thresholds = label_thresholds
        loaded_policy.config[RANKING_LENGTH] = ranking_length

        # Some dummy similarities
        similarities = np.array([[3.2, 0.2, -1.2, -4.3, -5.1, 2.3]])

        query_intent = default_domain.intents[query_intent_index]

        metadata = loaded_policy._collect_action_metadata(
            default_domain, similarities, query_intent=query_intent
        )

        # Expected outer-most keys
        assert sorted(list(metadata.keys())) == sorted([QUERY_INTENT_KEY, RANKING_KEY])

        # Schema validation for query intent key
        assert sorted(list(metadata[QUERY_INTENT_KEY].keys())) == sorted(
            [NAME, SCORE_KEY, THRESHOLD_KEY, SEVERITY_KEY]
        )

        # Test all elements of metadata for query intent
        assert metadata[QUERY_INTENT_KEY].get(NAME) == query_intent
        test_individual_label_metadata(
            metadata.get(QUERY_INTENT_KEY),
            label_thresholds,
            similarities,
            query_intent_index,
        )

        # Check if ranking is sorted correctly and truncated to `ranking_length`
        sorted_label_similarities = sorted(
            [(index, score) for index, score in enumerate(similarities[0])],
            key=lambda x: -x[1],
        )
        sorted_label_similarities = (
            sorted_label_similarities[:ranking_length]
            if ranking_length
            else sorted_label_similarities
        )
        expected_label_rankings = [
            default_domain.intents[index] for index, _ in sorted_label_similarities
        ]
        collected_label_rankings = [
            label_metadata.get(NAME) for label_metadata in metadata.get(RANKING_KEY)
        ]
        assert collected_label_rankings == expected_label_rankings

        # Test all elements of metadata for all labels in ranking
        for label_metadata in metadata.get(RANKING_KEY):
            label_index = default_domain.intents.index(label_metadata.get(NAME))
            test_individual_label_metadata(
                label_metadata, label_thresholds, similarities, label_index
            )

    @pytest.mark.parametrize(
        "tracker_events_for_training, expected_trackers_with_events",
        [
            # Filter because of no intent and action name
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="happy to make it work"),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filter because of no action name
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filter because of no intent
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="happy to make it work"),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # No filter needed
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filter to return empty list of trackers
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [],
            ),
        ],
    )
    def test_filter_training_trackers(
        self,
        tracker_events_for_training: List[List[Event]],
        expected_trackers_with_events: List[List[Event]],
        domain: Domain,
    ):
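        """Trackers whose user turns lack an intent or whose bot turns lack an
        action name should be filtered out of the training data."""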
        trackers_for_training = [
            TrackerWithCachedStates.from_events(
                sender_id=f"{tracker_index}", evts=events, domain=domain
            )
            for tracker_index, events in enumerate(tracker_events_for_training)
        ]

        filtered_trackers = UnexpecTEDIntentPolicy._get_trackers_for_training(
            trackers_for_training
        )
        assert len(filtered_trackers) == len(expected_trackers_with_events)
        for collected_tracker, expected_tracker_events in zip(
            filtered_trackers, expected_trackers_with_events
        ):
            collected_tracker_events = list(collected_tracker.events)
            assert collected_tracker_events == expected_tracker_events
Ejemplo n.º 11
0
async def test_request_correct_slots_after_unhappy_path_with_custom_required_slots(
):
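    """A custom validation action re-requests a slot after an unhappy path.

    When the form resumes, it should ask for the same slot (`slot_2`) again, as
    returned by the custom validation action.
    """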
    form_name = "some_form"
    slot_name_1 = "slot_1"
    slot_name_2 = "slot_2"

    domain = f"""
        slots:
          {slot_name_1}:
            type: any
          {slot_name_2}:
            type: any
        forms:
          {form_name}:
            {slot_name_1}:
            - type: from_intent
              intent: some_intent
              value: some_value
            {slot_name_2}:
            - type: from_intent
              intent: some_intent
              value: some_value
        actions:
        - validate_{form_name}
        """
    domain = Domain.from_yaml(domain)

    tracker = DialogueStateTracker.from_events(
        "default",
        [
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "slot_2"),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(
                "hello",
                intent={
                    "name": "greet",
                    "confidence": 1.0
                },
            ),
            ActionExecutionRejected(form_name),
            ActionExecuted("utter_greet"),
        ],
    )

    action_server_url = "http://my-action-server:5055/webhook"

    # Custom form validation action changes the order of the requested slots
    validate_return_events = [
        {
            "event": "slot",
            "name": REQUESTED_SLOT,
            "value": slot_name_2
        },
    ]

    # The form should ask for the same slot again when coming back after the unhappy path
    expected_events = [SlotSet(REQUESTED_SLOT, slot_name_2)]

    with aioresponses() as mocked:
        mocked.post(action_server_url,
                    payload={"events": validate_return_events})

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        events = await action.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.responses),
            tracker,
            domain,
        )
        assert events == expected_events
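The test above stubs the custom validation action with aioresponses, which patches aiohttp so the FormAction never reaches a real action server. A stripped-down illustration of that mocking pattern follows; the URL and the canned payload are arbitrary values chosen for the sketch.

import aiohttp
from aioresponses import aioresponses


async def call_stubbed_action_server() -> dict:
    url = "http://my-action-server:5055/webhook"
    with aioresponses() as mocked:
        # Any POST to `url` issued through aiohttp now returns this canned body.
        mocked.post(
            url,
            payload={
                "events": [
                    {"event": "slot", "name": "requested_slot", "value": "slot_2"}
                ]
            },
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json={}) as response:
                return await response.json()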
Example #12
async def test_switch_forms_with_same_slot(default_agent: Agent):
    """Tests switching of forms, where the first slot is the same in both forms.

    Tests the fix for issue 7710."""

    # Define two forms in the domain, with same first slot
    slot_a = "my_slot_a"

    form_1 = "my_form_1"
    utter_ask_form_1 = f"Please provide the value for {slot_a} of form 1"

    form_2 = "my_form_2"
    utter_ask_form_2 = f"Please provide the value for {slot_a} of form 2"

    domain = f"""
version: "2.0"
nlu:
- intent: order_status
  examples: |
    - check status of my order
    - when are my shoes coming in
- intent: return
  examples: |
    - start a return
    - I don't want my shoes anymore
forms:
  {form_1}:
    {slot_a}:
    - type: from_entity
      entity: number
  {form_2}:
    {slot_a}:
    - type: from_entity
      entity: number
responses:
    utter_ask_{form_1}_{slot_a}:
    - text: {utter_ask_form_1}
    utter_ask_{form_2}_{slot_a}:
    - text: {utter_ask_form_2}
"""

    domain = Domain.from_yaml(domain)

    # Driving it like rasa/core/processor
    processor = MessageProcessor(
        default_agent.interpreter,
        default_agent.policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
    )

    # activate the first form
    tracker = DialogueStateTracker.from_events(
        "some-sender",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("order status", {
                "name": "form_1",
                "confidence": 1.0
            }),
            DefinePrevUserUtteredFeaturization(False),
        ],
    )
    # rasa/core/processor.predict_next_action
    prediction = PolicyPrediction([], "some_policy")
    action_1 = FormAction(form_1, None)

    await processor._run_action(
        action_1,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        prediction,
    )

    events_expected = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("order status", {
            "name": "form_1",
            "confidence": 1.0
        }),
        DefinePrevUserUtteredFeaturization(False),
        ActionExecuted(form_1),
        ActiveLoop(form_1),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_1,
            metadata={"utter_action": f"utter_ask_{form_1}_{slot_a}"},
        ),
    ]
    assert tracker.applied_events() == events_expected

    next_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("return my shoes", {
            "name": "form_2",
            "confidence": 1.0
        }),
        DefinePrevUserUtteredFeaturization(False),
    ]
    tracker.update_with_events(
        next_events,
        domain,
    )
    events_expected.extend(next_events)

    # form_1 is still active; the bot first checks whether the user utterance
    # provides valid data for the requested slot, and rejects execution since it does not
    await processor._run_action(
        action_1,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        prediction,
    )
    events_expected.extend([ActionExecutionRejected(action_name=form_1)])
    assert tracker.applied_events() == events_expected

    # Next, bot predicts form_2
    action_2 = FormAction(form_2, None)
    await processor._run_action(
        action_2,
        tracker,
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        prediction,
    )
    events_expected.extend([
        ActionExecuted(form_2),
        ActiveLoop(form_2),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_2,
            metadata={"utter_action": f"utter_ask_{form_2}_{slot_a}"},
        ),
    ])
    assert tracker.applied_events() == events_expected
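As a sanity check outside the test, the form switch can also be observed directly on a tracker: replaying the expected event sequence leaves my_form_2 as the active loop, which is what the applied_events() assertions above encode. The helper below is a small sketch assuming the Rasa 2.x module layout used elsewhere in these examples.

from typing import Iterable, Optional

from rasa.shared.core.events import Event
from rasa.shared.core.trackers import DialogueStateTracker


def active_loop_after(events: Iterable[Event]) -> Optional[str]:
    """Apply the events to a fresh tracker and report which loop is active."""
    tracker = DialogueStateTracker.from_events("sanity-check", evts=list(events))
    return tracker.active_loop_name


# e.g. active_loop_after(events_expected) would return "my_form_2" here.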
Example #13
async def test_whole_loop():
    expected_activation_events = [
        ActionExecutionRejected("tada"),
        ActionExecuted("test"),
    ]

    expected_do_events = [ActionExecuted("do")]
    expected_deactivation_events = [SlotSet("deactivated")]

    form_name = "my form"

    class MyLoop(LoopAction):
        def name(self) -> Text:
            return form_name

        async def activate(self, *args: Any) -> List[Event]:
            return expected_activation_events

        async def do(self, *args: Any) -> List[Event]:
            events_so_far = args[-1]
            assert events_so_far == [
                ActiveLoop(form_name), *expected_activation_events
            ]

            return expected_do_events

        async def deactivate(self, *args) -> List[Event]:
            events_so_far = args[-1]
            assert events_so_far == [
                ActiveLoop(form_name),
                *expected_activation_events,
                *expected_do_events,
                ActiveLoop(None),
            ]

            return expected_deactivation_events

        async def is_done(self, *args) -> bool:
            events_so_far = args[-1]
            return events_so_far == [
                ActiveLoop(form_name),
                *expected_activation_events,
                *expected_do_events,
            ]

    tracker = DialogueStateTracker.from_events("some sender", [])
    domain = Domain.empty()

    action = MyLoop()
    actual = await action.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        tracker,
        domain,
    )

    assert actual == [
        ActiveLoop(form_name),
        *expected_activation_events,
        *expected_do_events,
        ActiveLoop(None),
        *expected_deactivation_events,
    ]
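The assertions inside MyLoop pin down the sequencing contract the test verifies: the ActiveLoop(form_name) marker precedes the activation events, do sees everything emitted so far, and deactivate only runs after is_done reports completion and the ActiveLoop(None) marker has been appended. The sketch below restates that control flow in simplified form; the hook signatures are reduced to the tracker plus the accumulated events, and this is not Rasa's actual LoopAction.run.

from typing import List

from rasa.shared.core.events import ActiveLoop, Event
from rasa.shared.core.trackers import DialogueStateTracker


async def run_loop_once(loop: "LoopAction", tracker: DialogueStateTracker) -> List[Event]:
    events: List[Event] = []
    if not tracker.active_loop_name:
        # Activation: mark the loop as active, then let the loop add its own events.
        events.append(ActiveLoop(loop.name()))
        events.extend(await loop.activate(tracker, events))
    # Main body of the loop for this turn.
    events.extend(await loop.do(tracker, events))
    if await loop.is_done(tracker, events):
        # Deactivation: the ActiveLoop(None) marker comes before the clean-up events.
        events.append(ActiveLoop(None))
        events.extend(await loop.deactivate(tracker, events))
    return events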
Example #14
def _emulate_form_rejection(partial_tracker: DialogueStateTracker) -> None:
    from rasa.shared.core.events import ActionExecutionRejected

    rejected_action_name: Text = partial_tracker.active_loop_name
    partial_tracker.update(ActionExecutionRejected(rejected_action_name))
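A hypothetical usage of the helper (the tracker setup below is purely illustrative): with a form currently active, the call appends an ActionExecutionRejected event for that form, after which other policies can predict the next action as if the form had refused to run.

from rasa.shared.core.events import ActionExecuted, ActionExecutionRejected, ActiveLoop
from rasa.shared.core.trackers import DialogueStateTracker

tracker = DialogueStateTracker.from_events(
    "demo", evts=[ActionExecuted("some_form"), ActiveLoop("some_form")]
)
_emulate_form_rejection(tracker)
assert isinstance(list(tracker.events)[-1], ActionExecutionRejected)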