コード例 #1
0
ファイル: actions.py プロジェクト: KaczmarekWill/airo_chatbot
	def _query_objects(self, dispatcher: CollectingDispatcher, tracker: Tracker) -> List[Dict]:
		"""
		Queries the knowledge base for objects of the type stored in the
		object-type slot and outputs them to the user. The objects are
		filtered by the attribute slots of that type that are currently set.

		Args:
			dispatcher: the dispatcher passed on to ``utter_objects``
			tracker: the tracker providing the object-type and attribute slots

		Returns: list of slot events. The attribute slots of the object type
			are always reset via ``reset_attribute_slots``; when objects were
			found, the object-type, mention, last-object, last-object-type and
			listed-objects slots are set as well.
		"""
		object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
		object_attributes = self.knowledge_base.get_attributes_of_object(object_type)

		# Collect the attribute slots of this object type that carry a value;
		# they act as filters for the knowledge-base query.
		attributes = get_attribute_slots(tracker, object_attributes)

		objects = self.knowledge_base.get_objects(object_type, attributes)

		# This variant of utter_objects also receives the filter attributes.
		self.utter_objects(dispatcher, object_type, objects, attributes)

		if not objects:
			return reset_attribute_slots(tracker, object_attributes)

		key_attribute = self.knowledge_base.get_key_attribute_of_object(object_type)

		# SLOT_LAST_OBJECT is only filled when exactly one object was found;
		# with several results it is cleared (None).
		last_object = None if len(objects) > 1 else objects[0][key_attribute]
		listed_objects = [obj[key_attribute] for obj in objects]

		slots = [
			SlotSet(SLOT_OBJECT_TYPE, object_type),
			SlotSet(SLOT_MENTION, None),
			SlotSet(SLOT_LAST_OBJECT, last_object),
			SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
			SlotSet(SLOT_LISTED_OBJECTS, listed_objects),
		]

		return slots + reset_attribute_slots(tracker, object_attributes)
コード例 #2
0
ファイル: actions.py プロジェクト: xinru1414/rasa-sdk
    async def _query_objects(self, dispatcher: CollectingDispatcher,
                             object_type: Text,
                             tracker: Tracker) -> List[Dict]:
        """
        Queries the knowledge base for objects of the requested object type and
        outputs those to the user. The objects are filtered by any attribute the
        user mentioned in the request.

        Each knowledge-base call and ``utter_objects`` is wrapped in
        ``utils.call_potential_coroutine`` and awaited, presumably so that both
        synchronous and asynchronous implementations are supported.

        Args:
            dispatcher: the dispatcher passed on to ``utter_objects``
            object_type: the type of objects to query
            tracker: the tracker whose attribute slots filter the query

        Returns: list of slot events. The attribute slots of ``object_type``
            are always reset via ``reset_attribute_slots``; when objects were
            found, the object-type, mention, attribute, last-object,
            last-object-type and listed-objects slots are set as well.
        """
        object_attributes = await utils.call_potential_coroutine(
            self.knowledge_base.get_attributes_of_object(object_type))

        # get all set attribute slots of the object type to be able to filter the
        # list of objects
        attributes = get_attribute_slots(tracker, object_attributes)
        # query the knowledge base
        objects = await utils.call_potential_coroutine(
            self.knowledge_base.get_objects(object_type, attributes))

        await utils.call_potential_coroutine(
            self.utter_objects(dispatcher, object_type, objects))

        if not objects:
            return reset_attribute_slots(tracker, object_attributes)

        key_attribute = await utils.call_potential_coroutine(
            self.knowledge_base.get_key_attribute_of_object(object_type))

        # SLOT_LAST_OBJECT is only filled when exactly one object was found;
        # with several results it is cleared (None).
        last_object = None if len(objects) > 1 else objects[0][key_attribute]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, last_object),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS,
                    [obj[key_attribute] for obj in objects]),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)
コード例 #3
0
ファイル: test_utils.py プロジェクト: yuri789/rasa-sdk
def test_get_attribute_slots():
    """get_attribute_slots returns exactly the attribute slots that have values."""
    object_attributes = ["name", "cuisine", "price-range"]

    # "price-range" has no value in the tracker below, so it must be absent.
    expected_attribute_slots = [
        {"name": "name", "value": "PastaBar"},
        {"name": "cuisine", "value": "Italian"},
    ]

    # The second positional argument is presumably the tracker's slot values
    # (per rasa_sdk.Tracker's constructor) — only "name" and "cuisine" are set.
    tracker = Tracker(
        "default",
        {"name": "PastaBar", "cuisine": "Italian"},
        {},
        [],
        False,
        None,
        {},
        "action_listen",
    )

    attribute_slots = get_attribute_slots(tracker, object_attributes)

    # Compare in both directions: a one-way membership loop passes vacuously
    # on an empty result and never notices a missing expected slot.
    assert len(attribute_slots) == len(expected_attribute_slots)
    for expected in expected_attribute_slots:
        assert expected in attribute_slots
    for actual in attribute_slots:
        assert actual in expected_attribute_slots
コード例 #4
0
ファイル: actions.py プロジェクト: nbeuchat/rasa-sdk
    async def _query_objects(self, dispatcher: CollectingDispatcher,
                             object_type: Text,
                             tracker: Tracker) -> List[Dict]:
        """
        Queries the knowledge base for objects of the requested object type and
        outputs those to the user. The objects are filtered by any attribute the
        user mentioned in the request.

        Each knowledge-base call and ``utter_objects`` is awaited when
        ``utils.is_coroutine_action`` reports it as a coroutine function and
        called directly otherwise (see ``_call_maybe_coroutine``).

        Args:
            dispatcher: the dispatcher passed on to ``utter_objects``
            object_type: the type of objects to query
            tracker: the tracker whose attribute slots filter the query

        Returns: list of slot events. The attribute slots of ``object_type``
            are always reset via ``reset_attribute_slots``; when objects were
            found, the object-type, mention, attribute, last-object,
            last-object-type and listed-objects slots are set as well.
        """
        knowledge_base = self.knowledge_base

        object_attributes = await self._call_maybe_coroutine(
            knowledge_base.get_attributes_of_object, object_type)

        # get all set attribute slots of the object type to be able to filter the
        # list of objects
        attributes = get_attribute_slots(tracker, object_attributes)
        # query the knowledge base
        objects = await self._call_maybe_coroutine(
            knowledge_base.get_objects, object_type, attributes)

        await self._call_maybe_coroutine(
            self.utter_objects, dispatcher, object_type, objects)

        if not objects:
            return reset_attribute_slots(tracker, object_attributes)

        key_attribute = await self._call_maybe_coroutine(
            knowledge_base.get_key_attribute_of_object, object_type)

        # SLOT_LAST_OBJECT is only filled when exactly one object was found;
        # with several results it is cleared (None).
        last_object = None if len(objects) > 1 else objects[0][key_attribute]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, last_object),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS,
                    [obj[key_attribute] for obj in objects]),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)

    @staticmethod
    async def _call_maybe_coroutine(func: Any, *args: Any) -> Any:
        """Call ``func(*args)``, awaiting it when ``utils.is_coroutine_action``
        reports ``func`` as a coroutine function."""
        if utils.is_coroutine_action(func):
            return await func(*args)
        return func(*args)
コード例 #5
0
    def _query_attribute(self, dispatcher: CollectingDispatcher,
                         tracker: Tracker) -> List[Dict]:
        """
        Queries the knowledge base for the value of the requested attribute of the
        mentioned object and outputs it to the user.

        If the object name or the attribute cannot be determined, or no matching
        object has the attribute, the user is asked to rephrase. When several
        matching objects have the attribute, a notice is sent first and then the
        value of each one is reported.

        Args:
            dispatcher: the dispatcher used to send the answers
            tracker: the tracker providing the object-type and attribute slots

        Returns: list of slot events
        """
        object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
        attribute = tracker.get_slot(SLOT_ATTRIBUTE)

        object_name = self._get_object_name(
            tracker,
            self.knowledge_base.ordinal_mention_mapping,
            self.use_last_object_mention,
        )

        # Lazy %-formatting; %s applies str() exactly like the concatenation it replaces.
        logger.info("_query_attribute [object_type]:%s [attribute]:%s [object_name]:%s",
                    object_type, attribute, object_name)

        if not object_name or not attribute:
            logger.info("object_name or attribute not available")
            dispatcher.utter_template("utter_ask_rephrase", tracker)
            return [SlotSet(SLOT_MENTION, None)]

        object_attributes = self.knowledge_base.get_attributes_of_object(
            object_type)

        # get all set attribute slots of the object type to be able to filter the
        # list of objects
        attributes = get_attribute_slots(tracker, object_attributes)
        # query the knowledge base, restricted to the mentioned object name
        objects = self.knowledge_base.get_objects(object_type, attributes,
                                                  object_name)
        key_attribute = self.knowledge_base.get_key_attribute_of_object(
            object_type)

        # Check every object for the attribute, not just the first one, so that
        # an object without it is skipped instead of raising KeyError below.
        matches = [obj for obj in objects or [] if attribute in obj]

        if not matches:
            logger.info("object not found or attribute not in any object")
            dispatcher.utter_template("utter_ask_rephrase", tracker)
            return [SlotSet(SLOT_MENTION, None)] + reset_attribute_slots(
                tracker, object_attributes)

        if len(matches) > 1:
            dispatcher.utter_message("Ho trovato più di un risultato:")

        # Loop-invariant: the representation function depends only on the type.
        repr_function = self.knowledge_base.get_representation_function_of_object(
            object_type)

        for object_of_interest in matches:
            self.utter_attribute_value(dispatcher,
                                       repr_function(object_of_interest),
                                       attribute,
                                       object_of_interest[attribute])

        # NOTE(review): with several matches this stores the last reported
        # object as SLOT_LAST_OBJECT — confirm that is the intended behavior.
        object_identifier = matches[-1][key_attribute]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_LAST_OBJECT, object_identifier),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)