コード例 #1
0
def test_get_or_create():
    """A slot set on a saved tracker is still set when the tracker is fetched again."""
    key, value = 'location', 'Easter Island'
    tracker_store = InMemoryTrackerStore(domain)
    sender = UserMessage.DEFAULT_SENDER_ID

    first = tracker_store.get_or_create_tracker(sender)
    first.update(SlotSet(key, value))
    assert first.get_slot(key) == value

    tracker_store.save(first)

    # A second lookup for the same sender must see the persisted slot value.
    second = tracker_store.get_or_create_tracker(sender)
    assert second.get_slot(key) == value
コード例 #2
0
def test_tracker_store_remembers_max_history(default_domain: Domain):
    """The max_event_history passed at creation survives a save/retrieve cycle."""
    sender = "myuser"
    limit = 42
    tracker_store = InMemoryTrackerStore(default_domain)

    original = tracker_store.get_or_create_tracker(sender, max_event_history=limit)
    original.update(Restarted())
    tracker_store.save(original)

    reloaded = tracker_store.retrieve(sender)
    assert original._max_event_history == reloaded._max_event_history
    assert reloaded._max_event_history == limit
コード例 #3
0
def _tracker_store_and_tracker_with_slot_set() -> Tuple[InMemoryTrackerStore, DialogueStateTracker]:
    """Return an in-memory store and a tracker whose "cuisine" slot is "French".

    The tracker is obtained from the store for the default sender. The slot
    update is applied to the returned tracker only; this helper does not call
    ``save`` afterwards.
    """
    tracker_store = InMemoryTrackerStore(domain)
    tracker = tracker_store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    tracker.update(SlotSet("cuisine", "French"))
    return tracker_store, tracker
コード例 #4
0
def test_tracker_serialisation():
    """A tracker round-tripped through serialise/deserialise compares equal."""
    tracker_store = InMemoryTrackerStore(domain)
    sender = UserMessage.DEFAULT_SENDER_ID

    tracker = tracker_store.get_or_create_tracker(sender)
    tracker.update(SlotSet("location", "Easter Island"))

    blob = tracker_store.serialise_tracker(tracker)
    restored = tracker_store.deserialise_tracker(sender, blob)
    assert tracker == restored
コード例 #5
0
def test_restart_after_retrieval_from_tracker_store(default_domain: Domain):
    """idx_after_latest_restart() is unchanged by a save/retrieve cycle."""
    sender = "myuser"
    tracker_store = InMemoryTrackerStore(default_domain)
    before = tracker_store.get_or_create_tracker(sender)

    # Four action_listen events, then a restart.
    listen_events = [ActionExecuted("action_listen") for _ in range(4)]
    for event in listen_events:
        before.update(event)
    before.update(Restarted())
    expected_idx = before.idx_after_latest_restart()

    tracker_store.save(before)
    after = tracker_store.retrieve(sender)
    assert expected_idx == after.idx_after_latest_restart()
コード例 #6
0
class RasaServiceLocal(MqttService):
    """Run a Rasa model in-process and serve it over MQTT.

    The dialogue agent and an NLU interpreter are loaded from a local model,
    and dialogue state is kept in an in-memory tracker store keyed by site id.
    Incoming messages are routed by topic in :meth:`on_message`.
    """

    def __init__(self, config, loop):
        """Load the Rasa model, NLU interpreter and tracker store.

        :param config: service configuration. ``config['services']['RasaServiceLocal']``
            provides ``model_path`` and ``rasa_actions_url``. The top-level
            ``keep_listening`` key is read later by :meth:`finish`.
        :param loop: event loop, passed through to ``MqttService``.
        """
        super(RasaServiceLocal, self).__init__(config, loop)
        self.config = config
        # Comma-separated topic filters subscribed to in connect_hook.
        # NOTE(review): 'dialog/ended' is subscribed but on_message has no
        # branch for it, and the 'core/ended' branch in on_message has no
        # matching subscription in this class -- confirm which is intended.
        self.subscribe_to = ','.join([
            'hermod/+/rasa/get_domain',
            'hermod/+/rasa/set_slots',
            'hermod/+/dialog/ended',
            'hermod/+/dialog/init',
            'hermod/+/nlu/externalparse',
            'hermod/+/nlu/parse',
            'hermod/+/intent',
            'hermod/+/dialog/started',
        ])
        service_config = config['services']['RasaServiceLocal']
        model_path = get_model(service_config.get('model_path'))
        endpoint = EndpointConfig(service_config.get('rasa_actions_url'))
        # NOTE(review): a file name is passed where a Domain object is usually
        # expected; presumably Agent.load assigns the model's domain to this
        # store when it is handed over -- confirm.
        self.tracker_store = InMemoryTrackerStore('domain.yml')
        # handle_intent feeds the agent pre-classified "/intent{...}" strings,
        # so the agent only needs the regex interpreter. Free-text parsing is
        # done separately through self.text_interpreter.
        regex_interpreter = RegexInterpreter()
        self.text_interpreter = RasaNLUInterpreter(model_path + '/nlu')
        self.agent = Agent.load(model_path,
                                action_endpoint=endpoint,
                                tracker_store=self.tracker_store,
                                interpreter=regex_interpreter)

    async def connect_hook(self):
        """Subscribe to every configured topic, then announce readiness."""
        for sub in self.subscribe_to.split(","):
            await self.client.subscribe(sub)
        await self.client.publish('hermod/rasa/ready', json.dumps({}))

    async def on_message(self, message):
        """Route an incoming MQTT message to its handler.

        Topics look like ``hermod/<site>/...``, and the second path segment is
        taken as the site id. The payload is decoded as UTF-8 JSON. Anything
        that is not a JSON object is treated as an empty payload.
        """
        topic = "{}".format(message.topic)
        site = topic.split("/")[1]
        prefix = 'hermod/' + site
        payload_string = str(message.payload, encoding='utf-8')
        try:
            payload = json.loads(payload_string)
        except json.JSONDecodeError:
            # Best effort: a non-JSON payload is treated as empty.
            payload = {}
        if not isinstance(payload, dict):
            # The handlers read the payload with .get(), so only JSON objects
            # are usable; other JSON values would raise AttributeError.
            payload = {}

        if topic == prefix + '/rasa/set_slots':
            if payload:
                await self.set_slots(site, payload)

        elif topic == prefix + '/nlu/parse':
            if payload:
                await self.client.publish(prefix + '/display/startwaiting',
                                          json.dumps({}))
                try:
                    await self.nlu_parse_request(site, payload.get('query'), payload)
                finally:
                    # Clear the waiting indicator even if parsing raises.
                    await self.client.publish(prefix + '/display/stopwaiting',
                                              json.dumps({}))

        elif topic == prefix + '/nlu/externalparse':
            if payload:
                await self.nlu_external_parse_request(site, payload.get('query'), payload)

        elif topic == prefix + '/intent':
            if payload:
                await self.client.publish(prefix + '/display/startwaiting',
                                          json.dumps({}))
                try:
                    await self.handle_intent(site, payload)
                finally:
                    # Clear the waiting indicator even if intent handling raises.
                    await self.client.publish(prefix + '/display/stopwaiting',
                                              json.dumps({}))

        elif topic == prefix + '/tts/finished':
            # tts/finished is subscribed per turn in handle_intent; drop it again.
            await self.client.unsubscribe(prefix + '/tts/finished')
            await self.finish(site, payload)

        elif topic == prefix + '/dialog/started':
            await self.reset_tracker(site)

        elif topic == prefix + '/dialog/init':
            # Save the dialog init data to a slot so custom actions can read it.
            tracker = self.tracker_store.get_or_create_tracker(site)
            tracker.update(SlotSet("hermod_client", json.dumps(payload)))
            self.tracker_store.save(tracker)

        elif topic == prefix + '/rasa/get_domain':
            await self.send_domain(site)

        elif topic == prefix + '/core/ended':
            await self.send_story(site, payload)

    async def send_story(self, site, payload):
        """Publish the site's conversation history, exported as stories, to ``rasa/story``."""
        tracker = self.tracker_store.get_or_create_tracker(site)
        response = tracker.export_stories()
        await self.client.publish(
            'hermod/' + site + '/rasa/story',
            json.dumps({'id': payload.get('id', ''), 'story': response}))

    async def send_domain(self, site):
        """Publish the agent's domain, as a dict, to ``rasa/domain`` for a site."""
        await self.client.publish('hermod/' + site + '/rasa/domain',
                                  json.dumps(self.agent.domain.as_dict()))

    async def reset_tracker(self, site):
        """Reset conversation history for a site (currently a no-op)."""
        # TODO: reset this site's tracker. An earlier sketch fetched it with
        # get_or_create_tracker(site) and called tracker._reset().
        pass

    async def handle_intent(self, site, payload):
        """Run an NLU result through the dialogue agent and speak the reply.

        The intent name and entities in *payload* are encoded as a
        ``/intent{"entity": "value"}`` message for the regex interpreter. Text
        responses are joined with ``'. '`` and published to ``tts/say``, after
        subscribing to ``tts/finished`` so that :meth:`finish` runs once
        speech ends. If there is no payload or no text response, the turn is
        finished immediately.
        """
        await self.client.publish('hermod/' + site + '/core/started',
                                  json.dumps(payload))
        if not payload:
            await self.finish(site, payload)
            return

        intent_name = payload.get('intent', {}).get('name', '')
        entities_json = {}
        for entity in payload.get('entities', []):
            entities_json[entity.get('entity')] = entity.get('value')
        intent_json = "/" + intent_name + json.dumps(entities_json)

        responses = await self.agent.handle_text(intent_json, sender_id=site,
                                                 output_channel=None)
        # Responses may lack a "text" field (presumably non-text bot messages);
        # a None entry would make '. '.join raise TypeError, so skip those.
        # A None result is also tolerated.
        messages = [response.get("text") for response in (responses or [])
                    if response.get("text") is not None]
        if not messages:
            await self.finish(site, payload)
            return

        await self.client.subscribe('hermod/' + site + '/tts/finished')
        await self.client.publish(
            'hermod/' + site + '/tts/say',
            json.dumps({
                "text": '. '.join(messages),
                "id": payload.get('id', '')
            }))

    async def set_slots(self, site, payload):
        """Apply ``payload['slots']`` to the site's tracker, save it and publish the slots.

        Each entry of ``payload['slots']`` is read as a dict with ``slot`` and
        ``value`` keys. Nothing happens when *payload* is empty.
        """
        tracker = self.tracker_store.get_or_create_tracker(site)
        if not payload:
            return
        for slot in payload.get('slots', []):
            tracker.update(SlotSet(slot.get('slot'), slot.get('value')))
        self.tracker_store.save(tracker)
        await self.client.publish('hermod/' + site + '/dialog/slots',
                                  json.dumps(tracker.current_slot_values()))

    async def send_slots(self, site):
        """Publish the current tracker slot values for a site to ``dialog/slots``."""
        tracker = self.tracker_store.get_or_create_tracker(site)
        slots = tracker.current_slot_values()
        await self.client.publish('hermod/' + site + '/dialog/slots',
                                  json.dumps(slots))

    def _clear_force_slots(self, tracker):
        """Unset the hermod_force_continue / hermod_force_end slots and save the tracker."""
        tracker.update(SlotSet("hermod_force_continue", None))
        tracker.update(SlotSet("hermod_force_end", None))
        self.tracker_store.save(tracker)

    async def finish(self, site, payload):
        """Decide whether the dialogue continues or ends after a turn.

        The ``hermod_force_continue`` and ``hermod_force_end`` slots take
        priority when set to the string ``'true'``, and both are cleared once
        either one is used. Otherwise the top-level ``keep_listening`` config
        value decides. Afterwards the current slots and a ``core/ended``
        message are published.
        """
        tracker = self.tracker_store.get_or_create_tracker(site)
        slots = tracker.current_slot_values()
        if slots.get('hermod_force_continue', False) == 'true':
            self._clear_force_slots(tracker)
            next_topic = '/dialog/continue'
        elif slots.get('hermod_force_end', False) == 'true':
            self._clear_force_slots(tracker)
            next_topic = '/dialog/end'
        elif self.config.get('keep_listening') == "true":
            next_topic = '/dialog/continue'
        else:
            next_topic = '/dialog/end'
        await self.client.publish('hermod/' + site + next_topic,
                                  json.dumps({"id": payload.get("id", "")}))
        await self.send_slots(site)
        await self.client.publish('hermod/' + site + '/core/ended',
                                  json.dumps(payload))

    async def nlu_parse_request(self, site, text, payload):
        """Parse *text* with the NLU interpreter and publish the result to ``nlu/intent``."""
        response = await self.text_interpreter.parse(text)
        response['id'] = payload.get('id', '')
        await self.client.publish('hermod/' + site + '/nlu/intent',
                                  json.dumps(response))

    async def nlu_external_parse_request(self, site, text, payload):
        """Parse *text* and publish it to ``nlu/externalintent`` without starting the hermod flow."""
        response = await self.text_interpreter.parse(text)
        response['id'] = payload.get('id', '')
        await self.client.publish('hermod/' + site + '/nlu/externalintent',
                                  json.dumps(response))