Пример #1
0
def test_can_read_test_story(default_domain):
    """Parse the markdown test stories and check the events of one story.

    ``data/test_stories/stories.md`` must yield exactly 7 trackers. The
    tracker with exactly 5 events is then checked event by event.
    """
    trackers = extract_trackers_from_file("data/test_stories/stories.md",
                                          default_domain,
                                          featurizer=BinaryFeaturizer())
    assert len(trackers) == 7
    # this should be the story simple_story_with_only_end -> show_it_all
    # the generated stories are in a non stable order - therefore we need to
    # do some trickery to find the one we want to test: select it by its
    # event count instead of by its position.
    candidates = [t for t in trackers if len(t.events) == 5]
    # Fail with a readable message instead of an IndexError if the expected
    # story was not extracted.
    assert candidates, "no extracted tracker has exactly 5 events"
    tracker = candidates[0]

    assert tracker.events[0] == ActionExecuted("action_listen")
    assert tracker.events[1] == UserUttered(
        "simple",
        intent={"name": "simple", "confidence": 1.0},
        parse_data={
            "text": "simple",
            "intent_ranking": [{"confidence": 1.0, "name": "simple"}],
            "intent": {"confidence": 1.0, "name": "simple"},
            "entities": [],
        })
    assert tracker.events[2] == ActionExecuted("utter_default")
    assert tracker.events[3] == ActionExecuted("utter_greet")
    assert tracker.events[4] == ActionExecuted("action_listen")
Пример #2
0
def test_persist_and_read_test_story(tmpdir, default_domain):
    """Round-trip the test stories through ``Story.dump_to_file``.

    Every tracker recovered from the dumped file must export to a story
    string that also appears among the trackers extracted from the original
    file. No such string may be matched twice.
    """
    source_file = "data/test_stories/stories.md"
    story_graph = extract_story_graph_from_file(source_file, default_domain)
    dumped_file = tmpdir.join("persisted_story.md").strpath
    Story(story_graph.story_steps).dump_to_file(dumped_file)

    recovered = extract_trackers_from_file(dumped_file, default_domain,
                                           BinaryFeaturizer())
    original = extract_trackers_from_file(source_file, default_domain,
                                          BinaryFeaturizer())
    unmatched = {tracker.export_stories() for tracker in original}
    for tracker in recovered:
        exported = tracker.export_stories()
        assert exported in unmatched
        # consume the match so that a duplicate recovered story fails
        unmatched.discard(exported)
Пример #3
0
def test_tracker_write_to_story(tmpdir, default_domain):
    """Export a dialogue tracker as a story file and parse it back.

    The re-read file must yield a single tracker with 8 events. Its seventh
    event must set the ``location`` slot to ``"central"``.
    """
    original = tracker_from_dialogue_file(
        "data/test_dialogues/enter_name.json", default_domain)
    story_file = tmpdir.join("export.md").strpath
    original.export_stories_to_file(story_file)

    parsed = extract_trackers_from_file(story_file, default_domain,
                                        BinaryFeaturizer())
    assert len(parsed) == 1
    round_tripped = parsed[0]
    assert len(round_tripped.events) == 8
    assert round_tripped.events[6] == SlotSet("location", "central")
Пример #4
0
    def test_persist_and_load(self, trained_policy, default_domain, tmpdir):
        """Persisting and reloading a policy must not change its predictions.

        The policy is saved into ``tmpdir``. It is then loaded back through
        its class's ``load`` with the same featurizer and max history. Both
        instances are compared on every tracker from ``DEFAULT_STORIES_FILE``.
        """
        model_dir = tmpdir.strpath
        trained_policy.persist(model_dir)
        policy_cls = trained_policy.__class__
        reloaded = policy_cls.load(model_dir,
                                   trained_policy.featurizer,
                                   trained_policy.max_history)
        story_trackers = extract_trackers_from_file(DEFAULT_STORIES_FILE,
                                                    default_domain,
                                                    BinaryFeaturizer())

        for tracker in story_trackers:
            got = reloaded.predict_action_probabilities(tracker,
                                                        default_domain)
            expected = trained_policy.predict_action_probabilities(
                tracker, default_domain)
            assert got == expected
Пример #5
0
def test_dump_and_restore_as_json(default_agent, tmpdir):
    """Each tracker must survive a JSON dump / load / ``from_dict`` round trip.

    The trackers come from ``DEFAULT_STORIES_FILE``. Each one's state,
    including events, is written with ``utils.dump_obj_as_json_to_file`` and
    read back. It is then rebuilt with ``DialogueStateTracker.from_dict``,
    and the result must equal the original tracker.
    """
    agent = default_agent
    trackers = extract_trackers_from_file(
        DEFAULT_STORIES_FILE, agent.domain, agent.featurizer,
        agent.interpreter, agent.policy_ensemble.max_history())

    json_file = tmpdir.join("dumped_tracker.json").strpath

    for original in trackers:
        state = original.current_state(should_include_events=True)
        # the same file is overwritten for every tracker
        utils.dump_obj_as_json_to_file(json_file, state)

        loaded = json.loads(utils.read_file(json_file))
        restored = DialogueStateTracker.from_dict(
            loaded.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
            loaded.get("events", []),
            agent.domain)

        assert restored == original