Code example #1
0
def test_train_with_e2e_data(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    tracker_events: List[List[Event]],
    skip_training: bool,
    domain: Domain,
):
    """Train `UnexpecTEDIntentPolicy` on trackers built from raw event lists.

    One `TrackerWithCachedStates` is built per entry of `tracker_events`,
    using the entry's position as its sender id. When `skip_training` is
    true, the training call must emit a `UserWarning` (checked with
    `pytest.warns`). Otherwise training simply has to complete without
    raising.
    """
    intent_featurizer = IntentMaxHistoryTrackerFeaturizer(
        IntentTokenizerSingleStateFeaturizer()
    )
    policy = UnexpecTEDIntentPolicy(
        UnexpecTEDIntentPolicy.get_default_config(),
        default_model_storage,
        Resource("UnexpecTEDIntentPolicy"),
        default_execution_context,
        featurizer=intent_featurizer,
    )

    trackers_for_training = []
    for position, event_list in enumerate(tracker_events):
        tracker = TrackerWithCachedStates.from_events(
            sender_id=f"{position}", evts=event_list, domain=domain
        )
        trackers_for_training.append(tracker)

    if not skip_training:
        policy.train(trackers_for_training, domain, precomputations=None)
        return

    # Only the skip-training scenario is expected to warn during training.
    with pytest.warns(UserWarning):
        policy.train(trackers_for_training, domain, precomputations=None)
Code example #2
0
 def _standard_featurizer(self) -> TrackerFeaturizer:
     """Build this policy's default tracker featurizer.

     Wraps an `IntentTokenizerSingleStateFeaturizer` in an
     `IntentMaxHistoryTrackerFeaturizer`. The history limit is looked up in
     `self.config` under `POLICY_MAX_HISTORY` via `.get`, so it is `None`
     when that key is absent.
     """
     state_featurizer = IntentTokenizerSingleStateFeaturizer()
     history_limit = self.config.get(POLICY_MAX_HISTORY)
     return IntentMaxHistoryTrackerFeaturizer(
         state_featurizer, max_history=history_limit
     )
Code example #3
0
File: test_policies.py, project: praneethgb/rasa
class PolicyTestCollection:
    """Tests every policy needs to fulfill.

    Each policy can declare further tests on its own. Concrete test classes
    must override `_policy_class_to_test` to return the policy class under
    test; the fixtures and tests below are then run against that class.
    """

    # This is the amount of history we test on. It configures the
    # `featurizer` fixture and is asserted on in `test_featurizer`.
    max_history = 3

    @staticmethod
    def _policy_class_to_test() -> Type[PolicyGraphComponent]:
        """Return the policy class this collection exercises.

        Raises:
            NotImplementedError: Always; concrete test classes override it.
        """
        raise NotImplementedError

    @pytest.fixture(scope="class")
    def resource(self) -> Resource:
        """Provide a randomly named resource shared by the whole test class."""
        return Resource(uuid.uuid4().hex)

    @pytest.fixture(scope="class")
    def model_storage(self, tmp_path_factory: TempPathFactory) -> ModelStorage:
        """Provide a `LocalModelStorage` rooted in a fresh temporary directory."""
        return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))

    @pytest.fixture(scope="class")
    def execution_context(self) -> ExecutionContext:
        """Provide an execution context built over an empty graph schema."""
        return ExecutionContext(GraphSchema({}), uuid.uuid4().hex)

    def _config(
        self, config_override: Optional[Dict[Text, Any]] = None
    ) -> Dict[Text, Any]:
        """Return the policy's default config merged with `config_override`.

        Args:
            config_override: Keys that replace the corresponding defaults.
                `None` or an empty dict yields the plain default config.

        Returns:
            A new dict; neither the defaults nor the override are mutated.
        """
        default_config = self._policy_class_to_test().get_default_config()
        return {**default_config, **(config_override or {})}

    def create_policy(
        self,
        featurizer: Optional[TrackerFeaturizer],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        config: Optional[Dict[Text, Any]] = None,
    ) -> PolicyGraphComponent:
        """Instantiate (but do not train) the policy under test.

        Args:
            featurizer: Featurizer forwarded to the policy constructor; may
                be `None`.
            model_storage: Storage the policy persists to and loads from.
            resource: Resource identifying the policy within the storage.
            execution_context: Context forwarded to the policy constructor.
            config: Overrides merged over the default config via `_config`.

        Returns:
            The newly constructed policy instance.
        """
        return self._policy_class_to_test()(
            config=self._config(config),
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
            featurizer=featurizer,
        )

    @pytest.fixture(scope="class")
    def featurizer(self) -> TrackerFeaturizer:
        """Provide a `MaxHistoryTrackerFeaturizer` configured with `max_history`."""
        return MaxHistoryTrackerFeaturizer(
            SingleStateFeaturizer(), max_history=self.max_history
        )

    @pytest.fixture(scope="class")
    def default_domain(self, domain_path: Text) -> Domain:
        """Provide the domain loaded from `domain_path`."""
        return Domain.load(domain_path)

    @pytest.fixture(scope="class")
    def tracker(self, default_domain: Domain) -> DialogueStateTracker:
        """Provide a fresh tracker using the default domain's slots."""
        return DialogueStateTracker(DEFAULT_SENDER_ID, default_domain.slots)

    @pytest.fixture(scope="class")
    def trained_policy(
        self,
        featurizer: Optional[TrackerFeaturizer],
        stories_path: Text,
        default_domain: Domain,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> PolicyGraphComponent:
        """Provide a policy trained on the stories at `stories_path`.

        Training trackers are generated with an augmentation factor of 20.
        """
        policy = self.create_policy(
            featurizer, model_storage, resource, execution_context
        )
        training_trackers = train_trackers(
            default_domain, stories_path, augmentation_factor=20
        )
        policy.train(training_trackers, default_domain)
        return policy

    def test_featurizer(
        self,
        trained_policy: PolicyGraphComponent,
        resource: Resource,
        model_storage: ModelStorage,
        tmp_path: Path,
        execution_context: ExecutionContext,
    ):
        """The featurizer setup survives training and a persist/load cycle."""
        # NOTE(review): `tmp_path` is requested but not used by this test.
        assert isinstance(trained_policy.featurizer, MaxHistoryTrackerFeaturizer)
        assert trained_policy.featurizer.max_history == self.max_history
        assert isinstance(
            trained_policy.featurizer.state_featurizer, SingleStateFeaturizer
        )

        loaded = trained_policy.__class__.load(
            self._config(trained_policy.config),
            model_storage,
            resource,
            execution_context,
        )

        assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer)
        assert loaded.featurizer.max_history == self.max_history
        assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer)

    @pytest.mark.parametrize("should_finetune", [False, True])
    def test_persist_and_load(
        self,
        trained_policy: PolicyGraphComponent,
        default_domain: Domain,
        should_finetune: bool,
        stories_path: Text,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        """A loaded policy predicts exactly what the trained one predicts.

        The policy is loaded both in and out of finetuning mode, and the
        `finetune_mode` flag must mirror the execution context.
        """
        loaded = trained_policy.__class__.load(
            self._config(trained_policy.config),
            model_storage,
            resource,
            dataclasses.replace(execution_context, is_finetuning=should_finetune),
        )

        assert loaded.finetune_mode == should_finetune

        trackers = train_trackers(default_domain, stories_path, augmentation_factor=20)

        for tracker in trackers:
            predicted_probabilities = loaded.predict_action_probabilities(
                tracker, default_domain
            )
            actual_probabilities = trained_policy.predict_action_probabilities(
                tracker, default_domain
            )
            assert predicted_probabilities == actual_probabilities

    def test_prediction_on_empty_tracker(
        self, trained_policy: PolicyGraphComponent, default_domain: Domain
    ):
        """An empty tracker yields one in-range probability per domain action."""
        tracker = DialogueStateTracker(DEFAULT_SENDER_ID, default_domain.slots)
        prediction = trained_policy.predict_action_probabilities(
            tracker,
            default_domain,
        )
        assert not prediction.is_end_to_end_prediction
        assert len(prediction.probabilities) == default_domain.num_actions
        assert max(prediction.probabilities) <= 1.0
        assert min(prediction.probabilities) >= 0.0

    @pytest.mark.filterwarnings(
        "ignore:.*without a trained model present.*:UserWarning"
    )
    def test_persist_and_load_empty_policy(
        self,
        default_domain: Domain,
        default_model_storage: ModelStorage,
        execution_context: ExecutionContext,
    ):
        """A policy trained on no trackers can still be persisted and loaded."""
        # A dedicated resource keeps this policy apart from `trained_policy`.
        resource = Resource(uuid.uuid4().hex)
        empty_policy = self.create_policy(
            None,
            default_model_storage,
            resource,
            execution_context,
        )

        empty_policy.train([], default_domain)
        loaded = empty_policy.__class__.load(
            self._config(),
            default_model_storage,
            resource,
            execution_context,
        )

        assert loaded is not None

    @staticmethod
    def _get_next_action(
        policy: PolicyGraphComponent, events: List[Event], domain: Domain
    ) -> Text:
        """Return the highest-scoring action for a tracker built from `events`.

        Ties resolve to the first action with the maximal score.
        """
        tracker = get_tracker(events)
        scores = policy.predict_action_probabilities(
            tracker,
            domain,
        ).probabilities
        index = scores.index(max(scores))
        return domain.action_names_or_texts[index]

    @pytest.mark.parametrize(
        "featurizer_config, tracker_featurizer, state_featurizer",
        [
            (
                [
                    {
                        # TODO: remove "2" when migration of policies is done
                        "name": "MaxHistoryTrackerFeaturizer2",
                        "max_history": 12,
                        "state_featurizer": [],
                    }
                ],
                MaxHistoryTrackerFeaturizer(max_history=12),
                type(None),
            ),
            (
                # TODO: remove "2" when migration of policies is done
                [{"name": "MaxHistoryTrackerFeaturizer2", "max_history": 12}],
                MaxHistoryTrackerFeaturizer(max_history=12),
                type(None),
            ),
            (
                [
                    {
                        # TODO: remove "2" when migration of policies is done
                        "name": "IntentMaxHistoryTrackerFeaturizer2",
                        "max_history": 12,
                        "state_featurizer": [
                            {"name": "IntentTokenizerSingleStateFeaturizer2"}
                        ],
                    }
                ],
                IntentMaxHistoryTrackerFeaturizer(max_history=12),
                IntentTokenizerSingleStateFeaturizer,
            ),
        ],
    )
    def test_different_featurizer_configs(
        self,
        featurizer_config: Optional[List[Dict[Text, Any]]],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        tracker_featurizer: MaxHistoryTrackerFeaturizer,
        # `type(None)` is passed when no state featurizer is configured.
        state_featurizer: Type[Any],
    ):
        """The policy builds its featurizer from the `featurizer` config key."""
        featurizer_config_override = (
            {"featurizer": featurizer_config} if featurizer_config else {}
        )
        # `create_policy` already merges the override over the defaults.
        policy = self.create_policy(
            None,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
            config=featurizer_config_override,
        )

        featurizer = policy.featurizer
        assert isinstance(featurizer, tracker_featurizer.__class__)

        if featurizer_config:
            expected_max_history = featurizer_config[0].get(POLICY_MAX_HISTORY)
        else:
            expected_max_history = self._config().get(POLICY_MAX_HISTORY)

        assert featurizer.max_history == expected_max_history

        assert isinstance(featurizer.state_featurizer, state_featurizer)

    @pytest.mark.parametrize(
        "featurizer_config",
        [
            [
                # TODO: remove "2" when migration of policies is done
                {"name": "MaxHistoryTrackerFeaturizer2", "max_history": 12},
                {"name": "MaxHistoryTrackerFeaturizer2", "max_history": 12},
            ],
            [
                {
                    # TODO: remove "2" when migration of policies is done
                    "name": "IntentMaxHistoryTrackerFeaturizer2",
                    "max_history": 12,
                    "state_featurizer": [
                        {"name": "IntentTokenizerSingleStateFeaturizer2"},
                        {"name": "IntentTokenizerSingleStateFeaturizer2"},
                    ],
                }
            ],
        ],
    )
    def test_different_invalid_featurizer_configs(
        self,
        trained_policy: PolicyGraphComponent,
        featurizer_config: Optional[List[Dict[Text, Any]]],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ):
        """Multiple tracker or state featurizers are rejected at construction."""
        with pytest.raises(InvalidPolicyConfig):
            self.create_policy(
                None,
                model_storage=model_storage,
                resource=resource,
                execution_context=execution_context,
                config={"featurizer": featurizer_config},
            )
Code example #4
0
 def featurizer(self) -> TrackerFeaturizer:
     """Return an intent-based max-history tracker featurizer.

     Pairs an `IntentTokenizerSingleStateFeaturizer` with an
     `IntentMaxHistoryTrackerFeaturizer` configured with `self.max_history`.
     """
     return IntentMaxHistoryTrackerFeaturizer(
         IntentTokenizerSingleStateFeaturizer(),
         max_history=self.max_history,
     )
Code example #5
0
 def _standard_featurizer(
     max_history: Optional[int] = None,
 ) -> TrackerFeaturizer:
     """Build the default intent featurizer with an optional history limit.

     Args:
         max_history: Forwarded unchanged as the `max_history` argument of
             `IntentMaxHistoryTrackerFeaturizer`; defaults to `None`.

     Returns:
         An `IntentMaxHistoryTrackerFeaturizer` wrapping an
         `IntentTokenizerSingleStateFeaturizer`.
     """
     single_state_featurizer = IntentTokenizerSingleStateFeaturizer()
     tracker_featurizer = IntentMaxHistoryTrackerFeaturizer(
         single_state_featurizer, max_history=max_history
     )
     return tracker_featurizer