コード例 #1
0
ファイル: test_trackers.py プロジェクト: alexlana/rasa-1
def test_events_metadata():
    """Metadata attached to events must survive tracker serialization.

    Three events each carry their own ``metadata`` dict. The tracker is built
    from them and dumped with ``EventVerbosity.ALL``. Each serialized event
    must expose the same ``"metadata"`` value it was created with.
    """
    original_events = [
        ActionExecuted("one", metadata={"one": 1}),
        user_uttered("two", 1, metadata={"two": 2}),
        ActionExecuted(ACTION_LISTEN_NAME, metadata={"three": 3}),
    ]

    state = get_tracker(original_events).current_state(EventVerbosity.ALL)
    serialized = state["events"]

    assert serialized[0]["metadata"] == {"one": 1}
    assert serialized[1]["metadata"] == {"two": 2}
    assert serialized[2]["metadata"] == {"three": 3}
コード例 #2
0
    async def test_rephrasing_instead_affirmation(
        self, default_channel, default_nlg, default_domain
    ):
        """A confident new intent given instead of an affirmation is kept.

        The second ``greet`` (confidence 0.2) is followed by the affirmation
        question. The user then answers with ``bye`` at confidence 1. After the
        reverts, the exported story must contain only the first
        greet/utter_hello exchange and the ``bye``.
        """
        conversation = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 1),
            ActionExecuted("utter_hello"),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        reverted = await self._get_tracker_after_reverts(
            conversation, default_channel, default_nlg, default_domain
        )
        latest_intent = reverted.latest_message.parse_data["intent"]["name"]
        assert "bye" == latest_intent

        expected_story = "## sender\n* greet\n    - utter_hello\n* bye\n"
        assert reverted.export_stories() == expected_story
コード例 #3
0
 def test_predict_action_listen(self, priority, domain_with_mapping, intent_mapping):
     """After the mapped action has run, the policy should predict action_listen.

     The last event is the mapped action executed with
     ``policy="policy_0_MappingPolicy"``. The highest-scoring action must be
     ``ACTION_LISTEN_NAME``, and the score vector must not be all zeros.
     """
     mapping_policy = self.create_policy(None, priority)
     tracker = get_tracker(
         [
             ActionExecuted(ACTION_LISTEN_NAME),
             user_uttered(intent_mapping[0], 1),
             ActionExecuted(intent_mapping[1], policy="policy_0_MappingPolicy"),
         ]
     )
     probabilities = mapping_policy.predict_action_probabilities(
         tracker, domain_with_mapping
     )
     best_index = probabilities.index(max(probabilities))
     predicted = domain_with_mapping.action_names[best_index]
     assert predicted == ACTION_LISTEN_NAME
     all_zero = [0] * domain_with_mapping.num_actions
     assert probabilities != all_zero
コード例 #4
0
    async def test_affirmed_rephrasing(self, trained_policy,
                                       default_dispatcher_collecting,
                                       default_domain):
        """A denied guess followed by a rephrase leaves only the final ``bye``.

        The flow is:
        - low-confidence greet, then the affirmation question;
        - deny, then the rephrase question;
        - low-confidence bye, then the affirmation question again;
        - a confident ``bye``.

        After the reverts, the exported story must be just ``* bye``.
        """
        first_attempt = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("deny", 1),
            ActionExecuted(ACTION_DEFAULT_ASK_REPHRASE_NAME),
        ]
        rephrase_attempt = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        tracker = await self._get_tracker_after_reverts(
            first_attempt + rephrase_attempt,
            default_dispatcher_collecting,
            default_domain,
        )

        intent_name = tracker.latest_message.parse_data["intent"]["name"]
        assert "bye" == intent_name
        assert tracker.export_stories() == "## sender\n* bye\n"
コード例 #5
0
 def test_do_not_follow_other_policy(self, priority, domain_with_mapping,
                                     intent_mapping):
     """Ignore a mapped action that another policy executed.

     The mapped action in the history was run with ``policy="other_policy"``.
     The policy under test is therefore expected to return an all-zero score
     vector.
     """
     mapping_policy = self.create_policy(None, priority)
     history = [
         ActionExecuted(ACTION_LISTEN_NAME),
         user_uttered(intent_mapping[0], 1),
         ActionExecuted(intent_mapping[1], policy="other_policy"),
     ]
     tracker = get_tracker(history)
     scores = mapping_policy.predict_action_probabilities(
         tracker, domain_with_mapping, RegexInterpreter()
     )
     no_prediction = [0] * domain_with_mapping.num_actions
     assert scores == no_prediction
コード例 #6
0
    def test_predict_mapped_action(
        self,
        priority: int,
        domain_with_mapping: Domain,
        intent_mapping: Tuple[Text, Text],
    ):
        """Predict the mapped action right after its intent is uttered.

        ``intent_mapping[0]`` is uttered with confidence 1. The next action
        chosen by the policy must be ``intent_mapping[1]``.
        """
        policy = self.create_policy(None, priority)
        user_turn = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered(intent_mapping[0], 1),
        ]

        next_action = self._get_next_action(policy, user_turn, domain_with_mapping)
        assert next_action == intent_mapping[1]
コード例 #7
0
ファイル: test_policies.py プロジェクト: ysinjab/rasa
    async def test_successful_rephrasing(
        self,
        default_channel: OutputChannel,
        default_nlg: NaturalLanguageGenerator,
        default_domain: Domain,
    ):
        """A denied low-confidence intent followed by a confident rephrase.

        Only the rephrased intent should survive. After the reverts, the latest
        intent is ``bye`` and the Markdown story contains just that turn.
        """
        failed_turn = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("deny", 1),
            ActionExecuted(ACTION_DEFAULT_ASK_REPHRASE_NAME),
        ]
        rephrased_turn = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        tracker = await self._get_tracker_after_reverts(
            failed_turn + rephrased_turn,
            default_channel,
            default_nlg,
            default_domain,
        )

        intent = tracker.latest_message.parse_data["intent"]
        assert "bye" == intent[INTENT_NAME_KEY]
        story = tracker.export_stories(MarkdownStoryWriter())
        assert story == "## sender\n* bye\n"
コード例 #8
0
ファイル: test_policies.py プロジェクト: yungliu/rasa_nlu
    async def test_rephrasing_instead_affirmation(self, trained_policy,
                                                  default_dispatcher_collecting,
                                                  default_domain):
        """A confident new intent given instead of an affirmation is kept.

        The low-confidence second ``greet`` and its affirmation question should
        be reverted. The story keeps the first greeting exchange plus ``bye``.
        """
        confident_greeting = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 1),
            ActionExecuted("utter_hello"),
        ]
        rephrase_instead_of_affirming = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        tracker = await self._get_tracker_after_reverts(
            confident_greeting + rephrase_instead_of_affirming,
            default_dispatcher_collecting,
            default_domain,
        )

        assert 'bye' == tracker.latest_message.parse_data['intent']['name']
        expected_story = "## sender\n* greet\n    - utter_hello\n* bye\n"
        assert tracker.export_stories() == expected_story
コード例 #9
0
    async def test_rephrasing_instead_affirmation(
        self,
        default_channel: OutputChannel,
        default_nlg: NaturalLanguageGenerator,
        default_domain: Domain,
    ):
        """A confident new intent given instead of an affirmation is kept (e2e).

        After the reverts, the latest intent is ``bye``. The end-to-end
        Markdown export keeps the first greeting exchange plus ``bye``, each
        user turn annotated with ``: Random``.
        """
        history = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 1),
            ActionExecuted("utter_hello"),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("greet", 0.2),
            ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("bye", 1),
        ]

        tracker = await self._get_tracker_after_reverts(
            history, default_channel, default_nlg, default_domain
        )
        intent_name = tracker.latest_message.parse_data["intent"][INTENT_NAME_KEY]
        assert "bye" == intent_name

        e2e_story = tracker.export_stories(MarkdownStoryWriter(), e2e=True)
        assert e2e_story == (
            "## sender\n"
            "* greet: Random\n"
            "    - utter_hello\n"
            "* bye: Random\n"
        )
コード例 #10
0
 def test_do_not_follow_other_policy(
     self,
     priority: int,
     domain_with_mapping: Domain,
     intent_mapping: Tuple[Text, Text],
 ):
     """Ignore a mapped action that another policy executed.

     The history ends with the mapped action run under
     ``policy="other_policy"``. The prediction must have all-zero
     probabilities and must not be flagged as an end-to-end prediction.
     """
     mapping_policy = self.create_policy(None, priority)
     tracker = get_tracker(
         [
             ActionExecuted(ACTION_LISTEN_NAME),
             user_uttered(intent_mapping[0], 1),
             ActionExecuted(intent_mapping[1], policy="other_policy"),
         ]
     )
     prediction = mapping_policy.predict_action_probabilities(
         tracker, domain_with_mapping, RegexInterpreter()
     )
     scores = prediction.probabilities
     assert scores == [0] * domain_with_mapping.num_actions
     assert not prediction.is_end_to_end_prediction
コード例 #11
0
def test_fallback_wins_over_mapping():
    """FallbackPolicy wins over MappingPolicy for a low-confidence restart.

    The restart intent is uttered with confidence 0.0001. The ensemble must
    pick the FallbackPolicy (index 0) and predict the default fallback action.
    """
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events(
        "test",
        [
            ActionExecuted(ACTION_LISTEN_NAME),
            # Low confidence should trigger fallback
            utilities.user_uttered(USER_INTENT_RESTART, 0.0001),
        ],
        [],
    )

    ensemble = SimplePolicyEnsemble([FallbackPolicy(), MappingPolicy()])

    probabilities, winner = ensemble.probabilities_using_best_policy(
        tracker, domain)
    top_index = probabilities.index(max(probabilities))
    fallback_index = 0
    predicted_action = domain.action_for_index(top_index, None)

    expected_winner = f"policy_{fallback_index}_{FallbackPolicy.__name__}"
    assert winner == expected_winner
    assert predicted_action.name() == ACTION_DEFAULT_FALLBACK_NAME
コード例 #12
0
 def test_predict_action_listen(
     self,
     priority: int,
     domain_with_mapping: Domain,
     intent_mapping: Tuple[Text, Text],
 ):
     """After the mapped action has run, the policy should predict action_listen.

     The history ends with the mapped action executed under
     ``policy="policy_0_MappingPolicy"``. The top-scoring action must be
     ``ACTION_LISTEN_NAME``. The prediction must not be end-to-end, and its
     probabilities must not be all zeros.
     """
     mapping_policy = self.create_policy(None, priority)
     history = [
         ActionExecuted(ACTION_LISTEN_NAME),
         user_uttered(intent_mapping[0], 1),
         ActionExecuted(intent_mapping[1], policy="policy_0_MappingPolicy"),
     ]
     prediction = mapping_policy.predict_action_probabilities(
         get_tracker(history), domain_with_mapping, RegexInterpreter()
     )
     scores = prediction.probabilities
     winning_index = scores.index(max(scores))
     winning_action = domain_with_mapping.action_names_or_texts[winning_index]
     assert not prediction.is_end_to_end_prediction
     assert winning_action == ACTION_LISTEN_NAME
     assert prediction.probabilities != [0] * domain_with_mapping.num_actions
コード例 #13
0
def test_fallback_mapping_restart():
    """MappingPolicy handles a confident restart that follows a fallback.

    TwoStageFallbackPolicy (priority 2) and MappingPolicy (priority 1) are
    combined in one ensemble. The mapping policy, at index 1, must win and
    predict the restart action.
    """
    domain = Domain.load("data/test_domains/default.yml")
    history = [
        ActionExecuted(ACTION_DEFAULT_FALLBACK_NAME),
        utilities.user_uttered(USER_INTENT_RESTART, 1),
    ]
    tracker = DialogueStateTracker.from_events("test", history, [])

    ensemble = SimplePolicyEnsemble(
        [
            TwoStageFallbackPolicy(priority=2, deny_suggestion_intent_name="deny"),
            MappingPolicy(priority=1),
        ]
    )

    scores, winner = ensemble.probabilities_using_best_policy(tracker, domain)
    top_index = scores.index(max(scores))
    mapping_index = 1
    next_action = domain.action_for_index(top_index, None)

    assert winner == f"policy_{mapping_index}_{MappingPolicy.__name__}"
    assert next_action.name() == ACTION_RESTART_NAME
コード例 #14
0
def test_form_wins_over_everything_else(ensemble: SimplePolicyEnsemble):
    """FormPolicy wins when the history starts with a Form event for the form.

    The domain declares a single form. The ensemble must report the
    FormPolicy (index 0) as best and predict the form's own action.
    """
    form_name = "test-form"
    domain_yaml = f"""
    forms:
    - {form_name}
    """
    domain = Domain.from_yaml(domain_yaml)

    tracker = DialogueStateTracker.from_events(
        "test",
        [
            Form(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            utilities.user_uttered("test", 1),
        ],
        [],
    )
    scores, winner = ensemble.probabilities_using_best_policy(
        tracker, domain)

    best_index = scores.index(max(scores))
    next_action = domain.action_for_index(best_index, None)

    form_policy_index = 0
    assert winner == f"policy_{form_policy_index}_{FormPolicy.__name__}"
    assert next_action.name() == form_name
コード例 #15
0
def test_form_wins_over_everything_else(ensemble: SimplePolicyEnsemble):
    """FormPolicy wins when the history starts with an ActiveLoop for the form.

    The domain declares a single form. The prediction must name the
    FormPolicy (index 0) as best, and its top action must be the form itself.
    """
    form_name = "test-form"
    domain_yaml = f"""
    forms:
    - {form_name}
    """
    domain = Domain.from_yaml(domain_yaml)

    history = [
        ActiveLoop(form_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        utilities.user_uttered("test", 1),
    ]
    tracker = DialogueStateTracker.from_events("test", history, [])
    prediction = ensemble.probabilities_using_best_policy(
        tracker, domain, RegexInterpreter())

    best_index = prediction.max_confidence_index
    next_action = rasa.core.actions.action.action_for_index(best_index, domain, None)

    form_policy_index = 0
    expected_policy = f"policy_{form_policy_index}_{FormPolicy.__name__}"
    assert prediction.policy_name == expected_policy
    assert next_action.name() == form_name
コード例 #16
0
def test_get_last_event_with_reverted():
    """An ActionExecuted undone by ActionReverted is not reported as the last one."""
    tracker = get_tracker(
        [ActionExecuted("one"), ActionReverted(), user_uttered("two", 1)]
    )

    last_action = tracker.get_last_event_for(ActionExecuted)
    assert last_action is None
コード例 #17
0
def test_get_last_event_for():
    """The most recent ActionExecuted is found even when a user turn follows it.

    ``get_last_event_for`` can return ``None`` when nothing matches. That case
    is asserted explicitly, so a regression fails as a test assertion rather
    than as an ``AttributeError`` on ``.action_name``.
    """
    events = [ActionExecuted("one"), user_uttered("two", 1)]

    tracker = get_tracker(events)
    last_action = tracker.get_last_event_for(ActionExecuted)

    assert last_action is not None
    assert last_action.action_name == "one"
コード例 #18
0
ファイル: test_ensemble.py プロジェクト: xeronith/rasa
    )
    max_confidence_index = result.index(max(result))
    index_of_mapping_policy = 1
    next_action = domain.action_for_index(max_confidence_index, None)

    assert best_policy == f"policy_{index_of_mapping_policy}_{MappingPolicy.__name__}"
    assert next_action.name() == ACTION_RESTART_NAME


@pytest.mark.parametrize(
    "events",
    [
        # Restart intent uttered after a Form("test-form") event
        # (presumably with the form active).
        [
            Form("test-form"),
            ActionExecuted(ACTION_LISTEN_NAME),
            utilities.user_uttered(USER_INTENT_RESTART, 1),
        ],
        # Restart intent uttered with no Form event in the history.
        [
            ActionExecuted(ACTION_LISTEN_NAME),
            utilities.user_uttered(USER_INTENT_RESTART, 1),
        ],
    ],
)
def test_mapping_wins_over_form(events: List[Event]):
    """Build a tracker from ``events`` against a domain that declares ``test-form``.

    NOTE(review): this excerpt ends right after the tracker is created. The
    ensemble prediction and the assertions that give the test its name appear
    to be truncated. Confirm against the full source.
    """
    domain = """
    forms:
    - test-form
    """
    # Rebinds ``domain`` from the YAML string to the parsed Domain object.
    domain = Domain.from_yaml(domain)
    tracker = DialogueStateTracker.from_events("test", events, [])
コード例 #19
0
ファイル: test_policies.py プロジェクト: sanaayakurup/rasa-1
    def test_restart_if_paused(self, priority, domain_with_mapping):
        """A restart intent still leads to the restart action while paused."""
        policy = self.create_policy(None, priority)
        paused_then_restart = [
            ConversationPaused(),
            user_uttered(USER_INTENT_RESTART, 1),
        ]

        predicted = self._get_next_action(
            policy, paused_then_restart, domain_with_mapping
        )
        assert predicted == ACTION_RESTART_NAME
コード例 #20
0
ファイル: test_trackers.py プロジェクト: malhotra1432/rasa-1
    # assert that the tracker contains the slot with the modified value
    assert tracker.get_slot(slot_name) == value_to_set

    # assert that the initial slot has not been affected
    assert slot.value == initial_value


@pytest.mark.parametrize(
    "events, expected_applied_events",
    [
        (
            [
                # Form gets triggered.
                ActionExecuted(ACTION_LISTEN_NAME),
                user_uttered("fill_whole_form"),
                # Form executes and fills slots.
                ActionExecuted("loop"),
                ActiveLoop("loop"),
                SlotSet("slot1", "value"),
                SlotSet("slot2", "value2"),
            ],
            [
                ActionExecuted(ACTION_LISTEN_NAME),
                user_uttered("fill_whole_form"),
                ActionExecuted("loop"),
                ActiveLoop("loop"),
                SlotSet("slot1", "value"),
                SlotSet("slot2", "value2"),
            ],
        ),
コード例 #21
0
ファイル: test_trackers.py プロジェクト: pranavdurai10/rasa
def test_get_last_event_for_with_skip():
    """``skip=1`` passes over the newest ActionExecuted and returns the one before.

    ``get_last_event_for`` can return ``None`` when no event matches. That case
    is asserted explicitly, so a regression fails as a test assertion rather
    than as an ``AttributeError`` on ``.action_name``.
    """
    events = [ActionExecuted("one"), user_uttered("two", 1), ActionExecuted("three")]

    tracker = get_tracker(events)
    skipped_one = tracker.get_last_event_for(ActionExecuted, skip=1)

    assert skipped_one is not None
    assert skipped_one.action_name == "one"
コード例 #22
0
ファイル: test_policies.py プロジェクト: ysinjab/rasa
    def test_ask_affirmation(self, trained_policy: Policy, default_domain: Domain):
        """A user message with confidence 0.2 makes the policy ask for affirmation."""
        history = [
            ActionExecuted(ACTION_LISTEN_NAME),
            user_uttered("Hi", 0.2),
        ]

        assert (
            self._get_next_action(trained_policy, history, default_domain)
            == ACTION_DEFAULT_ASK_AFFIRMATION_NAME
        )
コード例 #23
0
def test_get_last_event_for():
    """The most recent ActionExecuted is found even when a user turn follows it.

    ``get_last_event_for`` can return ``None`` when nothing matches. That case
    is asserted explicitly, so a regression fails as a test assertion rather
    than as an ``AttributeError`` on ``.action_name``.
    """
    events = [ActionExecuted('one'), user_uttered('two', 1)]

    tracker = get_tracker(events)
    last_action = tracker.get_last_event_for(ActionExecuted)

    assert last_action is not None
    assert last_action.action_name == 'one'