Exemple #1
0
async def test_persist_and_read_test_story(tmp_path: Path, domain: Domain):
    """Round-trip the Markdown test stories through a dump and a reload.

    The story steps extracted from the source file are written to a file
    below ``tmp_path``. Trackers are then loaded from both the dumped file
    and the source file, and every tracker recovered from the dump must
    export a story that the source file also produces.
    """
    source_stories = "data/test_stories/stories.md"
    # Both loads use exactly the same loader options.
    load_options = dict(
        use_story_concatenation=False,
        tracker_limit=1000,
        remove_duplicates=False,
    )

    story_graph = await training.extract_story_graph(source_stories, domain)
    persisted_file = tmp_path / "persisted_story.md"
    Story(story_graph.story_steps).dump_to_file(str(persisted_file))

    recovered_trackers = await training.load_data(
        str(persisted_file), domain, **load_options
    )
    existing_trackers = await training.load_data(
        source_stories, domain, **load_options
    )

    unmatched_stories = set()
    for tracker in existing_trackers:
        unmatched_stories.add(tracker.export_stories(MarkdownStoryWriter()))

    for tracker in recovered_trackers:
        exported = tracker.export_stories(MarkdownStoryWriter())
        assert exported in unmatched_stories
        # Each source story may be matched by at most one recovered tracker.
        unmatched_stories.discard(exported)
Exemple #2
0
    async def test_rephrasing_instead_affirmation(
        self,
        default_channel: OutputChannel,
        default_nlg: NaturalLanguageGenerator,
        default_domain: Domain,
    ):
        """Check the tracker when the user rephrases instead of affirming.

        The conversation holds a confident "greet", a low-confidence "greet"
        answered with the ask-affirmation action, and then a confident "bye".
        The tracker returned by ``_get_tracker_after_reverts`` must report
        "bye" as its latest intent and export a story without the
        low-confidence turn or the affirmation action.
        """
        conversation = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 1),
            ActionExecuted("utter_hello"),
            ActionExecuted(ACTION_LISTEN_NAME),
            # Low-confidence message that triggers the affirmation request.
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # The user sends a different, confident intent instead of affirming.
            user_uttered("bye", 1),
        ]

        reverted_tracker = await self._get_tracker_after_reverts(
            conversation, default_channel, default_nlg, default_domain
        )

        latest_intent = reverted_tracker.latest_message.parse_data["intent"]
        assert latest_intent[INTENT_NAME_KEY] == "bye"

        expected_story = "## sender\n* greet\n    - utter_hello\n* bye\n"
        exported_story = reverted_tracker.export_stories(MarkdownStoryWriter())
        assert exported_story == expected_story
Exemple #3
0
async def test_tracker_dump_e2e_story(default_agent: Agent):
    """Export a short conversation as an end-to-end Markdown story.

    Two messages are sent through the agent under a fixed sender id. The
    tracker stored for that sender is then exported with ``e2e=True`` and
    compared line by line with the expected story.
    """
    conversation_id = "test_tracker_dump_e2e_story"

    for message in ("/greet", "/goodbye"):
        await default_agent.handle_text(message, sender_id=conversation_id)

    tracker = default_agent.tracker_store.get_or_create_tracker(conversation_id)
    exported = tracker.export_stories(MarkdownStoryWriter(), e2e=True)

    expected_lines = [
        "## test_tracker_dump_e2e_story",
        "* greet: /greet",
        "    - utter_greet",
        "* goodbye: /goodbye",
    ]
    # Surrounding whitespace is stripped before splitting into lines.
    assert exported.strip().split("\n") == expected_lines
    async def test_successful_rephrasing(self, default_channel, default_nlg,
                                         default_domain):
        """Check the tracker after the user denies and then rephrases.

        The conversation holds a low-confidence "greet" answered with the
        ask-affirmation action, a "deny" answered with the ask-rephrase
        action, and finally a confident "bye". The tracker returned by
        ``_get_tracker_after_reverts`` must report "bye" as its latest
        intent and export a story that contains only that turn.
        """
        conversation = [
            ActionExecuted(ACTION_LISTEN_NAME),
            # Low-confidence message that triggers the affirmation request.
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # The user denies, so the bot asks for a rephrasing.
            user_uttered("deny", 1),
            ActionExecuted(ACTION_DEFAULT_ASK_REPHRASE_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        reverted_tracker = await self._get_tracker_after_reverts(
            conversation, default_channel, default_nlg, default_domain)

        latest_intent = reverted_tracker.latest_message.parse_data["intent"]
        assert latest_intent[INTENT_NAME_KEY] == "bye"

        exported_story = reverted_tracker.export_stories(MarkdownStoryWriter())
        assert exported_story == "## sender\n* bye\n"
Exemple #5
0
async def test_persist_legacy_form_story():
    """Check that a form conversation exports to the expected Markdown story.

    A tracker is fed a hand-written event sequence that simulates a user
    entering, interrupting, resuming and leaving ``some_form``. The story
    text below is written with the legacy form event name; before the
    comparison that name is replaced with the ``ActiveLoop`` event name, and
    the result must appear in the tracker's Markdown export.
    """
    domain = Domain.load("data/test_domains/form.yml")

    tracker = DialogueStateTracker("", domain.slots)

    legacy_story = (
        "* greet\n"
        "    - utter_greet\n"
        "* start_form\n"
        "    - some_form\n"
        '    - form{"name": "some_form"}\n'
        "* default\n"
        "    - utter_default\n"
        "    - some_form\n"
        "* stop\n"
        "    - utter_ask_continue\n"
        "* affirm\n"
        "    - some_form\n"
        "* stop\n"
        "    - utter_ask_continue\n"
        "* inform\n"
        "    - some_form\n"
        '    - form{"name": null}\n'
        "* goodbye\n"
        "    - utter_goodbye\n"
    )

    # simulate talking to the form
    events = [
        UserUttered(intent={"name": "greet"}),
        ActionExecuted("utter_greet"),
        ActionExecuted("action_listen"),
        # start the form
        UserUttered(intent={"name": "start_form"}),
        ActionExecuted("some_form"),
        ActiveLoop("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "default"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_default"),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # out of form input but continue with the form
        UserUttered(intent={"name": "affirm"}),
        LoopInterrupted(True),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # out of form input
        UserUttered(intent={"name": "stop"}),
        ActionExecutionRejected("some_form"),
        ActionExecuted("utter_ask_continue"),
        ActionExecuted("action_listen"),
        # form input
        UserUttered(intent={"name": "inform"}),
        LoopInterrupted(False),
        ActionExecuted("some_form"),
        ActionExecuted("action_listen"),
        # no loop is active any more
        ActiveLoop(None),
        UserUttered(intent={"name": "goodbye"}),
        ActionExecuted("utter_goodbye"),
        ActionExecuted("action_listen"),
    ]
    # Plain loop: the updates are side effects, so no result list is built.
    for event in events:
        tracker.update(event)

    # Swap the legacy event name in the hand-written story for the
    # ``ActiveLoop`` one before comparing against the export.
    expected_story = legacy_story.replace(
        f"- {LegacyForm.type_name}", f"- {ActiveLoop.type_name}"
    )

    assert expected_story in tracker.export_stories(MarkdownStoryWriter())
Exemple #6
0
def test_skip_markdown_writing_deprecation():
    """Check that ``dumps`` emits no warning when asked to skip the deprecation.

    ``MarkdownStoryWriter.dumps`` is called with
    ``ignore_deprecation_warning=True`` while all warnings are recorded; the
    recorded list must be empty afterwards.
    """
    # Imported locally so the block stands on its own. The standard-library
    # recorder replaces ``pytest.warns(None)``, which is deprecated since
    # pytest 7 and rejected by pytest 8.
    import warnings

    with warnings.catch_warnings(record=True) as recorded:
        # Ensure no active filter hides a warning that would be emitted.
        warnings.simplefilter("always")
        MarkdownStoryWriter.dumps([], ignore_deprecation_warning=True)

    assert not recorded
Exemple #7
0
def test_markdown_writing_deprecation():
    """Check that dumping with a ``MarkdownStoryWriter`` raises a ``FutureWarning``."""
    no_steps = []
    with pytest.warns(FutureWarning):
        # The writer is created inside the block so that a warning from
        # either construction or ``dumps`` is captured.
        story_writer = MarkdownStoryWriter()
        story_writer.dumps(no_steps)