Exemplo n.º 1
0
    def train(self, training_trackers: List[DialogueStateTracker],
              **kwargs: Any) -> None:
        """Train the policy ensemble on already loaded dialogue trackers.

        Args:
            training_trackers: trackers to train on; a plain string (e.g. a
                file name) is rejected with an exception
            **kwargs: additional arguments passed on to the policy
                ensemble's trainer (e.g. keras parameters)

        Raises:
            AgentNotReady: if `self.is_ready()` reports a falsy value
            Exception: if a truthy `featurizer` or `max_history` keyword is
                given, or if `training_trackers` is a string
        """

        if not self.is_ready():
            raise AgentNotReady("Can't train without a policy ensemble.")

        # keyword arguments that this method no longer accepts
        removed_arguments = ('featurizer', 'max_history')
        if any(kwargs.get(name) for name in removed_arguments):
            removed_arguments_message = (
                "Passing `featurizer` and `max_history` "
                "to `agent.train(...)` is not supported anymore. "
                "Pass appropriate featurizer "
                "directly to the policy instead. More info "
                "https://rasa.com/docs/core/migrations.html#x-to"
                "-0-9-0"
            )
            raise Exception(removed_arguments_message)

        # a string here most likely means the caller handed over a file name
        # to load training data from, which is no longer handled
        if isinstance(training_trackers, str):
            file_name_message = (
                "Passing a file name to `agent.train(...)` is "
                "not supported anymore. Rather load the data with "
                "`data = agent.load_data(file_name)` and pass it "
                "to `agent.train(data)`."
            )
            raise Exception(file_name_message)

        debug_message = "Agent trainer got kwargs: {}".format(kwargs)
        logger.debug(debug_message)
        check_domain_sanity(self.domain)

        ensemble = self.policy_ensemble
        ensemble.train(training_trackers, self.domain, **kwargs)
        self._set_fingerprint()
Exemplo n.º 2
0
    def train(
            self,
            training_trackers,  # type: List[DialogueStateTracker]
            **kwargs  # type: **Any
    ):
        # type: (...) -> None
        """Train the policy ensemble on the given dialogue trackers.

            :param training_trackers: trackers to train on; a file name is
                                      still accepted (deprecated) and is
                                      loaded through `self.load_data`
            :param kwargs: additional arguments passed to the underlying ML
                           trainer (e.g. keras parameters)
            :raises Exception: if a truthy `featurizer` or `max_history`
                               keyword argument is passed
        """

        # keyword arguments that this method no longer accepts
        removed_arguments = ('featurizer', 'max_history')
        if any(kwargs.get(name) for name in removed_arguments):
            removed_arguments_message = (
                "Passing `featurizer` and `max_history` "
                "to `agent.train(...)` is not supported anymore. "
                "Pass appropriate featurizer "
                "directly to the policy instead. More info "
                "https://core.rasa.com/migrations.html#x-to-0-9-0"
            )
            raise Exception(removed_arguments_message)

        # TODO: DEPRECATED - remove in version 0.10
        got_file_name = isinstance(training_trackers, string_types)
        if got_file_name:
            # a string most likely is a file name to load training data
            # from; warn and load it on the caller's behalf
            deprecation_notice = (
                "Passing a file name to `agent.train(...)` is "
                "deprecated. Rather load the data with "
                "`data = agent.load_data(file_name)` and pass it "
                "to `agent.train(data)`."
            )
            logger.warning(deprecation_notice)
            training_trackers = self.load_data(training_trackers)

        debug_message = "Agent trainer got kwargs: {}".format(kwargs)
        logger.debug(debug_message)
        check_domain_sanity(self.domain)

        ensemble = self.policy_ensemble
        ensemble.train(training_trackers, self.domain, **kwargs)
Exemplo n.º 3
0
    def train(self,
              training_trackers,  # type: List[DialogueStateTracker]
              **kwargs  # type: **Any
              ):
        # type: (...) -> None
        """Train the policy ensemble on the given dialogue trackers.

            :param training_trackers: trackers to train on; a file name is
                                      still accepted (deprecated) and is
                                      loaded through `self.load_data`
            :param kwargs: additional arguments passed to the underlying ML
                           trainer (e.g. keras parameters)
            :raises Exception: if a truthy `featurizer` or `max_history`
                               keyword argument is passed
        """

        # reject keyword arguments this method no longer accepts
        for removed_name in ('featurizer', 'max_history'):
            if kwargs.get(removed_name):
                raise Exception(
                    "Passing `featurizer` and `max_history` "
                    "to `agent.train(...)` is not supported anymore. "
                    "Pass appropriate featurizer "
                    "directly to the policy instead. More info "
                    "https://core.rasa.com/migrations.html#x-to-0-9-0"
                )

        # TODO: DEPRECATED - remove in version 0.10
        if isinstance(training_trackers, string_types):
            # a string most likely is a file name to load training data
            # from; warn and load it on the caller's behalf
            file_name_warning = (
                "Passing a file name to `agent.train(...)` is "
                "deprecated. Rather load the data with "
                "`data = agent.load_data(file_name)` and pass it "
                "to `agent.train(data)`."
            )
            logger.warning(file_name_warning)
            training_trackers = self.load_data(training_trackers)

        kwargs_report = "Agent trainer got kwargs: {}".format(kwargs)
        logger.debug(kwargs_report)
        check_domain_sanity(self.domain)

        self.policy_ensemble.train(
            training_trackers, self.domain, **kwargs)
Exemplo n.º 4
0
    def train(self,
              resource_name=None,
              interpreter=None,
              input_channel=None,
              max_history=3,
              augmentation_factor=20,
              max_training_samples=None,
              max_number_of_trackers=2000,
              **kwargs):
        """Train the ensemble, then hand over to online training.

        :param resource_name: forwarded to `self._prepare_training_data`
                              (presumably the location of the training
                              data - confirm against that helper)
        :param interpreter: forwarded to `self.run_online_training`
        :param input_channel: forwarded to `self.run_online_training`
        :param max_history: forwarded to `self._prepare_training_data` and
                            to the `OnlinePolicyEnsemble` constructor
        :param augmentation_factor: forwarded to
                                    `self._prepare_training_data`
        :param max_training_samples: forwarded to
                                     `self._prepare_training_data`
        :param max_number_of_trackers: forwarded to
                                       `self._prepare_training_data`
        :param kwargs: additional arguments passed to `self.ensemble.train`
        """
        kwargs_report = "Policy trainer got kwargs: {}".format(kwargs)
        logger.debug(kwargs_report)
        check_domain_sanity(self.domain)

        prepared_data = self._prepare_training_data(
            resource_name, max_history, augmentation_factor,
            max_training_samples, max_number_of_trackers)

        self.ensemble.train(
            prepared_data, self.domain, self.featurizer, **kwargs)

        # online learning doesn't support it yet
        prepared_data.reset_metadata()

        online_ensemble = OnlinePolicyEnsemble(
            self.ensemble, self.featurizer, max_history, prepared_data)
        self.run_online_training(
            online_ensemble, self.domain, interpreter, input_channel)
Exemplo n.º 5
0
    def train_online(
            self,
            training_trackers,  # type: List[DialogueStateTracker]
            input_channel=None,  # type: Optional[InputChannel]
            max_visual_history=3,  # type: int
            **kwargs  # type: **Any
    ):
        # type: (...) -> None
        """Train a policy ensemble in online learning mode.

            The policy ensemble is first trained on `training_trackers`;
            it is then wrapped in an `OnlinePolicyEnsemble` whose
            `run_online_training` is started.

            :param training_trackers: trackers to train on; a file name is
                                      still accepted (deprecated) and is
                                      loaded through `self.load_data`
            :param input_channel: forwarded to `run_online_training`
            :param max_visual_history: forwarded to the
                                       `OnlinePolicyEnsemble` constructor
            :param kwargs: additional arguments passed to
                           `self.policy_ensemble.train`
            :raises ValueError: if the agent has no interpreter set
        """
        # The docstring has to be the first statement of the body to be
        # picked up as `__doc__`; the import stays local to the method as
        # in the original code.
        from rasa_core.policies.online_trainer import OnlinePolicyEnsemble

        if not self.interpreter:
            raise ValueError("When using online learning, you need to specify "
                             "an interpreter for the agent to use.")

        # TODO: DEPRECATED - remove in version 0.10
        if isinstance(training_trackers, string_types):
            # the user most likely passed in a file name to load training
            # data from
            logger.warning("Passing a file name to `agent.train_online(...)` "
                           "is deprecated. Rather load the data with "
                           "`data = agent.load_data(file_name)` and pass it "
                           "to `agent.train_online(data)`.")
            training_trackers = self.load_data(training_trackers)

        logger.debug("Agent online trainer got kwargs: {}".format(kwargs))
        check_domain_sanity(self.domain)

        self.policy_ensemble.train(training_trackers, self.domain, **kwargs)

        ensemble = OnlinePolicyEnsemble(self.policy_ensemble,
                                        training_trackers, max_visual_history)

        ensemble.run_online_training(self.domain, self.interpreter,
                                     input_channel)
Exemplo n.º 6
0
    def train(self, filename=None, max_history=3,
              augmentation_factor=20, max_training_samples=None,
              max_number_of_trackers=2000, **kwargs):
        """Train the policy ensemble on a domain with data read from a file.

        The method checks the domain, prepares the training data and trains
        `self.ensemble` on it. Nothing is returned.

        :param filename: story file containing the training conversations
        :param max_history: number of past actions to consider for the
                            prediction of the next action
        :param augmentation_factor: how many stories should be created by
                                    randomly concatenating stories
        :param max_training_samples: specifies how many training samples to
                                     train on - `None` to use all examples
        :param max_number_of_trackers: limits the tracker generation during
                                       story file parsing - `None` for unlimited
        :param kwargs: additional arguments passed to the underlying ML trainer
                       (e.g. keras parameters)
        """

        kwargs_report = "Policy trainer got kwargs: {}".format(kwargs)
        logger.debug(kwargs_report)
        check_domain_sanity(self.domain)

        prepared = self._prepare_training_data(
            filename, max_history, augmentation_factor,
            max_training_samples, max_number_of_trackers)
        # the helper's result is expected to unpack into exactly two parts
        train_x, train_y = prepared

        self.ensemble.train(
            train_x, train_y, self.domain, self.featurizer, **kwargs)
Exemplo n.º 7
0
    def train(self,
              filename=None, interpreter=None, input_channel=None,
              max_history=3, augmentation_factor=20, max_training_samples=None,
              max_number_of_trackers=2000, **kwargs):
        """Train the ensemble, then hand over to online training.

        :param filename: forwarded to `self._prepare_training_data`
                         (presumably the training data file - confirm
                         against that helper)
        :param interpreter: forwarded to `self.run_online_training`
        :param input_channel: forwarded to `self.run_online_training`
        :param max_history: forwarded to `self._prepare_training_data` and
                            to the `OnlinePolicyEnsemble` constructor
        :param augmentation_factor: forwarded to
                                    `self._prepare_training_data`
        :param max_training_samples: forwarded to
                                     `self._prepare_training_data`
        :param max_number_of_trackers: forwarded to
                                       `self._prepare_training_data`
        :param kwargs: additional arguments passed to `self.ensemble.train`
        """
        kwargs_report = "Policy trainer got kwargs: {}".format(kwargs)
        logger.debug(kwargs_report)
        check_domain_sanity(self.domain)

        prepared = self._prepare_training_data(
            filename, max_history, augmentation_factor,
            max_training_samples, max_number_of_trackers)
        # the helper's result is expected to unpack into exactly two parts
        train_x, train_y = prepared

        self.ensemble.train(
            train_x, train_y, self.domain, self.featurizer, **kwargs)

        # the online ensemble receives the same two parts as one tuple
        online_ensemble = OnlinePolicyEnsemble(
            self.ensemble, self.featurizer, max_history, (train_x, train_y))
        self.run_online_training(
            online_ensemble, self.domain, interpreter, input_channel)
Exemplo n.º 8
0
    def train_online(self,
                     training_trackers,  # type: List[DialogueStateTracker]
                     input_channel=None,  # type: Optional[InputChannel]
                     max_visual_history=3,  # type: int
                     **kwargs  # type: **Any
                     ):
        # type: (...) -> None
        """Train a policy ensemble in online learning mode.

            The policy ensemble is first trained on `training_trackers`;
            it is then wrapped in an `OnlinePolicyEnsemble` whose
            `run_online_training` is started.

            :param training_trackers: trackers to train on; a file name is
                                      still accepted (deprecated) and is
                                      loaded through `self.load_data`
            :param input_channel: forwarded to `run_online_training`
            :param max_visual_history: forwarded to the
                                       `OnlinePolicyEnsemble` constructor
            :param kwargs: additional arguments passed to
                           `self.policy_ensemble.train`
            :raises ValueError: if the agent has no interpreter set
        """
        # The docstring has to be the first statement of the body to be
        # picked up as `__doc__`; the import stays local to the method as
        # in the original code.
        from rasa_core.policies.online_trainer import OnlinePolicyEnsemble

        if not self.interpreter:
            raise ValueError(
                    "When using online learning, you need to specify "
                    "an interpreter for the agent to use.")

        # TODO: DEPRECATED - remove in version 0.10
        if isinstance(training_trackers, string_types):
            # the user most likely passed in a file name to load training
            # data from
            logger.warning("Passing a file name to `agent.train_online(...)` "
                           "is deprecated. Rather load the data with "
                           "`data = agent.load_data(file_name)` and pass it "
                           "to `agent.train_online(data)`.")
            training_trackers = self.load_data(training_trackers)

        logger.debug("Agent online trainer got kwargs: {}".format(kwargs))
        check_domain_sanity(self.domain)

        self.policy_ensemble.train(training_trackers, self.domain, **kwargs)

        ensemble = OnlinePolicyEnsemble(self.policy_ensemble,
                                        training_trackers,
                                        max_visual_history)

        ensemble.run_online_training(self.domain, self.interpreter,
                                     input_channel)