示例#1
0
 def __init__(self, domain_path='agents/prototypes/concertbot/domain.yml',
              nlu_endpoint="http://localhost:5000"):
     """Set up the action executor, the NLU HTTP interpreter and the domain.

     :param domain_path: path of the domain YAML file passed to
         ``Domain.load``.
     :param nlu_endpoint: base URL of the Rasa NLU HTTP server. It defaults
         to the previously hard-coded ``http://localhost:5000``, so existing
         callers are unaffected.
     """
     self.executor = ActionExecutor()
     # Register every action defined in the ``actions.procs`` package.
     self.executor.register_package('actions.procs')
     endpoint = EndpointConfig(nlu_endpoint)
     self.interpreter = RasaNLUHttpInterpreter(model_name="nlu",
                                               project_name='chinese',
                                               endpoint=endpoint)
     self.domain = Domain.load(domain_path)
示例#2
0
async def nlu_parse(url, message):
    """Parse *message* with a Rasa NLU HTTP interpreter served at *url*.

    A throw-away tracker holding a single ``requested_language`` slot set to
    ``"en"`` is passed along with the message. The slot value is expected
    to show up in the interpreter's result.

    :param url: base URL used to build the interpreter's ``EndpointConfig``.
    :param message: text to parse.
    :return: whatever the interpreter's ``parse`` coroutine returns.
    """
    slot_name = "requested_language"
    state = DialogueStateTracker.from_dict("1", [], [Slot(slot_name)])
    # The interpreter result is expected to carry this 'en' value back.
    state._set_slot(slot_name, "en")
    interpreter = RasaNLUHttpInterpreter(EndpointConfig(url))
    return await interpreter.parse(message, tracker=state)
示例#3
0
class ActionRunner(object):
    """Run a single registered custom action against a freshly parsed message.

    The text is parsed by a Rasa NLU HTTP interpreter. A new
    ``DialogueStateTracker`` is built from the result, and an action
    registered from the ``actions.procs`` package is invoked with it.
    """

    def __init__(self, domain_path='agents/prototypes/concertbot/domain.yml',
                 nlu_endpoint="http://localhost:5000"):
        """Set up the action executor, the NLU HTTP interpreter and the domain.

        :param domain_path: path of the domain YAML file passed to
            ``Domain.load``.
        :param nlu_endpoint: base URL of the Rasa NLU HTTP server. It defaults
            to the previously hard-coded ``http://localhost:5000``.
        """
        self.executor = ActionExecutor()
        # Register every action defined in the ``actions.procs`` package.
        self.executor.register_package('actions.procs')
        endpoint = EndpointConfig(nlu_endpoint)
        self.interpreter = RasaNLUHttpInterpreter(model_name="nlu",
                                                  project_name='chinese',
                                                  endpoint=endpoint)
        self.domain = Domain.load(domain_path)

    def create_api_response(self, events, messages):
        """Build the action-server style response dict.

        A falsy *events* value (e.g. ``None``) is reported as an empty list.
        """
        return {"events": events or [], "responses": messages}

    def prepare(self, text):
        """Parse *text* and return a new tracker holding the user utterance.

        The tracker (sender id ``"default"``) receives a ``UserUttered`` event
        built from the parse result. It also receives the slot events that
        ``Domain.slots_for_entities`` derives from the parsed entities.
        """
        tracker = DialogueStateTracker("default", self.domain.slots)
        parse_data = self.interpreter.parse(text)
        tracker.update(
            UserUttered(text, parse_data["intent"], parse_data["entities"],
                        parse_data))
        # store all entities as slots
        for e in self.domain.slots_for_entities(parse_data["entities"]):
            tracker.update(e)

        print("Logged UserUtterance - "
              "tracker now has {} events".format(len(tracker.events)))
        return tracker

    def execute(self, action_name, text, get_tracker=False):
        """Run the action *action_name* on the user message *text*.

        $ python -m sagas.bots.action_runner execute action_about_date '找音乐会'
        $ python -m sagas.bots.action_runner execute action_about_date '找音乐会' True
        $ python -m sagas.bots.action_runner execute action_joke '找音乐会'

        :param action_name: name of an action registered with the executor.
        :param text: user message to parse and feed to the action.
        :param get_tracker: when true, also return the tracker with the
            action's returned events applied to it.
        :return: ``{"events": [...], "responses": [...]}``, or a
            ``(response, tracker)`` tuple when *get_tracker* is true.
        :raises ValueError: if no action named *action_name* is registered.
        """
        # Look the action up before parsing. An unknown name then fails with
        # a clear error, instead of "'NoneType' object is not callable" after
        # a pointless NLU round trip.
        action = self.executor.actions.get(action_name)
        if action is None:
            raise ValueError("Unknown action: {!r}".format(action_name))

        tracker = self.prepare(text)
        dispatcher = CollectingDispatcher()
        events = action(dispatcher, tracker, self.domain)
        resp = self.create_api_response(events, dispatcher.messages)
        if not get_tracker:
            return resp

        # Actions may return None. Treat that as "no events", as
        # create_api_response does, rather than crashing in deserialisation.
        for ev in deserialise_events(events or []):
            tracker.update(ev)
        return resp, tracker
示例#4
0
async def test_http_interpreter():
    """The HTTP interpreter must POST only the text and a null token.

    The request goes to ``/model/parse`` on the configured endpoint.
    """
    parse_url = "https://example.com/model/parse"
    with aioresponses() as mocked:
        mocked.post(parse_url)

        interpreter = RasaNLUHttpInterpreter(
            endpoint=EndpointConfig("https://example.com"))
        await interpreter.parse(text="message_text")

        sent_request = latest_request(mocked, "POST", parse_url)
        expected_payload = {"text": "message_text", "token": None}
        assert json_of_latest_request(sent_request) == expected_payload
示例#5
0
async def test_http_parsing():
    """MessageProcessor.parse_message must POST to the interpreter's /model/parse."""
    parse_url = "https://interpreter.com/model/parse"
    message = UserMessage("lunch?")
    endpoint = EndpointConfig("https://interpreter.com")

    with aioresponses() as mocked:
        mocked.post(parse_url, repeat=True, status=200)

        interpreter = RasaNLUHttpInterpreter(endpoint_config=endpoint)
        try:
            await MessageProcessor(
                interpreter, None, None, None, None).parse_message(message)
        except KeyError:
            # The mocked reply lacks intent/entities, which the processor's
            # logging looks up. Only the outgoing request matters here.
            pass

        assert latest_request(mocked, "POST", parse_url)
示例#6
0
async def test_parsing_with_tracker():
    """The tracker passed to the agent must reach the NLU interpreter's parse.

    ``RasaNLUHttpInterpreter.parse`` is patched with ``mocked_parse``, which
    is defined outside this view. The assertion below relies on it exposing
    the tracker's ``requested_language`` slot in its result.
    """
    slot_name = "requested_language"
    tracker = DialogueStateTracker.from_dict("1", [], [Slot(slot_name)])
    # 'en' is expected to come back as part of the interpreter result.
    tracker._set_slot(slot_name, "en")

    endpoint = EndpointConfig("https://interpreter.com")
    with aioresponses() as mocked:
        mocked.post("https://interpreter.com/parse", repeat=True, status=200)

        # Swap in the test's own parse implementation.
        with patch.object(RasaNLUHttpInterpreter, "parse", mocked_parse):
            agent = Agent(None, None,
                          RasaNLUHttpInterpreter(endpoint_config=endpoint))
            parsed = await agent.parse_message_using_nlu_interpreter(
                "lunch?", tracker)

            assert parsed["requested_language"] == "en"
示例#7
0
async def test_http_parsing():
    """MessageProcessor must POST to /parse and forward the message's id."""
    message = UserMessage('lunch?')
    parse_url = 'https://interpreter.com/parse'
    endpoint = EndpointConfig('https://interpreter.com')

    with aioresponses() as mocked:
        mocked.post(parse_url, repeat=True, status=200)

        interpreter = RasaNLUHttpInterpreter(endpoint=endpoint)
        try:
            await MessageProcessor(
                interpreter, None, None, None, None)._parse_message(message)
        except KeyError:
            # The mocked reply lacks intent/entities, which the processor's
            # logging looks up. Only the outgoing request matters here.
            pass

        sent_request = latest_request(mocked, 'POST', parse_url)
        assert sent_request
        payload = json_of_latest_request(sent_request)
        assert payload['message_id'] == message.message_id
示例#8
0
async def test_http_interpreter(endpoint_url, joined_url):
    """The payload POSTed to *joined_url* must hold text, token and message id.

    *endpoint_url* and *joined_url* are presumably supplied by a pytest
    parametrization outside this view: the configured base URL and the
    parse URL expected to be derived from it.
    """
    expected_payload = {
        "text": "message_text",
        "token": None,
        "message_id": "message_id",
    }
    with aioresponses() as mocked:
        mocked.post(joined_url)

        interpreter = RasaNLUHttpInterpreter(
            endpoint_config=EndpointConfig(endpoint_url))
        await interpreter.parse(text="message_text", message_id="message_id")

        sent_request = latest_request(mocked, "POST", joined_url)
        assert json_of_latest_request(sent_request) == expected_payload
示例#9
0
async def test_http_interpreter():
    """The interpreter must POST the legacy /parse payload.

    The payload carries the query text as ``q``, the default project, a null
    model and token, and the given message id.
    """
    parse_url = "https://example.com/parse"
    with aioresponses() as mocked:
        mocked.post(parse_url)

        interpreter = RasaNLUHttpInterpreter(
            endpoint=EndpointConfig('https://example.com'))
        await interpreter.parse(text='message_text', message_id='1134')

        sent = json_of_latest_request(latest_request(mocked, "POST", parse_url))
        assert sent == {
            'project': 'default',
            'q': 'message_text',
            'message_id': '1134',
            'model': None,
            'token': None,
        }