コード例 #1
0
def test_policy_loading_load_returns_wrong_type(tmpdir):
    """Loading must fail when a policy's ``load`` returns the wrong type."""
    model_dir = str(tmpdir)

    ensemble = PolicyEnsemble([LoadReturnsWrongTypePolicy()])
    ensemble.train([], None)
    ensemble.persist(model_dir)

    # Re-loading the persisted ensemble is expected to raise.
    with pytest.raises(Exception):
        PolicyEnsemble.load(model_dir)
コード例 #2
0
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Any
             tracker_store=None,  # type: Optional[TrackerStore]
             action_factory=None,  # type: Optional[Text]
             rules_file=None,  # type: Any
             generator=None,  # type: Any
             create_dispatcher=None  # type: Any
             ):
        # type: (...) -> Agent
        """Load a persisted agent model from the directory ``path``.

        Args:
            path: directory containing ``domain.yml`` and the files read by
                ``PolicyEnsemble.load``.
            interpreter: passed through ``NaturalLanguageInterpreter.create``.
            tracker_store: passed to ``cls.create_tracker_store`` together
                with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.
            rules_file: forwarded unchanged to the constructor.
            generator: forwarded unchanged to the constructor.
            create_dispatcher: forwarded unchanged to the constructor.

        Returns:
            A new ``cls`` instance built from the loaded components.

        Raises:
            ValueError: if ``path`` is ``None`` or points to a file instead
                of a model directory.
        """
        if path is None:
            raise ValueError("No domain path specified.")

        # A file path would otherwise fail obscurely inside os.path.join /
        # TemplateDomain.load; give the same explicit hint sibling loaders do.
        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)
        ensemble = PolicyEnsemble.load(path)
        _interpreter = NaturalLanguageInterpreter.create(interpreter)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)
        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=_interpreter,
                   tracker_store=_tracker_store,
                   rules_file=rules_file,
                   generator=generator,
                   create_dispatcher=create_dispatcher)
コード例 #3
0
ファイル: my_agent.py プロジェクト: JCourt1/bot_creation_tool
    def load(
        cls,
        path,  # type: Text
        interpreter=None,  # type: Union[NLI, Text, None]
        tracker_store=None,  # type: Optional[TrackerStore]
        action_factory=None  # type: Optional[Text]
    ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the files read by
                ``PolicyEnsemble.load``.
            interpreter: passed through ``NaturalLanguageInterpreter.create``.
            tracker_store: passed to ``cls.create_tracker_store`` together
                with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.

        Returns:
            A new ``cls`` instance built from the loaded components.

        Raises:
            ValueError: if ``path`` is ``None`` or points to a file instead
                of a model directory.
        """
        if path is None:
            raise ValueError("No domain path specified.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train; checked
        # before the (potentially expensive) policy load so we fail fast
        domain.compare_with_specification(path)
        ensemble = PolicyEnsemble.load(path)
        _interpreter = NaturalLanguageInterpreter.create(interpreter)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)

        return cls(domain, ensemble, _interpreter, _tracker_store)
コード例 #4
0
def _load_and_set_updated_model(agent: 'Agent', model_directory: Text,
                                fingerprint: Text):
    """Load the persisted model into memory and set the model on the agent.

    When ``_get_stack_model_directory`` reports a stacked model, the core
    policies come from its ``core`` subdirectory and a fresh
    ``RasaNLUInterpreter`` is built from its ``nlu`` subdirectory. Otherwise
    ``model_directory`` itself is the core model and the agent's current
    interpreter is reused.

    Errors raised while loading the policy ensemble or updating the agent
    are logged and swallowed, so the previously loaded model stays active.
    Errors from loading the domain are not caught and propagate.
    """
    logger.debug("Found new model with fingerprint {}. Loading..."
                 "".format(fingerprint))

    stack_dir = _get_stack_model_directory(model_directory)
    if not stack_dir:
        interpreter = agent.interpreter
        core_model = model_directory
    else:
        from rasa_core.interpreter import RasaNLUInterpreter
        core_model = os.path.join(stack_dir, "core")
        interpreter = RasaNLUInterpreter(
            model_directory=os.path.join(stack_dir, "nlu"))

    domain = Domain.load(
        os.path.join(os.path.abspath(core_model), "domain.yml"))

    # noinspection PyBroadException
    try:
        ensemble = PolicyEnsemble.load(core_model)
        agent.update_model(domain, ensemble, fingerprint, interpreter)
        logger.debug("Finished updating agent to new model.")
    except Exception:
        # Best effort: keep serving with the current model on failure.
        logger.exception("Failed to load policy and update agent. "
                         "The previous model will stay loaded instead.")
コード例 #5
0
ファイル: agent.py プロジェクト: githubclj/rasa_core
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Union[NLI, Text, None]
             tracker_store=None,  # type: Optional[TrackerStore]
             action_factory=None,  # type: Optional[Text]
             generator=None  # type: Union[EndpointConfig, NLG]
             ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the files read by
                ``PolicyEnsemble.load``.
            interpreter: forwarded unchanged to the constructor.
            tracker_store: passed to ``cls.create_tracker_store`` together
                with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.
            generator: forwarded unchanged to the constructor.

        Returns:
            A new ``cls`` instance built from the loaded components.

        Raises:
            ValueError: if ``path`` is ``None`` or points to a file instead
                of a model directory.
        """
        if path is None:
            raise ValueError("No domain path specified.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train; checked
        # before the (potentially expensive) policy load so we fail fast
        domain.compare_with_specification(path)
        ensemble = PolicyEnsemble.load(path)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)

        return cls(domain, ensemble, interpreter, generator, _tracker_store)
コード例 #6
0
ファイル: agent.py プロジェクト: punitcs81/chatbot
    def load(cls,
             path: Text,
             interpreter: Optional[NaturalLanguageInterpreter] = None,
             generator: Union[EndpointConfig, 'NLG'] = None,
             tracker_store: Optional['TrackerStore'] = None,
             action_endpoint: Optional[EndpointConfig] = None,
             ) -> 'Agent':
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the files read by
                ``PolicyEnsemble.load``.
            interpreter, generator, tracker_store, action_endpoint:
                forwarded unchanged to the constructor.

        Returns:
            A new ``cls`` instance built from the loaded domain and policies.

        Raises:
            ValueError: if ``path`` is empty or points to a file instead of
                a model directory.
        """
        if not path:
            raise ValueError("You need to provide a valid directory where "
                             "to load the agent from when calling "
                             "`Agent.load`.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = Domain.load(os.path.join(path, "domain.yml"))

        # ensures the domain hasn't changed between test and train; checked
        # before the (potentially expensive) policy load so we fail fast
        domain.compare_with_specification(path)

        # `path` is guaranteed non-empty by the guard above, so there is no
        # "no ensemble" case to handle here.
        ensemble = PolicyEnsemble.load(path)

        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=interpreter,
                   generator=generator,
                   tracker_store=tracker_store,
                   action_endpoint=action_endpoint)
コード例 #7
0
def test_policy_loading_simple(tmpdir):
    """A persisted ensemble must reload with equal policies."""
    model_dir = str(tmpdir)

    trained = PolicyEnsemble([WorkingPolicy()])
    trained.train([], None)
    trained.persist(model_dir)

    reloaded = PolicyEnsemble.load(model_dir)
    assert trained.policies == reloaded.policies
コード例 #8
0
 def load(cls, path, interpreter=None, tracker_store=None):
     # type: (Text, Any, Optional[TrackerStore]) -> Agent
     """Load a persisted agent (domain, featurizer, policies) from ``path``.

     Args:
         path: directory containing ``domain.yml`` plus the files read by
             ``Featurizer.load`` and ``PolicyEnsemble.load``.
         interpreter: passed through ``NaturalLanguageInterpreter.create``.
         tracker_store: passed to ``cls._create_tracker_store`` together
             with the loaded domain.

     Returns:
         A new ``cls`` instance built from the loaded components.

     Raises:
         ValueError: if ``path`` is ``None`` or points to a file instead of
             a model directory.
     """
     # Validate up front instead of letting os.path.join raise a TypeError.
     if path is None:
         raise ValueError("No domain path specified.")
     if os.path.isfile(path):
         raise ValueError("You are trying to load a MODEL from a file "
                          "('{}'), which is not possible. \n"
                          "The persisted path should be a directory "
                          "containing the various model files. \n\n"
                          "If you want to load training data instead of "
                          "a model, use `agent.load_data(...)` "
                          "instead.".format(path))

     domain = TemplateDomain.load(os.path.join(path, "domain.yml"))
     # ensures the domain hasn't changed between test and train
     domain.compare_with_specification(path)
     featurizer = Featurizer.load(path)
     ensemble = PolicyEnsemble.load(path, featurizer)
     _interpreter = NaturalLanguageInterpreter.create(interpreter)
     _tracker_store = cls._create_tracker_store(tracker_store, domain)
     return cls(domain, ensemble, featurizer, _interpreter, _tracker_store)
コード例 #9
0
def _update_model_from_server(model_server: EndpointConfig,
                              agent: 'Agent') -> None:
    """Load a zipped Rasa Core model from a URL and update the passed agent.

    The model is pulled into a fresh temporary directory. That directory is
    removed again unless a new model was pulled and successfully loaded into
    ``agent``. This keeps repeated calls (e.g. from a model-pulling worker)
    from accumulating empty or broken temp directories.

    Raises:
        InvalidURL: if ``model_server.url`` is not a valid URL.
    """
    import shutil

    if not is_url(model_server.url):
        raise InvalidURL(model_server.url)

    model_directory = tempfile.mkdtemp()
    model_in_use = False
    try:
        new_model_fingerprint = _pull_model_and_fingerprint(
            model_server, model_directory, agent.fingerprint)
        if new_model_fingerprint:
            domain_path = os.path.join(os.path.abspath(model_directory),
                                       "domain.yml")
            domain = Domain.load(domain_path)
            policy_ensemble = PolicyEnsemble.load(model_directory)
            agent.update_model(domain, policy_ensemble, new_model_fingerprint)
            model_in_use = True
        else:
            logger.debug("No new model found at "
                         "URL {}".format(model_server.url))
    finally:
        # Keep the directory only when it backs the agent's current model.
        if not model_in_use:
            shutil.rmtree(model_directory, ignore_errors=True)
コード例 #10
0
 def loadAgent(path,  # type: Text
               interpreter=None,  # type: Any
               tracker_store=None,  # type: Optional[TrackerStore]
               action_factory=None,  # type: Optional[Text]
               core_server=None  # type: Any
               ):
     # type: (...) -> SnipsMqttAgent
     """Load a persisted model from ``path`` into a ``SnipsMqttAgent``.

     Args:
         path: directory containing ``domain.yml`` plus the files read by
             ``Featurizer.load`` and ``PolicyEnsemble.load``.
         interpreter: passed through ``NaturalLanguageInterpreter.create``.
         tracker_store: passed to ``SnipsMqttAgent.create_tracker_store``
             together with the loaded domain.
         action_factory: forwarded to ``SnipsDomain.load``.
         core_server: forwarded to ``SnipsDomain.load``.

     Returns:
         A new ``SnipsMqttAgent`` built from the loaded components.

     Raises:
         ValueError: if ``path`` is ``None``.
     """
     if path is None:
         raise ValueError("No domain path specified.")
     domain = SnipsDomain.load(os.path.join(path, "domain.yml"),
                               action_factory, core_server)
     # ensures the domain hasn't changed between test and train
     domain.compare_with_specification(path)
     featurizer = Featurizer.load(path)
     ensemble = PolicyEnsemble.load(path, featurizer)
     _interpreter = NaturalLanguageInterpreter.create(interpreter)
     _tracker_store = SnipsMqttAgent.create_tracker_store(
         tracker_store, domain)
     # NOTE(review): debug output goes to stdout; consider a logger instead.
     print("CREATED SNIPS AGENT")
     return SnipsMqttAgent(domain, ensemble, featurizer, _interpreter,
                           _tracker_store)
コード例 #11
0
ファイル: agent.py プロジェクト: rohitjun08/rasa_core
def _update_model_from_server(
        model_server,  # type: EndpointConfig
        agent,  # type: Agent
):
    # type: (...) -> None
    """Load a zipped Rasa Core model from a URL and update the passed agent.

    The model is pulled into a fresh temporary directory. That directory is
    removed again unless a new model was pulled and successfully loaded into
    ``agent``. This keeps repeated calls (e.g. from a model-pulling worker)
    from accumulating empty or broken temp directories.

    Raises:
        InvalidURL: if ``model_server.url`` is not a valid URL.
    """
    import shutil

    if not is_url(model_server.url):
        raise InvalidURL(model_server.url)

    model_directory = tempfile.mkdtemp()
    model_in_use = False
    try:
        new_model_fingerprint = _pull_model_and_fingerprint(
                model_server, model_directory, agent.fingerprint)
        if new_model_fingerprint:
            domain_path = os.path.join(os.path.abspath(model_directory),
                                       "domain.yml")
            domain = Domain.load(domain_path)
            policy_ensemble = PolicyEnsemble.load(model_directory)
            agent.update_model(domain, policy_ensemble, new_model_fingerprint)
            model_in_use = True
        else:
            logger.debug("No new model found at "
                         "URL {}".format(model_server.url))
    finally:
        # Keep the directory only when it backs the agent's current model.
        if not model_in_use:
            shutil.rmtree(model_directory, ignore_errors=True)
コード例 #12
0
ファイル: agent.py プロジェクト: rohitjun08/rasa_core
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
             generator=None,  # type: Union[EndpointConfig, NLG]
             tracker_store=None,  # type: Optional[TrackerStore]
             action_endpoint=None,  # type: Optional[EndpointConfig]
             ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the files read by
                ``PolicyEnsemble.load``.
            interpreter, generator, tracker_store, action_endpoint:
                forwarded unchanged to the constructor.

        Returns:
            A new ``cls`` instance built from the loaded domain and policies.

        Raises:
            ValueError: if ``path`` is empty or points to a file instead of
                a model directory.
        """
        if not path:
            raise ValueError("You need to provide a valid directory where "
                             "to load the agent from when calling "
                             "`Agent.load`.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = Domain.load(os.path.join(path, "domain.yml"))

        # ensures the domain hasn't changed between test and train; checked
        # before the (potentially expensive) policy load so we fail fast
        domain.compare_with_specification(path)

        # `path` is guaranteed non-empty by the guard above, so there is no
        # "no ensemble" case to handle here.
        ensemble = PolicyEnsemble.load(path)

        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=interpreter,
                   generator=generator,
                   tracker_store=tracker_store,
                   action_endpoint=action_endpoint)
コード例 #13
0
ファイル: conftest.py プロジェクト: stevenshichn/rasa_core
def moodbot_metadata():
    """Return what ``PolicyEnsemble.load_metadata`` reads for the moodbot model."""
    metadata = PolicyEnsemble.load_metadata(MOODBOT_MODEL_PATH)
    return metadata
コード例 #14
0
def test_domain_spec_dm():
    """Load the bAbI example policy ensemble and persist it back in place."""
    # NOTE(review): this rewrites the checked-in model directory; consider
    # persisting into a temporary directory instead.
    path = 'examples/babi/models/policy/current'
    ensemble = PolicyEnsemble.load(path, BinaryFeaturizer())
    ensemble.persist(path)
コード例 #15
0
def test_valid_policy_configurations(valid_config):
    """A valid policy config must produce a truthy ensemble."""
    ensemble = PolicyEnsemble.from_dict(valid_config)
    assert ensemble
コード例 #16
0
def test_invalid_policy_configurations(invalid_config):
    """An invalid policy config must be rejected with InvalidPolicyConfig."""
    expected_error = InvalidPolicyConfig
    with pytest.raises(expected_error):
        PolicyEnsemble.from_dict(invalid_config)
コード例 #17
0
def test_ensemble_from_dict():
    """Build an ensemble from a config dict and verify each policy's settings."""

    def check_memoization(p):
        assert p.max_history == 5

    def check_keras(p):
        featurizer = p.featurizer
        state_featurizer = featurizer.state_featurizer
        # Assert policy
        assert p.epochs == 50
        # Assert featurizer
        assert isinstance(featurizer, MaxHistoryTrackerFeaturizer)
        assert featurizer.max_history == 5
        # Assert state_featurizer
        assert isinstance(state_featurizer, BinarySingleStateFeaturizer)

    def check_fallback(p):
        assert p.fallback_action_name == 'action_default_fallback'
        assert p.nlu_threshold == 0.7
        assert p.core_threshold == 0.7

    ensemble_dict = {
        'policies': [
            {
                'epochs': 50,
                'name': 'KerasPolicy',
                'featurizer': [{
                    'max_history': 5,
                    'name': 'MaxHistoryTrackerFeaturizer',
                    'state_featurizer': [{
                        'name': 'BinarySingleStateFeaturizer',
                    }],
                }],
            },
            {
                'max_history': 5,
                'name': 'MemoizationPolicy',
            },
            {
                'core_threshold': 0.7,
                'name': 'FallbackPolicy',
                'nlu_threshold': 0.7,
                'fallback_action_name': 'action_default_fallback',
            },
            {
                'name': 'FormPolicy',
            },
        ]
    }
    ensemble = PolicyEnsemble.from_dict(ensemble_dict)

    # Check if all policies are present
    assert len(ensemble) == 4
    # MemoizationPolicy is parent of FormPolicy, so exclude FormPolicy here
    assert any(isinstance(p, MemoizationPolicy)
               and not isinstance(p, FormPolicy)
               for p in ensemble)
    assert any(isinstance(p, KerasPolicy) for p in ensemble)
    assert any(isinstance(p, FallbackPolicy) for p in ensemble)
    assert any(isinstance(p, FormPolicy) for p in ensemble)

    # Verify policy configurations
    for policy in ensemble:
        if isinstance(policy, MemoizationPolicy) \
                and not isinstance(policy, FormPolicy):
            check_memoization(policy)
        elif isinstance(policy, KerasPolicy):
            check_keras(policy)
        elif isinstance(policy, FallbackPolicy):
            check_fallback(policy)
コード例 #18
0
ファイル: superagent.py プロジェクト: 13927729580/rasa-addons
    def load(
            cls,
            path,
            domain=None,
            policies=None,
            interpreter=None,
            generator=None,
            tracker_store=None,
            action_endpoint=None,
            rules=None,
            create_dispatcher=None,
            model_server=None,  # type: Optional[EndpointConfig]
            wait_time_between_pulls=None,  # type: Optional[int]
            create_nlg=None):
        # type: (...) -> Agent
        """Build an agent from a model directory and/or explicit components.

        Args:
            path: model directory holding ``domain.yml`` and the policy
                files. May be empty when ``domain`` is passed explicitly.
            domain: used as-is when given; otherwise loaded from ``path``
                unless ``model_server`` is set.
            policies: used as-is when given; otherwise loaded from ``path``
                when there is one and ``model_server`` is not set.
            model_server: if set, nothing is loaded from ``path``. Instead the
                model is pulled from the server once, or every
                ``wait_time_between_pulls`` seconds if that is set.
            rules: an ``EndpointConfig`` combined with
                ``wait_time_between_pulls`` starts a rules-pulling worker;
                any other truthy value is loaded once via
                ``SuperAgent.get_rules``.
            interpreter, generator, tracker_store, action_endpoint,
            create_dispatcher, create_nlg: forwarded to the constructor.

        Returns:
            The constructed agent.

        Raises:
            ValueError: if neither ``path`` nor ``domain`` is given, or if
                ``path`` points to a file instead of a model directory.
        """
        if not path and not domain:
            raise ValueError("You need to provide a valid directory where "
                             "to load the agent from when calling "
                             "`Agent.load`.")

        # `path` may legitimately be empty when `domain` is given; only probe
        # the filesystem when there is a path (os.path.isfile(None) raises).
        if path and os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        # We don't want to block if the path doesn't exist but a model_server
        # is supplied. (If `domain` is None here, the first guard ensures
        # `path` is set.)
        if domain is None and model_server is None:
            domain = TemplateDomain.load(os.path.join(path, "domain.yml"))
        if policies is None and model_server is None and path:
            policies = PolicyEnsemble.load(path)

        # ensures the domain hasn't changed between test and train; there is
        # no persisted specification to compare against without a path
        if model_server is None and path:
            domain.compare_with_specification(path)

        agent = cls(domain=domain,
                    policies=policies,
                    interpreter=interpreter,
                    tracker_store=tracker_store,
                    generator=generator,
                    action_endpoint=action_endpoint,
                    rules=rules,
                    create_dispatcher=create_dispatcher,
                    create_nlg=create_nlg)
        if model_server:
            if wait_time_between_pulls:
                # continuously pull the model every `wait_time_between_pulls` seconds
                start_model_pulling_in_worker(model_server,
                                              wait_time_between_pulls, agent)
            else:
                # just pull the model once
                _update_model_from_server(model_server, agent)

        if rules:
            # Start worker if `wait_time_between_pulls` is set
            if isinstance(rules, EndpointConfig) and wait_time_between_pulls:
                # continuously pull the rules every `wait_time_between_pulls` seconds
                start_rules_pulling_in_worker(rules, wait_time_between_pulls,
                                              agent)
            # In all other cases we only want to load the rules once
            else:
                agent.rules = SuperAgent.get_rules(rules)
        return agent