コード例 #1
0
def test_temporary_tracker():
    """The temporary tracker keeps the domain slots and appends the form's events.

    Expected event order: the old tracker's events, an unset ``requested_slot``,
    the form's own ``ActionExecuted`` and finally the new events.
    """
    slot_under_test = "some_slot"
    conversation_id = "test"
    domain = Domain.from_yaml(
        f"""
        version: "2.0"
        slots:
          {slot_under_test}:
            type: unfeaturized
        """
    )

    history = [ActionExecuted(ACTION_LISTEN_NAME)]
    original_tracker = DialogueStateTracker.from_events(
        conversation_id, history, slots=domain.slots
    )
    appended = [Restarted()]
    form = FormAction("some name", None)
    temporary = form._temporary_tracker(original_tracker, appended, domain)

    assert slot_under_test in temporary.slots.keys()
    expected_events = (
        history
        + [SlotSet(REQUESTED_SLOT), ActionExecuted(form.name())]
        + appended
    )
    assert list(temporary.events) == expected_events
コード例 #2
0
def test_extract_requested_slot_mapping_does_not_apply(slot_mapping: Dict):
    """Nothing is extracted for the requested slot when its mapping does not apply.

    The user message (intent ``greet``) carries a ``some_slot`` entity, but the
    slot's only mapping is ``slot_mapping``, which is expected not to match it.
    """
    form_name = "some_form"
    requested_slot = "some_slot"
    form = FormAction(form_name, None)

    form_config = {requested_slot: [slot_mapping]}
    domain = Domain.from_dict({"forms": {form_name: form_config}})

    greet_intent = {"name": "greet", "confidence": 1.0}
    slot_entity = {"entity": requested_slot, "value": "some_value"}
    tracker = DialogueStateTracker.from_events(
        "default",
        [
            SlotSet(REQUESTED_SLOT, "some_slot"),
            UserUttered("bla", intent=greet_intent, entities=[slot_entity]),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    extracted = form.extract_requested_slot(tracker, domain)
    # the mapping must not apply (e.g. incorrect intent), so nothing is extracted
    assert extracted == {}
コード例 #3
0
ファイル: action.py プロジェクト: ChenHuaYou/rasa
    def _fails_unique_entity_mapping_check(
        self,
        slot_name: Text,
        mapping: Dict[Text, Any],
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> bool:
        """Tell whether an entity mapping fails the unique-entity check.

        Returns ``True`` only when all of these hold:

        * the mapping is of type ``from_entity``;
        * a form is active;
        * ``slot_name`` is not the slot the form currently requests;
        * ``slot_name`` is a required slot of the active form;
        * the form does not consider the mapping's entity unique.

        Otherwise returns ``False``.
        """
        from rasa.core.actions.forms import FormAction

        is_entity_mapping = mapping[MAPPING_TYPE] == str(SlotMappingType.FROM_ENTITY)
        if not is_entity_mapping:
            return False

        active_form = tracker.active_loop_name
        # no active form, or the slot is the one being asked for right now
        if not active_form or tracker.get_slot(REQUESTED_SLOT) == slot_name:
            return False

        form_action = FormAction(active_form, self._action_endpoint)
        if slot_name not in form_action.required_slots(domain):
            return False

        return not form_action.entity_mapping_is_unique(mapping, domain)
コード例 #4
0
def test_name_of_utterance():
    """The ask response name is form-specific when the domain defines one.

    ``utter_ask_<form>_<slot>`` is returned for a slot that has such a response;
    a slot without any response falls back to ``utter_ask_<slot>``.
    """
    form_name = "another_form"
    slot_name = "num_people"
    full_utterance_name = f"utter_ask_{form_name}_{slot_name}"

    domain = f"""
    forms:
    - {form_name}:
        {slot_name}:
        - type: from_text
    responses:
        {full_utterance_name}:
        - text: "How many people?"
    """
    domain = Domain.from_yaml(domain)

    # well-formed URL (scheme followed by `//`)
    action_server_url = "http://my-action-server:5055/webhook"

    with aioresponses():
        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        assert action._name_of_utterance(domain, slot_name) == full_utterance_name
        # no form-specific response for this slot -> generic name
        assert (
            action._name_of_utterance(domain, "another_slot")
            == "utter_ask_another_slot"
        )
コード例 #5
0
def test_extract_requested_slot_from_entity(
    mapping_not_intent: Optional[Text],
    mapping_intent: Optional[Text],
    mapping_role: Optional[Text],
    mapping_group: Optional[Text],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check entity-based extraction of the requested slot under mapping restrictions.

    ``some_slot`` maps from ``some_entity`` restricted by the given role, group,
    intent and not-intent; the extracted values must equal ``expected_slot_values``.
    """
    form_name = "some form"
    form = FormAction(form_name, None)

    entity_mapping = form.from_entity(
        entity="some_entity",
        role=mapping_role,
        group=mapping_group,
        intent=mapping_intent,
        not_intent=mapping_not_intent,
    )
    slot_mappings = {"some_slot": [entity_mapping]}
    domain = Domain.from_dict({"forms": {form_name: slot_mappings}})

    user_intent = {"name": intent, "confidence": 1.0}
    conversation = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        UserUttered("bla", intent=user_intent, entities=entities),
    ]
    tracker = DialogueStateTracker.from_events("default", conversation)

    extracted = form.extract_requested_slot(tracker, domain)
    assert extracted == expected_slot_values
コード例 #6
0
def test_invalid_slot_mapping():
    """An unknown slot mapping type makes requested-slot extraction raise ``ValueError``."""
    form_name = "my_form"
    form = FormAction(form_name, None)
    requested = "test"
    tracker = DialogueStateTracker.from_events(
        "sender", [SlotSet(REQUESTED_SLOT, requested)]
    )

    bad_mapping = {"type": "invalid"}
    domain = Domain.from_dict({"forms": [{form_name: {requested: [bad_mapping]}}]})

    with pytest.raises(ValueError):
        form.extract_requested_slot(tracker, domain)
コード例 #7
0
ファイル: action.py プロジェクト: pu55yf3r/rasa
def action_from_name(
    name: Text,
    action_endpoint: Optional[EndpointConfig],
    user_actions: List[Text],
    should_use_form_action: bool = False,
    retrieval_intents: Optional[List[Text]] = None,
) -> "Action":
    """Return an action instance for the name.

    Lookup order:

    * a built-in default action, unless ``user_actions`` contains the same name;
    * a retrieval response or a response template for ``utter_`` names;
    * a ``FormAction`` when ``should_use_form_action`` is set;
    * otherwise a ``RemoteAction`` executed via ``action_endpoint``.
    """
    default_by_name = {
        default.name(): default for default in default_actions(action_endpoint)
    }
    if name in default_by_name and name not in user_actions:
        return default_by_name[name]

    if name.startswith(UTTER_PREFIX):
        if is_retrieval_action(name, retrieval_intents or []):
            return ActionRetrieveResponse(name)
        return ActionUtterTemplate(name)

    if should_use_form_action:
        from rasa.core.actions.forms import FormAction

        return FormAction(name, action_endpoint)

    return RemoteAction(name, action_endpoint)
コード例 #8
0
async def test_action_rejection():
    """The active form rejects execution when the message does not fill its slot.

    The requested slot only maps from ``some_entity``. The user message carries no
    entities, so ``run`` must raise ``ActionExecutionRejection``.
    """
    form_name = "my form"
    slot_to_fill = "some slot"
    tracker = DialogueStateTracker.from_events(
        sender_id="bla",
        evts=[
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, slot_to_fill),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("haha", {"name": "greet"}),
        ],
    )
    action = FormAction(form_name, None)
    domain = f"""
    forms:
      {form_name}:
        {slot_to_fill}:
        - type: from_entity
          entity: some_entity
    slots:
      {slot_to_fill}:
        type: unfeaturized
    """
    domain = Domain.from_yaml(domain)

    with pytest.raises(ActionExecutionRejection):
        await action.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            tracker,
            domain,
        )
コード例 #9
0
def action_from_name(name: Text, domain: Domain,
                     action_endpoint: Optional[EndpointConfig]) -> "Action":
    """Look up and instantiate the action called ``name``.

    Args:
        name: Name of the action to resolve.
        domain: Domain of the current model.
        action_endpoint: Endpoint used to execute custom actions.

    Returns:
        The action instance for ``name``.
    """
    builtin_actions = {
        candidate.name(): candidate
        for candidate in default_actions(action_endpoint)
    }
    if name in builtin_actions and name not in domain.user_actions_and_forms:
        return builtin_actions[name]

    if name.startswith(UTTER_PREFIX):
        if is_retrieval_action(name, domain.retrieval_intents):
            return ActionRetrieveResponse(name)
        return ActionUtterTemplate(name)

    # A form runs as FormAction unless the user defined a custom action with the
    # same name, which then overrides it.
    if name in domain.form_names and name not in domain.user_actions:
        from rasa.core.actions.forms import FormAction

        return FormAction(name, action_endpoint)

    return RemoteAction(name, action_endpoint)
コード例 #10
0
async def test_set_slot_and_deactivate():
    """The ``from_text`` slot is filled from the message, then the form deactivates."""
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "dasdasdfasdf"
    tracker = DialogueStateTracker.from_events(
        sender_id="bla",
        evts=[
            Form(form_name),
            SlotSet(REQUESTED_SLOT, slot_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(slot_value),
        ],
    )

    domain_yaml = f"""
    forms:
    - {form_name}:
        {slot_name}:
        - type: from_text
    slots:
      {slot_name}:
        type: unfeaturized
    """
    domain = Domain.from_yaml(domain_yaml)

    form = FormAction(form_name, None)
    produced = await form.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    expected = [
        SlotSet(slot_name, slot_value),
        SlotSet(REQUESTED_SLOT, None),
        Form(None),
    ]
    assert produced == expected
コード例 #11
0
async def test_activate_with_prefilled_slot():
    """Activation keeps a slot set beforehand and requests the next unfilled slot."""
    slot_name = "num_people"
    slot_value = 5

    prefilled = SlotSet(slot_name, slot_value)
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=[prefilled])
    form_name = "my form"
    form = FormAction(form_name, None)

    next_slot_to_request = "next slot to request"
    domain_yaml = f"""
    forms:
      {form_name}:
        {slot_name}:
        - type: from_entity
          entity: {slot_name}
        {next_slot_to_request}:
        - type: from_text
    slots:
      {slot_name}:
        type: unfeaturized
    """
    domain = Domain.from_yaml(domain_yaml)
    produced = await form.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    expected = [
        ActiveLoop(form_name),
        SlotSet(slot_name, slot_value),
        SlotSet(REQUESTED_SLOT, next_slot_to_request),
    ]
    assert produced == expected
コード例 #12
0
ファイル: test_forms.py プロジェクト: takjub/rasa
async def test_ask_for_slot(
    domain: Dict,
    expected_action: Text,
    monkeypatch: MonkeyPatch,
    default_nlg: TemplatedNaturalLanguageGenerator,
):
    """``_ask_for_slot`` resolves the expected action via ``action_for_name_or_text``."""
    slot_name = "sun"

    # stands in for action.action_for_name_or_text
    lookup_mock = Mock(return_value=action.ActionListen())
    endpoint_config = Mock()
    patched_name = action.action_for_name_or_text.__name__
    monkeypatch.setattr(action, patched_name, lookup_mock)

    form = FormAction("my_form", endpoint_config)
    domain = Domain.from_dict(domain)
    output_channel = CollectingOutputChannel()
    empty_tracker = DialogueStateTracker.from_events("dasd", [])
    await form._ask_for_slot(
        domain, default_nlg, output_channel, slot_name, empty_tracker
    )

    lookup_mock.assert_called_once_with(expected_action, domain, endpoint_config)
コード例 #13
0
async def test_activate():
    """Running an inactive form activates it, sets the requested slot and utters."""
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=[])
    form_name = "my form"
    form = FormAction(form_name, None)
    slot_name = "num_people"
    domain_yaml = f"""
forms:
  {form_name}:
    {slot_name}:
    - type: from_entity
      entity: number
responses:
    utter_ask_num_people:
    - text: "How many people?"
"""
    domain = Domain.from_yaml(domain_yaml)

    produced = await form.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    expected_state_events = [ActiveLoop(form_name), SlotSet(REQUESTED_SLOT, slot_name)]
    assert produced[:-1] == expected_state_events
    # the last event is the bot asking for the slot
    assert isinstance(produced[-1], BotUttered)
コード例 #14
0
async def test_validate_slots_on_activation_with_other_action_after_user_utterance(
):
    """Slots are validated on activation even if another action ran after the message.

    The mocked validation action returns ``"✅"`` for the form's only slot. The form
    therefore fills it and deactivates in the same run.
    """
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "hi"
    events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(slot_value,
                    entities=[{
                        "entity": "num_tables",
                        "value": 5
                    }]),
        ActionExecuted("action_in_between"),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

    domain = f"""
    slots:
      {slot_name}:
        type: unfeaturized
    forms:
      {form_name}:
        {slot_name}:
        - type: from_text
    actions:
    - validate_{form_name}
    """
    domain = Domain.from_yaml(domain)
    # well-formed URL (scheme followed by `//`); used both for mocking and requests
    action_server_url = "http://my-action-server:5055/webhook"

    expected_slot_value = "✅"
    with aioresponses() as mocked:
        # response of the custom validation action
        mocked.post(
            action_server_url,
            payload={
                "events": [{
                    "event": "slot",
                    "name": slot_name,
                    "value": expected_slot_value
                }]
            },
        )

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        events = await action.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            tracker,
            domain,
        )

    assert events == [
        ActiveLoop(form_name),
        SlotSet(slot_name, expected_slot_value),
        SlotSet(REQUESTED_SLOT, None),
        ActiveLoop(None),
    ]
コード例 #15
0
def test_extract_requested_slot_default():
    """With no form mappings, an entity named like the requested slot fills it."""
    requested_slot = "some_slot"
    entity_value = "some_value"
    form = FormAction("some form", None)

    tracker = DialogueStateTracker.from_events(
        "default",
        [
            SlotSet(REQUESTED_SLOT, requested_slot),
            UserUttered(
                "bla", entities=[{"entity": requested_slot, "value": entity_value}]
            ),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    extracted = form.extract_requested_slot(tracker, Domain.empty())
    assert extracted == {requested_slot: entity_value}
コード例 #16
0
async def test_trigger_slot_mapping_applies(
    trigger_slot_mapping: Dict, expected_value: Text
):
    """``extract_other_slots`` fills ``other_slot`` through ``trigger_slot_mapping``.

    ``some_slot`` is requested; only the non-requested ``other_slot`` is expected
    in the result, with value ``expected_value``.
    """
    form_name = "some_form"
    requested_slot = "some_slot"
    other_slot = "other_slot"
    form = FormAction(form_name, None)

    requested_slot_mapping = {
        "type": "from_entity",
        "entity": requested_slot,
        "intent": "some_intent",
    }
    form_slots = {
        requested_slot: [requested_slot_mapping],
        other_slot: [trigger_slot_mapping],
    }
    domain = Domain.from_dict({"forms": [{form_name: form_slots}]})

    tracker = DialogueStateTracker.from_events(
        "default",
        [
            SlotSet(REQUESTED_SLOT, "some_slot"),
            UserUttered(
                "bla",
                intent={"name": "greet", "confidence": 1.0},
                entities=[{"entity": requested_slot, "value": "some_value"}],
            ),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    extracted = form.extract_other_slots(tracker, domain)
    assert extracted == {other_slot: expected_value}
コード例 #17
0
def action_for_name_or_text(
        action_name_or_text: Text, domain: Domain,
        action_endpoint: Optional[EndpointConfig]) -> "Action":
    """Retrieves an action by its name or by its text in case it's an end-to-end action.

    Args:
        action_name_or_text: The name of the action.
        domain: The current model domain.
        action_endpoint: The endpoint to execute custom actions.

    Raises:
        ActionNotFoundException: If action not in current domain.

    Returns:
        The instantiated action.
    """
    if action_name_or_text not in domain.action_names_or_texts:
        domain.raise_action_not_found_exception(action_name_or_text)

    defaults = {a.name(): a for a in default_actions(action_endpoint)}

    if (action_name_or_text in defaults
            and action_name_or_text not in domain.user_actions_and_forms):
        return defaults[action_name_or_text]

    if action_name_or_text.startswith(UTTER_PREFIX) and is_retrieval_action(
            action_name_or_text, domain.retrieval_intents):
        return ActionRetrieveResponse(action_name_or_text)

    if action_name_or_text in domain.action_texts:
        return ActionEndToEndResponse(action_name_or_text)

    if action_name_or_text.startswith(UTTER_PREFIX):
        return ActionUtterTemplate(action_name_or_text)

    # bf >
    # Forms whose domain definition carries `graph_elements` are bf forms.
    # `or {}` also covers a form declared without a body (value `None`), which
    # would otherwise raise AttributeError on `.get`.
    form_definition = domain.forms.get(action_name_or_text) or {}
    if form_definition.get("graph_elements") is not None:
        return generate_bf_form_action(action_name_or_text)
    if action_name_or_text in actions_bf:
        return actions_bf[action_name_or_text]
    # </ bf

    is_form = action_name_or_text in domain.form_names
    # Users can override the form by defining an action with the same name as the form
    user_overrode_form_action = is_form and action_name_or_text in domain.user_actions
    if is_form and not user_overrode_form_action:
        from rasa.core.actions.forms import FormAction

        return FormAction(action_name_or_text, action_endpoint)

    return RemoteAction(action_name_or_text, action_endpoint)
コード例 #18
0
def test_extract_other_slots_with_entity(
    some_other_slot_mapping: List[Dict[Text, Any]],
    some_slot_mapping: List[Dict[Text, Any]],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Values of slots other than the requested one are extracted from entities."""

    form_name = "some_form"
    form = FormAction(form_name, None)

    slot_mappings = {
        "some_other_slot": some_other_slot_mapping,
        "some_slot": some_slot_mapping,
    }
    domain = Domain.from_dict({"forms": [{form_name: slot_mappings}]})

    user_intent = {"name": intent, "confidence": 1.0}
    tracker = DialogueStateTracker.from_events(
        "default",
        [
            SlotSet(REQUESTED_SLOT, "some_slot"),
            UserUttered("bla", intent=user_intent, entities=entities),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    # `some_slot` is the requested slot; the result covers the non-requested ones
    extracted = form.extract_other_slots(tracker, domain)
    assert extracted == expected_slot_values
コード例 #19
0
async def test_validate_slots(validate_return_events: List[Dict],
                              expected_events: List[Event]):
    """The form applies the events returned by the custom validation action.

    The user message provides a text value for ``num_people`` and an entity for
    ``num_tables``. The mocked ``validate_<form>`` action answers with
    ``validate_return_events``, and the form must emit ``expected_events``.
    """
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "hi"
    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, slot_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(slot_value,
                    entities=[{
                        "entity": "num_tables",
                        "value": 5
                    }]),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

    domain = f"""
    slots:
      {slot_name}:
        type: any
      num_tables:
        type: any
    forms:
      {form_name}:
        {slot_name}:
        - type: from_text
        num_tables:
        - type: from_entity
          entity: num_tables
    actions:
    - validate_{form_name}
    """
    domain = Domain.from_yaml(domain)
    # well-formed URL (scheme followed by `//`); used both for mocking and requests
    action_server_url = "http://my-action-server:5055/webhook"

    with aioresponses() as mocked:
        mocked.post(action_server_url,
                    payload={"events": validate_return_events})

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        events = await action.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            tracker,
            domain,
        )
        assert events == expected_events
コード例 #20
0
async def test_ask_for_slot(
    domain: Dict, expected_action: Text, monkeypatch: MonkeyPatch
):
    """``_ask_for_slot`` resolves the expected ask action through ``action_from_name``."""
    requested_slot = "sun"

    # stands in for action.action_from_name
    lookup_mock = Mock(return_value=action.ActionListen())
    endpoint_config = Mock()
    patched_name = action.action_from_name.__name__
    monkeypatch.setattr(action, patched_name, lookup_mock)

    form = FormAction("my_form", endpoint_config)
    domain = Domain.from_dict(domain)
    empty_tracker = DialogueStateTracker.from_events("dasd", [])
    await form._ask_for_slot(domain, None, None, requested_slot, empty_tracker)

    lookup_mock.assert_called_once_with(expected_action, domain, endpoint_config)
コード例 #21
0
ファイル: test_forms.py プロジェクト: takjub/rasa
def test_name_of_utterance(utterance_name: Text):
    """The ask response defined in the domain is the one chosen for the form's slot."""
    form_name = "my_form"
    slot_name = "num_people"

    domain_yaml = f"""
    forms:
      {form_name}:
        {slot_name}:
        - type: from_text
    responses:
        {utterance_name}:
        - text: "How many people?"
    """
    domain = Domain.from_yaml(domain_yaml)

    form = FormAction(form_name, None)

    chosen = form._name_of_utterance(domain, slot_name)
    assert chosen == utterance_name
コード例 #22
0
ファイル: test_forms.py プロジェクト: takjub/rasa
async def test_ask_for_slot_if_not_utter_ask(
        monkeypatch: MonkeyPatch,
        default_nlg: TemplatedNaturalLanguageGenerator):
    """Without an ask response in the domain, no events are produced and no action is looked up."""
    # stands in for action.action_for_name_or_text
    lookup_mock = Mock(return_value=action.ActionListen())
    endpoint_config = Mock()
    patched_name = action.action_for_name_or_text.__name__
    monkeypatch.setattr(action, patched_name, lookup_mock)

    form = FormAction("my_form", endpoint_config)
    asked = await form._ask_for_slot(
        Domain.empty(),
        default_nlg,
        CollectingOutputChannel(),
        "some slot",
        DialogueStateTracker.from_events("dasd", []),
    )

    assert not asked
    lookup_mock.assert_not_called()
コード例 #23
0
async def test_activate_and_immediate_deactivate():
    """If the triggering message fills the only slot, the form activates and deactivates in one run."""
    slot_name = "num_people"
    slot_value = 5

    user_entities = [{"entity": slot_name, "value": slot_value}]
    tracker = DialogueStateTracker.from_events(
        sender_id="bla",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("haha", {"name": "greet"}, entities=user_entities),
        ],
    )
    form_name = "my form"
    form = FormAction(form_name, None)
    domain_yaml = f"""
    forms:
      {form_name}:
        {slot_name}:
        - type: from_entity
          entity: {slot_name}
    slots:
      {slot_name}:
        type: unfeaturized
    """
    domain = Domain.from_yaml(domain_yaml)
    produced = await form.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    expected = [
        ActiveLoop(form_name),
        SlotSet(slot_name, slot_value),
        SlotSet(REQUESTED_SLOT, None),
        ActiveLoop(None),
    ]
    assert produced == expected
コード例 #24
0
ファイル: test_forms.py プロジェクト: takjub/rasa
async def test_no_slots_extracted_with_custom_slot_mappings(
        custom_events: List[Event]):
    """The form rejects execution when an off-topic message fills no slot.

    The mocked ``validate_<form>`` action returns ``custom_events``. ``run`` must
    raise ``ActionExecutionRejection``.
    """
    form_name = "my form"
    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "num_tables"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("off topic"),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

    domain = f"""
    slots:
      num_tables:
        type: any
    forms:
      {form_name}:
        num_tables:
        - type: from_entity
          entity: num_tables
    actions:
    - validate_{form_name}
    """
    domain = Domain.from_yaml(domain)
    # well-formed URL (scheme followed by `//`); used both for mocking and requests
    action_server_url = "http://my-action-server:5055/webhook"

    with aioresponses() as mocked:
        mocked.post(action_server_url, payload={"events": custom_events})

        action_server = EndpointConfig(action_server_url)
        action = FormAction(form_name, action_server)

        with pytest.raises(ActionExecutionRejection):
            await action.run(
                CollectingOutputChannel(),
                TemplatedNaturalLanguageGenerator(domain.templates),
                tracker,
                domain,
            )
コード例 #25
0
def action_from_name(
    name: Text,
    action_endpoint: Optional[EndpointConfig],
    user_actions: List[Text],
    should_use_form_action: bool = False,
    retrieval_intents: Optional[List[Text]] = None,
    domain: Optional[Domain] = None,
) -> "Action":
    """Return an action instance for the name.

    Lookup order:

    * default actions not overridden by ``user_actions``;
    * retrieval actions and ``utter_`` templates;
    * bf forms: names ending in ``_form`` that are listed in the domain's
      ``bf_forms`` slot;
    * actions registered in ``actions_bf``;
    * ``FormAction`` when ``should_use_form_action`` is set;
    * otherwise a ``RemoteAction`` executed via ``action_endpoint``.
    """

    # bf
    # Names of bf forms, read from the initial value of the `bf_forms` slot;
    # presumably a list of dicts with a "name" key (TODO confirm schema).
    # If several slots are called `bf_forms`, the last one wins.
    bf_forms = []
    if domain:
        for slot in domain.slots or []:
            if slot.name == "bf_forms":
                # A slot without an initial value (None) means "no bf forms";
                # iterating over None below would raise a TypeError.
                bf_forms = slot.initial_value or []
        bf_forms = [f.get("name") for f in bf_forms]
    # /bf

    defaults = {a.name(): a for a in default_actions(action_endpoint)}

    if name in defaults and name not in user_actions:
        return defaults[name]
    elif name.startswith(UTTER_PREFIX) and is_retrieval_action(
            name, retrieval_intents or []):
        return ActionRetrieveResponse(name)
    elif name.startswith(UTTER_PREFIX):
        return ActionUtterTemplate(name)
    # bf >
    elif name.endswith("_form") and name in bf_forms:
        return generate_bf_form_action(name)
    elif name in actions_bf:
        return actions_bf[name]
    # </ bf
    elif should_use_form_action:
        from rasa.core.actions.forms import FormAction

        return FormAction(name, action_endpoint)
    else:
        return RemoteAction(name, action_endpoint)
コード例 #26
0
ファイル: test_forms.py プロジェクト: takjub/rasa
async def test_request_correct_slots_after_unhappy_path_with_custom_required_slots(
):
    """After an unhappy path, the form requests the slot chosen by custom validation.

    The mocked ``validate_<form>`` action sets ``requested_slot`` to ``slot_2``. When
    the form resumes, its only event must be ``SlotSet(requested_slot, slot_2)``.
    """
    form_name = "some_form"
    first_slot = "slot_1"
    second_slot = "slot_2"

    domain_yaml = f"""
        slots:
          {first_slot}:
            type: any
          {second_slot}:
            type: any
        forms:
          {form_name}:
            {first_slot}:
            - type: from_intent
              intent: some_intent
              value: some_value
            {second_slot}:
            - type: from_intent
              intent: some_intent
              value: some_value
        actions:
        - validate_{form_name}
        """
    domain = Domain.from_yaml(domain_yaml)

    greet = {"name": "greet", "confidence": 1.0}
    tracker = DialogueStateTracker.from_events(
        "default",
        [
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "slot_2"),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("hello", intent=greet),
            ActionExecutionRejected(form_name),
            ActionExecuted("utter_greet"),
        ],
    )

    action_server_url = "http://my-action-server:5055/webhook"

    # The custom validation action changes the order of the requested slots
    validation_events = [
        {"event": "slot", "name": REQUESTED_SLOT, "value": second_slot},
    ]

    # Coming back after the unhappy path, the form asks for the same slot again
    expected = [SlotSet(REQUESTED_SLOT, second_slot)]

    with aioresponses() as mocked:
        mocked.post(action_server_url, payload={"events": validation_events})

        form = FormAction(form_name, EndpointConfig(action_server_url))
        produced = await form.run(
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            tracker,
            domain,
        )
        assert produced == expected
コード例 #27
0
ファイル: test_forms.py プロジェクト: takjub/rasa
async def test_switch_forms_with_same_slot(default_agent: Agent):
    """Switching forms works when both forms share their first slot (issue 7710).

    Sequence under test:

    1. ``form_1`` activates and asks for the shared slot.
    2. The user sends a new message; ``form_1`` rejects it.
    3. ``form_2`` activates and asks for the same slot with its own response.
    """

    # Two forms with the same first slot, each with its own ask response
    slot_a = "my_slot_a"

    form_1 = "my_form_1"
    utter_ask_form_1 = f"Please provide the value for {slot_a} of form 1"

    form_2 = "my_form_2"
    utter_ask_form_2 = f"Please provide the value for {slot_a} of form 2"

    domain_yaml = f"""
version: "2.0"
nlu:
- intent: order_status
  examples: |
    - check status of my order
    - when are my shoes coming in
- intent: return
  examples: |
    - start a return
    - I don't want my shoes anymore
forms:
  {form_1}:
    {slot_a}:
    - type: from_entity
      entity: number
  {form_2}:
    {slot_a}:
    - type: from_entity
      entity: number
responses:
    utter_ask_{form_1}_{slot_a}:
    - text: {utter_ask_form_1}
    utter_ask_{form_2}_{slot_a}:
    - text: {utter_ask_form_2}
"""

    domain = Domain.from_yaml(domain_yaml)

    # The processor drives the actions the way rasa/core/processor does
    processor = MessageProcessor(
        default_agent.interpreter,
        default_agent.policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        TemplatedNaturalLanguageGenerator(domain.templates),
    )

    # Step 1: the user triggers the first form
    tracker = DialogueStateTracker.from_events(
        "some-sender",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
            DefinePrevUserUtteredFeaturization(False),
        ],
    )
    # stands in for what rasa/core/processor.predict_next_action produces
    prediction = PolicyPrediction([], "some_policy")

    async def run_form(form_action: FormAction) -> None:
        # Run one form action against the shared tracker with a fresh channel/NLG
        await processor._run_action(
            form_action,
            tracker,
            CollectingOutputChannel(),
            TemplatedNaturalLanguageGenerator(domain.templates),
            prediction,
        )

    first_form = FormAction(form_1, None)
    await run_form(first_form)

    expected = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("order status", {"name": "form_1", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
        ActionExecuted(form_1),
        ActiveLoop(form_1),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_1,
            metadata={"template_name": f"utter_ask_{form_1}_{slot_a}"},
        ),
    ]
    assert tracker.applied_events() == expected

    # Step 2: while form_1 is active, the user asks for something else
    second_turn = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("return my shoes", {"name": "form_2", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(False),
    ]
    tracker.update_with_events(second_turn, domain)
    expected += second_turn

    # form_1 is still active and first validates the user message against the
    # requested slot; the message provides no value, so execution is rejected
    await run_form(first_form)
    expected.append(ActionExecutionRejected(action_name=form_1))
    assert tracker.applied_events() == expected

    # Step 3: the bot predicts form_2, which asks with its own response
    second_form = FormAction(form_2, None)
    await run_form(second_form)
    expected += [
        ActionExecuted(form_2),
        ActiveLoop(form_2),
        SlotSet(REQUESTED_SLOT, slot_a),
        BotUttered(
            text=utter_ask_form_2,
            metadata={"template_name": f"utter_ask_{form_2}_{slot_a}"},
        ),
    ]
    assert tracker.applied_events() == expected