Beispiel #1
0
    async def run(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: "DomainDict",
    ) -> List[EventType]:
        """Execute the side effects of this form.

        Steps:
        - activate if needed
        - validate user input if needed
        - set validated slots
        - utter_ask_{slot} template with the next required slot
        - submit the form if all required slots are set
        - deactivate the form

        `request_next_slot`, `_log_form_slots`, `submit` and `deactivate` may
        each be implemented either as a plain method or as a coroutine; both
        variants are dispatched via `utils.is_coroutine_action`.

        Args:
            dispatcher: used to send messages back to the user.
            tracker: the conversation tracker for the current user.
            domain: the bot's domain.

        Returns:
            The accumulated events: activation and validation events, followed
            either by the events requesting the next slot, or by the
            submission and deactivation events.
        """
        # activate the form
        events = await self._activate_if_required(dispatcher, tracker, domain)
        # validate user input
        events.extend(await self._validate_if_required(dispatcher, tracker, domain))

        # validation may have deactivated the form; nothing more to do then
        if ActiveLoop(None) in events:
            return events

        # create temp tracker with populated slots from `validate` method
        temp_tracker = tracker.copy()
        for e in events:
            if e["event"] == "slot":
                temp_tracker.slots[e["name"]] = e["value"]

        # `request_next_slot` may be a coroutine (e.g. when `required_slots` is
        # async). Calling it without `await` would hand a coroutine object to
        # `events.extend` instead of the requested events.
        if utils.is_coroutine_action(self.request_next_slot):
            next_slot_events = await self.request_next_slot(  # type: ignore
                dispatcher, temp_tracker, domain
            )
        else:
            next_slot_events = self.request_next_slot(dispatcher, temp_tracker, domain)

        if next_slot_events is not None:
            # request next slot
            events.extend(next_slot_events)
            return events

        # there is nothing more to request, so we can submit
        if utils.is_coroutine_action(self._log_form_slots):
            # an un-awaited coroutine here would silently skip the logging
            await self._log_form_slots(temp_tracker)  # type: ignore
        else:
            self._log_form_slots(temp_tracker)
        logger.debug(f"Submitting the form '{self.name()}'")
        if utils.is_coroutine_action(self.submit):
            events.extend(await self.submit(dispatcher, temp_tracker, domain))
        else:
            # see https://github.com/python/mypy/issues/5206
            events.extend(
                cast(
                    List[EventType],
                    self.submit(dispatcher, temp_tracker, domain),
                )
            )
        # deactivate the form after submission
        if utils.is_coroutine_action(self.deactivate):
            events.extend(await self.deactivate())  # type: ignore
        else:
            events.extend(self.deactivate())

        return events
Beispiel #2
0
    async def get_object(
        self, object_type: Text, object_identifier: Text
    ) -> Optional[Dict[Text, Any]]:
        """Look up exactly one object of `object_type` by an identifier.

        The identifier is first compared case-insensitively (after `str()`
        conversion) with each object's key attribute. If no object matches,
        it is instead searched for as a case-insensitive substring of each
        object's representation, as produced by the representation function
        for `object_type`.

        Args:
            object_type: the type of object to look up.
            object_identifier: the key value or (part of) the representation
                of the wanted object.

        Returns:
            The single matching object, or None if `object_type` is not in
            `self.data` or if zero or several objects match.
        """
        if object_type not in self.data:
            return None

        candidates = self.data[object_type]

        if utils.is_coroutine_action(self.get_key_attribute_of_object):
            key_attribute = await self.get_key_attribute_of_object(object_type)
        else:
            # see https://github.com/python/mypy/issues/5206
            key_attribute = cast(Text, self.get_key_attribute_of_object(object_type))

        wanted = str(object_identifier).lower()

        # first try: exact (case-insensitive) match on the key attribute, e.g. 'id'
        matches = [
            candidate
            for candidate in candidates
            if str(candidate[key_attribute]).lower() == wanted
        ]

        if not matches:
            # the object may have been referred to directly, so compare the
            # representation of each object with the given identifier
            if utils.is_coroutine_action(self.get_representation_function_of_object):
                repr_function = await self.get_representation_function_of_object(
                    object_type
                )
            else:
                # see https://github.com/python/mypy/issues/5206
                repr_function = cast(
                    Callable, self.get_representation_function_of_object(object_type)
                )

            matches = [
                candidate
                for candidate in candidates
                if wanted in str(repr_function(candidate)).lower()
            ]

        if len(matches) != 1:
            # TODO:
            #  if multiple objects are found, the objects could be shown
            #  to the user. the user then needs to clarify what object he meant.
            return None

        return matches[0]
Beispiel #3
0
    async def run(
        self, action_call: Dict[Text, Any]
    ) -> Optional[Dict[Text, Any]]:
        """Run the action named in `action_call["next_action"]`.

        Builds a `Tracker` from `action_call["tracker"]` and a fresh
        `CollectingDispatcher`, invokes the action (awaiting it if it is a
        coroutine), validates the returned events and wraps them together with
        the dispatcher's messages via `_create_api_response`.

        Args:
            action_call: the request payload, read for the keys
                "next_action", "tracker" and "domain".

        Returns:
            The API response, or None (after logging a warning) when the
            payload names no action.

        Raises:
            ActionNotFoundException: if the named action is not registered.
        """
        from rasa_sdk.interfaces import Tracker

        action_name = action_call.get("next_action")
        if not action_name:
            logger.warning("Received an action call without an action.")
            return None

        logger.debug(f"Received request to run '{action_name}'")
        action = self.actions.get(action_name)
        if not action:
            raise ActionNotFoundException(action_name)

        tracker_json = action_call.get("tracker")
        domain = action_call.get("domain", {})
        tracker = Tracker.from_dict(tracker_json)
        dispatcher = CollectingDispatcher()

        if utils.is_coroutine_action(action):
            events = await action(dispatcher, tracker, domain)
        else:
            events = action(dispatcher, tracker, domain)

        # an action that returned None (or any other falsy value) produced no events
        events = events or []

        validated_events = self.validate_events(events, action_name)
        logger.debug(f"Finished running '{action_name}'")
        return self._create_api_response(validated_events, dispatcher.messages)
Beispiel #4
0
    async def utter_objects(
        self,
        dispatcher: CollectingDispatcher,
        object_type: Text,
        objects: List[Dict[Text, Any]],
    ) -> None:
        """Tell the user which objects of `object_type` were found.

        When `objects` is non-empty, one header message is sent followed by one
        numbered message (starting at 1) per object, rendered with the
        knowledge base's representation function for `object_type`. Otherwise
        a single "nothing found" message is sent.

        Args:
            dispatcher: the dispatcher
            object_type: the object type
            objects: the list of objects
        """
        if not objects:
            dispatcher.utter_message(
                text=f"I could not find any objects of type '{object_type}'.")
            return

        dispatcher.utter_message(
            text=f"Found the following objects of type '{object_type}':")

        get_repr = self.knowledge_base.get_representation_function_of_object
        if utils.is_coroutine_action(get_repr):
            repr_function = await get_repr(object_type)
        else:
            repr_function = get_repr(object_type)

        for position, item in enumerate(objects, start=1):
            dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
Beispiel #5
0
 async def _validate_if_required(
     self,
     dispatcher: "CollectingDispatcher",
     tracker: "Tracker",
     domain: "DomainDict",
 ) -> List[EventType]:
     """Run `self.validate(...)` if the current user input must be validated.

     Validation is required when either
     - there is no active loop yet (the form is being activated), or
     - the latest action was `action_listen` and the active loop is not
       marked as interrupted (its `LOOP_INTERRUPTED_KEY` entry is falsy or
       absent).

     Returns:
         The events produced by `self.validate(...)`, or an empty list when
         validation is skipped.
     """
     active_loop = tracker.active_loop
     if active_loop:
         need_validation = (
             tracker.latest_action_name == "action_listen"
             and not active_loop.get(LOOP_INTERRUPTED_KEY, False)
         )
     else:
         # no active loop: the form is being activated right now
         need_validation = True

     if not need_validation:
         logger.debug("Skipping validation")
         return []

     logger.debug(f"Validating user input '{tracker.latest_message}'")
     if utils.is_coroutine_action(self.validate):
         return await self.validate(dispatcher, tracker, domain)
     # see https://github.com/python/mypy/issues/5206
     return cast(List[Dict[Text, Any]], self.validate(dispatcher, tracker, domain))
Beispiel #6
0
    async def validate_slots(
        self,
        slot_dict: Dict[Text, Any],
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: Dict[Text, Any],
    ) -> List[EventType]:
        """Run the per-slot `validate_{slot}` helpers over `slot_dict`.

        For every (slot, value) pair a method named `validate_{slot}` is looked
        up on this object. When no such method exists, the value is kept as-is.
        Each helper's output dict is merged into `slot_dict`, which is updated
        in place. A helper that returns a non-dict value has it wrapped as
        `{slot: output}`, after a deprecation warning is logged.

        Returns:
            One `SlotSet` event per entry of the updated `slot_dict`.
        """
        missing = object()
        # iterate over a snapshot because `slot_dict` is updated inside the loop
        for slot_name, slot_value in list(slot_dict.items()):
            validator = getattr(self, f"validate_{slot_name}", missing)
            if validator is missing:
                # no helper defined: accept the value unchanged
                output = {slot_name: slot_value}
            elif utils.is_coroutine_action(validator):
                output = await validator(slot_value, dispatcher, tracker, domain)
            else:
                output = validator(slot_value, dispatcher, tracker, domain)

            if not isinstance(output, dict):
                logger.warning(
                    "Returning values in helper validation methods is deprecated. "
                    + f"Your `validate_{slot_name}()` method should return "
                    + "a dict of {'slot_name': value} instead."
                )
                output = {slot_name: output}
            slot_dict.update(output)

        # validation succeeded: set slots to the (possibly updated) values
        return [SlotSet(name, value) for name, value in slot_dict.items()]
Beispiel #7
0
 async def _validate_if_required(
     self,
     dispatcher: "CollectingDispatcher",
     tracker: "Tracker",
     domain: Dict[Text, Any],
 ) -> List[EventType]:
     """Run `self.validate(...)` if the current user input must be validated.

     Validation is required when either
     - there is no active loop yet (the form is being activated), or
     - the latest action was `action_listen` and the active loop's
       "validate" entry is truthy (a missing entry counts as True).

     Returns:
         The events produced by `self.validate(...)`, or an empty list when
         validation is skipped.
     """
     active_loop = tracker.active_loop
     if active_loop:
         need_validation = (
             tracker.latest_action_name == "action_listen"
             and active_loop.get("validate", True)
         )
     else:
         # no active loop: the form is being activated right now
         need_validation = True

     if not need_validation:
         logger.debug("Skipping validation")
         return []

     logger.debug(f"Validating user input '{tracker.latest_message}'")
     if utils.is_coroutine_action(self.validate):
         return await self.validate(dispatcher, tracker, domain)
     return self.validate(dispatcher, tracker, domain)
    async def utter_objects(
        self,
        dispatcher: CollectingDispatcher,
        object_type: Text,
        objects: List[Dict[Text, Any]],
    ) -> None:
        """Tell the user which objects of `object_type` were found.

        The object type is translated for display with `self.en_to_zh`. When
        `objects` is non-empty, a header message is sent followed by one
        numbered message (starting at 1) per object, rendered with the
        knowledge base's representation function. Otherwise a single
        "nothing found" message is sent.

        Args:
            dispatcher: the dispatcher
            object_type: the object type
            objects: the list of objects
        """
        if not objects:
            dispatcher.utter_message(
                text="我没找到任何{}.".format(self.en_to_zh(object_type)))
            return

        dispatcher.utter_message(
            text="找到下列{}:".format(self.en_to_zh(object_type)))

        get_repr = self.knowledge_base.get_representation_function_of_object
        if utils.is_coroutine_action(get_repr):
            repr_function = await get_repr(object_type)
        else:
            repr_function = get_repr(object_type)

        for position, item in enumerate(objects, start=1):
            dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
Beispiel #9
0
    async def extract_other_slots(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: Dict[Text, Any],
    ) -> Dict[Text, Any]:
        """Extract values for required slots other than the requested one.

        For every required slot except `REQUESTED_SLOT`'s current value, the
        slot's mappings are tried in order. The first one yielding a
        non-None value wins. Two kinds of mapping are considered:
        - a `from_entity` mapping whose entity has the slot's own name
          (and whose intent conditions are met), and
        - a `from_trigger_intent` mapping (with intent conditions met),
          honoured only while this form is not the active form.

        Returns:
            A dict mapping slot names to extracted values; slots for which
            nothing was extracted are absent.
        """
        requested_slot = tracker.get_slot(REQUESTED_SLOT)

        if utils.is_coroutine_action(self.required_slots):
            slot_names = await self.required_slots(tracker)
        else:
            slot_names = self.required_slots(tracker)

        extracted: Dict[Text, Any] = {}
        for slot_name in slot_names:
            if slot_name == requested_slot:
                # the requested slot is handled elsewhere
                continue

            for mapping in self.get_mappings_for_slot(slot_name):
                # an entity that carries the same name as the slot
                from_entity = (
                    mapping["type"] == "from_entity"
                    and mapping.get("entity") == slot_name
                    and self.intent_is_desired(mapping, tracker)
                )
                # a trigger-intent mapping, only outside this active form
                from_trigger = (
                    tracker.active_form.get("name") != self.name()
                    and mapping["type"] == "from_trigger_intent"
                    and self.intent_is_desired(mapping, tracker)
                )

                if from_entity:
                    candidate = self.get_entity_value(slot_name, tracker)
                elif from_trigger:
                    candidate = mapping.get("value")
                else:
                    candidate = None

                if candidate is not None:
                    logger.debug(f"Extracted '{candidate}' for extra slot '{slot_name}'.")
                    extracted[slot_name] = candidate
                    # first successful mapping wins; move on to the next slot
                    break

        return extracted
Beispiel #10
0
    async def _log_form_slots(self, tracker: "Tracker") -> None:
        """Debug-log each required slot with its current value before submission.

        `self.required_slots` is awaited when it is a coroutine.
        """
        if utils.is_coroutine_action(self.required_slots):
            slot_names = await self.required_slots(tracker)
        else:
            slot_names = self.required_slots(tracker)

        # one tab-indented "name: value" line per required slot
        summary = "\n".join(
            f"\t{name}: {tracker.get_slot(name)}" for name in slot_names
        )
        logger.debug(
            f"No slots left to request, all required slots are filled:\n{summary}"
        )
Beispiel #11
0
    async def validate(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: "DomainDict",
    ) -> List[EventType]:
        """Validate slots by calling a validation function for each slot.

        For every slot returned by `tracker.slots_to_validate()`, the method
        `validate_<slot name with '-' replaced by '_'>` is looked up on this
        object. Slots without such a method are skipped, with a debug log. A
        truthy validator result is merged into the slot dict. A falsy result
        triggers a `warnings.warn` and is otherwise ignored.

        Args:
            dispatcher: the dispatcher which is used to
                send messages back to the user.
            tracker: the conversation tracker for the current user.
            domain: the bot's domain.
        Returns:
            `SlotSet` events for every validated slot.
        """
        slots: Dict[Text, Any] = tracker.slots_to_validate()

        # iterate over a snapshot because `slots` is updated inside the loop
        for slot_name, slot_value in list(slots.items()):
            method_name = "validate_" + slot_name.replace("-", "_")
            validator = getattr(self, method_name, None)

            if not validator:
                logger.debug(
                    f"Skipping validation for `{slot_name}`: there is no validation function specified."
                )
                continue

            if utils.is_coroutine_action(validator):
                result = await validator(slot_value, dispatcher, tracker, domain)
            else:
                result = validator(slot_value, dispatcher, tracker, domain)

            if not result:
                warnings.warn(
                    f"Cannot validate `{slot_name}`: make sure the validation function returns the correct output."
                )
                continue

            slots.update(result)

        return [SlotSet(name, value) for name, value in slots.items()]
Beispiel #12
0
    async def _activate_if_required(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: Dict[Text, Any],
    ) -> List[EventType]:
        """Activate this form unless it is already the active form.

        When activating, the values of required slots that
        `_should_request_slot` reports as not needing a request (i.e. filled
        before activation) are passed through `validate_slots`.

        Returns:
            An empty list if this form is already active. Otherwise a
            `Form(self.name())` event followed by any `SlotSet` events from
            validating the pre-filled slots.
        """
        active_form_name = tracker.active_form.get("name")
        if active_form_name is not None:
            logger.debug(f"The form '{tracker.active_form}' is active")
        else:
            logger.debug("There is no active form")

        if active_form_name == self.name():
            return []

        logger.debug(f"Activated the form '{self.name()}'")
        events = [Form(self.name())]

        if utils.is_coroutine_action(self.required_slots):
            slot_names = await self.required_slots(tracker)
        else:
            slot_names = self.required_slots(tracker)

        # values of required slots that were filled before activation
        prefilled_slots = {
            name: tracker.get_slot(name)
            for name in slot_names
            if not self._should_request_slot(tracker, name)
        }

        if not prefilled_slots:
            logger.debug("No pre-filled required slots to validate.")
            return events

        logger.debug(f"Validating pre-filled required slots: {prefilled_slots}")
        events.extend(
            await self.validate_slots(prefilled_slots, dispatcher, tracker, domain)
        )
        return events
Beispiel #13
0
    async def request_next_slot(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: "Tracker",
        domain: Dict[Text, Any],
    ) -> Optional[List[EventType]]:
        """Ask the user for the first required slot that still needs a value.

        Utters the `utter_ask_{slot}` template, with the tracker's current slot
        values passed as template variables, for the first required slot for
        which `_should_request_slot` is true.

        Returns:
            `[SlotSet(REQUESTED_SLOT, slot)]` for that slot, or None if no
            required slot needs to be requested.
        """
        if utils.is_coroutine_action(self.required_slots):
            slot_names = await self.required_slots(tracker)
        else:
            slot_names = self.required_slots(tracker)

        for name in slot_names:
            if not self._should_request_slot(tracker, name):
                continue
            logger.debug(f"Request next slot '{name}'")
            template_name = f"utter_ask_{name}"
            dispatcher.utter_message(template=template_name, **tracker.slots)
            return [SlotSet(REQUESTED_SLOT, name)]

        # every required slot is already filled
        return None
Beispiel #14
0
    async def utter_objects(
        self,
        dispatcher: CollectingDispatcher,
        object_type: Text,
        objects: List[Dict[Text, Any]],
    ) -> None:
        """Tell the user which objects of `object_type` were found.

        When `objects` is non-empty, a header message is sent followed by one
        numbered message (starting at 1) per object, rendered with the
        knowledge base's representation function for `object_type`.
        Otherwise a single "nothing found" message is sent.

        Args:
            dispatcher: the dispatcher
            object_type: the object type
            objects: the list of objects
        """
        if not objects:
            dispatcher.utter_message(
                text=f"Ich konnte leider nichts zu  '{object_type}' finden."
            )
            return

        dispatcher.utter_message(
            text=f"Folgendes ist in der '{object_type}' versichert:"
        )

        get_repr = self.knowledge_base.get_representation_function_of_object
        if utils.is_coroutine_action(get_repr):
            repr_function = await get_repr(object_type)
        else:
            repr_function = get_repr(object_type)

        for position, item in enumerate(objects, start=1):
            dispatcher.utter_message(text=f"{position}: {repr_function(item)}")
Beispiel #15
0
def test_action_async_check():
    """`run` of a synchronous action is not a coroutine; an async one's is."""
    sync_run = CustomAction.run
    assert not is_coroutine_action(sync_run)

    async_run = CustomAsyncAction.run
    assert is_coroutine_action(async_run)
Beispiel #16
0
    async def _query_attribute(
        self, dispatcher: CollectingDispatcher, tracker: Tracker
    ) -> List[Dict]:
        """Answer a question about one attribute of a mentioned object.

        The object type and attribute are read from the tracker slots
        `SLOT_OBJECT_TYPE` / `SLOT_ATTRIBUTE`, and the object name is resolved
        with `get_object_name`. If the object or attribute cannot be
        determined, `utter_ask_rephrase` is uttered and only the mention slot
        is reset. Otherwise the value is passed to `utter_attribute_value`.

        Args:
            dispatcher: the dispatcher
            tracker: the tracker

        Returns: list of slots
        """

        async def _invoke(func, *args):
            # knowledge-base hooks and `utter_attribute_value` may be plain
            # functions or coroutines
            if utils.is_coroutine_action(func):
                return await func(*args)
            return func(*args)

        object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
        attribute = tracker.get_slot(SLOT_ATTRIBUTE)

        object_name = get_object_name(
            tracker,
            self.knowledge_base.ordinal_mention_mapping,
            self.use_last_object_mention,
        )

        if not object_name or not attribute:
            dispatcher.utter_message(template="utter_ask_rephrase")
            return [SlotSet(SLOT_MENTION, None)]

        object_of_interest = await _invoke(
            self.knowledge_base.get_object, object_type, object_name
        )

        if not object_of_interest or attribute not in object_of_interest:
            dispatcher.utter_message(template="utter_ask_rephrase")
            return [SlotSet(SLOT_MENTION, None)]

        value = object_of_interest[attribute]

        repr_function = await _invoke(
            self.knowledge_base.get_representation_function_of_object, object_type
        )
        object_representation = repr_function(object_of_interest)

        key_attribute = await _invoke(
            self.knowledge_base.get_key_attribute_of_object, object_type
        )
        object_identifier = object_of_interest[key_attribute]

        await _invoke(
            self.utter_attribute_value,
            dispatcher,
            object_representation,
            attribute,
            value,
        )

        return [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_LAST_OBJECT, object_identifier),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
        ]
Beispiel #17
0
    async def _query_objects(
        self, dispatcher: CollectingDispatcher, object_type: Text, tracker: Tracker
    ) -> List[Dict]:
        """
        Queries the knowledge base for objects of the requested object type and
        outputs those to the user. The objects are filtered by any attribute the
        user mentioned in the request.

        Args:
            dispatcher: the dispatcher
            object_type: the type of objects to list
            tracker: the tracker

        Returns: list of slots. If no objects were found, only the attribute
            slots of `object_type` are reset. Otherwise the object-type,
            mention, attribute, last-object, last-object-type and
            listed-objects slots are set, followed by the attribute resets.
            The last object is set only when exactly one object was found.
        """

        async def _invoke(func, *args):
            # knowledge-base hooks and `utter_objects` may be plain
            # functions or coroutines
            if utils.is_coroutine_action(func):
                return await func(*args)
            return func(*args)

        object_attributes = await _invoke(
            self.knowledge_base.get_attributes_of_object, object_type
        )

        # get all set attribute slots of the object type to be able to filter
        # the list of objects
        attributes = get_attribute_slots(tracker, object_attributes)
        # query the knowledge base
        objects = await _invoke(
            self.knowledge_base.get_objects, object_type, attributes
        )

        await _invoke(self.utter_objects, dispatcher, object_type, objects)

        if not objects:
            return reset_attribute_slots(tracker, object_attributes)

        key_attribute = await _invoke(
            self.knowledge_base.get_key_attribute_of_object, object_type
        )

        # only remember a "last object" when the result is unambiguous
        last_object = objects[0][key_attribute] if len(objects) <= 1 else None
        listed_objects = [item[key_attribute] for item in objects]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, last_object),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS, listed_objects),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)
Beispiel #18
0
    async def _query_attribute(
        self,
        dispatcher: CollectingDispatcher,
        object_type: Text,
        attribute: Text,
        tracker: Tracker,
    ) -> List[Dict]:
        """
        Queries the knowledge base for the value of the requested attribute of the
        mentioned object and outputs it to the user.

        Args:
            dispatcher: the dispatcher
            object_type: the type of the mentioned object
            attribute: the attribute whose value is requested
            tracker: the tracker

        Returns: list of slots. When the object name or attribute cannot be
            resolved, `utter_ask_rephrase` is uttered and only the mention
            slot is reset.
        """

        async def _invoke(func, *args):
            # knowledge-base hooks and `utter_attribute_value` may be plain
            # functions or coroutines
            if utils.is_coroutine_action(func):
                return await func(*args)
            return func(*args)

        object_name = get_object_name(
            tracker,
            self.knowledge_base.ordinal_mention_mapping,
            self.use_last_object_mention,
        )

        if not object_name or not attribute:
            dispatcher.utter_message(template="utter_ask_rephrase")
            return [SlotSet(SLOT_MENTION, None)]

        object_of_interest = await _invoke(
            self.knowledge_base.get_object, object_type, object_name
        )

        if not object_of_interest or attribute not in object_of_interest:
            dispatcher.utter_message(template="utter_ask_rephrase")
            return [SlotSet(SLOT_MENTION, None)]

        value = object_of_interest[attribute]

        repr_function = await _invoke(
            self.knowledge_base.get_representation_function_of_object, object_type
        )
        object_representation = repr_function(object_of_interest)

        key_attribute = await _invoke(
            self.knowledge_base.get_key_attribute_of_object, object_type
        )
        object_identifier = object_of_interest[key_attribute]

        await _invoke(
            self.utter_attribute_value,
            dispatcher,
            object_representation,
            attribute,
            value,
        )

        return [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_LAST_OBJECT, object_identifier),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
        ]