async def run(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: "DomainDict",
) -> List[EventType]:
    """Execute the side effects of this form.

    Steps:
    - activate if needed
    - validate user input if needed
    - set validated slots
    - utter_ask_{slot} template with the next required slot
    - submit the form if all required slots are set
    - deactivate the form

    Every overridable hook called here (``request_next_slot``,
    ``_log_form_slots``, ``submit`` and ``deactivate``) may be implemented
    either as a plain method or as a coroutine; each one is dispatched via
    ``utils.is_coroutine_action`` and awaited only when it is a coroutine.

    Args:
        dispatcher: the dispatcher used to send messages back to the user.
        tracker: the conversation tracker for the current user.
        domain: the bot's domain.

    Returns:
        The events produced by activation and validation, followed either by
        the events requesting the next slot, or by the submission and
        deactivation events.
    """
    # activate the form
    events = await self._activate_if_required(dispatcher, tracker, domain)
    # validate user input
    events.extend(await self._validate_if_required(dispatcher, tracker, domain))

    # validation may have deactivated the form; nothing more to do then
    if ActiveLoop(None) in events:
        return events

    # create temp tracker with populated slots from `validate` method
    temp_tracker = tracker.copy()
    for e in events:
        if e["event"] == "slot":
            temp_tracker.slots[e["name"]] = e["value"]

    # `request_next_slot` may be a coroutine. Calling it without awaiting
    # would yield a coroutine object, which is never `None`, so the form
    # would never be submitted and `events.extend` would fail.
    if utils.is_coroutine_action(self.request_next_slot):
        next_slot_events = await self.request_next_slot(  # type: ignore
            dispatcher, temp_tracker, domain
        )
    else:
        next_slot_events = self.request_next_slot(dispatcher, temp_tracker, domain)

    if next_slot_events is not None:
        # request next slot
        events.extend(next_slot_events)
        return events

    # there is nothing more to request, so we can submit
    # (`_log_form_slots` may also be a coroutine; an un-awaited one would
    # never run and would trigger a "never awaited" RuntimeWarning)
    if utils.is_coroutine_action(self._log_form_slots):
        await self._log_form_slots(temp_tracker)  # type: ignore
    else:
        self._log_form_slots(temp_tracker)

    logger.debug(f"Submitting the form '{self.name()}'")
    if utils.is_coroutine_action(self.submit):
        events.extend(await self.submit(dispatcher, temp_tracker, domain))
    else:
        # see https://github.com/python/mypy/issues/5206
        events.extend(
            cast(
                List[EventType],
                self.submit(dispatcher, temp_tracker, domain),
            )
        )

    # deactivate the form after submission
    if utils.is_coroutine_action(self.deactivate):
        events.extend(await self.deactivate())  # type: ignore
    else:
        events.extend(self.deactivate())

    return events
async def get_object(
    self, object_type: Text, object_identifier: Text
) -> Optional[Dict[Text, Any]]:
    """Find the single object of ``object_type`` identified by ``object_identifier``.

    Matching happens in two passes, both case-insensitive on ``str`` values:

    1. the identifier must equal the object's key attribute (e.g. ``id``);
    2. if pass 1 finds nothing, the identifier must be a substring of the
       object's representation (as produced by the type's repr function).

    Args:
        object_type: the type of object to look up.
        object_identifier: the key value or (part of) the representation.

    Returns:
        The unique matching object; ``None`` if the type is unknown, nothing
        matches, or the match is ambiguous (more than one object).
    """
    if object_type not in self.data:
        return None

    candidates = self.data[object_type]
    wanted = str(object_identifier).lower()

    if utils.is_coroutine_action(self.get_key_attribute_of_object):
        key_attribute = await self.get_key_attribute_of_object(object_type)
    else:
        # see https://github.com/python/mypy/issues/5206
        key_attribute = cast(Text, self.get_key_attribute_of_object(object_type))

    # pass 1: direct match on the key attribute
    matches = [
        candidate
        for candidate in candidates
        if str(candidate[key_attribute]).lower() == wanted
    ]

    if not matches:
        # pass 2: the object was referred to by (part of) its representation
        if utils.is_coroutine_action(self.get_representation_function_of_object):
            to_repr = await self.get_representation_function_of_object(object_type)
        else:
            # see https://github.com/python/mypy/issues/5206
            to_repr = cast(
                Callable, self.get_representation_function_of_object(object_type)
            )
        matches = [
            candidate
            for candidate in candidates
            if wanted in str(to_repr(candidate)).lower()
        ]

    # TODO: when several objects match, they could be listed so the user can
    # clarify which one was meant.
    if len(matches) != 1:
        return None
    return matches[0]
async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
    """Execute the custom action requested by ``action_call``.

    ``action_call["next_action"]`` names the action, ``action_call["tracker"]``
    holds the serialized tracker and ``action_call["domain"]`` (default ``{}``)
    the domain. The action may be sync or async; a falsy return value is
    treated as "no events".

    Returns:
        The API response built from the validated events and the dispatcher's
        messages, or ``None`` when the call names no action.

    Raises:
        ActionNotFoundException: if no action is registered under that name.
    """
    from rasa_sdk.interfaces import Tracker

    action_name = action_call.get("next_action")
    if not action_name:
        logger.warning("Received an action call without an action.")
        return None

    logger.debug(f"Received request to run '{action_name}'")
    action = self.actions.get(action_name)
    if not action:
        raise ActionNotFoundException(action_name)

    serialized_tracker = action_call.get("tracker")
    domain = action_call.get("domain", {})
    tracker = Tracker.from_dict(serialized_tracker)
    dispatcher = CollectingDispatcher()

    if utils.is_coroutine_action(action):
        events = await action(dispatcher, tracker, domain)
    else:
        events = action(dispatcher, tracker, domain)

    # make sure the action did not just return `None`...
    events = events or []

    validated_events = self.validate_events(events, action_name)
    logger.debug(f"Finished running '{action_name}'")
    return self._create_api_response(validated_events, dispatcher.messages)
async def utter_objects(
    self,
    dispatcher: CollectingDispatcher,
    object_type: Text,
    objects: List[Dict[Text, Any]],
) -> None:
    """Tell the user which objects of ``object_type`` were found.

    Sends a header line followed by one numbered message per object (using the
    knowledge base's representation function for the type, which may be sync
    or async), or a single "nothing found" message when ``objects`` is empty.

    Args:
        dispatcher: the dispatcher
        object_type: the object type
        objects: the list of objects
    """
    if not objects:
        dispatcher.utter_message(
            text=f"I could not find any objects of type '{object_type}'.")
        return

    dispatcher.utter_message(
        text=f"Found the following objects of type '{object_type}':")

    lookup_repr = self.knowledge_base.get_representation_function_of_object
    if utils.is_coroutine_action(lookup_repr):
        repr_function = await lookup_repr(object_type)
    else:
        repr_function = lookup_repr(object_type)

    for position, item in enumerate(objects, start=1):
        dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
async def _validate_if_required(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: "DomainDict",
) -> List[EventType]:
    """Run ``self.validate(...)`` when validation is required, else return ``[]``.

    Validation is required when either
    - there is no active loop yet (the form is being activated), or
    - the previous action was ``action_listen`` and the active loop is not
      flagged as interrupted (``LOOP_INTERRUPTED_KEY``).

    ``self.validate`` may be sync or async.
    """
    # no active_loop means that it is called during activation
    activating = not tracker.active_loop
    continuing_after_user_input = (
        not activating
        and tracker.latest_action_name == "action_listen"
        and not tracker.active_loop.get(LOOP_INTERRUPTED_KEY, False)
    )

    if not (activating or continuing_after_user_input):
        logger.debug("Skipping validation")
        return []

    logger.debug(f"Validating user input '{tracker.latest_message}'")
    if utils.is_coroutine_action(self.validate):
        return await self.validate(dispatcher, tracker, domain)

    # see https://github.com/python/mypy/issues/5206
    return cast(List[Dict[Text, Any]], self.validate(dispatcher, tracker, domain))
async def validate_slots(
    self,
    slot_dict: Dict[Text, Any],
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: Dict[Text, Any],
) -> List[EventType]:
    """Validate ``slot_dict`` using the form's ``validate_{slot}`` helpers.

    For each slot/value pair, ``validate_{slot}`` (sync or async) is called if
    the form defines it; otherwise the value is kept as is. Helpers should
    return a dict of slot updates; any non-dict return is treated as the
    deprecated "return the value" style, logged as a warning, and wrapped as
    ``{slot: returned_value}``.

    Note: ``slot_dict`` is updated in place with every helper's output.

    Returns:
        One ``SlotSet`` per entry of the (updated) ``slot_dict``.
    """
    for slot_name, slot_value in list(slot_dict.items()):
        validator = getattr(
            self, f"validate_{slot_name}", lambda *_: {slot_name: slot_value}
        )
        call_args = (slot_value, dispatcher, tracker, domain)

        if utils.is_coroutine_action(validator):
            outcome = await validator(*call_args)
        else:
            outcome = validator(*call_args)

        if not isinstance(outcome, dict):
            logger.warning(
                "Returning values in helper validation methods is deprecated. "
                + f"Your `validate_{slot_name}()` method should return "
                + "a dict of {'slot_name': value} instead."
            )
            outcome = {slot_name: outcome}

        slot_dict.update(outcome)

    # validation succeeded: set slots to the extracted values
    return [SlotSet(name, value) for name, value in slot_dict.items()]
async def _validate_if_required(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: Dict[Text, Any],
) -> List[EventType]:
    """Run ``self.validate(...)`` when validation is required, else return ``[]``.

    Validation is required when either
    - there is no active loop yet (the form is being activated), or
    - the previous action was ``action_listen`` and the active loop's
      ``"validate"`` flag is not falsy (it defaults to ``True``).

    ``self.validate`` may be sync or async.
    """
    # no active_loop means that it is called during activation
    if not tracker.active_loop:
        need_validation = True
    else:
        need_validation = (
            tracker.latest_action_name == "action_listen"
            and tracker.active_loop.get("validate", True)
        )

    if not need_validation:
        logger.debug("Skipping validation")
        return []

    logger.debug(f"Validating user input '{tracker.latest_message}'")
    if utils.is_coroutine_action(self.validate):
        return await self.validate(dispatcher, tracker, domain)
    return self.validate(dispatcher, tracker, domain)
async def utter_objects(
    self,
    dispatcher: CollectingDispatcher,
    object_type: Text,
    objects: List[Dict[Text, Any]],
) -> None:
    """Tell the user (in Chinese) which objects of ``object_type`` were found.

    The object type is translated with ``self.en_to_zh`` for display. Sends a
    header followed by one numbered line per object (rendered with the
    knowledge base's representation function, sync or async), or a single
    "nothing found" message when ``objects`` is empty.

    Args:
        dispatcher: the dispatcher
        object_type: the object type
        objects: the list of objects
    """
    if not objects:
        dispatcher.utter_message(
            text="我没找到任何{}.".format(self.en_to_zh(object_type)))
        return

    dispatcher.utter_message(
        text="找到下列{}:".format(self.en_to_zh(object_type)))

    lookup_repr = self.knowledge_base.get_representation_function_of_object
    if utils.is_coroutine_action(lookup_repr):
        repr_function = await lookup_repr(object_type)
    else:
        repr_function = lookup_repr(object_type)

    for position, item in enumerate(objects, start=1):
        dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
async def extract_other_slots(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: Dict[Text, Any],
) -> Dict[Text, Any]:
    """Extract values for required slots other than the one being requested.

    For each required slot (``self.required_slots`` may be sync or async) that
    is not the currently requested slot, its mappings are tried in order:
    - a ``from_entity`` mapping whose entity has the slot's name, and
    - a ``from_trigger_intent`` mapping, only while this form is not the
      active one,
    each gated by ``self.intent_is_desired``. The first mapping yielding a
    non-``None`` value wins for that slot.

    Returns:
        A dict of slot name -> extracted value (slots without a value are
        omitted).
    """
    requested_slot = tracker.get_slot(REQUESTED_SLOT)
    extracted = {}

    if utils.is_coroutine_action(self.required_slots):
        required_slots = await self.required_slots(tracker)
    else:
        required_slots = self.required_slots(tracker)

    for slot in required_slots:
        # only look at the *other* slots
        if slot == requested_slot:
            continue

        for mapping in self.get_mappings_for_slot(slot):
            # slot filled by an entity with the same name
            from_entity = (
                mapping["type"] == "from_entity"
                and mapping.get("entity") == slot
                and self.intent_is_desired(mapping, tracker)
            )
            # slot filled from the intent that triggered the form
            from_trigger = (
                tracker.active_form.get("name") != self.name()
                and mapping["type"] == "from_trigger_intent"
                and self.intent_is_desired(mapping, tracker)
            )

            value = None
            if from_entity:
                value = self.get_entity_value(slot, tracker)
            elif from_trigger:
                value = mapping.get("value")

            if value is None:
                continue

            logger.debug(f"Extracted '{value}' for extra slot '{slot}'.")
            extracted[slot] = value
            # this slot is done, move on to the next one
            break

    return extracted
async def _log_form_slots(self, tracker: "Tracker") -> None:
    """Debug-log every required slot and its current value before submission.

    ``self.required_slots`` may be sync or async.
    """
    required = (
        await self.required_slots(tracker)
        if utils.is_coroutine_action(self.required_slots)
        else self.required_slots(tracker)
    )

    rendered = "\n".join(
        f"\t{slot_name}: {tracker.get_slot(slot_name)}" for slot_name in required
    )
    logger.debug(
        f"No slots left to request, all required slots are filled:\n{rendered}"
    )
async def validate(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: "DomainDict",
) -> List[EventType]:
    """Validate slots by calling a validation function for each slot.

    For every slot returned by ``tracker.slots_to_validate()``, the method
    ``validate_<slot_name>`` (with ``-`` replaced by ``_``) is looked up on the
    instance. Slots without such a method are kept unchanged. A validation
    method may be sync or async and must return a dict of slot updates, which
    is merged into the slots to set; any other return value is ignored with a
    warning.

    Args:
        dispatcher: the dispatcher which is used to send messages back to the user.
        tracker: the conversation tracker for the current user.
        domain: the bot's domain.

    Returns:
        `SlotSet` events for every validated slot.
    """
    slots: Dict[Text, Any] = tracker.slots_to_validate()

    for slot_name, slot_value in list(slots.items()):
        function_name = f"validate_{slot_name.replace('-','_')}"
        validate_func = getattr(self, function_name, None)

        if not validate_func:
            logger.debug(
                f"Skipping validation for `{slot_name}`: there is no validation function specified."
            )
            continue

        if utils.is_coroutine_action(validate_func):
            validation_output = await validate_func(
                slot_value, dispatcher, tracker, domain
            )
        else:
            validation_output = validate_func(
                slot_value, dispatcher, tracker, domain
            )

        # Only a dict can be merged safely. A truthiness check would pass
        # non-dict values (e.g. a bare slot value) to `dict.update`, which
        # either raises or silently inserts bogus slots, and would wrongly
        # warn about an empty dict ("nothing to change").
        if isinstance(validation_output, dict):
            slots.update(validation_output)
        else:
            warnings.warn(
                f"Cannot validate `{slot_name}`: make sure the validation function returns the correct output."
            )

    return [SlotSet(slot, value) for slot, value in slots.items()]
async def _activate_if_required(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: Dict[Text, Any],
) -> List[EventType]:
    """Activate this form when it is not already the active one.

    On activation, returns a ``Form(self.name())`` event followed by the events
    from ``self.validate_slots`` for every required slot that
    ``_should_request_slot`` reports as already filled. Returns ``[]`` when this
    form is already active. ``self.required_slots`` may be sync or async.
    """
    if tracker.active_form.get("name") is None:
        logger.debug("There is no active form")
    else:
        logger.debug(f"The form '{tracker.active_form}' is active")

    if tracker.active_form.get("name") == self.name():
        return []

    logger.debug(f"Activated the form '{self.name()}'")
    events = [Form(self.name())]

    if utils.is_coroutine_action(self.required_slots):
        required_slots = await self.required_slots(tracker)
    else:
        required_slots = self.required_slots(tracker)

    # collect values of required slots filled before activation
    prefilled_slots = {
        slot_name: tracker.get_slot(slot_name)
        for slot_name in required_slots
        if not self._should_request_slot(tracker, slot_name)
    }

    if not prefilled_slots:
        logger.debug("No pre-filled required slots to validate.")
        return events

    logger.debug(f"Validating pre-filled required slots: {prefilled_slots}")
    events.extend(
        await self.validate_slots(prefilled_slots, dispatcher, tracker, domain)
    )
    return events
async def request_next_slot(
    self,
    dispatcher: "CollectingDispatcher",
    tracker: "Tracker",
    domain: Dict[Text, Any],
) -> Optional[List[EventType]]:
    """Ask for the first required slot that still needs a value.

    Utters ``utter_ask_<slot>`` (with all current slot values as template
    variables) and returns ``[SlotSet(REQUESTED_SLOT, slot)]`` for the first
    required slot that ``_should_request_slot`` selects. Returns ``None`` when
    no required slot is left. ``self.required_slots`` may be sync or async.
    """
    required = (
        await self.required_slots(tracker)
        if utils.is_coroutine_action(self.required_slots)
        else self.required_slots(tracker)
    )

    for candidate in required:
        if not self._should_request_slot(tracker, candidate):
            continue
        logger.debug(f"Request next slot '{candidate}'")
        dispatcher.utter_message(template=f"utter_ask_{candidate}", **tracker.slots)
        return [SlotSet(REQUESTED_SLOT, candidate)]

    # no more required slots to fill
    return None
async def utter_objects(
    self,
    dispatcher: CollectingDispatcher,
    object_type: Text,
    objects: List[Dict[Text, Any]],
) -> None:
    """Tell the user (in German) which objects of ``object_type`` were found.

    Sends a header followed by one numbered line per object (rendered with the
    knowledge base's representation function, sync or async), or a single
    "nothing found" message when ``objects`` is empty.

    Args:
        dispatcher: the dispatcher
        object_type: the object type
        objects: the list of objects
    """
    if not objects:
        dispatcher.utter_message(
            text=f"Ich konnte leider nichts zu '{object_type}' finden."
        )
        return

    dispatcher.utter_message(
        text=f"Folgendes ist in der '{object_type}' versichert:"
    )

    lookup_repr = self.knowledge_base.get_representation_function_of_object
    if utils.is_coroutine_action(lookup_repr):
        repr_function = await lookup_repr(object_type)
    else:
        repr_function = lookup_repr(object_type)

    for position, item in enumerate(objects, start=1):
        dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
def test_action_async_check():
    """`is_coroutine_action` distinguishes a sync `run` from an async one."""
    sync_run = CustomAction.run
    async_run = CustomAsyncAction.run

    assert not is_coroutine_action(sync_run)
    assert is_coroutine_action(async_run)
async def _query_attribute(
    self, dispatcher: CollectingDispatcher, tracker: Tracker
) -> List[Dict]:
    """
    Look up the requested attribute of the mentioned object and tell the user.

    The object type and attribute come from the ``SLOT_OBJECT_TYPE`` and
    ``SLOT_ATTRIBUTE`` slots; the object itself is resolved via
    ``get_object_name``. If any of these cannot be resolved, or the object lacks
    the attribute, the user is asked to rephrase. Knowledge-base lookups and
    ``utter_attribute_value`` may each be sync or async.

    Args:
        dispatcher: the dispatcher
        tracker: the tracker

    Returns: list of slots
    """

    async def _invoke(func, *args):
        # await `func` only when it is a coroutine function
        if utils.is_coroutine_action(func):
            return await func(*args)
        return func(*args)

    def _ask_to_rephrase():
        dispatcher.utter_message(template="utter_ask_rephrase")
        return [SlotSet(SLOT_MENTION, None)]

    object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
    attribute = tracker.get_slot(SLOT_ATTRIBUTE)
    kb = self.knowledge_base

    object_name = get_object_name(
        tracker,
        kb.ordinal_mention_mapping,
        self.use_last_object_mention,
    )
    if not object_name or not attribute:
        return _ask_to_rephrase()

    object_of_interest = await _invoke(kb.get_object, object_type, object_name)
    if not object_of_interest or attribute not in object_of_interest:
        return _ask_to_rephrase()

    value = object_of_interest[attribute]

    repr_function = await _invoke(
        kb.get_representation_function_of_object, object_type
    )
    object_representation = repr_function(object_of_interest)

    key_attribute = await _invoke(kb.get_key_attribute_of_object, object_type)
    object_identifier = object_of_interest[key_attribute]

    await _invoke(
        self.utter_attribute_value,
        dispatcher,
        object_representation,
        attribute,
        value,
    )

    return [
        SlotSet(SLOT_OBJECT_TYPE, object_type),
        SlotSet(SLOT_ATTRIBUTE, None),
        SlotSet(SLOT_MENTION, None),
        SlotSet(SLOT_LAST_OBJECT, object_identifier),
        SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
    ]
async def _query_objects(self, dispatcher: CollectingDispatcher,
                         object_type: Text, tracker: Tracker) -> List[Dict]:
    """
    List the knowledge-base objects of ``object_type`` to the user.

    Objects are filtered by whatever attribute slots of that type are set on
    the tracker. Every knowledge-base call and ``utter_objects`` may be sync
    or async.

    Args:
        dispatcher: the dispatcher
        object_type: the type of objects to query
        tracker: the tracker

    Returns:
        The attribute-slot resets alone when nothing was found; otherwise the
        object/mention/attribute/last-object/listed-object slot updates
        followed by the attribute-slot resets. ``SLOT_LAST_OBJECT`` is only
        set when exactly one object was found.
    """

    async def _invoke(func, *args):
        # await `func` only when it is a coroutine function
        if utils.is_coroutine_action(func):
            return await func(*args)
        return func(*args)

    kb = self.knowledge_base

    object_attributes = await _invoke(kb.get_attributes_of_object, object_type)

    # set attribute slots of this object type act as filters for the query
    attributes = get_attribute_slots(tracker, object_attributes)

    # query the knowledge base
    objects = await _invoke(kb.get_objects, object_type, attributes)

    await _invoke(self.utter_objects, dispatcher, object_type, objects)

    if not objects:
        return reset_attribute_slots(tracker, object_attributes)

    key_attribute = await _invoke(kb.get_key_attribute_of_object, object_type)

    # only a single result can be referred to later as "the last object"
    last_object = None if len(objects) > 1 else objects[0][key_attribute]
    listed_identifiers = [obj[key_attribute] for obj in objects]

    slots = [
        SlotSet(SLOT_OBJECT_TYPE, object_type),
        SlotSet(SLOT_MENTION, None),
        SlotSet(SLOT_ATTRIBUTE, None),
        SlotSet(SLOT_LAST_OBJECT, last_object),
        SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
        SlotSet(SLOT_LISTED_OBJECTS, listed_identifiers),
    ]
    return slots + reset_attribute_slots(tracker, object_attributes)
async def _query_attribute(
    self,
    dispatcher: CollectingDispatcher,
    object_type: Text,
    attribute: Text,
    tracker: Tracker,
) -> List[Dict]:
    """
    Look up ``attribute`` of the mentioned object and tell the user its value.

    The object is resolved via ``get_object_name``. If it cannot be resolved,
    ``attribute`` is empty, or the object lacks the attribute, the user is
    asked to rephrase. Knowledge-base lookups and ``utter_attribute_value``
    may each be sync or async.

    Args:
        dispatcher: the dispatcher
        object_type: the type of the mentioned object
        attribute: the attribute to look up
        tracker: the tracker

    Returns: list of slots
    """

    async def _invoke(func, *args):
        # await `func` only when it is a coroutine function
        if utils.is_coroutine_action(func):
            return await func(*args)
        return func(*args)

    def _ask_to_rephrase():
        dispatcher.utter_message(template="utter_ask_rephrase")
        return [SlotSet(SLOT_MENTION, None)]

    kb = self.knowledge_base

    object_name = get_object_name(
        tracker,
        kb.ordinal_mention_mapping,
        self.use_last_object_mention,
    )
    if not object_name or not attribute:
        return _ask_to_rephrase()

    object_of_interest = await _invoke(kb.get_object, object_type, object_name)
    if not object_of_interest or attribute not in object_of_interest:
        return _ask_to_rephrase()

    value = object_of_interest[attribute]

    repr_function = await _invoke(
        kb.get_representation_function_of_object, object_type
    )
    object_representation = repr_function(object_of_interest)

    key_attribute = await _invoke(kb.get_key_attribute_of_object, object_type)
    object_identifier = object_of_interest[key_attribute]

    await _invoke(
        self.utter_attribute_value,
        dispatcher,
        object_representation,
        attribute,
        value,
    )

    return [
        SlotSet(SLOT_OBJECT_TYPE, object_type),
        SlotSet(SLOT_ATTRIBUTE, None),
        SlotSet(SLOT_MENTION, None),
        SlotSet(SLOT_LAST_OBJECT, object_identifier),
        SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
    ]