Example #1
def test_policy_predictions_dont_change_persistence():
    original_user_message = UserUttered("hi", intent={"name": "greet"})
    tracker = DialogueStateTracker.from_events(
        "Vova",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("hi", intent={"name": "greet"}),
            DefinePrevUserUtteredFeaturization(True),
            EntitiesAdded(entities=[{
                "entity": "entity1",
                "value": "value1"
            }]),
        ],
    )

    user_message: UserUttered = list(tracker.events)[1]
    # The entities from the policy predictions are accessible
    assert user_message.entities

    actual_serialized = user_message.as_dict()

    # Assert entities predicted by policies are not persisted
    assert not actual_serialized["parse_data"]["entities"]

    expected_serialized = original_user_message.as_dict()
    # don't compare timestamps
    expected_serialized.pop("timestamp")
    actual_serialized.pop("timestamp")

    assert actual_serialized == expected_serialized
Example #2
def test_policy_prediction_reflected_in_tracker_state():
    entities_predicted_by_policy = [{"entity": "entity1", "value": "value1"}]
    nlu_entities = [{"entity": "entityNLU", "value": "value100"}]

    tracker = DialogueStateTracker.from_events(
        "Tester",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(
                "hi",
                intent={"name": "greet"},
                entities=nlu_entities.copy(),
                message_id="unique",
                metadata={"some": "data"},
            ),
            DefinePrevUserUtteredFeaturization(True),
            EntitiesAdded(entities=entities_predicted_by_policy),
        ],
    )

    tracker_state = tracker.current_state()

    expected_state = {
        "sender_id": "Tester",
        "slots": {},
        "latest_message": {
            "intent": {
                "name": "greet"
            },
            "entities": nlu_entities + entities_predicted_by_policy,
            "text": "hi",
            "message_id": "unique",
            "metadata": {
                "some": "data"
            },
        },
        "latest_event_time": 1514764800.0,
        "followup_action": None,
        "paused": False,
        "events": None,
        "latest_input_channel": None,
        "active_loop": {},
        "latest_action": {
            "action_name": "action_listen"
        },
        "latest_action_name": "action_listen",
    }

    assert tracker_state == expected_state

    # Make sure we didn't change the actual event
    assert tracker.latest_message.parse_data["entities"] == nlu_entities
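Taken together, the two tests above pin down one contract: entities predicted by a policy (via `EntitiesAdded`) are visible when reading the latest message and the current tracker state, but they are never merged back into the persisted NLU parse data. A minimal plain-Python sketch of that pattern follows; it is an illustration only, not Rasa's actual `UserUttered` implementation, and the names `base_entities`/`_policy_entities` are made up:

from typing import Dict, List


class LatestMessage:
    """Toy stand-in for a user message with a policy-entity overlay."""

    def __init__(self, base_entities: List[Dict]) -> None:
        self._base_entities = base_entities        # what NLU extracted
        self._policy_entities: List[Dict] = []     # what policies predicted later

    def add_policy_entities(self, entities: List[Dict]) -> None:
        self._policy_entities.extend(entities)

    @property
    def entities(self) -> List[Dict]:
        # Reads see the merged view ...
        return self._base_entities + self._policy_entities

    def as_dict(self) -> Dict:
        # ... but only the original NLU result is persisted.
        return {"parse_data": {"entities": list(self._base_entities)}}


message = LatestMessage([{"entity": "entityNLU", "value": "value100"}])
message.add_policy_entities([{"entity": "entity1", "value": "value1"}])
assert len(message.entities) == 2
assert message.as_dict()["parse_data"]["entities"] == [
    {"entity": "entityNLU", "value": "value100"}
]
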
Example #3
    def _create_optional_event_for_entities(
        self,
        prediction_output: Dict[Text, tf.Tensor],
        is_e2e_prediction: bool,
        interpreter: NaturalLanguageInterpreter,
        tracker: DialogueStateTracker,
    ) -> Optional[List[Event]]:
        if tracker.latest_action_name != ACTION_LISTEN_NAME or not is_e2e_prediction:
            # Entities belong only to the last user message, and only if the
            # user text was used for the prediction; a user message always
            # comes after action_listen.
            return None

        if not self.config[ENTITY_RECOGNITION]:
            # entity recognition is not turned on, so no entities can be predicted
            return None

        # The batch dimension of the entity prediction is not the batch size;
        # it is the number of text inputs in the batch (only the last ones if a
        # max-history featurizer is used, otherwise all of them). Therefore, to
        # pick entities for the latest user message, we take the last batch
        # dimension of the entity prediction.
        predicted_tags, confidence_values = rasa.utils.train_utils.entity_label_to_tags(
            prediction_output,
            self._entity_tag_specs,
            self.config[BILOU_FLAG],
            prediction_index=-1,
        )

        if ENTITY_ATTRIBUTE_TYPE not in predicted_tags:
            # no entities detected
            return None

        # entities belong to the last message of the tracker
        # convert the predicted tags to actual entities
        text = tracker.latest_message.text
        parsed_message = interpreter.featurize_message(
            Message(data={TEXT: text}))
        tokens = parsed_message.get(TOKENS_NAMES[TEXT])
        entities = EntityExtractor.convert_predictions_into_entities(
            text,
            tokens,
            predicted_tags,
            self.split_entities_config,
            confidences=confidence_values,
        )

        # add the extractor name
        for entity in entities:
            entity[EXTRACTOR] = "TEDPolicy"

        return [EntitiesAdded(entities)]
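The comment about the batch dimension is the crux of `_create_optional_event_for_entities`: the entity head produces one row of tag predictions per text input, so `prediction_index=-1` selects the row that belongs to the latest user message. A small NumPy sketch of that selection (shapes and values are made up for illustration):

import numpy as np

# One row of predicted tag ids per text input seen by the featurizer
# (3 text inputs, 5 tokens each); the values here are made up.
predicted_tag_ids = np.array(
    [
        [0, 0, 1, 0, 0],  # oldest user text
        [0, 2, 0, 0, 0],
        [0, 0, 0, 3, 0],  # latest user text
    ]
)

# `prediction_index=-1` keeps only the tags for the latest user message.
latest_tags = predicted_tag_ids[-1]
assert latest_tags.tolist() == [0, 0, 0, 3, 0]
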
Example #4
    events: List[Event],
    comparison_result: bool,
):
    result = all(event == events[0] for event in events)
    assert result == comparison_result


tested_events = [
    EntitiesAdded(
        entities=[
            {
                "entity": "city",
                "value": "London",
                "role": "destination",
                "group": "test",
            },
            {
                "entity": "count",
                "value": 1
            },
        ],
        timestamp=None,
    ),
    DefinePrevUserUtteredFeaturization(use_text_for_featurization=False,
                                       timestamp=None,
                                       metadata=None),
    ReminderCancelled(timestamp=1621590172.3872123),
    ReminderScheduled(timestamp=None,
                      trigger_date_time=datetime.now(),
                      intent="greet"),
    ActionExecutionRejected(action_name="my_action"),
Example #5
def test_autofill_slots_for_policy_entities():
    policy_entity, policy_entity_value = "policy_entity", "end-to-end"
    nlu_entity, nlu_entity_value = "nlu_entity", "nlu rocks"
    domain = Domain.from_yaml(
        textwrap.dedent(f"""
    entities:
    - {nlu_entity}
    - {policy_entity}

    slots:
        {nlu_entity}:
            type: text
        {policy_entity}:
            type: text
    """))

    tracker = DialogueStateTracker.from_events(
        "some sender",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(
                "hi",
                intent={"name": "greet"},
                entities=[{
                    "entity": nlu_entity,
                    "value": nlu_entity_value
                }],
            ),
            DefinePrevUserUtteredFeaturization(True),
            EntitiesAdded(entities=[
                {
                    "entity": policy_entity,
                    "value": policy_entity_value
                },
                {
                    "entity": nlu_entity,
                    "value": nlu_entity_value
                },
            ]),
        ],
        domain=domain,
        slots=domain.slots,
    )

    # Slots are correctly set
    assert tracker.slots[nlu_entity].value == nlu_entity_value
    assert tracker.slots[policy_entity].value == policy_entity_value

    expected_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(
            "hi",
            intent={"name": "greet"},
            entities=[
                {
                    "entity": nlu_entity,
                    "value": nlu_entity_value
                },
                # Added by the `EntitiesAdded` event
                {
                    "entity": policy_entity,
                    "value": policy_entity_value
                },
            ],
        ),
        # SlotSet event added for entity predicted by NLU
        SlotSet(nlu_entity, nlu_entity_value),
        DefinePrevUserUtteredFeaturization(True),
        EntitiesAdded(entities=[
            {
                "entity": policy_entity,
                "value": policy_entity_value
            },
            {
                "entity": nlu_entity,
                "value": nlu_entity_value
            },
        ]),
        # SlotSet event added for entity predicted by policies
        # This event is somewhat redundant. We don't deduplicate, as this is a
        # true reflection of the given events and it doesn't change the actual
        # state.
        SlotSet(nlu_entity, nlu_entity_value),
        SlotSet(policy_entity, policy_entity_value),
    ]

    for actual, expected in zip(tracker.events, expected_events):
        assert actual == expected
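This test leans on the ordinary slot auto-fill rule: whenever an entity's name matches a slot name, a `SlotSet` event is appended, and the rule applies both to NLU entities and to entities added later by `EntitiesAdded`. A rough sketch of the rule over plain dicts (an illustration, not Rasa's slot-mapping code):

from typing import Dict, List


def autofill_slots(entities: List[Dict], slot_names: List[str]) -> List[Dict]:
    """Return SlotSet-like dicts for every entity whose name matches a slot."""
    return [
        {"slot": entity["entity"], "value": entity["value"]}
        for entity in entities
        if entity["entity"] in slot_names
    ]


slot_names = ["nlu_entity", "policy_entity"]
slot_events = autofill_slots(
    [
        {"entity": "policy_entity", "value": "end-to-end"},
        {"entity": "nlu_entity", "value": "nlu rocks"},
    ],
    slot_names,
)
assert slot_events == [
    {"slot": "policy_entity", "value": "end-to-end"},
    {"slot": "nlu_entity", "value": "nlu rocks"},
]
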
Example #6
class TestMemoizationPolicy(PolicyTestCollection):
    @staticmethod
    def _policy_class_to_test() -> Type[PolicyGraphComponent]:
        return MemoizationPolicy

    @pytest.fixture(scope="class")
    def featurizer(self) -> TrackerFeaturizer:
        featurizer = MaxHistoryTrackerFeaturizer(None,
                                                 max_history=self.max_history)
        return featurizer

    def test_featurizer(
        self,
        trained_policy: PolicyGraphComponent,
        resource: Resource,
        model_storage: ModelStorage,
        tmp_path: Path,
        execution_context: ExecutionContext,
    ) -> None:
        assert isinstance(trained_policy.featurizer,
                          MaxHistoryTrackerFeaturizer)
        assert trained_policy.featurizer.state_featurizer is None
        loaded = trained_policy.__class__.load(
            self._config(trained_policy.config),
            model_storage,
            resource,
            execution_context,
        )
        assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer)
        assert loaded.featurizer.state_featurizer is None

    def test_memorise(
        self,
        trained_policy: MemoizationPolicy,
        default_domain: Domain,
        stories_path: Text,
    ):
        trackers = train_trackers(default_domain,
                                  stories_path,
                                  augmentation_factor=20)

        trained_policy.train(trackers, default_domain)
        lookup_with_augmentation = trained_policy.lookup

        trackers = [
            t for t in trackers
            if not hasattr(t, "is_augmented") or not t.is_augmented
        ]

        (
            all_states,
            all_actions,
        ) = trained_policy.featurizer.training_states_and_labels(
            trackers, default_domain)

        for tracker, states, actions in zip(trackers, all_states, all_actions):
            recalled = trained_policy.recall(states, tracker, default_domain,
                                             None)
            assert recalled == actions[0]

        nums = np.random.randn(default_domain.num_states)
        random_states = [{
            f: num
            for f, num in zip(default_domain.input_states, nums)
        }]
        assert trained_policy._recall_states(random_states) is None

        # Compare the lookups for augmentation_factor values of 0 and 20:
        trackers_no_augmentation = train_trackers(default_domain,
                                                  stories_path,
                                                  augmentation_factor=0)

        trained_policy.train(trackers_no_augmentation, default_domain)
        lookup_no_augmentation = trained_policy.lookup

        assert lookup_no_augmentation == lookup_with_augmentation

    def test_memorise_with_nlu(self, trained_policy: MemoizationPolicy,
                               default_domain: Domain):
        tracker = tracker_from_dialogue(TEST_DEFAULT_DIALOGUE, default_domain)
        states = trained_policy._prediction_states(tracker, default_domain)

        recalled = trained_policy.recall(states, tracker, default_domain, None)
        assert recalled is not None

    def test_finetune_after_load(
        self,
        trained_policy: MemoizationPolicy,
        resource: Resource,
        model_storage: ModelStorage,
        execution_context: ExecutionContext,
        default_domain: Domain,
        stories_path: Text,
    ):

        execution_context = dataclasses.replace(execution_context,
                                                is_finetuning=True)
        loaded_policy = MemoizationPolicy.load(trained_policy.config,
                                               model_storage, resource,
                                               execution_context)

        assert loaded_policy.finetune_mode

        new_story = TrackerWithCachedStates.from_events(
            "channel",
            domain=default_domain,
            slots=default_domain.slots,
            evts=[
                ActionExecuted(ACTION_LISTEN_NAME),
                UserUttered(intent={"name": "why"}),
                ActionExecuted("utter_channel"),
                ActionExecuted(ACTION_LISTEN_NAME),
            ],
        )
        original_train_data = train_trackers(default_domain,
                                             stories_path,
                                             augmentation_factor=20)

        loaded_policy.train(
            original_train_data + [new_story],
            default_domain,
        )

        # Get the states of the new story so we can compute their feature keys
        new_story_states, _ = loaded_policy.featurizer.training_states_and_labels(
            [new_story], default_domain)

        # Feature keys for each new state should be present in the lookup
        for states in new_story_states:
            state_key = loaded_policy._create_feature_key(states)
            assert state_key in loaded_policy.lookup

    @pytest.mark.parametrize(
        "tracker_events_with_action, tracker_events_without_action",
        [
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[{
                        "entity": "name",
                        "value": "Peter"
                    }]),
                    SlotSet("name", "Peter"),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    SlotSet("name", "Peter"),
                    EntitiesAdded(entities=[{
                        "entity": "name",
                        "value": "Peter"
                    }]),
                ],
            ),
        ],
    )
    def test_ignore_action_unlikely_intent(
        self,
        trained_policy: MemoizationPolicy,
        default_domain: Domain,
        tracker_events_with_action: List[Event],
        tracker_events_without_action: List[Event],
    ):
        tracker_with_action = DialogueStateTracker.from_events(
            "test 1",
            evts=tracker_events_with_action,
            slots=default_domain.slots)
        tracker_without_action = DialogueStateTracker.from_events(
            "test 2",
            evts=tracker_events_without_action,
            slots=default_domain.slots)
        prediction_with_action = trained_policy.predict_action_probabilities(
            tracker_with_action,
            default_domain,
        )
        prediction_without_action = trained_policy.predict_action_probabilities(
            tracker_without_action,
            default_domain,
        )

        # Memoization shouldn't be affected by the
        # presence of action_unlikely_intent.
        assert (prediction_with_action.probabilities ==
                prediction_without_action.probabilities)

    @pytest.mark.parametrize(
        "featurizer_config, tracker_featurizer, state_featurizer",
        [
            (None, MaxHistoryTrackerFeaturizer(), type(None)),
            ([], MaxHistoryTrackerFeaturizer(), type(None)),
        ],
    )
    def test_empty_featurizer_configs(
        self,
        featurizer_config: Optional[Dict[Text, Any]],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        tracker_featurizer: MaxHistoryTrackerFeaturizer,
        state_featurizer: Type[SingleStateFeaturizer],
    ):
        featurizer_config_override = ({
            "featurizer": featurizer_config
        } if featurizer_config else {})
        policy = self.create_policy(
            None,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
            config=self._config(featurizer_config_override),
        )

        featurizer = policy.featurizer
        assert isinstance(featurizer, tracker_featurizer.__class__)

        if featurizer_config:
            expected_max_history = featurizer_config[0].get(POLICY_MAX_HISTORY)
        else:
            expected_max_history = self._config().get(POLICY_MAX_HISTORY)

        assert featurizer.max_history == expected_max_history

        assert isinstance(featurizer.state_featurizer, state_featurizer)

    @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None])
    def test_prediction(
        self,
        max_history: Optional[int],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        policy = self.create_policy(
            featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history),
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
        )

        GREET_INTENT_NAME = "greet"
        UTTER_GREET_ACTION = "utter_greet"
        UTTER_BYE_ACTION = "utter_goodbye"
        domain = Domain.from_yaml(f"""
            intents:
            - {GREET_INTENT_NAME}
            actions:
            - {UTTER_GREET_ACTION}
            - {UTTER_BYE_ACTION}
            slots:
                slot_1:
                    type: bool
                slot_2:
                    type: bool
                slot_3:
                    type: bool
                slot_4:
                    type: bool
            """)
        events = [
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            SlotSet("slot_1", True),
            ActionExecuted(UTTER_GREET_ACTION),
            SlotSet("slot_2", True),
            SlotSet("slot_3", True),
            ActionExecuted(UTTER_GREET_ACTION),
            ActionExecuted(UTTER_GREET_ACTION),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            SlotSet("slot_4", True),
            ActionExecuted(UTTER_BYE_ACTION),
        ]
        training_story = TrackerWithCachedStates.from_events(
            "training story",
            evts=events,
            domain=domain,
            slots=domain.slots,
        )
        test_story = TrackerWithCachedStates.from_events(
            "training story",
            events[:-1],
            domain=domain,
            slots=domain.slots,
        )
        policy.train([training_story], domain)
        prediction = policy.predict_action_probabilities(test_story, domain)
        assert (domain.action_names_or_texts[prediction.probabilities.index(
            max(prediction.probabilities))] == UTTER_BYE_ACTION)
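Several of the `TestMemoizationPolicy` checks above (`test_memorise`, `test_finetune_after_load`, `test_prediction`) revolve around the policy's lookup table, which maps a serialized window of dialogue states to the next action. A toy version of that idea, assuming states are plain dicts (this is not `MemoizationPolicy`'s real code):

import json
from typing import Dict, List, Optional, Text


class ToyMemoization:
    """Dict-backed memoization: state history in, next action out."""

    def __init__(self) -> None:
        self.lookup: Dict[Text, Text] = {}

    @staticmethod
    def _feature_key(states: List[Dict]) -> Text:
        return json.dumps(states, sort_keys=True)

    def memorise(self, states: List[Dict], action: Text) -> None:
        self.lookup[self._feature_key(states)] = action

    def recall(self, states: List[Dict]) -> Optional[Text]:
        return self.lookup.get(self._feature_key(states))


policy = ToyMemoization()
policy.memorise([{"prev_action": "action_listen", "intent": "greet"}], "utter_greet")
assert policy.recall([{"prev_action": "action_listen", "intent": "greet"}]) == "utter_greet"
assert policy.recall([{"intent": "goodbye"}]) is None
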
Example #7
class TestUnexpecTEDIntentPolicy(TestTEDPolicy):
    @staticmethod
    def _policy_class_to_test() -> Type[UnexpecTEDIntentPolicy]:
        return UnexpecTEDIntentPolicy

    @pytest.fixture(scope="class")
    def featurizer(self) -> TrackerFeaturizer:
        featurizer = IntentMaxHistoryTrackerFeaturizer(
            IntentTokenizerSingleStateFeaturizer(), max_history=self.max_history
        )
        return featurizer

    @staticmethod
    def persist_and_load_policy(
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        return trained_policy.__class__.load(
            trained_policy.config, model_storage, resource, execution_context
        )

    def test_ranking_length(self, trained_policy: UnexpecTEDIntentPolicy):
        assert trained_policy.config[RANKING_LENGTH] == LABEL_RANKING_LENGTH

    def test_ranking_length_and_renormalization(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        tracker: DialogueStateTracker,
        default_domain: Domain,
    ):
        precomputations = None
        prediction_metadata = trained_policy.predict_action_probabilities(
            tracker, default_domain, precomputations,
        ).action_metadata
        assert (
            prediction_metadata is None
            or len(prediction_metadata[RANKING_KEY])
            == trained_policy.config[RANKING_LENGTH]
        )

    def test_label_data_assembly(
        self, trained_policy: UnexpecTEDIntentPolicy, default_domain: Domain
    ):

        # Construct input data
        state_featurizer = trained_policy.featurizer.state_featurizer
        encoded_all_labels = state_featurizer.encode_all_labels(
            default_domain, precomputations=None
        )
        attribute_data, _ = model_data_utils.convert_to_data_format(encoded_all_labels)

        assembled_label_data = trained_policy._assemble_label_data(
            attribute_data, default_domain
        )
        assembled_label_data_signature = assembled_label_data.get_signature()

        assert list(assembled_label_data_signature.keys()) == [
            f"{LABEL}_{INTENT}",
            LABEL,
        ]
        assert assembled_label_data.num_examples == len(default_domain.intents)
        assert list(assembled_label_data_signature[f"{LABEL}_{INTENT}"].keys()) == [
            MASK,
            SENTENCE,
        ]
        assert list(assembled_label_data_signature[LABEL].keys()) == [IDS]
        assert assembled_label_data_signature[f"{LABEL}_{INTENT}"][SENTENCE][
            0
        ].units == len(default_domain.intents)

    def test_training_with_no_intent(
        self,
        featurizer: Optional[TrackerFeaturizer],
        default_domain: Domain,
        tmp_path: Path,
        caplog: LogCaptureFixture,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        stories = tmp_path / "stories.yml"
        stories.write_text(
            """
            version: "3.0"
            stories:
            - story: test path
              steps:
              - action: utter_greet
            """
        )
        policy = self.create_policy(
            featurizer=featurizer,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
        )
        import tests.core.test_policies

        training_trackers = tests.core.test_policies.train_trackers(
            default_domain, str(stories), augmentation_factor=20
        )

        with pytest.warns(UserWarning):
            policy.train(training_trackers, default_domain, precomputations=None)

    def test_prepared_data_for_threshold_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        default_domain: Domain,
        stories_path: Path,
    ):
        training_trackers = train_trackers(
            default_domain, stories_path, augmentation_factor=0
        )
        training_model_data, _ = trained_policy._prepare_for_training(
            training_trackers, default_domain, precomputations=None,
        )

        data_for_prediction = trained_policy._prepare_data_for_prediction(
            training_model_data
        )

        assert set(data_for_prediction.data.keys()).issubset(PREDICTION_FEATURES)

    def test_similarities_collection_for_label_ids(self):
        label_ids = np.array([[0, 1], [1, -1], [2, -1]])
        outputs = {
            "similarities": np.array(
                [[[1.2, 0.3, 0.2]], [[0.5, 0.2, 1.6]], [[0.01, 0.1, 1.7]],]
            )
        }
        label_id_similarities = UnexpecTEDIntentPolicy._collect_label_id_grouped_scores(
            outputs, label_ids
        )

        # Should contain similarities for all label ids except the padding token.
        assert sorted(list(label_id_similarities.keys())) == [0, 1, 2]

        # Cross-check that the collected similarities are correct for each label id.
        assert label_id_similarities[0] == {
            POSITIVE_SCORES_KEY: [1.2],
            NEGATIVE_SCORES_KEY: [0.5, 0.01],
        }
        assert label_id_similarities[1] == {
            POSITIVE_SCORES_KEY: [0.3, 0.2],
            NEGATIVE_SCORES_KEY: [0.1],
        }
        assert label_id_similarities[2] == {
            POSITIVE_SCORES_KEY: [1.7],
            NEGATIVE_SCORES_KEY: [0.2, 1.6],
        }
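The expected values in `test_similarities_collection_for_label_ids` follow from a single grouping rule: for each dialogue, the similarity to a label counts as a positive score if that label appears among the dialogue's label ids (ignoring the `-1` padding) and as a negative score otherwise. A NumPy sketch that reproduces the assertions above, written independently of the policy's internals:

import numpy as np

label_ids = np.array([[0, 1], [1, -1], [2, -1]])
similarities = np.array([[[1.2, 0.3, 0.2]], [[0.5, 0.2, 1.6]], [[0.01, 0.1, 1.7]]])

grouped = {}
for dialogue_index, dialogue_labels in enumerate(label_ids):
    positives = {label for label in dialogue_labels if label != -1}
    for label in range(similarities.shape[-1]):
        bucket = "positive" if label in positives else "negative"
        grouped.setdefault(label, {"positive": [], "negative": []})[bucket].append(
            float(similarities[dialogue_index, 0, label])
        )

assert grouped[0] == {"positive": [1.2], "negative": [0.5, 0.01]}
assert grouped[1] == {"positive": [0.3, 0.2], "negative": [0.1]}
assert grouped[2] == {"positive": [1.7], "negative": [0.2, 1.6]}
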

    def test_label_quantiles_computation(self):
        label_id_scores = {
            0: {
                POSITIVE_SCORES_KEY: [1.3, 0.2],
                NEGATIVE_SCORES_KEY: [
                    -0.1,
                    -1.2,
                    -2.3,
                    -4.1,
                    -0.5,
                    0.2,
                    0.8,
                    0.9,
                    -3.2,
                    -2.7,
                ],
            },
            3: {POSITIVE_SCORES_KEY: [1.3, 0.2], NEGATIVE_SCORES_KEY: [-0.1]},
            6: {POSITIVE_SCORES_KEY: [1.3, 0.2], NEGATIVE_SCORES_KEY: []},
        }
        expected_thresholds = {
            0: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                -0.1,
                -0.1,
                -0.5,
                -0.5,
                -1.2,
                -1.2,
                -1.2,
                -2.3,
                -2.3,
                -2.7,
                -2.7,
                -3.2,
                -3.2,
                -4.1,
                -4.1,
            ],
            3: [
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
            ],
            6: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
            ],
        }
        thresholds = UnexpecTEDIntentPolicy._compute_label_quantiles(label_id_scores)
        assert sorted(list(thresholds.keys())) == sorted(
            list(expected_thresholds.keys())
        )
        for label_id, tolerance_thresholds in thresholds.items():
            assert expected_thresholds[label_id] == tolerance_thresholds

    def test_post_training_threshold_computation(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        default_domain: Domain,
        stories_path: Path,
    ):
        training_trackers = train_trackers(
            default_domain, stories_path, augmentation_factor=0
        )
        training_model_data, label_ids = trained_policy._prepare_for_training(
            training_trackers, default_domain, precomputations=None,
        )

        trained_policy.compute_label_quantiles_post_training(
            training_model_data, label_ids
        )

        computed_thresholds = trained_policy.label_quantiles

        # -1 is used for padding and hence is not expected in the keys
        expected_keys = list(np.unique(label_ids))
        expected_keys.remove(-1)

        assert sorted(list(computed_thresholds.keys())) == sorted(expected_keys)

    @pytest.mark.parametrize(
        "tolerance, expected_thresholds",
        [
            (0.0, [0.2, -0.1, 0.2]),
            (0.75, [-2.9, -0.1, -4.3]),
            (0.72, [-2.7, -0.1, -4.0]),
            (0.78, [-2.9, -0.1, -4.3]),
            (1.0, [-4.1, -0.1, -5.5]),
        ],
    )
    def test_pick_thresholds_for_labels(
        self, tolerance: float, expected_thresholds: List[float]
    ):
        label_id_tolerance_thresholds = {
            0: [
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                0.2,
                -0.1,
                -0.1,
                -0.5,
                -0.5,
                -1.2,
                -1.2,
                -2.3,
                -2.3,
                -2.7,
                -2.9,
                -3.2,
                -3.2,
                -4.1,
                -4.1,
            ],
            3: [
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
                -0.1,
            ],
            4: [0.2 - (index * 0.3) for index in range(20)],
        }
        thresholds = UnexpecTEDIntentPolicy._pick_thresholds(
            label_id_tolerance_thresholds, tolerance
        )
        assert sorted(list(thresholds.keys())) == sorted(
            list(label_id_tolerance_thresholds.keys())
        )
        computed_values = list(thresholds.values())
        assert expected_thresholds == computed_values
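The parametrized expectations in `test_pick_thresholds_for_labels` are consistent with a simple indexing rule: treat `tolerance` as a fraction of the quantile list and pick the entry at `int(tolerance * len(quantiles))`, clamped to the last index. The sketch below reproduces two of the expected values under that assumption; it is a reading of the test data, not a copy of `_pick_thresholds`:

from typing import Dict, List


def pick_thresholds(
    quantiles_per_label: Dict[int, List[float]], tolerance: float
) -> Dict[int, float]:
    """Pick one threshold per label id based on the tolerance fraction."""
    picked = {}
    for label_id, quantiles in quantiles_per_label.items():
        index = min(int(tolerance * len(quantiles)), len(quantiles) - 1)
        picked[label_id] = quantiles[index]
    return picked


quantiles = {4: [0.2 - (index * 0.3) for index in range(20)]}
assert pick_thresholds(quantiles, 0.0)[4] == 0.2
assert pick_thresholds(quantiles, 1.0)[4] == 0.2 - 19 * 0.3
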

    @pytest.mark.parametrize(
        "predicted_similarity, threshold_value, is_unlikely",
        [(1.2, 0.2, False), (0.3, -0.1, False), (-1.5, 0.03, True)],
    )
    def test_unlikely_intent_check(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        predicted_similarity: float,
        threshold_value: float,
        is_unlikely: bool,
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        # Construct dummy similarities
        similarities = np.array([[0.0] * len(default_domain.intents)])
        dummy_intent_index = 4
        similarities[0, dummy_intent_index] = predicted_similarity

        loaded_policy.label_thresholds[dummy_intent_index] = threshold_value
        query_intent = default_domain.intents[dummy_intent_index]

        unlikely_intent_prediction = loaded_policy._check_unlikely_intent(
            default_domain, similarities, query_intent
        )

        assert is_unlikely == unlikely_intent_prediction

    def test_should_check_for_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        intent_index = 0
        assert (
            loaded_policy._should_check_for_intent(
                default_domain.intents[intent_index], default_domain
            )
            is False
        )

        intent_index = 4
        assert loaded_policy._should_check_for_intent(
            default_domain.intents[intent_index], default_domain
        )

        loaded_policy.config[IGNORE_INTENTS_LIST] = [
            default_domain.intents[intent_index]
        ]
        assert (
            loaded_policy._should_check_for_intent(
                default_domain.intents[intent_index], default_domain
            )
            is False
        )

    def test_no_action_unlikely_intent_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        expected_probabilities = [0] * default_domain.num_actions

        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

        tracker.update_with_events(
            [
                UserUttered(text="hello", intent={"name": "greet"}),
                ActionExecuted(action_name="utter_greet"),
            ],
            default_domain,
        )
        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

        loaded_policy.model = None

        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        assert prediction.probabilities == expected_probabilities

    @pytest.mark.parametrize(
        "predicted_similarity, threshold_value, is_unlikely",
        [(1.2, 0.2, False), (0.3, -0.1, False), (-1.5, 0.03, True)],
    )
    def test_action_unlikely_intent_prediction(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        predicted_similarity: float,
        threshold_value: float,
        is_unlikely: bool,
        monkeypatch: MonkeyPatch,
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        similarities = np.array([[[0.0] * len(default_domain.intents)]])

        dummy_intent_index = 4
        similarities[0, 0, dummy_intent_index] = predicted_similarity
        query_intent = default_domain.intents[dummy_intent_index]

        loaded_policy.label_thresholds[dummy_intent_index] = threshold_value

        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)

        tracker.update_with_events(
            [UserUttered(text="hello", intent={"name": query_intent})], default_domain,
        )

        # Preset the model predictions to the similarity values
        # so that we don't have to hard-code particular model outputs.
        monkeypatch.setattr(
            loaded_policy.model,
            "run_inference",
            lambda data: {"similarities": similarities},
        )

        prediction = loaded_policy.predict_action_probabilities(
            tracker, default_domain, precomputations
        )

        if not is_unlikely:
            assert prediction.probabilities == [0.0] * default_domain.num_actions
        else:
            assert (
                prediction.probabilities[
                    default_domain.index_for_action(ACTION_UNLIKELY_INTENT_NAME)
                ]
                == 1.0
            )

            # Make sure metadata is set. The exact structure
            # of the metadata is tested separately and
            # not as part of this test.
            assert prediction.action_metadata is not None
            # Assert metadata is serializable
            assert json.dumps(prediction.action_metadata)

    @pytest.mark.parametrize(
        "tracker_events, should_skip",
        [
            ([], True),
            ([ActionExecuted("action_listen")], True),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    EntitiesAdded([{"name": "dummy"}]),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    SlotSet("name"),
                ],
                False,
            ),
            (
                [
                    ActiveLoop("loop"),
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    ActionExecutionRejected("loop"),
                ],
                False,
            ),
            (
                [
                    ActionExecuted("action_listen"),
                    UserUttered("hi", intent={"name": "greet"}),
                    ActionExecuted("utter_greet"),
                ],
                True,
            ),
        ],
    )
    def test_skip_predictions_to_prevent_loop(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
        should_skip: bool,
        tmp_path: Path,
    ):
        """Skips predictions to prevent loop."""
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        precomputations = None
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        tracker.update_with_events(tracker_events, default_domain)
        with caplog.at_level(logging.DEBUG):
            prediction = loaded_policy.predict_action_probabilities(
                tracker, default_domain, precomputations
            )

        assert (
            "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text
        ) == should_skip

        if should_skip:
            assert prediction.probabilities == loaded_policy._default_predictions(
                default_domain
            )

    @pytest.mark.parametrize(
        "tracker_events",
        [
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
            ],
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                EntitiesAdded([{"name": "dummy"}]),
            ],
            [
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                SlotSet("name"),
            ],
            [
                ActiveLoop("loop"),
                ActionExecuted("action_listen"),
                UserUttered("hi", intent={"name": "inexistent_intent"}),
                ActionExecutionRejected("loop"),
            ],
        ],
    )
    def test_skip_predictions_if_new_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
    ):
        """Skips predictions if there's a new intent created."""
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
        tracker.update_with_events(tracker_events, default_domain)

        with caplog.at_level(logging.DEBUG):
            prediction = loaded_policy.predict_action_probabilities(
                tracker, default_domain, precomputations=None,
            )

        assert "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text

        assert prediction.probabilities == loaded_policy._default_predictions(
            default_domain
        )

    @pytest.mark.parametrize(
        "tracker_events_with_action, tracker_events_without_action",
        [
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[{"entity": "name", "value": "Peter"},]),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[{"entity": "name", "value": "Peter"},]),
                    ActionExecuted("utter_greet"),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                    UserUttered(text="sad", intent={"name": "thank_you"}),
                ],
            ),
        ],
    )
    def test_ignore_action_unlikely_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tracker_events_with_action: List[Event],
        tracker_events_without_action: List[Event],
        tmp_path: Path,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        precomputations = None
        tracker_with_action = DialogueStateTracker.from_events(
            "test 1", evts=tracker_events_with_action
        )
        tracker_without_action = DialogueStateTracker.from_events(
            "test 2", evts=tracker_events_without_action
        )
        prediction_with_action = loaded_policy.predict_action_probabilities(
            tracker_with_action, default_domain, precomputations
        )
        prediction_without_action = loaded_policy.predict_action_probabilities(
            tracker_without_action, default_domain, precomputations
        )

        # If the weights didn't change, then both trackers
        # should result in the same prediction. For `UnexpecTEDIntentPolicy`,
        # the real prediction is inside the action metadata.
        assert (
            prediction_with_action.action_metadata
            == prediction_without_action.action_metadata
        )

    def test_label_embedding_collection(self, trained_policy: UnexpecTEDIntentPolicy):
        label_ids = tf.constant([[[2], [-1]], [[1], [2]], [[0], [-1]]], dtype=tf.int32)

        all_label_embeddings = np.random.random((10, 20))

        # `-1` is used as the padding label id. Its embedding
        # will be the same as the one for `label_id=0`.
        expected_extracted_label_embeddings = tf.constant(
            np.concatenate(
                [
                    all_label_embeddings[2],
                    all_label_embeddings[0],
                    all_label_embeddings[1],
                    all_label_embeddings[2],
                    all_label_embeddings[0],
                    all_label_embeddings[0],
                ]
            ).reshape((3, 2, 20)),
            dtype=tf.float32,
        )

        actual_extracted_label_embeddings = trained_policy.model._get_labels_embed(
            label_ids, tf.constant(all_label_embeddings, dtype=tf.float32)
        )

        assert np.all(
            expected_extracted_label_embeddings == actual_extracted_label_embeddings
        )
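The expected tensor in this test encodes the padding rule spelled out in the comment: `-1` ids fall back to the embedding of label id 0 before the gather. A NumPy sketch of that lookup (independent of the policy's TensorFlow model code):

import numpy as np

label_ids = np.array([[2, -1], [1, 2], [0, -1]])
all_label_embeddings = np.random.random((10, 20))

# Padding ids (-1) fall back to the embedding of label id 0.
safe_ids = np.where(label_ids < 0, 0, label_ids)
gathered = all_label_embeddings[safe_ids]

assert gathered.shape == (3, 2, 20)
assert np.array_equal(gathered[0, 1], all_label_embeddings[0])
assert np.array_equal(gathered[1, 0], all_label_embeddings[1])
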

    @pytest.mark.parametrize(
        "query_intent_index, ranking_length", [(0, 0), (1, 3), (2, 1), (5, 0)]
    )
    def test_collect_action_metadata(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        tmp_path: Path,
        query_intent_index: int,
        ranking_length: int,
    ):
        loaded_policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )

        def test_individual_label_metadata(
            label_metadata: Dict[Text, Optional[float]],
            all_thresholds: Dict[int, float],
            all_similarities: np.array,
            label_index: int,
        ):

            expected_score = all_similarities[0][label_index]
            expected_threshold = (
                all_thresholds[label_index] if label_index in all_thresholds else None
            )
            expected_severity = (
                expected_threshold - expected_score if expected_threshold else None
            )

            assert label_metadata.get(SCORE_KEY) == expected_score
            assert label_metadata.get(THRESHOLD_KEY) == expected_threshold
            assert label_metadata.get(SEVERITY_KEY) == expected_severity

        # Monkey-patch certain attributes of the policy to make the testing easier.
        label_thresholds = {0: 1.2, 1: -0.3, 4: -2.3, 5: 0.2}
        loaded_policy.label_thresholds = label_thresholds
        loaded_policy.config[RANKING_LENGTH] = ranking_length

        # Some dummy similarities
        similarities = np.array([[3.2, 0.2, -1.2, -4.3, -5.1, 2.3]])

        query_intent = default_domain.intents[query_intent_index]

        metadata = loaded_policy._collect_action_metadata(
            default_domain, similarities, query_intent=query_intent
        )

        # Expected outer-most keys
        assert sorted(list(metadata.keys())) == sorted([QUERY_INTENT_KEY, RANKING_KEY])

        # Schema validation for query intent key
        assert sorted(list(metadata[QUERY_INTENT_KEY].keys())) == sorted(
            [NAME, SCORE_KEY, THRESHOLD_KEY, SEVERITY_KEY]
        )

        # Test all elements of metadata for query intent
        assert metadata[QUERY_INTENT_KEY].get(NAME) == query_intent
        test_individual_label_metadata(
            metadata.get(QUERY_INTENT_KEY),
            label_thresholds,
            similarities,
            query_intent_index,
        )

        # Check if ranking is sorted correctly and truncated to `ranking_length`
        sorted_label_similarities = sorted(
            [(index, score) for index, score in enumerate(similarities[0])],
            key=lambda x: -x[1],
        )
        sorted_label_similarities = (
            sorted_label_similarities[:ranking_length]
            if ranking_length
            else sorted_label_similarities
        )
        expected_label_rankings = [
            default_domain.intents[index] for index, _ in sorted_label_similarities
        ]
        collected_label_rankings = [
            label_metadata.get(NAME) for label_metadata in metadata.get(RANKING_KEY)
        ]
        assert collected_label_rankings == expected_label_rankings

        # Test all elements of metadata for all labels in ranking
        for label_metadata in metadata.get(RANKING_KEY):
            label_index = default_domain.intents.index(label_metadata.get(NAME))
            test_individual_label_metadata(
                label_metadata, label_thresholds, similarities, label_index
            )

    @pytest.mark.parametrize(
        "tracker_events_for_training, expected_trackers_with_events",
        [
            # Filtered out because of a missing intent and a missing action name
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="happy to make it work"),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filtered out because of a missing action name
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filtered out because of a missing intent
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello"),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="happy to make it work"),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # No filter needed
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted("utter_goodbye"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
            ),
            # Filtering returns an empty list of trackers
            (
                [
                    [
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(text="hello", intent={"name": "greet"}),
                        ActionExecuted("utter_greet"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                        UserUttered(
                            text="happy to make it work", intent={"name": "goodbye"}
                        ),
                        ActionExecuted(action_text="Great!"),
                        ActionExecuted(ACTION_LISTEN_NAME),
                    ],
                ],
                [],
            ),
        ],
    )
    def test_filter_training_trackers(
        self,
        tracker_events_for_training: List[List[Event]],
        expected_trackers_with_events: List[List[Event]],
        domain: Domain,
    ):
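        """Trackers that contain user turns without an intent or end-to-end
        actions should be dropped before training; purely intent-based
        trackers should be kept unchanged."""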
        trackers_for_training = [
            TrackerWithCachedStates.from_events(
                sender_id=f"{tracker_index}", evts=events, domain=domain
            )
            for tracker_index, events in enumerate(tracker_events_for_training)
        ]

        filtered_trackers = UnexpecTEDIntentPolicy._get_trackers_for_training(
            trackers_for_training
        )
        assert len(filtered_trackers) == len(expected_trackers_with_events)
        for collected_tracker, expected_tracker_events in zip(
            filtered_trackers, expected_trackers_with_events
        ):
            collected_tracker_events = list(collected_tracker.events)
            assert collected_tracker_events == expected_tracker_events
Example #8
class TestTEDPolicy(PolicyTestCollection):
    @staticmethod
    def _policy_class_to_test() -> Type[TEDPolicy]:
        return TEDPolicy

    def test_train_model_checkpointing(self, tmp_path: Path,
                                       tmp_path_factory: TempPathFactory):
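        """Training with checkpointing enabled should persist a checkpoint
        directory inside the model archive."""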
        train_core(
            domain="data/test_domains/default.yml",
            stories="data/test_yaml_stories/stories_defaultdomain.yml",
            output=str(tmp_path),
            fixed_model_name="my_model.tar.gz",
            config="data/test_config/config_ted_policy_model_checkpointing.yml",
        )

        storage_dir = tmp_path_factory.mktemp("storage dir")
        storage, _ = LocalModelStorage.from_model_archive(
            storage_dir, tmp_path / "my_model.tar.gz")

        checkpoint_dir = get_checkpoint_dir_path(storage_dir)
        assert checkpoint_dir.is_dir()

    def test_doesnt_checkpoint_with_no_checkpointing(
            self, tmp_path: Path, tmp_path_factory: TempPathFactory):
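        """Training without checkpointing enabled should not create a
        checkpoint directory inside the model archive."""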
        train_core(
            domain="data/test_domains/default.yml",
            stories="data/test_yaml_stories/stories_defaultdomain.yml",
            output=str(tmp_path),
            fixed_model_name="my_model.tar.gz",
            config=
            "data/test_config/config_ted_policy_no_model_checkpointing.yml",
        )

        storage_dir = tmp_path_factory.mktemp("storage dir")
        storage, _ = LocalModelStorage.from_model_archive(
            storage_dir, tmp_path / "my_model.tar.gz")

        checkpoint_dir = get_checkpoint_dir_path(storage_dir)
        assert not checkpoint_dir.is_dir()

    def test_doesnt_checkpoint_with_zero_eval_num_examples(
            self, tmp_path: Path, tmp_path_factory: TempPathFactory):
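        """Enabling checkpointing with `EVAL_NUM_EXAMPLES` set to 0 should
        warn the user and skip saving any checkpoint."""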
        checkpoint_dir = get_checkpoint_dir_path(tmp_path)
        assert not checkpoint_dir.is_dir()
        config_file = "config_ted_policy_model_checkpointing_zero_eval_num_examples.yml"
        with pytest.warns(UserWarning) as warning:
            train_core(
                domain="data/test_domains/default.yml",
                stories="data/test_yaml_stories/stories_defaultdomain.yml",
                output=str(tmp_path),
                fixed_model_name="my_model.tar.gz",
                config=f"data/test_config/{config_file}",
            )
        warn_text = (
            f"You have opted to save the best model, but the value of "
            f"'{EVAL_NUM_EXAMPLES}' is not greater than 0. No checkpoint model will be "
            f"saved.")

        assert len([w for w in warning if warn_text in str(w.message)]) == 1

        storage_dir = tmp_path_factory.mktemp("storage dir")
        storage, _ = LocalModelStorage.from_model_archive(
            storage_dir, tmp_path / "my_model.tar.gz")

        checkpoint_dir = get_checkpoint_dir_path(storage_dir)
        assert not checkpoint_dir.is_dir()

    @pytest.mark.parametrize(
        "should_finetune, epoch_override, expected_epoch_value",
        [
            (
                True,
                TEDPolicy.get_default_config()[EPOCHS] + 1,
                TEDPolicy.get_default_config()[EPOCHS] + 1,
            ),
            (
                False,
                TEDPolicy.get_default_config()[EPOCHS] + 1,
                TEDPolicy.get_default_config()[EPOCHS],
            ),  # trained_policy uses default epochs during training
        ],
    )
    def test_epoch_override_when_loaded(
        self,
        trained_policy: TEDPolicy,
        should_finetune: bool,
        epoch_override: int,
        expected_epoch_value: int,
        resource: Resource,
        model_storage: ModelStorage,
        execution_context: ExecutionContext,
    ):
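        """`EPOCH_OVERRIDE` should only change the number of epochs when the
        policy is loaded for finetuning; otherwise the epochs used during
        training are kept."""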
        execution_context.is_finetuning = should_finetune
        loaded_policy = trained_policy.__class__.load(
            {
                **self._config(), EPOCH_OVERRIDE: epoch_override
            },
            model_storage,
            resource,
            execution_context,
        )

        assert loaded_policy.config[EPOCHS] == expected_epoch_value

    def test_train_fails_with_checkpoint_zero_eval_num_epochs(
            self, tmp_path: Path):
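        """Training should abort with an `InvalidConfigException` when
        checkpointing is enabled but the number of epochs between evaluations
        is 0, and no model archive should be written."""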
        config_file = "config_ted_policy_model_checkpointing_zero_every_num_epochs.yml"
        match_string = ("Only values either equal to -1 or greater"
                        " than 0 are allowed for this parameter.")
        with pytest.raises(
                InvalidConfigException,
                match=match_string,
        ):
            train_core(
                domain="data/test_domains/default.yml",
                stories="data/test_yaml_stories/stories_defaultdomain.yml",
                output=str(tmp_path),
                config=f"data/test_config/{config_file}",
            )

        assert not (tmp_path / "my_model.tar.gz").is_file()

    def test_training_with_no_intent(
        self,
        featurizer: Optional[TrackerFeaturizer],
        default_domain: Domain,
        tmp_path: Path,
        caplog: LogCaptureFixture,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
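        """Training on stories that contain no user utterances should fail
        because no user features can be created for TED."""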
        stories = tmp_path / "stories.yml"
        stories.write_text("""
            version: "3.0"
            stories:
            - story: test path
              steps:
              - action: utter_greet
            """)
        policy = self.create_policy(
            featurizer=featurizer,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
        )
        import tests.core.test_policies

        training_trackers = tests.core.test_policies.train_trackers(
            default_domain, str(stories), augmentation_factor=20)
        with pytest.raises(RasaException) as e:
            policy.train(training_trackers,
                         default_domain,
                         precomputations=None)

        assert "No user features specified. Cannot train 'TED' model." == str(
            e.value)

    def test_similarity_type(self, trained_policy: TEDPolicy):
        assert trained_policy.config[SIMILARITY_TYPE] == "inner"

    def test_ranking_length(self, trained_policy: TEDPolicy):
        assert trained_policy.config[RANKING_LENGTH] == 0

    def test_ranking_length_and_renormalization(
        self,
        trained_policy: TEDPolicy,
        tracker: DialogueStateTracker,
        default_domain: Domain,
        monkeypatch: MonkeyPatch,
    ):
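        """With `RANKING_LENGTH` set to 0 the prediction should be a full
        probability distribution (all confidences positive, summing to 1);
        with a positive ranking length only that many confidences stay
        non-zero and the probabilities are not renormalized."""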
        precomputations = None
        prediction = trained_policy.predict_action_probabilities(
            tracker,
            default_domain,
            precomputations,
        )

        # first check the output is what we expect
        assert not prediction.is_end_to_end_prediction

        # check that ranking length is applied - without normalization
        if trained_policy.config[RANKING_LENGTH] == 0:
            assert sum([confidence for confidence in prediction.probabilities
                        ]) == pytest.approx(1)
            assert all(confidence > 0
                       for confidence in prediction.probabilities)
        else:
            assert (sum([
                confidence > 0 for confidence in prediction.probabilities
            ]) == trained_policy.config[RANKING_LENGTH])
            assert sum([confidence for confidence in prediction.probabilities
                        ]) != pytest.approx(1)

    def test_label_data_assembly(self, trained_policy: TEDPolicy,
                                 default_domain: Domain):
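        """The assembled label data should contain one example per domain
        action, with mask and sentence features keyed by the action name
        label and label ids keyed by `LABEL`."""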
        state_featurizer = trained_policy.featurizer.state_featurizer
        encoded_all_labels = state_featurizer.encode_all_labels(
            default_domain, precomputations=None)

        attribute_data, _ = model_data_utils.convert_to_data_format(
            encoded_all_labels)
        assembled_label_data = trained_policy._assemble_label_data(
            attribute_data, default_domain)
        assembled_label_data_signature = assembled_label_data.get_signature()

        assert list(assembled_label_data_signature.keys()) == [
            f"{LABEL}_{ACTION_NAME}",
            f"{LABEL}",
        ]
        assert assembled_label_data.num_examples == default_domain.num_actions
        assert list(assembled_label_data_signature[f"{LABEL}_{ACTION_NAME}"].
                    keys()) == [
                        MASK,
                        SENTENCE,
                    ]
        assert list(assembled_label_data_signature[LABEL].keys()) == [IDS]
        assert (assembled_label_data_signature[f"{LABEL}_{ACTION_NAME}"]
                [SENTENCE][0].units == default_domain.num_actions)

    def test_gen_batch(self, trained_policy: TEDPolicy, default_domain: Domain,
                       stories_path: Path):
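        """Batches should have the configured batch size in their first
        dimension and consistent sentence-feature sequence lengths (0 for
        "fake" features) for both `sequence` and `balanced` batching."""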
        training_trackers = tests.core.test_policies.train_trackers(
            default_domain, stories_path, augmentation_factor=0)
        precomputations = None
        training_data, label_ids, entity_tags = trained_policy._featurize_for_training(
            training_trackers,
            default_domain,
            precomputations,
        )

        _, all_labels = trained_policy._create_label_data(
            default_domain, precomputations)
        model_data = trained_policy._create_model_data(training_data,
                                                       label_ids, entity_tags,
                                                       all_labels)
        batch_size = 2
        data_generator = RasaBatchDataGenerator(model_data,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                batch_strategy="sequence")
        iterator = iter(data_generator)
        # model data keys were sorted, so the order is alphabetical
        (
            (
                batch_action_name_mask,
                _,
                _,
                batch_action_name_sentence_shape,
                batch_dialogue_length,
                batch_entities_mask,
                _,
                _,
                batch_entities_sentence_shape,
                batch_intent_mask,
                _,
                _,
                batch_intent_sentence_shape,
                batch_label_ids,
                batch_slots_mask,
                _,
                _,
                batch_slots_sentence_shape,
            ),
            _,
        ) = next(iterator)

        assert (batch_label_ids.shape[0] == batch_size
                and batch_dialogue_length.shape[0] == batch_size)
        # batch and dialogue dimensions are NOT combined for masks
        assert (batch_slots_mask.shape[0] == batch_size
                and batch_intent_mask.shape[0] == batch_size
                and batch_entities_mask.shape[0] == batch_size
                and batch_action_name_mask.shape[0] == batch_size)
        # some features might be "fake" so their sequence length is `0`
        seq_len = max([
            batch_intent_sentence_shape[1],
            batch_action_name_sentence_shape[1],
            batch_entities_sentence_shape[1],
            batch_slots_sentence_shape[1],
        ])
        assert (batch_intent_sentence_shape[1] == seq_len
                or batch_intent_sentence_shape[1] == 0)
        assert (batch_action_name_sentence_shape[1] == seq_len
                or batch_action_name_sentence_shape[1] == 0)
        assert (batch_entities_sentence_shape[1] == seq_len
                or batch_entities_sentence_shape[1] == 0)
        assert (batch_slots_sentence_shape[1] == seq_len
                or batch_slots_sentence_shape[1] == 0)

        data_generator = RasaBatchDataGenerator(model_data,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                batch_strategy="balanced")
        iterator = iter(data_generator)

        (
            (
                batch_action_name_mask,
                _,
                _,
                batch_action_name_sentence_shape,
                batch_dialogue_length,
                batch_entities_mask,
                _,
                _,
                batch_entities_sentence_shape,
                batch_intent_mask,
                _,
                _,
                batch_intent_sentence_shape,
                batch_label_ids,
                batch_slots_mask,
                _,
                _,
                batch_slots_sentence_shape,
            ),
            _,
        ) = next(iterator)

        assert (batch_label_ids.shape[0] == batch_size
                and batch_dialogue_length.shape[0] == batch_size)
        # some features might be "fake" so their sequence length is `0`
        seq_len = max([
            batch_intent_sentence_shape[1],
            batch_action_name_sentence_shape[1],
            batch_entities_sentence_shape[1],
            batch_slots_sentence_shape[1],
        ])
        assert (batch_intent_sentence_shape[1] == seq_len
                or batch_intent_sentence_shape[1] == 0)
        assert (batch_action_name_sentence_shape[1] == seq_len
                or batch_action_name_sentence_shape[1] == 0)
        assert (batch_entities_sentence_shape[1] == seq_len
                or batch_entities_sentence_shape[1] == 0)
        assert (batch_slots_sentence_shape[1] == seq_len
                or batch_slots_sentence_shape[1] == 0)

    @pytest.mark.parametrize(
        "tracker_events_with_action, tracker_events_without_action",
        [
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[
                        {
                            "entity": "name",
                            "value": "Peter"
                        },
                    ]),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("utter_greet"),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    EntitiesAdded(entities=[
                        {
                            "entity": "name",
                            "value": "Peter"
                        },
                    ]),
                    ActionExecuted("utter_greet"),
                ],
            ),
            (
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                ],
                [
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="hello", intent={"name": "greet"}),
                    ActionExecuted(ACTION_UNLIKELY_INTENT_NAME),
                    ActionExecuted("some_form"),
                    ActiveLoop("some_form"),
                    ActionExecuted(ACTION_LISTEN_NAME),
                    UserUttered(text="default", intent={"name": "default"}),
                ],
            ),
        ],
    )
    def test_ignore_action_unlikely_intent(
        self,
        trained_policy: TEDPolicy,
        default_domain: Domain,
        tracker_events_with_action: List[Event],
        tracker_events_without_action: List[Event],
    ):
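        """Predictions should be identical for trackers with and without
        `ACTION_UNLIKELY_INTENT` events, i.e. the policy ignores them."""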
        precomputations = None
        tracker_with_action = DialogueStateTracker.from_events(
            "test 1", evts=tracker_events_with_action)
        tracker_without_action = DialogueStateTracker.from_events(
            "test 2", evts=tracker_events_without_action)
        prediction_with_action = trained_policy.predict_action_probabilities(
            tracker_with_action,
            default_domain,
            precomputations,
        )
        prediction_without_action = trained_policy.predict_action_probabilities(
            tracker_without_action,
            default_domain,
            precomputations,
        )

        # Since the policy ignores `ACTION_UNLIKELY_INTENT` events, both
        # trackers should result in the same prediction.
        assert (prediction_with_action.probabilities ==
                prediction_without_action.probabilities)

    @pytest.mark.parametrize(
        "featurizer_config, tracker_featurizer, state_featurizer",
        [
            (None, MaxHistoryTrackerFeaturizer(), SingleStateFeaturizer),
            ([], MaxHistoryTrackerFeaturizer(), SingleStateFeaturizer),
        ],
    )
    def test_empty_featurizer_configs(
        self,
        featurizer_config: Optional[Dict[Text, Any]],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        tracker_featurizer: MaxHistoryTrackerFeaturizer,
        state_featurizer: Type[SingleStateFeaturizer],
    ):
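        """A `None` or empty featurizer config should fall back to a
        `MaxHistoryTrackerFeaturizer` with a `SingleStateFeaturizer` and the
        max history from the policy config."""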
        featurizer_config_override = ({
            "featurizer": featurizer_config
        } if featurizer_config else {})
        policy = self.create_policy(
            None,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
            config=self._config(featurizer_config_override),
        )

        featurizer = policy.featurizer
        assert isinstance(featurizer, tracker_featurizer.__class__)

        if featurizer_config:
            expected_max_history = featurizer_config[0].get(POLICY_MAX_HISTORY)
        else:
            expected_max_history = self._config().get(POLICY_MAX_HISTORY)

        assert featurizer.max_history == expected_max_history

        assert isinstance(featurizer.state_featurizer, state_featurizer)
Example #9
                   metadata={"utter_action": "utter_ask_cuisine"}),
        ActiveLoop("restaurant_form"),
        SlotSet("requested_slot", "cuisine"),
    ]


@pytest.mark.parametrize(
    "event",
    (
        EntitiesAdded(
            entities=[
                {
                    "entity": "city",
                    "value": "London"
                },
                {
                    "entity": "count",
                    "value": 1
                },
            ],
            timestamp=None,
        ),
        EntitiesAdded(entities=[]),
        EntitiesAdded(entities=[{
            "entity": "name",
            "value": "John",
            "role": "contact",
            "group": "test"
        }]),
        DefinePrevUserUtteredFeaturization(
            use_text_for_featurization=False, timestamp=None, metadata=None),