Exemplo n.º 1
0
def test_process_warns_if_intent_or_entities_not_in_domain(
    intent: Text,
    entities: Optional[Text],
    expected_intent: Text,
    domain_entities: List[Text],
):
    """Unpacking a `/intent{entities}` message whose names are not in the domain.

    The domain only declares `expected_intent` and `domain_entities`, while the
    message text is built from `intent` and `entities`. The assertions below
    treat any intent/entity text containing the substring "wrong" as a name
    that is missing from the domain. Unpacking is expected to emit a
    `UserWarning` in every case.

    Args:
        intent: Intent name placed directly after the intent message prefix.
        entities: Value serialised with `json.dumps` and appended to the
            message text, or `None` to append nothing.
            NOTE(review): annotated as `Text`, but any JSON-serialisable value
            (e.g. a dict of entity name -> value) works here.
        expected_intent: The only intent declared in the domain.
        domain_entities: The entity names declared in the domain.
    """
    # Serialise the entities once so that the message text and the later
    # "wrong" check look at the same representation. Testing
    # `"wrong" in entities` directly would check dict *keys* for equality with
    # "wrong" (never true for a key such as "wrong_entity"), and it would raise
    # a TypeError when `entities` is None.
    entities_text = json.dumps(entities) if entities is not None else ""

    # construct text according to pattern (do not add a confidence value)
    text = INTENT_MESSAGE_PREFIX + intent + entities_text
    message = Message(data={TEXT: text})

    # construct domain from expected intent/entities
    domain = Domain(
        intents=[expected_intent],
        entities=domain_entities,
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    # expect a warning
    with pytest.warns(UserWarning):
        unpacked_message = YAMLStoryReader.unpack_regex_message(message, domain)

    if "wrong" not in intent:
        # The intent is known, so it must be kept in the unpacked message.
        assert unpacked_message.data[INTENT][INTENT_NAME_KEY] == intent
        if "wrong" in entities_text:
            # Entities that are not in the domain are expected to be dropped,
            # leaving an empty (but present) entity list.
            assert unpacked_message.data[ENTITIES] is not None
            assert len(unpacked_message.data[ENTITIES]) == 0
    else:
        # An unknown intent means the message is returned unchanged.
        assert unpacked_message == message
Exemplo n.º 2
0
async def test_unpack_regex_message_has_correct_entity_start_and_end():
    """Check the full unpacked data of `/greet{"name": "Core"}`.

    The single entity is expected to span characters 6 to 22 of the message
    text, i.e. everything after the 6-character `/greet` prefix up to the end
    of the 22-character text. The extractor name passed to
    `unpack_regex_message` must be recorded on the entity.
    """
    entity_name = "name"
    payload = json.dumps({entity_name: "Core"})
    raw_message = Message(data={TEXT: f"/greet{payload}"})

    greet_domain = Domain(
        intents=["greet"],
        entities=[entity_name],
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    unpacked = YAMLStoryReader.unpack_regex_message(
        raw_message, greet_domain, entity_extractor_name="RegexMessageHandler"
    )

    # The intent is reported with full confidence, both as the prediction
    # and as the only entry of the ranking.
    expected_intent = {"name": "greet", "confidence": 1.0}
    expected_entity = {
        "entity": "name",
        "value": "Core",
        "start": 6,
        "end": 22,
        EXTRACTOR: "RegexMessageHandler",
    }
    expected_data = {
        "text": '/greet{"name": "Core"}',
        "intent": expected_intent,
        "intent_ranking": [expected_intent],
        "entities": [expected_entity],
    }
    assert unpacked.data == expected_data
Exemplo n.º 3
0
    def process(
        self, messages: List[Message], domain: Optional[Domain] = None
    ) -> List[Message]:
        """Add hard-coded intents and entities for messages starting with '/'.

        Every message is passed through `YAMLStoryReader.unpack_regex_message`,
        with this component's `name` recorded as the entity extractor.

        Args:
            messages: The messages to handle.
            domain: Optional domain; if given, it is used to check whether the
                intent and entities are valid.

        Returns:
            One message per input message, in the same order, whose intent and
            entity predictions may have been replaced when the message started
            with a `/`.
        """
        unpacked_messages = []
        for msg in messages:
            unpacked = YAMLStoryReader.unpack_regex_message(
                msg, domain, entity_extractor_name=self.name
            )
            unpacked_messages.append(unpacked)
        return unpacked_messages
Exemplo n.º 4
0
def test_process_unpacks_attributes_from_single_message_and_fallsback_if_needed(
    confidence: Optional[Text],
    entities: Optional[Text],
    expected_confidence: float,
    expected_entities: Optional[List[Dict[Text, Any]]],
    should_warn: bool,
):
    """Unpack a whitespace-padded `/intent@confidence{entities}` message.

    The message starts out with a dummy intent and a sentence feature. After
    unpacking:
    - the features are gone;
    - the data holds exactly text, intent, intent ranking and entities;
    - the text is stripped;
    - the intent and its single ranking entry carry `expected_confidence`;
    - the entities match `expected_entities`.

    Args:
        confidence: Text appended after "@" in the message, or `None` to omit
            the confidence part.
        entities: Raw text appended after the intent (and confidence), or
            `None` to omit it.
        expected_confidence: Confidence expected on the unpacked intent.
        expected_entities: Expected entities, each with at least a type and a
            value. Empty or `None` means the unpacked entity list must be
            empty.
        should_warn: Whether unpacking is expected to emit a `UserWarning`.
    """
    # dummy intent
    expected_intent = "my-intent"

    # construct text according to pattern, padded with whitespace on both ends
    text = " \t  " + INTENT_MESSAGE_PREFIX + expected_intent
    if confidence is not None:
        text += f"@{confidence}"
    if entities is not None:
        text += entities
    text += " \t "

    # create a message with some dummy attributes and features
    message = Message(
        data={TEXT: text, INTENT: "extracted-from-the-pattern-text-via-nlu"},
        features=[
            Features(
                features=np.zeros((1, 1)),
                feature_type=FEATURE_TYPE_SENTENCE,
                attribute=TEXT,
                origin="nlu-pipeline",
            )
        ],
    )

    # construct domain from expected intent/entities; `expected_entities` may
    # be None (see its annotation), which must not crash the comprehension
    domain_entities = [
        item[ENTITY_ATTRIBUTE_TYPE] for item in (expected_entities or [])
    ]
    domain = Domain(
        intents=[expected_intent],
        entities=domain_entities,
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    # extract information
    if should_warn:
        with pytest.warns(UserWarning):
            unpacked_message = YAMLStoryReader.unpack_regex_message(message, domain)
    else:
        unpacked_message = YAMLStoryReader.unpack_regex_message(message, domain)

    assert not unpacked_message.features

    assert set(unpacked_message.data.keys()) == {
        TEXT,
        INTENT,
        INTENT_RANKING_KEY,
        ENTITIES,
    }

    assert unpacked_message.data[TEXT] == message.data[TEXT].strip()

    assert set(unpacked_message.data[INTENT].keys()) == {
        INTENT_NAME_KEY,
        PREDICTED_CONFIDENCE_KEY,
    }
    assert unpacked_message.data[INTENT][INTENT_NAME_KEY] == expected_intent
    assert (
        unpacked_message.data[INTENT][PREDICTED_CONFIDENCE_KEY] == expected_confidence
    )

    intent_ranking = unpacked_message.data[INTENT_RANKING_KEY]
    assert len(intent_ranking) == 1
    assert intent_ranking[0] == {
        INTENT_NAME_KEY: expected_intent,
        PREDICTED_CONFIDENCE_KEY: expected_confidence,
    }
    if expected_entities:
        entity_data: List[Dict[Text, Any]] = unpacked_message.data[ENTITIES]
        # every unpacked entity carries exactly these attributes
        assert all(
            set(item.keys())
            == {
                ENTITY_ATTRIBUTE_VALUE,
                ENTITY_ATTRIBUTE_TYPE,
                ENTITY_ATTRIBUTE_START,
                ENTITY_ATTRIBUTE_END,
            }
            for item in entity_data
        )
        # The set comparison below would accept extra or duplicated entities,
        # so pin the count as well.
        assert len(entity_data) == len(expected_entities)
        assert set(
            (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
            for item in expected_entities
        ) == set(
            (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
            for item in entity_data
        )
    else:
        assert unpacked_message.data[ENTITIES] is not None
        assert len(unpacked_message.data[ENTITIES]) == 0