Beispiel #1
0
def test_extract_requested_slot_mapping_does_not_apply(slot_mapping: Dict):
    """Check that nothing is extracted when the given slot mapping does not apply.

    Args:
        slot_mapping: The single mapping registered for the requested slot in
            the form definition of the domain.
    """
    form_name = "some_form"
    entity_name = "some_slot"
    form = FormAction(form_name, None)

    forms_config = {form_name: {entity_name: [slot_mapping]}}
    domain = Domain.from_dict({"forms": forms_config})

    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        UserUttered(
            "bla",
            intent={"name": "greet", "confidence": 1.0},
            entities=[{"entity": entity_name, "value": "some_value"}],
        ),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_requested_slot(tracker, domain, "some_slot")

    # The user message carries a matching entity, yet no slot value may be
    # produced because the mapping is not applicable to this message.
    assert extracted == {}
Beispiel #2
0
def test_tracker_update_slots_with_entity(domain: Domain):
    """Check that a user message with an entity fills the slot of the same name.

    Args:
        domain: Domain whose first entity is used as both entity and slot name.
    """
    tracker = DialogueStateTracker("default", domain.slots)

    test_entity = domain.entities[0]
    expected_slot_value = "test user"

    annotated_entity = {
        "start": 1,
        "end": 5,
        "value": expected_slot_value,
        "entity": test_entity,
        "extractor": "manual",
    }
    greet_intent = {"name": "greet", PREDICTED_CONFIDENCE_KEY: 1.0}
    user_event = UserUttered("/greet", greet_intent, [annotated_entity])

    tracker.update(user_event, domain)

    assert tracker.get_slot(test_entity) == expected_slot_value
Beispiel #3
0
def test_tracker_store_deprecated_session_retrieval_kwarg():
    """Check `retrieve_full_tracker` usage when the deprecated kwarg is passed.

    The store is built with `retrieve_events_from_previous_conversation_sessions`
    and `retrieve_full_tracker` is replaced by a mock; after saving and
    retrieving a tracker the mock must have been called exactly once.
    """
    store = SQLTrackerStore(
        Domain.empty(), retrieve_events_from_previous_conversation_sessions=True
    )

    conversation_id = uuid.uuid4().hex
    session_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, session_events)

    # Swap in a mock so that the call can be observed.
    full_tracker_spy = Mock()
    store.retrieve_full_tracker = full_tracker_spy

    store.save(tracker)
    store.retrieve(conversation_id)

    full_tracker_spy.assert_called_once()
    def _parse_raw_user_utterance(
            self, step: Dict[Text, Any]) -> Optional[UserUttered]:
        """Create a `UserUttered` event from a raw story step.

        Args:
            step: Story step; may contain `KEY_USER_MESSAGE` (annotated user
                text) and/or `KEY_ENTITIES`.

        Returns:
            The user event. Its text is `None` when the step carries no user
            message, i.e. only the intent was provided.
        """
        from rasa.shared.nlu.interpreter import RegexInterpreter

        intent = {"name": self._user_intent_from_step(step), "confidence": 1.0}

        if KEY_USER_MESSAGE not in step:
            # Only the intent was provided in the stories: there is no text
            # and the entities come from the step's explicit entity list.
            entities = self._parse_raw_entities(step.get(KEY_ENTITIES, []))
            return UserUttered(None, intent, entities)

        user_message = step[KEY_USER_MESSAGE].strip()
        entities = entities_parser.find_entities_in_training_example(
            user_message)
        plain_text = entities_parser.replace_entities(user_message)

        if plain_text.startswith(INTENT_MESSAGE_PREFIX):
            # Messages starting with the intent prefix are parsed by the
            # regex interpreter, whose entities replace the annotated ones.
            parsed = RegexInterpreter().synchronous_parse(plain_text)
            entities = parsed.get(ENTITIES, [])

        return UserUttered(plain_text, intent, entities)
Beispiel #5
0
async def test_reminder_lock(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    caplog: LogCaptureFixture,
):
    """Check that handling a reminder logs the deletion of the conversation lock."""
    caplog.clear()
    with caplog.at_level(logging.DEBUG):
        sender_id = uuid.uuid4().hex

        reminder = ReminderScheduled("remind", datetime.datetime.now())
        store = default_processor.tracker_store
        tracker = store.get_or_create_tracker(sender_id)

        history = (
            UserUttered("test"),
            ActionExecuted("action_schedule_reminder"),
            reminder,
        )
        for event in history:
            tracker.update(event)

        store.save(tracker)

        await default_processor.handle_reminder(
            reminder, sender_id, default_channel
        )

        expected_log = f"Deleted lock for conversation '{sender_id}'."
        assert expected_log in caplog.text
Beispiel #6
0
async def test_reminder_restart(default_channel: CollectingOutputChannel,
                                default_processor: MessageProcessor):
    """Check that a reminder followed by a restart does not add events when handled."""
    sender_id = uuid.uuid4().hex

    reminder = ReminderScheduled(
        "utter_greet", datetime.datetime.now(), kill_on_user_message=False
    )
    store = default_processor.tracker_store
    tracker = store.get_or_create_tracker(sender_id)

    history = (
        reminder,
        Restarted(),  # cancels the reminder
        UserUttered("test"),
    )
    for event in history:
        tracker.update(event)

    store.save(tracker)
    await default_processor.handle_reminder(reminder, sender_id, default_channel)

    # Fetch the stored tracker again to inspect what handling the reminder did.
    stored_tracker = default_processor.tracker_store.retrieve(sender_id)
    # nothing should have been executed
    assert len(stored_tracker.events) == 4
Beispiel #7
0
    def _parse_raw_user_utterance(self, step: Dict[Text, Any]) -> Optional[UserUttered]:
        """Create a `UserUttered` event from a raw story step.

        Args:
            step: Story step; may contain `KEY_USER_MESSAGE` (annotated user
                text) and/or `KEY_ENTITIES`.

        Returns:
            The user event. Its text is `None` when the step carries no user
            message, i.e. only the intent was provided.
        """
        name, retrieval_intent = self._user_intent_from_step(step)
        intent = {
            INTENT_NAME_KEY: name,
            FULL_RETRIEVAL_INTENT_NAME_KEY: retrieval_intent,
            PREDICTED_CONFIDENCE_KEY: 1.0,
        }

        if KEY_USER_MESSAGE not in step:
            # Only the intent was provided in the stories: there is no text
            # and the entities come from the step's explicit entity list.
            entities = self._parse_raw_entities(step.get(KEY_ENTITIES, []))
            return UserUttered(None, intent, entities)

        user_message = step[KEY_USER_MESSAGE].strip()
        entities = entities_parser.find_entities_in_training_example(user_message)
        plain_text = entities_parser.replace_entities(user_message)

        if plain_text.startswith(INTENT_MESSAGE_PREFIX):
            # For messages starting with the intent prefix, the entities of
            # the unpacked regex message replace the annotated ones.
            unpacked = self.unpack_regex_message(Message({TEXT: plain_text}))
            entities = unpacked.get(ENTITIES, [])

        return UserUttered(plain_text, intent, entities)
Beispiel #8
0
async def test_predict_form_action_if_in_form():
    """Check that the rule policy predicts the form again while it is active."""
    form_name = "some_form"

    domain_yaml = f"""
    intents:
    - {GREET_INTENT_NAME}
    actions:
    - {UTTER_GREET_ACTION}
    - some-action
    slots:
      {REQUESTED_SLOT}:
        type: unfeaturized
    forms:
    - {form_name}
"""
    domain = Domain.from_yaml(domain_yaml)

    policy = RulePolicy()
    policy.train([GREET_RULE], domain, RegexInterpreter())

    events_in_form = [
        # The form was run and is now the active loop
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        # The user answers the request for the slot
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]
    form_conversation = DialogueStateTracker.from_events(
        "in a form", evts=events_in_form, slots=domain.slots
    )

    # The prediction is expected to be the form itself
    action_probabilities = policy.predict_action_probabilities(
        form_conversation, domain, RegexInterpreter()
    )
    assert_predicted_action(action_probabilities, domain, form_name)
async def test_can_read_test_story_with_entities_without_value(domain: Domain):
    """Check that a YAML story with a value-less entity loads with an empty value.

    Args:
        domain: Domain used to load the story file.
    """
    trackers = await training.load_data(
        "data/test_yaml_stories/story_with_or_and_entities_with_no_value.yml",
        domain,
        use_story_concatenation=False,
        tracker_limit=1000,
        remove_duplicates=False,
    )
    assert len(trackers) == 1

    expected_user_event = UserUttered(
        intent={"name": "greet", "confidence": 1.0},
        entities=[{"entity": "name", "value": ""}],
        parse_data={
            "text": "/greet",
            "intent_ranking": [{"confidence": 1.0, "name": "greet"}],
            "intent": {"confidence": 1.0, "name": "greet"},
            "entities": [{"entity": "name", "value": ""}],
        },
    )

    loaded_events = trackers[0].events
    assert loaded_events[-4] == expected_user_event
    assert loaded_events[-2] == ActionExecuted("utter_greet")
    assert loaded_events[-1] == ActionExecuted("action_listen")
Beispiel #10
0
def test_extract_other_slots_with_entity(
    some_other_slot_mapping: List[Dict[Text, Any]],
    some_slot_mapping: List[Dict[Text, Any]],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check extraction of values for slots other than the requested one.

    Args:
        some_other_slot_mapping: Mappings of the slot that is not requested.
        some_slot_mapping: Mappings of the requested slot.
        entities: Entities attached to the user message.
        intent: Intent name of the user message.
        expected_slot_values: Values `extract_other_slots` should return.
    """
    form_name = "some_form"
    form = FormAction(form_name, None)

    slot_mappings = {
        "some_other_slot": some_other_slot_mapping,
        "some_slot": some_slot_mapping,
    }
    domain = Domain.from_dict({"forms": {form_name: slot_mappings}})

    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        UserUttered(
            "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
        ),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    # Only slots that are not currently requested are considered here.
    extracted = form.extract_other_slots(tracker, domain)
    assert extracted == expected_slot_values
Beispiel #11
0
async def test_ask_rephrase_after_failed_affirmation():
    """Check that the bot asks to rephrase after the user denies the suggestion."""
    rephrase_text = "please rephrase"
    conversation = [
        # Low-confidence user message that requires a fallback
        *_message_requiring_fallback(),
        ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME),
        # The action asks the user for affirmation
        *_two_stage_clarification_request(),
        ActionExecuted(ACTION_LISTEN_NAME),
        # The user rejects the suggested intents
        UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}),
    ]
    tracker = DialogueStateTracker.from_events("some-sender", evts=conversation)

    domain_yaml = f"""
        version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
        responses:
            utter_ask_rephrase:
            - text: {rephrase_text}
        """
    domain = Domain.from_yaml(domain_yaml)
    action = TwoStageFallbackAction()

    nlg = TemplatedNaturalLanguageGenerator(domain.responses)
    events = await action.run(CollectingOutputChannel(), nlg, tracker, domain)

    assert len(events) == 1
    assert isinstance(events[0], BotUttered)

    bot_utterance = events[0]
    assert isinstance(bot_utterance, BotUttered)
    assert bot_utterance.text == rephrase_text
Beispiel #12
0
    async def test_finetune_after_load(
        self, trained_policy: MemoizationPolicy, default_domain: Domain, tmp_path: Path
    ):
        """Check that a policy loaded in finetune mode memorizes a newly added story.

        Args:
            trained_policy: Already trained policy that gets persisted.
            default_domain: Domain used for the new story and for training.
            tmp_path: Directory the policy is persisted to and loaded from.
        """
        trained_policy.persist(tmp_path)
        loaded_policy = MemoizationPolicy.load(tmp_path, should_finetune=True)
        assert loaded_policy.finetune_mode

        story_events = [
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": "why"}),
            ActionExecuted("utter_channel"),
            ActionExecuted(ACTION_LISTEN_NAME),
        ]
        new_story = TrackerWithCachedStates.from_events(
            "channel",
            evts=story_events,
            domain=default_domain,
            slots=default_domain.slots,
        )

        original_train_data = await train_trackers(
            default_domain, augmentation_factor=20
        )
        training_trackers = original_train_data + [new_story]
        loaded_policy.train(training_trackers, default_domain, RegexInterpreter())

        # Compute the tracker states of the new story ...
        featurizer = loaded_policy.featurizer
        new_story_states, _ = featurizer.training_states_and_actions(
            [new_story], default_domain
        )

        # ... and make sure the key of each of them is found in the lookup.
        for states in new_story_states:
            assert loaded_policy._create_feature_key(states) in loaded_policy.lookup
Beispiel #13
0
def test_contradicting_rules():
    """Check that two rules predicting different actions for the same input fail.

    Training must raise `InvalidRule`, and the error message must mention both
    action names and both rule names.
    """
    utter_anti_greet_action = "utter_anti_greet"
    domain_yaml = f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
- {utter_anti_greet_action}
    """
    domain = Domain.from_yaml(domain_yaml)
    policy = RulePolicy()

    rule_events = [
        ActionExecuted(RULE_SNIPPET_ACTION_NAME),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(utter_anti_greet_action),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    anti_greet_rule = TrackerWithCachedStates.from_events(
        "anti greet rule", evts=rule_events, domain=domain, slots=domain.slots
    )
    anti_greet_rule.is_rule_tracker = True

    with pytest.raises(InvalidRule) as execinfo:
        policy.train([GREET_RULE, anti_greet_rule], domain, RegexInterpreter())

    names_in_message = {
        UTTER_GREET_ACTION,
        GREET_RULE.sender_id,
        utter_anti_greet_action,
        anti_greet_rule.sender_id,
    }
    error_message = execinfo.value.message
    assert all(name in error_message for name in names_in_message)
Beispiel #14
0
async def test_failing_form_activation_due_to_no_rule():
    """Check that the policy falls back when no rule activates the form."""
    form_name = "some_form"
    other_intent = "bye"
    domain_yaml = f"""
        intents:
        - {GREET_INTENT_NAME}
        - {other_intent}
        actions:
        - {UTTER_GREET_ACTION}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """
    domain = Domain.from_yaml(domain_yaml)

    policy = RulePolicy()
    policy.train([GREET_RULE], domain, RegexInterpreter())

    conversation_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": other_intent}),
    ]
    tracker = DialogueStateTracker.from_events(
        "casd", evts=conversation_events, slots=domain.slots
    )

    # No rule for the form activation was given, so no rule can match.
    prediction = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter()
    )

    assert prediction.max_confidence == policy._core_fallback_threshold
Beispiel #15
0
def test_restrict_multiple_user_inputs_in_rules():
    """Check that a rule with more user inputs than allowed is rejected."""
    domain_yaml = f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
    """
    domain = Domain.from_yaml(domain_yaml)
    policy = RulePolicy()

    rule_start = [
        ActionExecuted(RULE_SNIPPET_ACTION_NAME),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    greet_turn = [
        UserUttered(intent={"name": GREET_INTENT_NAME}),
        ActionExecuted(UTTER_GREET_ACTION),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    # One user turn more than the policy allows within a rule.
    too_many_turns = greet_turn * (policy.ALLOWED_NUMBER_OF_USER_INPUTS + 1)

    forbidden_rule = DialogueStateTracker.from_events(
        "bla", evts=rule_start + too_many_turns
    )
    forbidden_rule.is_rule_tracker = True

    with pytest.raises(InvalidRule):
        policy.train([forbidden_rule], domain, RegexInterpreter())
def test_extract_requested_slot_from_entity_no_intent():
    """Check slot extraction from a differently named entity without an intent."""
    entity_mapping = {"type": "from_entity", "entity": ["some_entity"]}
    spec = {
        "name": "default_form",
        "slots": [{"name": "some_slot", "filling": [entity_mapping]}],
    }

    form, tracker = new_form_and_tracker(spec, "some_slot")

    user_event = UserUttered(
        entities=[{"entity": "some_entity", "value": "some_value"}]
    )
    tracker.update(user_event)

    extracted = form.extract_requested_slot(
        OutputChannel(), nlg, tracker, Domain.empty()
    )
    assert extracted == {"some_slot": "some_value"}
Beispiel #17
0
def test_get_latest_entity_values(
    entities: List[Dict[Text, Any]], expected_values: List[Text], default_domain: Domain
):
    """Check `get_latest_entity_values` before and after a user message.

    Args:
        entities: Entities of the user message; the first one supplies the
            entity type, role and group used for the lookup.
        expected_values: Values expected from the role/group-filtered lookup.
        default_domain: Domain providing the tracker's slots.
    """
    first_entity = entities[0]
    entity_type = first_entity.get("entity")
    entity_role = first_entity.get("role")
    entity_group = first_entity.get("group")

    tracker = DialogueStateTracker("default", default_domain.slots)
    # A fresh tracker has no events and therefore no entity values.
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values(entity_type)) == []

    greet_intent = {"name": "greet", PREDICTED_CONFIDENCE_KEY: 1.0}
    tracker.update(UserUttered("/greet", greet_intent, entities))

    latest_values = tracker.get_latest_entity_values(
        entity_type, entity_role=entity_role, entity_group=entity_group
    )
    assert list(latest_values) == expected_values
    assert list(tracker.get_latest_entity_values("unknown")) == []
Beispiel #18
0
def test_extract_requested_slot_from_entity(
    mapping_not_intent: Optional[Text],
    mapping_intent: Optional[Text],
    mapping_role: Optional[Text],
    mapping_group: Optional[Text],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check entity-based extraction of the requested slot under restrictions.

    Args:
        mapping_not_intent: `not_intent` restriction of the mapping.
        mapping_intent: `intent` restriction of the mapping.
        mapping_role: `role` restriction of the mapping.
        mapping_group: `group` restriction of the mapping.
        entities: Entities attached to the user message.
        intent: Intent name of the user message.
        expected_slot_values: Values `extract_requested_slot` should return.
    """
    form_name = "some form"
    form = FormAction(form_name, None)

    entity_mapping = form.from_entity(
        entity="some_entity",
        role=mapping_role,
        group=mapping_group,
        intent=mapping_intent,
        not_intent=mapping_not_intent,
    )
    forms_config = {form_name: {"some_slot": [entity_mapping]}}
    domain = Domain.from_dict({"forms": forms_config})

    events = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some_slot"),
        UserUttered(
            "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
        ),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_requested_slot(tracker, domain, "some_slot")
    assert extracted == expected_slot_values
Beispiel #19
0
    def _parse_message(self, message: Text, line_num: int) -> UserUttered:
        """Turn a story message into a `UserUttered` event.

        Args:
            message: The message text as found in the story.
            line_num: Line of the message, used only in the warning text.

        Returns:
            The user event built from the parsed message. A warning is raised
            when a domain is set and does not list the event's intent.
        """
        if not self.use_e2e:
            parse_data = RegexInterpreter().synchronous_parse(message)
            text = None
            intent = parse_data.get("intent")
        else:
            parsed = self.parse_e2e_message(message, self._is_used_for_training)
            text = parsed.get("text")
            # Prefer the intent response key and fall back to the plain intent.
            intent_name_or_response_key = parsed.get(
                "intent_response_key", default=parsed.get("intent")
            )
            intent = {INTENT_NAME_KEY: intent_name_or_response_key}
            entities = parsed.get("entities")
            parse_data = {
                "text": text,
                "intent": intent,
                "intent_ranking": [intent],
                "entities": entities,
            }

        utterance = UserUttered(
            text, intent, parse_data.get("entities"), parse_data
        )
        intent_name = utterance.intent.get(INTENT_NAME_KEY)

        if self.domain and intent_name not in self.domain.intents:
            warning_text = (
                f"Found unknown intent '{intent_name}' on line {line_num}. "
                "Please, make sure that all intents are "
                "listed in your domain yaml."
            )
            rasa.shared.utils.io.raise_warning(
                warning_text, UserWarning, docs=DOCS_URL_DOMAINS
            )

        return utterance
Beispiel #20
0
def test_policy_priority():
    """Check that the ensemble uses the higher-priority policy in either order."""
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    priority_1 = ConstantPolicy(priority=1, predict_index=0)
    priority_2 = ConstantPolicy(priority=2, predict_index=1)

    policy_ensemble_0 = SimplePolicyEnsemble([priority_1, priority_2])
    policy_ensemble_1 = SimplePolicyEnsemble([priority_2, priority_1])

    priority_2_result = priority_2.predict_action_probabilities(
        tracker, domain, RegexInterpreter()
    )

    # Each ensemble is paired with the position of `priority_2` inside of it.
    ensembles_and_positions = ((policy_ensemble_0, 1), (policy_ensemble_1, 0))
    for ensemble, position in ensembles_and_positions:
        result, best_policy = ensemble.probabilities_using_best_policy(
            tracker, domain, RegexInterpreter()
        )
        assert best_policy == f"policy_{position}_{type(priority_2).__name__}"
        assert result == priority_2_result
Beispiel #21
0
def test_tracker_store_storage_and_retrieval(store):
    """Check that a saved tracker is returned again and that ids are kept apart.

    Args:
        store: The tracker store under test.
    """
    sender_id = "some-id"
    tracker = store.get_or_create_tracker(sender_id)
    assert tracker.sender_id == sender_id

    # A newly created tracker holds just the initial action listen.
    assert list(tracker.events) == [ActionExecuted(ACTION_LISTEN_NAME)]

    # Add a user message and persist the tracker.
    greet_intent = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", greet_intent, []))
    assert tracker.latest_message.intent.get("name") == "greet"
    store.save(tracker)

    # Asking for the same id must give back the saved state.
    retrieved_tracker = store.get_or_create_tracker(sender_id)
    assert retrieved_tracker.sender_id == sender_id
    assert len(retrieved_tracker.events) == 2
    assert retrieved_tracker.latest_message.intent.get("name") == "greet"

    # A different id must give a tracker with only its initial event.
    other_id = "some-other-id"
    other_tracker = store.get_or_create_tracker(other_id)
    assert other_tracker.sender_id == other_id
    assert len(other_tracker.events) == 1
Beispiel #22
0
async def test_set_slot_and_deactivate():
    """Check that the form fills its only slot from text and then deactivates."""
    form_name = "my form"
    slot_name = "num_people"
    slot_value = "dasdasdfasdf"

    conversation = [
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, slot_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(slot_value),
    ]
    tracker = DialogueStateTracker.from_events(sender_id="bla", evts=conversation)

    domain_yaml = f"""
    forms:
      {form_name}:
        {slot_name}:
        - type: from_text
    slots:
      {slot_name}:
        type: text
        influence_conversation: false
    """
    domain = Domain.from_yaml(domain_yaml)

    action = FormAction(form_name, None)
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    resulting_events = await action.run(
        CollectingOutputChannel(), nlg, tracker, domain
    )

    expected_events = [
        SlotSet(slot_name, slot_value),
        SlotSet(REQUESTED_SLOT, None),
        ActiveLoop(None),
    ]
    assert resulting_events == expected_events
Beispiel #23
0
async def test_applied_events_after_action_session_start(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
):
    """Check the applied events after the session start action has been run."""
    slot_set = SlotSet("my_slot", "value")
    # The user triggers the session start manually through its intent.
    restart_message = UserUttered(
        text=f"/{USER_INTENT_SESSION_START}",
        intent={"name": USER_INTENT_SESSION_START},
    )
    history = [slot_set, ActionExecuted(ACTION_LISTEN_NAME), restart_message]
    tracker = DialogueStateTracker.from_events("🕵️‍♀️", history)

    # Run the session start action and apply everything it returns.
    session_start_events = await ActionSessionStart().run(
        default_channel, template_nlg, tracker, Domain.empty()
    )
    for event in session_start_events:
        tracker.update(event)

    expected_applied = [slot_set, ActionExecuted(ACTION_LISTEN_NAME)]
    assert tracker.applied_events() == expected_applied
Beispiel #24
0
def test_yaml_writer_doesnt_dump_action_unlikely_intent():
    """Check that the unlikely-intent action is left out of the dumped story."""
    tracker = DialogueStateTracker.from_events(
        "default",
        [
            UserUttered("Hello", {"name": "greet"}),
            ActionExecuted("utter_hello"),
            ActionExecuted(ACTION_UNLIKELY_INTENT_NAME, metadata={"key1": "value1"}),
            ActionExecuted("utter_bye"),
        ],
    )

    writer = YAMLStoryWriter()
    dump = writer.dumps(tracker.as_story().story_steps, is_test_story=True)

    expected_yaml = textwrap.dedent("""
    version: "3.0"
    stories:
    - story: default
      steps:
      - intent: greet
        user: |-
          Hello
      - action: utter_hello
      - action: utter_bye

""")
    assert dump.strip() == expected_yaml.strip()
Beispiel #25
0
def test_operator_seq_does_not_allow_overlap(negated: bool):
    """Check the history of a sequence marker over repeated intent sequences.

    Args:
        negated: Whether the marker is negated, which flips the expected history.
    """
    # Intent name of each user event and whether the (non-negated) marker
    # is expected to apply at that event.
    intents_and_applies = [
        ("1", False),
        ("2", False),
        ("1", False),
        ("2", False),
        ("3", True),
        ("3", False),
    ]
    events = [
        UserUttered(intent={INTENT_NAME_KEY: name}) for name, _ in intents_and_applies
    ]
    # Negation inverts every entry; `!=` acts as XOR on booleans.
    expected = [applies != negated for _, applies in intents_and_applies]

    sub_markers = [IntentDetectedMarker(name) for name in ("1", "2", "3")]
    marker = SequenceMarker(sub_markers, name="marker_name", negated=negated)
    for event in events:
        marker.track(event)

    assert marker.history == expected
Beispiel #26
0
def test_operators_nested_simple():
    """Check evaluation of an `and` marker that nests an `or` marker."""

    def user_turn(intent_name):
        # One user event carrying only the given intent name.
        return UserUttered(intent={"name": intent_name})

    events = [
        *(user_turn(name) for name in ("1", "2", "3")),
        SlotSet("s1", value="any"),
        *(user_turn(name) for name in ("4", "5", "6")),
    ]

    either_intent = OrMarker([IntentDetectedMarker("4"), IntentDetectedMarker("6")])
    marker = AndMarker(markers=[SlotSetMarker("s1"), either_intent], name="marker_name")

    evaluation = marker.evaluate_events(events)
    hits = evaluation[0]["marker_name"]

    assert len(hits) == 2
    assert hits[0].preceding_user_turns == 3
    assert hits[1].preceding_user_turns == 5
Beispiel #27
0
    def process_user_utterance(user_utterance: UserUttered,
                               is_test_story: bool = False) -> OrderedDict:
        """Converts a single user utterance into an ordered dict.

        Depending on the utterance and `is_test_story`, the result may hold
        the intent name, the formatted message text and a list of entities.

        Args:
            user_utterance: Original user utterance object.
            is_test_story: If `True`, the message text is written whenever the
                utterance has text, and the separate entity list is omitted.

        Returns:
            Dict with a user utterance (built as a `CommentedMap`).
        """
        result = CommentedMap()
        # The intent is written only if there is one and the utterance is not
        # featurized through its text.
        if user_utterance.intent_name and not user_utterance.use_text_for_featurization:
            result[KEY_USER_INTENT] = user_utterance.intent_name

        # Utterance objects that provide `inline_comment` get the returned
        # comment (if any) attached as an end-of-line comment.
        # NOTE(review): the comment is attached to KEY_USER_INTENT even when
        # that key was not written above - confirm this is intended.
        if hasattr(user_utterance, "inline_comment"):
            comment = user_utterance.inline_comment()
            if comment:
                result.yaml_add_eol_comment(comment, KEY_USER_INTENT)

        if user_utterance.text and (
                # We only print the utterance text if it was an end-to-end prediction
                user_utterance.use_text_for_featurization
                # or if we want to print a conversation test story.
                or is_test_story):
            result[KEY_USER_MESSAGE] = LiteralScalarString(
                rasa.shared.core.events.format_message(
                    user_utterance.text,
                    user_utterance.intent_name,
                    user_utterance.entities,
                ))

        # The separate entity list is written only outside of test stories.
        if len(user_utterance.entities) and not is_test_story:
            entities = []
            for entity in user_utterance.entities:
                if "value" in entity:
                    if hasattr(user_utterance, "inline_comment_for_entity"):
                        # One entry is added per predicted entity whose "start"
                        # equals this entity's "start"; an entity without such
                        # a predicted entity adds no entry.
                        for predicted in user_utterance.predicted_entities:
                            if predicted["start"] == entity["start"]:
                                commented_entity = user_utterance.inline_comment_for_entity(  # noqa: E501
                                    predicted, entity)
                                if commented_entity:
                                    # Entity -> value pair carrying the comment
                                    # as an end-of-line comment.
                                    entity_map = CommentedMap([
                                        (entity["entity"], entity["value"])
                                    ])
                                    entity_map.yaml_add_eol_comment(
                                        commented_entity,
                                        entity["entity"],
                                    )
                                    entities.append(entity_map)
                                else:
                                    entities.append(
                                        OrderedDict([(entity["entity"],
                                                      entity["value"])]))
                    else:
                        # Plain entity -> value pair.
                        entities.append(
                            OrderedDict([(entity["entity"], entity["value"])]))
                else:
                    # Without a value only the entity name is listed.
                    entities.append(entity["entity"])
            result[KEY_ENTITIES] = entities

        return result
Beispiel #28
0
    BotUttered,
    FollowupAction,
    UserUtteranceReverted,
    AgentUttered,
    SessionStarted,
    format_message,
)
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from tests.core.policies.test_rule_policy import GREET_INTENT_NAME, UTTER_GREET_ACTION


@pytest.mark.parametrize(
    "one_event,another_event",
    [
        (
            UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
            UserUttered("/goodbye", {"name": "goodbye", "confidence": 1.0}, []),
        ),
        (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
        (Restarted(), None),
        (AllSlotsReset(), None),
        (ConversationPaused(), None),
        (ConversationResumed(), None),
        (StoryExported(), None),
        (ActionReverted(), None),
        (UserUtteranceReverted(), None),
        (SessionStarted(), None),
        (ActionExecuted("my_action"), ActionExecuted("my_other_action")),
        (FollowupAction("my_action"), FollowupAction("my_other_action")),
        (
            BotUttered("my_text", {"my_data": 1}),
Beispiel #29
0
def random_user_uttered_event(timestamp: Optional[float] = None) -> UserUttered:
    """Create a `UserUttered` event with a random hex string as its text.

    Args:
        timestamp: Timestamp of the event; a random float in `[0, 1)` is
            used when it is `None`.

    Returns:
        The generated user event.
    """
    random_text = uuid.uuid4().hex
    if timestamp is None:
        timestamp = random.random()
    return UserUttered(random_text, timestamp=timestamp)
Beispiel #30
0
async def test_handle_message_with_session_start(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Check the full event sequence when a message arrives after session expiry."""
    sender_id = uuid.uuid4().hex
    entity = "name"

    def greeting(text, end, value):
        # Expected user event for a `/greet{...}` message carrying `entity`.
        return UserUttered(
            text,
            {INTENT_NAME_KEY: "greet", "confidence": 1.0},
            [{"entity": entity, "start": 6, "end": end, "value": value}],
        )

    def greet_response(text):
        # Expected bot event produced by the `utter_greet` response.
        return BotUttered(text, metadata={"template_name": "utter_greet"})

    slot_1 = {entity: "Core"}
    first_text = f"/greet{json.dumps(slot_1)}"
    await default_processor.handle_message(
        UserMessage(first_text, default_channel, sender_id)
    )

    assert default_channel.latest_output() == {
        "recipient_id": sender_id,
        "text": "hey there Core!",
    }

    # Make the processor report an expired session for the next message.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    slot_2 = {entity: "post-session start hello"}
    second_text = f"/greet{json.dumps(slot_2)}"
    await default_processor.handle_message(
        UserMessage(second_text, default_channel, sender_id)
    )

    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    expected = [
        # First session
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        greeting(first_text, 22, "Core"),
        SlotSet(entity, slot_1[entity]),
        ActionExecuted("utter_greet"),
        greet_response("hey there Core!"),
        ActionExecuted(ACTION_LISTEN_NAME),
        # Second session
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        # the initial SlotSet is reapplied after the SessionStarted sequence
        SlotSet(entity, slot_1[entity]),
        ActionExecuted(ACTION_LISTEN_NAME),
        greeting(second_text, 42, "post-session start hello"),
        SlotSet(entity, slot_2[entity]),
        ActionExecuted("utter_greet"),
        greet_response("hey there post-session start hello!"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]

    assert list(tracker.events) == expected