예제 #1
0
	def _query_objects(self, dispatcher: CollectingDispatcher, tracker: Tracker) -> List[Dict]:
		"""List the objects of the type held in the object-type slot.

		Every attribute slot of that object type which is currently set is used
		as a filter on the knowledge base. The matches, together with those
		filters, are handed to ``utter_objects``.

		Args:
			dispatcher: dispatcher used to reply to the user.
			tracker: conversation tracker the slots are read from.

		Returns:
			Slot events. When nothing matched, only the attribute-slot resets.
			Otherwise the object-type, mention, last-object, last-object-type
			and listed-objects updates, followed by the attribute-slot resets.
		"""
		kb = self.knowledge_base
		requested_type = tracker.get_slot(SLOT_OBJECT_TYPE)
		type_attributes = kb.get_attributes_of_object(requested_type)
		filters = get_attribute_slots(tracker, type_attributes)

		found = kb.get_objects(requested_type, filters)
		self.utter_objects(dispatcher, requested_type, found, filters)

		if not found:
			return reset_attribute_slots(tracker, type_attributes)

		key = kb.get_key_attribute_of_object(requested_type)
		identifiers = [item[key] for item in found]
		# a "last object" only makes sense when exactly one object matched
		single = None if len(found) > 1 else identifiers[0]

		events = [
			SlotSet(SLOT_OBJECT_TYPE, requested_type),
			SlotSet(SLOT_MENTION, None),
			SlotSet(SLOT_LAST_OBJECT, single),
			SlotSet(SLOT_LAST_OBJECT_TYPE, requested_type),
			SlotSet(SLOT_LISTED_OBJECTS, identifiers),
		]
		events.extend(reset_attribute_slots(tracker, type_attributes))
		return events
예제 #2
0
    async def _query_objects(self, dispatcher: CollectingDispatcher,
                             object_type: Text,
                             tracker: Tracker) -> List[Dict]:
        """
        Copied from ActionQueryKnowledgeBase and overridden
        in order to introduce LIMIT parameter

        Queries the knowledge base for objects of the requested object type and
        outputs those to the user. The objects are filtered by any attribute the
        user mentioned in the request. If the limit slot holds a value that
        ``int()`` accepts, it is forwarded to ``get_objects`` as ``limit``; a
        value ``int()`` rejects is logged and ignored instead of crashing.

        Args:
            dispatcher: the dispatcher
            object_type: the object type to query
            tracker: the tracker

        Returns: list of slots. The limit slot is cleared on every path, so a
            limit only applies to the request it came with.
        """
        object_attributes = await utils.call_potential_coroutine(
            self.knowledge_base.get_attributes_of_object(object_type))

        limit = tracker.get_slot(SLOT_LIMIT)
        logger.info(f"Limit is {limit}")
        # get all set attribute slots of the object type to be able to filter the
        # list of objects
        attributes = get_attribute_slots(tracker, object_attributes)
        var_args = {}
        if limit:
            try:
                var_args['limit'] = int(limit)
            except (TypeError, ValueError):
                # the slot value is not guaranteed to be numeric (presumably it
                # can come from free-text extraction); fall back to the
                # knowledge base's default limit rather than failing the action
                logger.warning("Ignoring non-integer limit %r", limit)

        # query the knowledge base
        objects = await utils.call_potential_coroutine(
            self.knowledge_base.get_objects(object_type, attributes,
                                            **var_args))

        await utils.call_potential_coroutine(
            self.utter_objects(dispatcher, object_type, objects, attributes))

        if not objects:
            # clear the limit here too, otherwise it would silently apply to
            # the user's next query
            return [SlotSet(SLOT_LIMIT, None)] + reset_attribute_slots(
                tracker, object_attributes)

        key_attribute = await utils.call_potential_coroutine(
            self.knowledge_base.get_key_attribute_of_object(object_type))

        # a "last object" only makes sense when exactly one object matched
        last_object = None if len(objects) > 1 else objects[0][key_attribute]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, last_object),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS,
                    [obj[key_attribute] for obj in objects]),
            SlotSet(SLOT_LIMIT, None),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)
예제 #3
0
    async def _query_objects(self, dispatcher: CollectingDispatcher,
                             object_type: Text,
                             tracker: Tracker) -> List[Dict]:
        """
        List the objects of ``object_type`` that match the user's filters.

        Every attribute slot of the object type that is currently set is used
        as a filter on the knowledge base; the matches are handed to
        ``utter_objects``. Knowledge-base calls and ``utter_objects`` go through
        ``utils.call_potential_coroutine``, so they may be sync or async.

        Args:
            dispatcher: the dispatcher used to reply to the user
            object_type: the object type to look up
            tracker: the tracker the attribute slots are read from

        Returns: list of slots -- only the attribute-slot resets when nothing
            matched, otherwise the object-type / mention / attribute /
            last-object / listed-objects updates followed by those resets
        """
        kb = self.knowledge_base
        call = utils.call_potential_coroutine

        type_attributes = await call(kb.get_attributes_of_object(object_type))
        filters = get_attribute_slots(tracker, type_attributes)
        matches = await call(kb.get_objects(object_type, filters))
        await call(self.utter_objects(dispatcher, object_type, matches))

        if not matches:
            return reset_attribute_slots(tracker, type_attributes)

        key = await call(kb.get_key_attribute_of_object(object_type))
        identifiers = [match[key] for match in matches]
        # a "last object" only makes sense when exactly one object matched
        single_match = None if len(matches) > 1 else identifiers[0]

        events = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, single_match),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS, identifiers),
        ]
        return events + reset_attribute_slots(tracker, type_attributes)
예제 #4
0
def test_reset_attribute_slots():
    """reset_attribute_slots resets exactly the attribute slots that hold a value.

    "price-range" is listed as an attribute but is absent from the dict given
    to the Tracker, so only "name" and "cuisine" are expected to be reset.
    """
    object_attributes = ["name", "cuisine", "price-range"]

    expected_reset_slots = [SlotSet("name", None), SlotSet("cuisine", None)]

    tracker = Tracker(
        "default",
        {"name": "PastaBar", "cuisine": "Italian"},
        {},
        [],
        False,
        None,
        {},
        "action_listen",
    )

    reset_slots = reset_attribute_slots(tracker, object_attributes)

    # Check both directions plus the length: checking only that each returned
    # slot is expected would pass vacuously on an empty or partial result.
    assert len(reset_slots) == len(expected_reset_slots)
    for s in reset_slots:
        assert s in expected_reset_slots
    for s in expected_reset_slots:
        assert s in reset_slots
예제 #5
0
    async def _query_objects(self, dispatcher: CollectingDispatcher,
                             object_type: Text,
                             tracker: Tracker) -> List[Dict]:
        """
        List the objects of ``object_type`` that match the user's filters.

        Every attribute slot of the object type that is currently set is used
        as a filter on the knowledge base; the matches are handed to
        ``utter_objects``.

        Args:
            dispatcher: the dispatcher used to reply to the user
            object_type: the object type to look up
            tracker: the tracker the attribute slots are read from

        Returns: list of slots -- only the attribute-slot resets when nothing
            matched, otherwise the object-type / mention / attribute /
            last-object / listed-objects updates followed by those resets
        """
        async def invoke(method: Any, *args: Any) -> Any:
            # knowledge-base methods and utter_objects may be plain or async;
            # await only when utils.is_coroutine_action says so
            if utils.is_coroutine_action(method):
                return await method(*args)
            return method(*args)

        kb = self.knowledge_base

        # the casts only narrow types for mypy and are no-ops at runtime
        # (see https://github.com/python/mypy/issues/5206)
        type_attributes = cast(
            List[Text], await invoke(kb.get_attributes_of_object, object_type))
        filters = get_attribute_slots(tracker, type_attributes)
        matches = cast(
            List[Dict[Text, Any]],
            await invoke(kb.get_objects, object_type, filters))
        await invoke(self.utter_objects, dispatcher, object_type, matches)

        if not matches:
            return reset_attribute_slots(tracker, type_attributes)

        key = cast(
            Text, await invoke(kb.get_key_attribute_of_object, object_type))
        identifiers = [match[key] for match in matches]
        # a "last object" only makes sense when exactly one object matched
        single_match = None if len(matches) > 1 else identifiers[0]

        events = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_LAST_OBJECT, single_match),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
            SlotSet(SLOT_LISTED_OBJECTS, identifiers),
        ]
        return events + reset_attribute_slots(tracker, type_attributes)
예제 #6
0
    def _query_attribute(self, dispatcher: CollectingDispatcher,
                         tracker: Tracker) -> List[Dict]:
        """
        Queries the knowledge base for the value of the requested attribute of the
        mentioned object(s) and outputs it to the user.

        Objects returned by the knowledge base that lack the requested attribute
        are skipped. If none of them has it, the user is asked to rephrase.

        Args:
            dispatcher: the dispatcher
            tracker: the tracker

        Returns: list of slots
        """
        object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
        attribute = tracker.get_slot(SLOT_ATTRIBUTE)

        object_name = self._get_object_name(
            tracker,
            self.knowledge_base.ordinal_mention_mapping,
            self.use_last_object_mention,
        )

        logger.info(
            "_query_attribute [object_type]:%s [attribute]:%s [object_name]:%s",
            object_type, attribute, object_name)

        if not object_name or not attribute:
            logger.info("object_name or attribute not available")
            dispatcher.utter_template("utter_ask_rephrase", tracker)
            return [SlotSet(SLOT_MENTION, None)]

        object_attributes = self.knowledge_base.get_attributes_of_object(
            object_type)

        # get all set attribute slots of the object type to be able to filter the
        # list of objects
        attributes = get_attribute_slots(tracker, object_attributes)

        # query the knowledge base
        # NOTE(review): object_name is passed as the third positional argument;
        # this assumes the project's knowledge base accepts it there — confirm
        # against its get_objects signature.
        objects = self.knowledge_base.get_objects(object_type, attributes,
                                                  object_name)

        # Keep only objects that actually carry the attribute: checking just the
        # first one would let a later object without it raise KeyError below.
        objects = [obj for obj in objects or [] if attribute in obj]

        if not objects:
            logger.info("object not found or attribute not in any object")
            dispatcher.utter_template("utter_ask_rephrase", tracker)
            return [SlotSet(SLOT_MENTION, None)] + reset_attribute_slots(
                tracker, object_attributes)

        # loop-invariant knowledge-base lookups, fetched once
        key_attribute = self.knowledge_base.get_key_attribute_of_object(
            object_type)
        repr_function = self.knowledge_base.get_representation_function_of_object(
            object_type)

        if len(objects) > 1:
            dispatcher.utter_message("Ho trovato più di un risultato:")

        for object_of_interest in objects:
            object_representation = repr_function(object_of_interest)
            self.utter_attribute_value(dispatcher, object_representation,
                                       attribute, object_of_interest[attribute])

        # with several matches, the last one reported becomes the "last object"
        object_identifier = objects[-1][key_attribute]

        slots = [
            SlotSet(SLOT_OBJECT_TYPE, object_type),
            SlotSet(SLOT_ATTRIBUTE, None),
            SlotSet(SLOT_MENTION, None),
            SlotSet(SLOT_LAST_OBJECT, object_identifier),
            SlotSet(SLOT_LAST_OBJECT_TYPE, object_type),
        ]

        return slots + reset_attribute_slots(tracker, object_attributes)