Пример #1
0
    def train(
        self,
        training_trackers: List[TrackerWithCachedStates],
        domain: Domain,
        interpreter: NaturalLanguageInterpreter,
        **kwargs: Any,
    ) -> None:
        """Train every policy in ``self.policies`` and record the training time.

        Args:
            training_trackers: Trackers to train on. If this is empty, no
                policy is trained and an info message is logged instead.
            domain: Handed to each policy's ``train`` call and to the action
                fingerprint creation.
            interpreter: Forwarded to each policy's ``train`` call as the
                ``interpreter`` keyword argument.
            **kwargs: Forwarded unchanged to each policy's ``train`` call.
        """
        if not training_trackers:
            logger.info(
                "Skipped training, because there are no training samples.")
        else:
            self._emit_rule_policy_warning(training_trackers)

            for current_policy in self.policies:
                # Each policy only receives the trackers selected for it.
                policy_trackers = SupportedData.trackers_for_policy(
                    current_policy, training_trackers
                )
                current_policy.train(
                    policy_trackers, domain, interpreter=interpreter, **kwargs
                )

            # Fingerprints are built from the full tracker list, not from the
            # per-policy selection.
            self.action_fingerprints = (
                rasa.core.training.training.create_action_fingerprints(
                    training_trackers, domain
                )
            )

        # The timestamp is written whether or not any training took place.
        self.date_trained = datetime.now().strftime("%Y%m%d-%H%M%S")
Пример #2
0
    def train(
        self,
        training_trackers: List[TrackerWithCachedStates],
        domain: Domain,
        **kwargs: Any,
    ) -> Resource:
        """Build ``self.lookup`` from the given trackers and persist the result.

        Args:
            training_trackers: Candidate trackers. Trackers whose
                ``is_augmented`` attribute is truthy are dropped before the
                remaining ones are narrowed to this policy's supported data.
            domain: Passed to the featurizer when producing states and labels.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``self._resource``, after ``self.persist()`` has been called.
        """
        # Keep only original trackers: a tracker is kept when it either has no
        # ``is_augmented`` attribute or that attribute is falsy.
        original_trackers = [
            tracker
            for tracker in training_trackers
            if not (hasattr(tracker, "is_augmented") and tracker.is_augmented)
        ]
        supported_trackers = SupportedData.trackers_for_supported_data(
            self.supported_data(), original_trackers
        )

        states_and_actions = self.featurizer.training_states_and_labels(
            supported_trackers, domain
        )
        states, actions = states_and_actions

        self.lookup = self._create_lookup_from_states(states, actions)
        logger.debug(f"Memorized {len(self.lookup)} unique examples.")

        self.persist()
        return self._resource
Пример #3
0
def test_get_training_trackers_for_policy(
    policy: Policy, n_rule_trackers: int, n_ml_trackers: int
):
    """Check how many rule and ML trackers are selected for ``policy``.

    Five trackers are built (two rule trackers, three ML trackers) and passed
    through ``SupportedData.trackers_for_policy``; the number of rule and
    non-rule trackers in the result must equal the expected counts.
    """
    # Tracker id -> whether it is a rule tracker (two rule-based, three ML).
    rule_flag_by_id = {
        "id1": True,
        "id2": False,
        "id3": False,
        "id4": True,
        "id5": False,
    }
    all_trackers = [
        DialogueStateTracker(tracker_id, slots=[], is_rule_tracker=is_rule)
        for tracker_id, is_rule in rule_flag_by_id.items()
    ]

    selected = SupportedData.trackers_for_policy(policy, all_trackers)

    rule_count = sum(1 for tracker in selected if tracker.is_rule_tracker)
    ml_count = sum(1 for tracker in selected if not tracker.is_rule_tracker)

    assert rule_count == n_rule_trackers
    assert ml_count == n_ml_trackers