Example #1
0
def test_agent_train(tmpdir, default_domain):
    """Round-trip a trained moodbot ``Agent`` through persist/load.

    Trains an ``Agent`` (single ``ScoringPolicy``, ``max_history=3``) on the
    moodbot stories, persists it into ``tmpdir``, loads it back, and checks
    that the featurizer type, the domain (actions, intents, entities,
    templates, slots) and the policy ensemble match the original agent.

    ``default_domain`` is requested as a fixture but not used by this test.
    """
    stories_path = 'examples/moodbot/data/stories.md'
    trained = Agent("examples/moodbot/domain.yml",
                    policies=[ScoringPolicy()])
    trained.train(stories_path, max_history=3)

    model_dir = tmpdir.strpath
    trained.persist(model_dir)
    restored = Agent.load(model_dir)

    # Featurizer: the exact class must survive the round trip.
    assert type(restored.featurizer) is type(trained.featurizer)  # nopep8

    # Domain: actions and slots are compared by name, the rest directly.
    def action_names(agent):
        return [action.name() for action in agent.domain.actions]

    def slot_names(agent):
        return [slot.name for slot in agent.domain.slots]

    assert action_names(restored) == action_names(trained)
    assert restored.domain.intents == trained.domain.intents
    assert restored.domain.entities == trained.domain.entities
    assert restored.domain.templates == trained.domain.templates
    assert slot_names(restored) == slot_names(trained)

    # Policies: same ensemble class holding the same policy classes, in order.
    assert type(restored.policy_ensemble) is type(trained.policy_ensemble)  # nopep8
    restored_policy_types = [type(p) for p in restored.policy_ensemble.policies]
    assert restored_policy_types == \
        [type(p) for p in trained.policy_ensemble.policies]
Example #2
0
def default_processor(default_domain):
    """Build a ``MessageProcessor`` backed by a trained ``ScoringPolicy``.

    The single-policy ensemble is trained on ``DEFAULT_STORIES_FILE`` with a
    ``BinaryFeaturizer`` and ``max_history=3``. The processor parses input
    with a ``RegexInterpreter`` and keeps trackers in an
    ``InMemoryTrackerStore`` for ``default_domain``.
    """
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain,
                            BinaryFeaturizer())
    trainer.train(DEFAULT_STORIES_FILE, max_history=3)

    return MessageProcessor(regex_interpreter, policy_ensemble,
                            default_domain,
                            InMemoryTrackerStore(default_domain))
Example #3
0
def core_server(tmpdir_factory):
    """Train and persist an ``Agent``, then wrap it in a ``RasaCoreServer``.

    The agent uses the default-with-topic test domain and a single
    ``ScoringPolicy`` trained on ``DEFAULT_STORIES_FILE`` (``max_history=3``).
    It is persisted into a fresh temporary ``model`` directory, and that
    directory path is handed to the server together with a
    ``RegexInterpreter``.
    """
    domain_file = "data/test_domains/default_with_topic.yml"
    model_dir = tmpdir_factory.mktemp("model").strpath

    trained = Agent(domain_file, policies=[ScoringPolicy()])
    trained.train(DEFAULT_STORIES_FILE, max_history=3)
    trained.persist(model_dir)

    return RasaCoreServer(model_dir, interpreter=RegexInterpreter())
Example #4
0
def test_message_processor(default_domain, capsys):
    """A ``_greet`` message on a console channel prints "hey there!".

    Trains a single-``ScoringPolicy`` ensemble on the default-domain DSL
    stories, feeds ``_greet`` through a ``MessageProcessor`` whose output
    goes to a ``ConsoleOutputChannel``, and inspects captured stdout.

    NOTE(review): another function named ``test_message_processor`` appears
    later in this listing. If both live in the same module, the later
    definition shadows this one and pytest never collects it; consider
    renaming one of them.
    """
    stories = "data/dsl_stories/stories_defaultdomain.md"
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain,
                            BinaryFeaturizer())
    trainer.train(stories, max_history=3)

    processor = MessageProcessor(regex_interpreter, policy_ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    processor.handle_message(UserMessage("_greet", ConsoleOutputChannel()))
    printed = capsys.readouterr()[0]  # index 0 is captured stdout
    assert "hey there!" in printed
Example #5
0
def test_message_processor(default_domain, capsys):
    """A ``_greet`` message with a ``name`` entity yields a personalised reply.

    Trains a single-``ScoringPolicy`` ensemble on the default-domain DSL
    stories, sends ``_greet[name=Core]`` through a ``MessageProcessor`` and
    checks the most recent output collected by a ``CollectingOutputChannel``.
    ``capsys`` is requested as a fixture but not used here.

    NOTE(review): an earlier function in this listing has the same name. If
    both live in the same module, this definition shadows the earlier one;
    consider renaming one of them.
    """
    stories = "data/dsl_stories/stories_defaultdomain.md"
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain,
                            BinaryFeaturizer())
    trainer.train(stories, max_history=3)

    processor = MessageProcessor(regex_interpreter, policy_ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    channel = CollectingOutputChannel()
    processor.handle_message(UserMessage("_greet[name=Core]", channel))
    expected = ("default", "hey there Core!")
    assert expected == channel.latest_output()
 def create_policy(self):
     """Return a new ``ScoringPolicy`` constructed with no arguments."""
     return ScoringPolicy()