def test_process_warns_if_intent_or_entities_not_in_domain(
    intent: Text,
    entities: Optional[Text],
    expected_intent: Text,
    domain_entities: List[Text],
):
    """Unpacking a shortcut message warns when intent/entities mismatch the domain.

    The domain is built from `expected_intent` and `domain_entities` only.
    A `UserWarning` is always expected. If the intent name contains "wrong",
    the unpacked message must equal the original one. Otherwise the intent
    name must be kept. If `entities` then contains "wrong", the entity list
    must be present but empty.
    """
    # Shortcut text is prefix + intent (+ JSON entities); no confidence suffix.
    pattern_text = INTENT_MESSAGE_PREFIX + intent
    if entities is not None:
        pattern_text += json.dumps(entities)
    original = Message(data={TEXT: pattern_text})

    # The domain only knows about the expected intent and the given entities.
    domain = Domain(
        intents=[expected_intent],
        entities=domain_entities,
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    with pytest.warns(UserWarning):
        unpacked = YAMLStoryReader.unpack_regex_message(original, domain)

    if "wrong" in intent:
        # Intent not in the domain: expect the message back unchanged.
        assert unpacked == original
        return

    assert unpacked.data[INTENT][INTENT_NAME_KEY] == intent
    if "wrong" in entities:
        # Entities not in the domain: expect a present but empty entity list.
        entity_list = unpacked.data[ENTITIES]
        assert entity_list is not None
        assert len(entity_list) == 0
def test_unpack_regex_message_has_correct_entity_start_and_end():
    """Entity offsets from a `/intent{json}` message cover the JSON payload.

    For the text `/greet{"name": "Core"}` the entity must start at index 6
    (right after `/greet`) and end at index 22 (the end of the text). The
    extractor name passed to `unpack_regex_message` must be recorded on the
    entity.
    """
    # Nothing in this test is awaited, so it is a plain (synchronous) test.
    # It no longer depends on an asyncio pytest plugin to be collected and run.
    entity = "name"
    entity_payload = {entity: "Core"}
    text = f"/greet{json.dumps(entity_payload)}"

    message = Message(data={TEXT: text})
    domain = Domain(
        intents=["greet"],
        entities=[entity],
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    unpacked = YAMLStoryReader.unpack_regex_message(
        message, domain, entity_extractor_name="RegexMessageHandler"
    )

    assert unpacked.data == {
        "text": '/greet{"name": "Core"}',
        "intent": {"name": "greet", "confidence": 1.0},
        "intent_ranking": [{"name": "greet", "confidence": 1.0}],
        "entities": [
            {
                "entity": "name",
                "value": "Core",
                "start": 6,
                "end": 22,
                EXTRACTOR: "RegexMessageHandler",
            }
        ],
    }
def process(
    self, messages: List[Message], domain: Optional[Domain] = None
) -> List[Message]:
    """Add hardcoded intents and entities to messages that start with '/'.

    Each message is passed to `YAMLStoryReader.unpack_regex_message`, with
    this component's name recorded as the entity extractor.

    Args:
        messages: The messages to handle.
        domain: Optional domain. When given, it is used to check whether the
            intent and entities of a message are valid.

    Returns:
        One message per input message, in the same order. The intent and
        entity predictions are replaced if the message started with `/`.
    """
    unpacked: List[Message] = []
    for msg in messages:
        unpacked.append(
            YAMLStoryReader.unpack_regex_message(
                msg, domain, entity_extractor_name=self.name
            )
        )
    return unpacked
def test_process_unpacks_attributes_from_single_message_and_fallsback_if_needed(
    confidence: Optional[Text],
    entities: Optional[Text],
    expected_confidence: float,
    expected_entities: Optional[List[Dict[Text, Any]]],
    should_warn: bool,
):
    """Unpacking a padded `/intent@confidence{entities}` message.

    The text is surrounded by whitespace, and the message carries a dummy
    intent and a dummy sentence feature. After unpacking, the test expects:

    - no features on the result;
    - exactly the TEXT, INTENT, INTENT_RANKING_KEY and ENTITIES data keys;
    - the stripped text;
    - the expected intent and confidence, with a one-element intent ranking;
    - either the expected (type, value) entity pairs, or an empty entity list.

    A `UserWarning` is required when `should_warn` is set.
    """
    # dummy intent
    expected_intent = "my-intent"

    # construct text according to pattern, padded with whitespace
    text = " \t " + INTENT_MESSAGE_PREFIX + expected_intent
    if confidence is not None:
        text += f"@{confidence}"
    if entities is not None:
        text += entities
    text += " \t "

    # create a message with some dummy attributes and features
    message = Message(
        data={TEXT: text, INTENT: "extracted-from-the-pattern-text-via-nlu"},
        features=[
            Features(
                features=np.zeros((1, 1)),
                feature_type=FEATURE_TYPE_SENTENCE,
                attribute=TEXT,
                origin="nlu-pipeline",
            )
        ],
    )

    # construct domain from expected intent/entities; `expected_entities`
    # may be None (see annotation), so treat that as "no entities"
    domain_entities = [
        item[ENTITY_ATTRIBUTE_TYPE] for item in (expected_entities or [])
    ]
    domain = Domain(
        intents=[expected_intent],
        entities=domain_entities,
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    # extract information
    if should_warn:
        with pytest.warns(UserWarning):
            unpacked_message = YAMLStoryReader.unpack_regex_message(message, domain)
    else:
        unpacked_message = YAMLStoryReader.unpack_regex_message(message, domain)

    assert not unpacked_message.features

    assert set(unpacked_message.data.keys()) == {
        TEXT,
        INTENT,
        INTENT_RANKING_KEY,
        ENTITIES,
    }

    assert unpacked_message.data[TEXT] == message.data[TEXT].strip()

    assert set(unpacked_message.data[INTENT].keys()) == {
        INTENT_NAME_KEY,
        PREDICTED_CONFIDENCE_KEY,
    }
    assert unpacked_message.data[INTENT][INTENT_NAME_KEY] == expected_intent
    assert (
        unpacked_message.data[INTENT][PREDICTED_CONFIDENCE_KEY]
        == expected_confidence
    )

    intent_ranking = unpacked_message.data[INTENT_RANKING_KEY]
    assert len(intent_ranking) == 1
    assert intent_ranking[0] == {
        INTENT_NAME_KEY: expected_intent,
        PREDICTED_CONFIDENCE_KEY: expected_confidence,
    }

    if expected_entities:
        entity_data: List[Dict[Text, Any]] = unpacked_message.data[ENTITIES]
        # every unpacked entity carries exactly these four attributes
        assert all(
            set(item.keys())
            == {
                ENTITY_ATTRIBUTE_VALUE,
                ENTITY_ATTRIBUTE_TYPE,
                ENTITY_ATTRIBUTE_START,
                ENTITY_ATTRIBUTE_END,
            }
            for item in entity_data
        )
        # compare (type, value) pairs only; offsets depend on the text padding
        assert {
            (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
            for item in expected_entities
        } == {
            (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
            for item in entity_data
        }
    else:
        assert unpacked_message.data[ENTITIES] is not None
        assert len(unpacked_message.data[ENTITIES]) == 0