コード例 #1
0
def test_form_submit_rule():
    """Once a form has deactivated itself, the trained submit rule must fire next."""
    form = "some_form"
    submit_action = "utter_submit"
    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - some-action
        - {submit_action}
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form}
    """)

    submit_rule = _form_submit_rule(domain, submit_action, form)
    training_rules = [GREET_RULE, submit_rule]

    policy = RulePolicy()
    policy.train(training_rules, domain, RegexInterpreter())

    events = [
        # A greeting activates the form.
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecuted(form),
        ActiveLoop(form),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        # The next user message fills the requested slot.
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecuted(form),
        # The form gets deactivated.
        ActiveLoop(None),
        SlotSet(REQUESTED_SLOT, None),
    ]
    tracker = DialogueStateTracker.from_events(
        "in a form", evts=events, slots=domain.slots
    )

    # The rule policy must hand over to the submit action now.
    probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter()
    )
    assert_predicted_action(probabilities, domain, submit_action)
コード例 #2
0
ファイル: action.py プロジェクト: sony2/rasa-for-botfront
 async def run(
     self,
     output_channel: "OutputChannel",
     nlg: "NaturalLanguageGenerator",
     tracker: "DialogueStateTracker",
     domain: "Domain",
 ) -> List[Event]:
     """Deactivate the active loop and clear the requested slot.

     None of the arguments are inspected; they exist to satisfy the action
     interface.

     Returns:
         ``ActiveLoop(None)`` followed by ``SlotSet(REQUESTED_SLOT, None)``.
     """
     deactivate_loop = ActiveLoop(None)
     reset_requested_slot = SlotSet(REQUESTED_SLOT, None)
     return [deactivate_loop, reset_requested_slot]
コード例 #3
0
async def test_import_nlu_training_data_from_e2e_stories(
    default_importer: TrainingDataImporter,
):
    """NLU examples derived from stories are added on top of the plain NLU data.

    ``default_importer`` is expected to be an ``E2EImporter``. The importer it
    wraps gets its ``get_stories`` patched to serve two hand-built stories.
    """
    # The fixture is the E2E wrapper around the underlying importer.
    assert isinstance(default_importer, E2EImporter)
    wrapped_importer = default_importer.importer

    story_graph = StoryGraph(
        [
            StoryStep(
                events=[
                    SlotSet("some slot", "doesn't matter"),
                    UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
                    ActionExecuted("utter_greet_from_stories"),
                ]
            ),
            StoryStep(
                events=[
                    UserUttered("how are you doing?"),
                    ActionExecuted("utter_greet_from_stories", action_text="Hi Joey."),
                ]
            ),
        ]
    )

    async def fake_get_stories(*_: Any, **__: Any) -> StoryGraph:
        return story_graph

    # Make the wrapped importer serve the stories above.
    wrapped_importer.get_stories = fake_get_stories

    # Stories and config pass through the wrapper unchanged.
    wrapped_stories = await wrapped_importer.get_stories()
    forwarded_stories = await default_importer.get_stories()
    assert wrapped_stories.as_story_string() == forwarded_stories.as_story_string()

    wrapped_config = await wrapped_importer.get_config()
    forwarded_config = await default_importer.get_config()
    assert wrapped_config == forwarded_config

    # The wrapper contributes extra NLU examples derived from the stories.
    nlu_data = await default_importer.get_nlu_data()
    combined_count = len(nlu_data.training_examples)
    plain_count = len((await wrapped_importer.get_nlu_data()).training_examples)
    assert combined_count > plain_count

    # Each story turn must have been turned into an NLU training message.
    expected_additional_messages = [
        Message(data={TEXT: "greet_from_stories", INTENT_NAME: "greet_from_stories"}),
        Message(data={ACTION_NAME: "utter_greet_from_stories", ACTION_TEXT: ""}),
        Message(data={TEXT: "how are you doing?", INTENT_NAME: None}),
        Message(
            data={ACTION_NAME: "utter_greet_from_stories", ACTION_TEXT: "Hi Joey."}
        ),
    ]
    missing = [
        message
        for message in expected_additional_messages
        if message not in nlu_data.training_examples
    ]
    assert not missing
コード例 #4
0
async def test_dont_predict_form_if_already_finished():
    """A form that already deactivated itself must not be predicted again."""
    form = "some_form"

    domain = Domain.from_yaml(f"""
    intents:
    - {GREET_INTENT_NAME}
    actions:
    - {UTTER_GREET_ACTION}
    - some-action
    slots:
      {REQUESTED_SLOT}:
        type: unfeaturized
    forms:
    - {form}
""")

    policy = RulePolicy()
    policy.train([GREET_RULE], domain, RegexInterpreter())

    conversation = [
        # The form is active and has requested a slot.
        ActionExecuted(form),
        ActiveLoop(form),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        # The user answers the requested slot.
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        # The form is satisfied and deactivates itself.
        ActionExecuted(form),
        ActiveLoop(None),
        SlotSet(REQUESTED_SLOT, None),
        # Another user message arrives after the form finished; it must not
        # re-trigger the form.
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]
    tracker = DialogueStateTracker.from_events(
        "in a form", evts=conversation, slots=domain.slots
    )

    # The general greet rule applies instead of the form.
    prediction = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter()
    )
    assert_predicted_action(prediction, domain, UTTER_GREET_ACTION)
コード例 #5
0
 async def run(
     self,
     output_channel: "OutputChannel",
     nlg: "NaturalLanguageGenerator",
     tracker: "DialogueStateTracker",
     domain: "Domain",
 ) -> List[Event]:
     """Run the action; see the parent class for the full contract.

     Returns:
         ``ActiveLoop(None)`` followed by ``SlotSet(REQUESTED_SLOT, None)``, which
         deactivates the loop and clears the requested slot.
     """
     deactivate_loop = ActiveLoop(None)
     reset_requested_slot = SlotSet(REQUESTED_SLOT, None)
     return [deactivate_loop, reset_requested_slot]
コード例 #6
0
ファイル: action.py プロジェクト: karen-white/rasa
    def _slot_set_events_from_tracker(
        tracker: "DialogueStateTracker",
    ) -> List["SlotSet"]:
        """Copy every applied ``SlotSet`` event of *tracker* into a new ``SlotSet``.

        Key, value and metadata are carried over. Events of other types are skipped.
        """
        applied_slot_events = (
            candidate
            for candidate in tracker.applied_events()
            if isinstance(candidate, SlotSet)
        )
        return [
            SlotSet(key=slot_event.key, value=slot_event.value, metadata=slot_event.metadata)
            for slot_event in applied_slot_events
        ]
コード例 #7
0
ファイル: forms.py プロジェクト: tomasmadeira/Rasa-x
 async def is_done(
     self,
     output_channel: "OutputChannel",
     nlg: "NaturalLanguageGenerator",
     tracker: "DialogueStateTracker",
     domain: "Domain",
     events_so_far: List[Event],
 ) -> bool:
     """Report whether the form has finished.

     The form counts as done once ``SlotSet(REQUESTED_SLOT, None)`` appears in
     *events_so_far*. The other arguments are not inspected.
     """
     requested_slot_cleared = SlotSet(REQUESTED_SLOT, None)
     return requested_slot_cleared in events_so_far
コード例 #8
0
async def test_predict_form_action_if_multiple_turns():
    """An active form keeps being predicted over several slot-filling turns."""
    form = "some_form"
    other_intent = "bye"
    domain = Domain.from_yaml(f"""
    intents:
    - {GREET_INTENT_NAME}
    - {other_intent}
    actions:
    - {UTTER_GREET_ACTION}
    - some-action
    slots:
      {REQUESTED_SLOT}:
        type: unfeaturized
    forms:
    - {form}
""")

    policy = RulePolicy()
    policy.train([GREET_RULE], domain, RegexInterpreter())

    conversation = [
        # The form is active and asks for a first slot.
        ActionExecuted(form),
        ActiveLoop(form),
        SlotSet(REQUESTED_SLOT, "some value"),
        # The user answers the first request.
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        # The form accepts it and asks for another slot.
        ActionExecuted(form),
        SlotSet(REQUESTED_SLOT, "some other"),
        # The user answers the second request.
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": other_intent}),
    ]
    tracker = DialogueStateTracker.from_events(
        "in a form", evts=conversation, slots=domain.slots
    )

    # The rule policy predicts the form again.
    prediction = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter()
    )
    assert_predicted_action(prediction, domain, form)
コード例 #9
0
ファイル: test_processor.py プロジェクト: ysinjab/rasa
async def test_restart_triggers_session_start(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """A ``/restart`` message restarts the chat and opens a fresh session."""
    # A trained RulePolicy is put first so that the default ActionRestart can
    # be predicted.
    rule_policy = RulePolicy()
    rule_policy.train([], default_processor.domain, RegexInterpreter())
    patched_policies = [rule_policy, *default_processor.policy_ensemble.policies]
    monkeypatch.setattr(
        default_processor.policy_ensemble, "policies", patched_policies
    )

    conversation_id = uuid.uuid4().hex

    entity_name = "name"
    greet_entities = {entity_name: "name1"}
    greet_text = f"/greet{json.dumps(greet_entities)}"

    await default_processor.handle_message(
        UserMessage(greet_text, default_channel, conversation_id)
    )
    assert default_channel.latest_output() == {
        "recipient_id": conversation_id,
        "text": "hey there name1!",
    }

    # Restart the chat.
    await default_processor.handle_message(
        UserMessage("/restart", default_channel, conversation_id)
    )

    tracker = default_processor.tracker_store.get_or_create_tracker(conversation_id)

    expected_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(
            greet_text,
            {INTENT_NAME_KEY: "greet", "confidence": 1.0},
            [{"entity": entity_name, "start": 6, "end": 23, "value": "name1"}],
        ),
        SlotSet(entity_name, greet_entities[entity_name]),
        ActionExecuted("utter_greet"),
        BotUttered("hey there name1!", metadata={"template_name": "utter_greet"}),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/restart", {INTENT_NAME_KEY: "restart", "confidence": 1.0}),
        ActionExecuted(ACTION_RESTART_NAME),
        Restarted(),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        # The restart means no earlier slot is carried into the new session.
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert list(tracker.events) == expected_events
コード例 #10
0
def test_operator_occur_never_applied_negated():
    """A negated occurrence ("never occurred") marker reports no relevant events.

    The sub-marker is an OR of "intent '1' detected" and "slot '2' set". After
    tracking the events below, ``relevant_events()`` must be empty.
    """
    # Only the events are needed; the per-event flags that used to be paired with
    # them were never checked by this test.
    events = [
        UserUttered(intent={INTENT_NAME_KEY: "1"}),
        SlotSet("2", value=None),
        UserUttered(intent={INTENT_NAME_KEY: "0"}),
        SlotSet("1", value="test"),
    ]
    sub_marker = OrMarker(
        [IntentDetectedMarker("1"), SlotSetMarker("2")],
        name="or marker",
        negated=False,
    )
    marker = OccurrenceMarker([sub_marker], name="or never occurred", negated=True)
    for event in events:
        marker.track(event)

    assert marker.relevant_events() == []
コード例 #11
0
def test_operator_and(negated: bool):
    """``AndMarker`` history matches the expected per-event flags, inverted if negated."""
    events_expected = [
        (UserUttered(intent={INTENT_NAME_KEY: "1"}), False),
        (SlotSet("2", value="bla"), False),
        (UserUttered(intent={INTENT_NAME_KEY: "1"}), True),
        (SlotSet("2", value=None), False),
        (UserUttered(intent={INTENT_NAME_KEY: "1"}), False),
        (SlotSet("2", value="bla"), False),
        (UserUttered(intent={INTENT_NAME_KEY: "2"}), False),
    ]
    events = [event for event, _ in events_expected]
    applies_when_not_negated = [applies for _, applies in events_expected]

    marker = AndMarker(
        [IntentDetectedMarker("1"), SlotSetMarker("2")],
        name="marker_name",
        negated=negated,
    )
    for event in events:
        marker.track(event)

    # Negation flips every expected flag.
    expected_history = [
        (not applies) if negated else applies
        for applies in applies_when_not_negated
    ]
    assert marker.history == expected_history
コード例 #12
0
 def set_slot(tracker, message):
     """Store *message* in the ``disambiguation_message`` slot of *tracker*.

     Args:
         tracker: Tracker that receives a ``SlotSet`` event via ``update``.
         message: Dict-like bot message; its ``quick_replies`` entry is inspected.

     Returns:
         *message* when the slot was set. ``None`` in two cases: fewer than two
         quick replies are present (only the deny-suggestions button would be
         shown), or updating the tracker raised.
     """
     # `quick_replies` may be present but explicitly None; treat that as "no
     # replies" instead of crashing in len().
     quick_replies = message.get("quick_replies") or []
     if len(quick_replies) < 2:
         return None  # abort if only deny_suggestions button would be shown
     try:
         tracker.update(SlotSet("disambiguation_message", value=message))
     except Exception as e:
         # Best effort: a failed update is logged and reported as None, not raised.
         logger.error("Could not set message slot: %s", e)
         return None
     return message
コード例 #13
0
ファイル: domain.py プロジェクト: malhotra1432/rasa-1
 def slots_for_entities(
     self, entities: List[Dict[Text, Any]]
 ) -> List[SlotSet]:
     """Build ``SlotSet`` events for auto-filled slots that match extracted entities.

     Nothing is produced when ``store_entities_as_slots`` is disabled. For each
     slot with ``auto_fill`` set, the ``"value"`` of every entity whose
     ``"entity"`` equals the slot name is collected. A slot whose ``type_name`` is
     ``"list"`` receives all collected values; any other slot receives only the
     last one. Slots without a matching entity produce no event.
     """
     if not self.store_entities_as_slots:
         return []

     slot_events = []
     for slot in self.slots:
         if not slot.auto_fill:
             continue
         values = [
             entity["value"] for entity in entities if entity["entity"] == slot.name
         ]
         if not values:
             continue
         slot_value = values if slot.type_name == "list" else values[-1]
         slot_events.append(SlotSet(slot.name, slot_value))
     return slot_events
コード例 #14
0
async def test_activate_and_immediate_deactivate():
    """A form whose only slot is filled by the triggering message ends at once."""
    slot_name = "num_people"
    slot_value = 5

    greeting_with_entity = UserUttered(
        "haha",
        {"name": "greet"},
        entities=[{"entity": slot_name, "value": slot_value}],
    )
    tracker = DialogueStateTracker.from_events(
        sender_id="bla",
        evts=[ActionExecuted(ACTION_LISTEN_NAME), greeting_with_entity],
    )
    form_name = "my form"
    action = FormAction(form_name, None)
    domain_yaml = f"""
    forms:
      {form_name}:
        {slot_name}:
        - type: from_entity
          entity: {slot_name}
    slots:
      {slot_name}:
        type: unfeaturized
    """
    domain = Domain.from_yaml(domain_yaml)

    produced_events = await action.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    # Activate, fill the slot from the entity, clear the request, deactivate.
    assert produced_events == [
        ActiveLoop(form_name),
        SlotSet(slot_name, slot_value),
        SlotSet(REQUESTED_SLOT, None),
        ActiveLoop(None),
    ]
コード例 #15
0
ファイル: test_slots.py プロジェクト: ChenHuaYou/rasa
    def test_apply_single_item_to_slot(
        self, value: Any, mappings: List[Dict[Text, Any]]
    ):
        """Applying ``SlotSet(name, value)`` leaves the slot holding ``["cat"]``.

        The parametrization (not visible here) presumably supplies a single item
        such as ``"cat"`` as *value*; the assertion checks it ends up wrapped in
        a list.
        """
        slot = self.create_slot(mappings=mappings, influence_conversation=False)
        tracker = DialogueStateTracker.from_events("sender", evts=[], slots=[slot])

        tracker.update(SlotSet(slot.name, value))

        stored_value = tracker.slots[slot.name].value
        assert stored_value == ["cat"]
コード例 #16
0
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Marker evaluation results written to CSV carry session/turn metadata.

    Five trackers are stored. Each one sets slots "0"-"4", then starts a new
    session with one user message, then sets slots "5"-"9". "marker1" (slot "2")
    must be reported in session 0 and "marker2" (slot "7") in session 1.
    """
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    for i in range(5):
        tracker = DialogueStateTracker(str(i), None)
        tracker.update_with_events([SlotSet(str(j), "slot") for j in range(5)],
                                   domain)
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(5 + j), "slot") for j in range(5)], domain)
        await store.save(tracker)

    tracker_loader = MarkerTrackerLoader(store, "all")

    results_path = tmp_path / "results.csv"

    markers = OrMarker(markers=[
        SlotSetMarker("2", name="marker1"),
        SlotSetMarker("7", name="marker2")
    ])
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    # The csv module requires files to be opened with newline="" so that quoted
    # fields and line endings are handled by the reader itself.
    with open(results_path, "r", newline="") as results:
        result_reader = csv.DictReader(results)
        senders = set()

        for row in result_reader:
            senders.add(row["sender_id"])
            if row["marker"] == "marker1":
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"

            if row["marker"] == "marker2":
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"

        assert len(senders) == 5
コード例 #17
0
def test_can_read_test_story_with_slots(domain: Domain):
    """The single story in the file ends with a slot event and then action_listen."""
    story_file = "data/test_yaml_stories/simple_story_with_only_end.yml"
    trackers = training.load_data(
        story_file,
        domain,
        use_story_concatenation=False,
        tracker_limit=1000,
        remove_duplicates=False,
    )
    assert len(trackers) == 1

    story_events = trackers[0].events
    assert story_events[-2] == SlotSet(key="name", value="peter")
    assert story_events[-1] == ActionExecuted("action_listen")
コード例 #18
0
def test_yaml_slot_without_value_is_parsed(domain: Domain):
    """A ``slot_was_set`` entry without a value gets the default text-slot value."""
    trackers = training.load_data(
        "data/test_yaml_stories/story_with_slot_was_set.yml",
        domain,
        use_story_concatenation=False,
        tracker_limit=1000,
        remove_duplicates=False,
    )

    expected_slot_event = SlotSet(key="name", value=DEFAULT_VALUE_TEXT_SLOTS)
    assert trackers[0].events[-2] == expected_slot_event
コード例 #19
0
ファイル: test_rule_policy.py プロジェクト: attgua/Geco
async def test_form_unhappy_path_from_general_rule():
    """A rejected form falls back to a general rule, then the form resumes."""
    form_name = "some_form"

    domain = Domain.from_yaml(
        f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """
    )

    # Only the general GREET_RULE is trained; no form-specific unhappy-path rule
    # is provided.
    policy = RulePolicy()
    policy.train([GREET_RULE], domain, RegexInterpreter())

    def predict(events):
        tracker = DialogueStateTracker.from_events(
            "casd", evts=events, slots=domain.slots
        )
        return policy.predict_action_probabilities(tracker, domain, RegexInterpreter())

    # The form rejects the user's message, so the unhappy path must be handled.
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    # The general rule's action handles it.
    assert_predicted_action(predict(conversation_events), domain, UTTER_GREET_ACTION)

    # After the general rule ran, the form is triggered again instead of the
    # rule's action_listen.
    conversation_events.append(ActionExecuted(UTTER_GREET_ACTION))
    assert_predicted_action(predict(conversation_events), domain, form_name)
コード例 #20
0
ファイル: forms.py プロジェクト: praneethgb/rasa
    async def validate_slots(
        self,
        slot_candidates: Dict[Text, Any],
        tracker: "DialogueStateTracker",
        domain: Domain,
        output_channel: OutputChannel,
        nlg: NaturalLanguageGenerator,
    ) -> List[Union[SlotSet, Event]]:
        """Validate freshly extracted slot values with the form's custom action.

        A ``SlotSet`` event is built for each candidate. If the domain has no action
        named ``validate_<form name>``, those events are returned unvalidated.
        Otherwise the candidate events are applied to a temporary tracker and the
        validation action is run remotely against it.

        Args:
            slot_candidates: Slot name -> extracted value for slots the form needs.
            tracker: The current conversation tracker.
            domain: The current model domain.
            output_channel: Channel the validation action may use to message the user.
            nlg: `NaturalLanguageGenerator` used for response generation.

        Returns:
            The validation action's events (bot messages and ``SlotSet`` events),
            followed by the candidate ``SlotSet`` events for every slot the action
            did not set itself.
        """
        logger.debug(f"Validating extracted slots: {slot_candidates}")
        candidate_events: List[Union[SlotSet, Event]] = [
            SlotSet(name, candidate_value)
            for name, candidate_value in slot_candidates.items()
        ]

        validation_action_name = f"validate_{self.name()}"
        has_validation_action = validation_action_name in domain.action_names_or_texts
        if not has_validation_action:
            return candidate_events

        temporary_tracker = self._temporary_tracker(tracker, candidate_events, domain)
        remote_validator = RemoteAction(validation_action_name, self.action_endpoint)
        validation_events = await remote_validator.run(
            output_channel, nlg, temporary_tracker, domain
        )

        # Slots the validator explicitly set; SlotSet(slot_name, None) marks a
        # slot as invalid.
        validator_slot_names = [
            validation_event.key
            for validation_event in validation_events
            if isinstance(validation_event, SlotSet)
        ]

        # Candidates the validator did not mention are assumed to be valid.
        unmentioned_candidates = [
            candidate
            for candidate in candidate_events
            if candidate.key not in validator_slot_names
        ]
        return validation_events + unmentioned_candidates
コード例 #21
0
def test_condition(condition_marker_type: Type[ConditionMarker], negated: bool):
    """A condition marker applies once per round of events (slots get un-set again).

    Each round is expected to satisfy the non-negated condition exactly once, so
    the marker applies ``rounds`` times. Negated, it applies to every other event.
    """
    marker = condition_marker_type(
        text="same-text", name="marker_name", negated=negated
    )
    one_round = [
        UserUttered(intent={"name": "1"}),
        UserUttered(intent={"name": "same-text"}),
        SlotSet("same-text", value="any"),
        SlotSet("same-text", value=None),
        ActionExecuted(action_name="same-text"),
    ]
    rounds = 3
    events = one_round * rounds

    for event in events:
        marker.track(event)

    assert len(marker.history) == len(events)
    applied = sum(marker.history)
    if negated:
        assert applied == len(events) - rounds
    else:
        assert applied == rounds
コード例 #22
0
def test_sessions_evaluated_returns_event_indices_wrt_tracker_not_dialogue():
    """Marker indices refer to positions in the full event list, not per session."""
    events = [
        ActionExecuted(action_name=ACTION_SESSION_START_NAME),
        UserUttered(intent={INTENT_NAME_KEY: "ignored"}),
        UserUttered(intent={INTENT_NAME_KEY: "ignored"}),
        UserUttered(intent={INTENT_NAME_KEY: "ignored"}),
        SlotSet("same-text", value="any"),
        ActionExecuted(action_name=ACTION_SESSION_START_NAME),
        UserUttered(intent={INTENT_NAME_KEY: "no-slot-set-here"}),
        UserUttered(intent={INTENT_NAME_KEY: "no-slot-set-here"}),
        SlotSet("same-text", value="any"),
    ]
    marker = SlotSetMarker(text="same-text", name="my-marker")

    evaluation = marker.evaluate_events(events)
    assert len(evaluation) == 2

    first_session_hits = evaluation[0]["my-marker"]
    assert len(first_session_hits) == 1
    assert first_session_hits[0].preceding_user_turns == 3
    assert first_session_hits[0].idx == 4

    second_session_hits = evaluation[1]["my-marker"]
    assert len(second_session_hits) == 1
    assert second_session_hits[0].preceding_user_turns == 2
    # Index 8 is the position in the whole event list, not within the session.
    assert second_session_hits[0].idx == 8
コード例 #23
0
def _tracker_store_and_tracker_with_slot_set() -> Tuple[InMemoryTrackerStore, DialogueStateTracker]:
    """Return an in-memory tracker store and a tracker with the ``cuisine`` slot set.

    The store is built from a module-level ``domain`` defined outside this block.
    The tracker for ``DEFAULT_SENDER_ID`` comes from the store and is updated with
    ``SlotSet("cuisine", "French")``. This function does not call ``store.save``
    after that update.
    """
    store = InMemoryTrackerStore(domain)
    tracker = store.get_or_create_tracker(DEFAULT_SENDER_ID)

    cuisine_event = SlotSet("cuisine", "French")
    tracker.update(cuisine_event)

    return store, tracker
コード例 #24
0
def get_or_create_tracker_store(store: TrackerStore) -> None:
    """Check that a slot set on a tracker survives a save/reload through *store*."""
    key, value = "location", "Easter Island"

    tracker = store.get_or_create_tracker(DEFAULT_SENDER_ID)
    tracker.update(SlotSet(key, value))
    assert tracker.get_slot(key) == value

    store.save(tracker)

    reloaded = store.get_or_create_tracker(DEFAULT_SENDER_ID)
    assert reloaded.get_slot(key) == value
コード例 #25
0
def test_rule_with_condition(rule_steps_without_stories: List[StoryStep]):
    """The first parsed rule is the conditional rule with the expected events."""
    first_rule = rule_steps_without_stories[0]
    assert first_rule.block_name == "Rule with condition"

    expected_events = [
        # Presumably the rule's condition: active loop plus requested slot.
        ActiveLoop("loop_q_form"),
        SlotSet("requested_slot", "some_slot"),
        ActionExecuted(RULE_SNIPPET_ACTION_NAME),
        UserUttered(
            intent={"name": "inform", "confidence": 1.0},
            entities=[{"entity": "some_slot", "value": "bla"}],
        ),
        ActionExecuted("loop_q_form"),
    ]
    assert first_rule.events == expected_events
コード例 #26
0
    async def validate_slots(
        self,
        slot_candidates: Dict[Text, Any],
        tracker: "DialogueStateTracker",
        domain: Domain,
        output_channel: OutputChannel,
        nlg: NaturalLanguageGenerator,
    ) -> List[Union[SlotSet, Event]]:
        """Run the form's custom validation action on freshly extracted slots.

        A ``SlotSet`` event is built for every candidate. If the domain has no action
        named ``validate_<form name>``, nothing is returned, because the extracted
        slots already have matching ``SlotSet`` events in the tracker. Otherwise
        the candidate events are applied to a temporary tracker and the validation
        action is run remotely against it.

        Args:
            slot_candidates: Slot name -> extracted value for slots the form needs.
            tracker: The current conversation tracker.
            domain: The current model domain.
            output_channel: Channel the validation action may use to message the user.
            nlg: `NaturalLanguageGenerator` used for response generation.

        Returns:
            Exactly the events returned by the validation action (bot messages and
            ``SlotSet`` events for validated slots), or ``[]`` when the domain has
            no such action.
        """
        logger.debug(f"Validating extracted slots: {slot_candidates}")
        candidate_events: List[Union[SlotSet, Event]] = [
            SlotSet(name, candidate_value)
            for name, candidate_value in slot_candidates.items()
        ]

        validation_action_name = f"validate_{self.name()}"
        has_validation_action = validation_action_name in domain.action_names_or_texts
        if not has_validation_action:
            return []

        # Per the original author's note, the temporary tracker holds only the
        # SlotSet events added since the last user utterance (`_temporary_tracker`
        # itself is not visible here).
        temporary_tracker = self._temporary_tracker(tracker, candidate_events, domain)

        remote_validator = RemoteAction(validation_action_name, self.action_endpoint)
        # Only the validator's own events are returned, so SlotSet events for
        # already-valid slots are not duplicated.
        return await remote_validator.run(output_channel, nlg, temporary_tracker, domain)
コード例 #27
0
    async def request_next_slot(
        self,
        tracker: "DialogueStateTracker",
        domain: Domain,
        output_channel: OutputChannel,
        nlg: NaturalLanguageGenerator,
        events_so_far: List[Event],
    ) -> List[Event]:
        """Ask the user for the next slot the form still needs.

        The slot to ask for is the value of the first
        ``SlotSet(REQUESTED_SLOT, ...)`` in *events_so_far*. If that is missing or
        falsy, it is looked up with ``_find_next_slot_to_request``.

        Returns:
            ``[SlotSet(REQUESTED_SLOT, None)]`` when ``is_done`` reports the form as
            finished, or when no slot is left to request. Otherwise the events from
            ``_ask_for_slot``, preceded by ``SlotSet(REQUESTED_SLOT, <slot>)`` when
            the slot had to be looked up here. Never returns ``None``.
        """
        if await self.is_done(output_channel, nlg, tracker, domain, events_so_far):
            # The custom slot-validation action decided to stop the form early.
            return [SlotSet(REQUESTED_SLOT, None)]

        requested_slot_values = (
            candidate.value
            for candidate in events_so_far
            if isinstance(candidate, SlotSet) and candidate.key == REQUESTED_SLOT
        )
        slot_to_request = next(requested_slot_values, None)

        temp_tracker = self._temporary_tracker(tracker, events_so_far, domain)

        leading_events = []
        if not slot_to_request:
            slot_to_request = self._find_next_slot_to_request(temp_tracker, domain)
            leading_events.append(SlotSet(REQUESTED_SLOT, slot_to_request))

        if not slot_to_request:
            # No required slot is left to fill.
            return [SlotSet(REQUESTED_SLOT, None)]

        question_events = await self._ask_for_slot(
            domain, nlg, output_channel, slot_to_request, temp_tracker
        )
        return leading_events + question_events
コード例 #28
0
async def test_required_slots(graph, age, authorization_req):
    """Required slots follow the form graph for both authorization outcomes.

        (start)
        |
        AGE --- fail age condition --- AUTHORIZATION --- fail authorization
        |                               |                       condition
        |                               pass authorization          |
        pass age condition              condition                   |
        |                               /                           |
        |                              /                           /
        COMMENTS ----------------------                           /
        |                                                        /
        (end) ---------------------------------------------------
    """
    spec = {"name": "default_form", "graph_elements": graph}

    form, tracker = new_form_and_tracker(
        spec, "age", ["authorization", "comments"]
    )
    tracker.update(SlotSet("age", age))

    authorization_slots = ["authorization"] if authorization_req else []

    # Without authorization, "comments" is only required when authorization
    # itself is not required.
    tracker.update(SlotSet("authorization", "false"))
    comment_slots = [] if authorization_req else ["comments"]
    assert form.required_slots(tracker) == [
        "age", *authorization_slots, *comment_slots
    ]

    # With authorization, "comments" is always required.
    tracker.update(SlotSet("authorization", "true"))
    assert form.required_slots(tracker) == [
        "age", *authorization_slots, "comments"
    ]
コード例 #29
0
async def test_validate_slots(validate_return_events: List[Dict],
                              expected_events: List[Event]):
    """The form applies the events returned by its remote validation action.

    The action server is mocked with ``aioresponses`` to answer with
    *validate_return_events*. The events produced by running the form must equal
    *expected_events*.
    """
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "hi"
    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, slot_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(slot_value,
                    entities=[{
                        "entity": "num_tables",
                        "value": 5
                    }]),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

    domain_yaml = f"""
    slots:
      {slot_name}:
        type: any
      num_tables:
        type: any
    forms:
      {form_name}:
        {slot_name}:
        - type: from_text
        num_tables:
        - type: from_entity
          entity: num_tables
    actions:
    - validate_{form_name}
    """
    domain = Domain.from_yaml(domain_yaml)
    # Well-formed URL: "http://" followed by host and port. The previous
    # "http:/..." had no host and only worked because the mock and the
    # endpoint shared the same malformed string.
    action_server_url = "http://my-action-server:5055/webhook"

    with aioresponses() as mocked:
        mocked.post(action_server_url,
                    payload={"events": validate_return_events})

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        produced_events = await action.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            tracker,
            domain,
        )
        assert produced_events == expected_events
コード例 #30
0
async def test_update_tracker_session_with_slots(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Slots set before a session restart are re-applied in the new session."""
    conversation_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(conversation_id)

    # One user message followed by five slot events.
    utterance = UserUttered("some utterance")
    tracker.update(utterance)

    slot_events = [SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)]
    for slot_event in slot_events:
        tracker.update(slot_event)

    # Force `_has_session_expired()` to report True so that
    # `_update_tracker_session()` actually starts a new session.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    await default_processor._update_tracker_session(tracker, default_channel)

    # `_update_tracker_session()` does not save the tracker, so save it explicitly.
    default_processor._save_tracker(tracker)

    # Reload the tracker and check every stored event.
    reloaded = default_processor.tracker_store.retrieve(conversation_id)
    stored_events = list(reloaded.events)

    # The first two events lead up to and include the user utterance.
    assert stored_events[:2] == [ActionExecuted(ACTION_LISTEN_NAME), utterance]

    # Then the five slot events.
    assert stored_events[2:7] == slot_events

    # Then the two-event session start sequence.
    assert stored_events[7:9] == [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
    ]

    # The five slot events are re-applied in the new session.
    assert stored_events[9:14] == slot_events

    # A final action listen, which must also be the last event.
    assert stored_events[14] == stored_events[-1] == ActionExecuted(ACTION_LISTEN_NAME)