コード例 #1
0
    def on_message(self, client, userdata, msg):
        """Dispatch one message received from the MQTT broker.

        Intent, intent-not-recognized, TTS-finished and ASR text-captured
        messages are decoded from their JSON payload and handed to
        ``handle_message``. Intents and not-recognized results are also pushed
        to every queue in ``self.message_queues`` as ``("intent", message)``.
        A message whose payload is rejected by ``_check_siteId`` is dropped
        immediately and is not forwarded. Every other message is finally
        forwarded to all queues as ``("mqtt", topic, payload)``.
        Exceptions are logged and never propagated.
        """

        def report(item):
            # NOTE(review): presumably called from the MQTT client thread,
            # hence the thread-safe hand-off to self.loop.
            for out_queue in self.message_queues:
                self.loop.call_soon_threadsafe(out_queue.put_nowait, item)

        try:
            topic = msg.topic
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), topic)

            if NluIntent.is_topic(topic):
                # Intent produced by an NLU query
                decoded = json.loads(msg.payload)
                if not self._check_siteId(decoded):
                    return
                intent = NluIntent(**decoded)
                report(("intent", intent))
                self.handle_message(topic, intent)
            elif topic == NluIntentNotRecognized.topic():
                # NLU could not recognize the input
                decoded = json.loads(msg.payload)
                if not self._check_siteId(decoded):
                    return
                not_recognized = NluIntentNotRecognized(**decoded)
                report(("intent", not_recognized))
                self.handle_message(topic, not_recognized)
            elif topic == TtsSayFinished.topic():
                # Text to speech finished (no siteId check on this topic)
                decoded = json.loads(msg.payload)
                self.handle_message(topic, TtsSayFinished(**decoded))
            elif topic == AsrTextCaptured.topic():
                # Speech to text result
                decoded = json.loads(msg.payload)
                if not self._check_siteId(decoded):
                    return
                self.handle_message(topic, AsrTextCaptured(**decoded))

            # Mirror the raw message to every external queue
            report(("mqtt", topic, msg.payload))
        except Exception:
            _LOGGER.exception("on_message")
コード例 #2
0
    async def recognize_intent(self,
                               text: str) -> typing.Dict[str, typing.Any]:
        """Publish an NLU query for ``text`` and wait for the matching answer.

        A fresh query id is generated. Only an ``NluIntent`` or
        ``NluIntentNotRecognized`` carrying that same id is accepted. For an
        ``NluIntent``, the ``intent`` mapping and each ``slots`` mapping are
        converted into ``rhasspyhermes.intent.Intent`` / ``Slot`` objects.

        Returns the first result produced by ``publish_wait``, or ``None`` if
        it produces none.
        """
        query_id = str(uuid4())
        query = NluQuery(id=query_id, input=text, siteId=self.siteId)
        wanted_types = (NluIntent, NluIntentNotRecognized)

        def handle_intent():
            # Generator driven by publish_wait: each value sent in is
            # unpacked as (topic, message); returning ends the wait.
            while True:
                _topic, message = yield

                if not isinstance(message, wanted_types):
                    continue
                if message.id != query_id:
                    continue

                if isinstance(message, NluIntent):
                    # Finish parsing the nested intent/slot fields
                    message.intent = rhasspyhermes.intent.Intent(
                        **message.intent)
                    message.slots = [
                        rhasspyhermes.intent.Slot(**slot_fields)
                        for slot_fields in message.slots
                    ]

                return True, message

        topics = [
            NluIntent.topic(intentName="#"),
            NluIntentNotRecognized.topic(),
        ]

        # Expecting only a single result
        async for result in self.publish_wait(handle_intent(), [query],
                                              topics):
            return result
コード例 #3
0
    async def async_test_not_recognized(self):
        """Nonsense input must yield exactly one NluIntentNotRecognized."""
        query_id = str(uuid.uuid4())
        text = "not a valid sentence at all"

        query = NluQuery(
            input=text,
            id=query_id,
            site_id=self.site_id,
            session_id=self.session_id,
        )

        # Drain everything the handler emits for this query
        results = [result async for result in self.hermes.on_message(query)]

        expected = NluIntentNotRecognized(
            input=text,
            id=query_id,
            site_id=self.site_id,
            session_id=self.session_id,
        )
        self.assertEqual(results, [expected])
コード例 #4
0
    def test_not_recognized(self):
        """An unrecognizable sentence must be published as NluIntentNotRecognized."""
        text = "set the garage light to red"
        self.hermes.publish = MagicMock()

        query = NluQuery(input=text, siteId=self.siteId, sessionId=self.sessionId)
        self.hermes.handle_query(query)

        expected = NluIntentNotRecognized(
            input=text, siteId=self.siteId, sessionId=self.sessionId)
        self.hermes.publish.assert_called_with(expected)
コード例 #5
0
    def handle_query(self, query: NluQuery):
        """Recognize the intent of ``query.input`` and publish the outcome.

        Recognition runs against ``self.graph``. It is restricted to the intents
        in ``query.intentFilter`` when that filter is non-empty. Only the first
        recognition is used; it is published as an ``NluIntent``, with the
        intent name also passed to ``publish`` as ``intentName``. When nothing
        is recognized, an ``NluIntentNotRecognized`` is published instead.
        """

        def intent_filter(intent_name: str) -> bool:
            """Accept every intent unless the query has a non-empty filter."""
            return not query.intentFilter or intent_name in query.intentFilter

        recognitions = recognize(query.input,
                                 self.graph,
                                 intent_filter=intent_filter)

        if not recognitions:
            # Not recognized
            self.publish(
                NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    siteId=query.siteId,
                    sessionId=query.sessionId,
                ))
            return

        # Use first recognition only.
        best = recognitions[0]
        assert best is not None
        assert best.intent is not None

        intent_slots = [
            Slot(
                entity=entity.entity,
                slotName=entity.entity,
                confidence=1,
                value=entity.value,
                raw_value=entity.raw_value,
                range=SlotRange(start=entity.raw_start, end=entity.raw_end),
            ) for entity in best.entities
        ]
        nlu_intent = NluIntent(
            input=query.input,
            id=query.id,
            siteId=query.siteId,
            sessionId=query.sessionId,
            intent=Intent(
                intentName=best.intent.name,
                confidenceScore=best.intent.confidence,
            ),
            slots=intent_slots,
        )
        self.publish(nlu_intent, intentName=best.intent.name)
コード例 #6
0
ファイル: test_nlu.py プロジェクト: rhasspy/rhasspy-test
    def test_http_text_to_intent_hermes_failure(self):
        """Test recognition failure with text-to-intent HTTP endpoint (Hermes format)"""
        sentence = "not a valid sentence"
        response = requests.post(
            self.api_url("text-to-intent"),
            data=sentence,
            params={"outputFormat": "hermes"},
        )
        self.check_status(response)

        body = response.json()
        self.assertEqual(body["type"], "intentNotRecognized")

        # The "value" payload must deserialize as a not-recognized message
        # that still carries the original input text.
        not_recognized = NluIntentNotRecognized.from_dict(body["value"])
        self.assertEqual(not_recognized.input, sentence)
コード例 #7
0
 def on_connect(self, client, userdata, flags, rc):
     """Subscribe to all handled topics after connecting to the MQTT broker.

     The fixed dialogue/TTS/NLU/ASR topics are subscribed first, followed by
     every key of ``self.wakeword_topics``. Exceptions are logged, not raised.
     """
     try:
         subscriptions = [
             DialogueStartSession.topic(),
             DialogueContinueSession.topic(),
             DialogueEndSession.topic(),
             TtsSayFinished.topic(),
             NluIntent.topic(intent_name="#"),
             NluIntentNotRecognized.topic(),
             AsrTextCaptured.topic(),
         ]
         subscriptions.extend(self.wakeword_topics.keys())

         for topic_name in subscriptions:
             self.client.subscribe(topic_name)
             _LOGGER.debug("Subscribed to %s", topic_name)
     except Exception:
         _LOGGER.exception("on_connect")
コード例 #8
0
    def _subscribe_callbacks(self):
        """Subscribe to every topic that has at least one registered callback.

        Topics are collected in this order:

        1. one ``NluIntent`` topic per key of ``self._callbacks_intent``;
        2. the ``HotwordDetected`` topic, if any hotword callback exists;
        3. the ``NluIntentNotRecognized`` topic, if any such callback exists;
        4. the keys of ``self._callbacks_topic``;
        5. ``self._additional_topic``.

        They are passed to ``subscribe_topics`` in a single call.
        """
        # Dict keys are already unique. Iterating them directly keeps the
        # registration order instead of the arbitrary order of a set.
        topics = [
            NluIntent.topic(intent_name=intent_name)
            for intent_name in self._callbacks_intent
        ]

        if self._callbacks_hotword:
            topics.append(HotwordDetected.topic())

        if self._callbacks_intent_not_recognized:
            topics.append(NluIntentNotRecognized.topic())

        topics.extend(self._callbacks_topic)
        topics.extend(self._additional_topic)

        self.subscribe_topics(*topics)
コード例 #9
0
ファイル: test_dialogue.py プロジェクト: rhasspy/rhasspy-test
    async def on_message_test_not_recognized(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ):
        """Drive and verify the dialogue flow for test_not_recognized.

        On ``DialogueSessionStarted``, check the session started on the base
        with the custom data stored for that site. Then store fresh custom data
        and reply with an ``NluIntentNotRecognized`` carrying it.

        On ``DialogueSessionEnded``, check the session ended on the base for
        the INTENT_NOT_RECOGNIZED reason, with that custom data and the
        recorded session id.

        A final ``None`` is always yielded.
        """
        _LOGGER.debug(message)

        if isinstance(message, DialogueSessionStarted):
            self.events["started"].set()
            started_site = message.site_id

            # Session must have started on the base with the expected data
            self.assertEqual(started_site, self.base_id)
            self.assertEqual(message.custom_data,
                             self.custom_data[started_site])
            self.session_ids[started_site] = message.session_id

            # New custom data, expected back on DialogueSessionEnded
            fresh_custom_data = str(uuid4())
            self.custom_data[started_site] = fresh_custom_data

            # Publish an intent not recognized message to abort the session
            yield NluIntentNotRecognized(
                input="test intent",
                site_id=started_site,
                session_id=message.session_id,
                custom_data=fresh_custom_data,
            )
        elif isinstance(message, DialogueSessionEnded):
            self.events["ended"].set()
            ended_site = message.site_id

            # Session must have been aborted on the base
            self.assertEqual(ended_site, self.base_id)
            self.assertEqual(
                message.termination.reason,
                DialogueSessionTerminationReason.INTENT_NOT_RECOGNIZED,
            )
            self.assertEqual(message.custom_data,
                             self.custom_data[ended_site])
            self.assertEqual(self.session_ids[ended_site],
                             message.session_id)

        yield None
コード例 #10
0
    async def on_raw_message(self, topic: str, payload: bytes):
        """This method handles messages from the MQTT broker.

        Dispatch rules:

        - Messages on the hotword, intent, NLU intent-not-recognized and
          dialogue intent-not-recognized topics are deserialized and passed to
          every matching registered callback.
        - Any other topic goes to the callbacks registered for that exact
          topic. If there are none, it goes to the callbacks whose
          ``topic_extras`` regex patterns match it.
        - A warning is logged when no callback was invoked.
        - Missing JSON keys are logged for the known topics. Any other
          exception is logged and swallowed.

        Arguments:
            topic: The topic of the received MQTT message.

            payload: The payload of the received MQTT message.

        .. warning:: Don't override this method in your app. This is where all the magic happens in Rhasspy Hermes App.
        """
        # Used to await regex-topic callbacks only when they return an awaitable.
        import inspect

        try:
            if HotwordDetected.is_topic(topic):
                # hermes/hotword/<wakeword_id>/detected
                try:
                    hotword_detected = HotwordDetected.from_json(payload)
                    for function_h in self._callbacks_hotword:
                        await function_h(hotword_detected)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif NluIntent.is_topic(topic):
                # hermes/intent/<intent_name>
                try:
                    nlu_intent = NluIntent.from_json(payload)
                    intent_name = nlu_intent.intent.intent_name
                    if intent_name in self._callbacks_intent:
                        for function_i in self._callbacks_intent[intent_name]:
                            await function_i(nlu_intent)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif NluIntentNotRecognized.is_topic(topic):
                # hermes/nlu/intentNotRecognized
                try:
                    nlu_intent_not_recognized = NluIntentNotRecognized.from_json(
                        payload)
                    for function_inr in self._callbacks_intent_not_recognized:
                        await function_inr(nlu_intent_not_recognized)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif DialogueIntentNotRecognized.is_topic(topic):
                # hermes/dialogueManager/intentNotRecognized
                try:
                    dialogue_intent_not_recognized = DialogueIntentNotRecognized.from_json(
                        payload)
                    for function_dinr in self._callbacks_dialogue_intent_not_recognized:
                        await function_dinr(dialogue_intent_not_recognized)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            else:
                unexpected_topic = True
                if topic in self._callbacks_topic:
                    for function_1 in self._callbacks_topic[topic]:
                        await function_1(TopicData(topic, {}), payload)
                        unexpected_topic = False
                else:
                    for function_2 in self._callbacks_topic_regex:
                        if hasattr(function_2, "topic_extras"):
                            topic_extras = getattr(function_2, "topic_extras")
                            for pattern, named_positions in topic_extras:
                                if re.match(pattern, topic) is not None:
                                    data = TopicData(topic, {})
                                    parts = topic.split(sep="/")
                                    if named_positions is not None:
                                        # Expose named topic segments to the callback
                                        for name, position in named_positions.items(
                                        ):
                                            data.data[name] = parts[position]

                                    result = function_2(data, payload)
                                    # Await async callbacks like the exact-topic
                                    # ones above; otherwise their coroutine is
                                    # created but never run.
                                    if inspect.isawaitable(result):
                                        await result
                                    unexpected_topic = False

                if unexpected_topic:
                    _LOGGER.warning("Unexpected topic: %s", topic)

        except Exception:
            _LOGGER.exception("on_raw_message")
コード例 #11
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[
        typing.Union[
            typing.Tuple[NluIntent, TopicArgs],
            NluIntentParsed,
            NluIntentNotRecognized,
            NluError,
        ]
    ]:
        """Do intent recognition using a remote HTTP server or external command.

        The query text (passed through ``self.word_transform`` when set) is sent
        to one of two backends:

        - ``self.nlu_url``: as the body of a POST, with a non-empty intent
          filter sent as the comma-separated ``intentFilter`` parameter;
        - ``self.nlu_command``: on the command's stdin.

        The JSON reply must contain an ``intent`` object. A non-empty
        ``intent.name`` means recognition succeeded.

        Yields:
            ``NluIntentParsed`` then ``(NluIntent, {"intent_name": ...})`` on
            success; ``NluIntentNotRecognized`` when the name is empty; or
            ``NluError`` if any step raises. Nothing is yielded when neither a
            URL nor a command is configured.
        """
        try:
            input_text = query.input

            # Fix casing
            if self.word_transform:
                input_text = self.word_transform(input_text)

            if self.nlu_url:
                # Use remote server
                _LOGGER.debug(self.nlu_url)

                params = {}

                # Add intent filter
                if query.intent_filter:
                    params["intentFilter"] = ",".join(query.intent_filter)

                async with self.http_session.post(
                    self.nlu_url, data=input_text, params=params, ssl=self.ssl_context
                ) as response:
                    response.raise_for_status()
                    intent_dict = await response.json()
            elif self.nlu_command:
                # Run external command
                _LOGGER.debug(self.nlu_command)
                proc = await asyncio.create_subprocess_exec(
                    *self.nlu_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    # Capture stderr so it can be logged below; without a
                    # pipe, communicate() always returns None for it.
                    stderr=asyncio.subprocess.PIPE,
                )

                input_bytes = (input_text.strip() + "\n").encode()
                output, error = await proc.communicate(input_bytes)
                if error:
                    # Diagnostics only: never fail the query on bad encoding.
                    _LOGGER.debug(error.decode(errors="replace"))

                intent_dict = json.loads(output)
            else:
                _LOGGER.warning("Not handling NLU query (no URL or command)")
                return

            intent_name = intent_dict["intent"].get("name", "")

            if intent_name:
                # Recognized
                tokens = query.input.split()
                slots = [
                    Slot(
                        entity=e["entity"],
                        slot_name=e["entity"],
                        confidence=1,
                        # Without structured details, wrap the entity's own
                        # value rather than a placeholder.
                        value=e.get("value_details", {"value": e["value"]}),
                        raw_value=e.get("raw_value", e["value"]),
                        range=SlotRange(
                            start=e.get("start", 0),
                            end=e.get("end", 1),
                            raw_start=e.get("raw_start"),
                            raw_end=e.get("raw_end"),
                        ),
                    )
                    for e in intent_dict.get("entities", [])
                ]
                confidence = intent_dict["intent"].get("confidence", 1.0)

                yield NluIntentParsed(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(
                        intent_name=intent_name,
                        confidence_score=confidence,
                    ),
                    slots=slots,
                )

                yield (
                    NluIntent(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(
                            intent_name=intent_name,
                            confidence_score=confidence,
                        ),
                        slots=slots,
                        asr_tokens=[NluIntent.make_asr_tokens(tokens)],
                        raw_input=query.input,
                        wakeword_id=query.wakeword_id,
                        lang=query.lang,
                    ),
                    {"intent_name": intent_name},
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                error=repr(e),
                context=repr(query),
                site_id=query.site_id,
                session_id=query.session_id,
            )
コード例 #12
0
def test_nlu_intent_not_Recognized():
    """NluIntentNotRecognized must use the hermes/nlu/intentNotRecognized topic."""
    expected_topic = "hermes/nlu/intentNotRecognized"
    assert NluIntentNotRecognized.topic() == expected_topic
コード例 #13
0
"""Tests for rhasspyhermes_app NLU."""
# pylint: disable=protected-access,too-many-function-args
import asyncio

import pytest
from rhasspyhermes.nlu import NluIntentNotRecognized

from rhasspyhermes_app import HermesApp

# Topic on which "intent not recognized" messages arrive; hard-coded rather
# than taken from NluIntentNotRecognized.topic().
INR_TOPIC = "hermes/nlu/intentNotRecognized"
# Sample message fed to the app; only ``input`` is given, other fields use defaults.
INR = NluIntentNotRecognized(input="covfefe")

# NOTE(review): _LOOP is not referenced in the visible code, and
# asyncio.get_event_loop() outside a running loop is deprecated in newer
# Python versions. Confirm whether it is still needed.
_LOOP = asyncio.get_event_loop()


@pytest.mark.asyncio
async def test_callbacks_intent_not_recognized(mocker):
    """Test intent not recognized callbacks.

    A message published on ``INR_TOPIC`` must reach a callback registered with
    ``on_intent_not_recognized``, deserialized into an object equal to ``INR``.
    """
    app = HermesApp("Test intentNotRecognized", mqtt_client=mocker.MagicMock())

    # Mock callback and apply on_intent_not_recognized decorator.
    inr = mocker.MagicMock()
    app.on_intent_not_recognized(inr)

    # Simulate app.run() without the MQTT client.
    app._subscribe_callbacks()

    # Simulate intent not recognized message.
    await app.on_raw_message(INR_TOPIC, INR.to_json())

    # Check whether callback has been called with the right Rhasspy Hermes object.
    # NOTE(review): if the mock's return value is awaited, that await fails only
    # after the call was recorded; on_raw_message is expected to log, not raise.
    inr.assert_called_once_with(INR)
コード例 #14
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Do intent recognition with the Snips NLU engine.

        The engine is (re)loaded via ``maybe_load_engine``. The query text,
        passed through ``self.word_transform`` when set, is then parsed with
        ``query.intent_filter``.

        Yields:
            ``NluIntentParsed`` then ``(NluIntent, {"intent_name": ...})`` when
            an intent name is found; ``NluIntentNotRecognized`` otherwise; or
            ``NluError`` (with the original input as context) if any step
            raises.
        """
        original_input = query.input

        try:
            self.maybe_load_engine()
            if not self.engine:
                # Explicit raise instead of ``assert`` so the check survives
                # ``python -O``; the message still ends up in NluError.error.
                raise RuntimeError(
                    "Snips engine not loaded. You may need to train.")

            input_text = query.input

            # Fix casing for output event
            if self.word_transform:
                input_text = self.word_transform(input_text)

            # Do parsing
            result = self.engine.parse(input_text, query.intent_filter)
            # Treat a null "intent" entry the same as a missing one.
            intent_name = (result.get("intent") or {}).get("intentName")

            if intent_name:
                slots = [
                    Slot(
                        slot_name=s["slotName"],
                        entity=s["entity"],
                        value=s["value"],
                        raw_value=s["rawValue"],
                        range=SlotRange(start=s["range"]["start"],
                                        end=s["range"]["end"]),
                    ) for s in (result.get("slots") or [])
                ]

                # intentParsed
                yield NluIntentParsed(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(intent_name=intent_name,
                                  confidence_score=1.0),
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(intent_name=intent_name,
                                      confidence_score=1.0),
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(query.input.split())
                        ],
                        raw_input=original_input,
                        wakeword_id=query.wakeword_id,
                        lang=query.lang,
                    ),
                    {
                        "intent_name": intent_name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_input,
            )
コード例 #15
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[
            NluIntentParsed, NluIntentNotRecognized, NluError, ]]:
        """Do intent recognition with a Rasa NLU server.

        Steps:

        1. When ``self.replace_numbers`` is set, digits in ``query.input`` are
           replaced with words; this mutates ``query``.
        2. The text, passed through ``self.word_transform`` when set, is POSTed
           to ``<rasa_url>/model/parse``.
        3. An intent is accepted when Rasa returns a non-empty name and the
           query's intent filter is empty/missing or contains that name.

        Yields:
            ``NluIntentParsed`` on success; ``NluIntentNotRecognized``
            otherwise; or ``NluError`` if any step raises.
        """
        try:
            # Replace digits with words
            if self.replace_numbers:
                # Have to assume whitespace tokenization
                words = rhasspynlu.replace_numbers(query.input.split(),
                                                   self.number_language)
                query.input = " ".join(words)

            input_text = query.input

            # Fix casing for output event
            if self.word_transform:
                input_text = self.word_transform(input_text)

            parse_url = urljoin(self.rasa_url, "model/parse")
            _LOGGER.debug(parse_url)

            async with self.http_session.post(
                    parse_url,
                    json={
                        "text": input_text,
                        "project": self.rasa_project
                    },
                    ssl=self.ssl_context,
            ) as response:
                response.raise_for_status()
                intent_json = await response.json()
                # Treat a null "intent" entry the same as a missing one.
                intent = intent_json.get("intent") or {}
                intent_name = intent.get("name", "")

                # An empty or missing intent filter allows every intent.
                if intent_name and (not query.intent_filter
                                    or intent_name in query.intent_filter):
                    confidence_score = float(intent.get("confidence", 0.0))
                    slots = [
                        Slot(
                            entity=e.get("entity", ""),
                            slot_name=e.get("entity", ""),
                            confidence=float(e.get("confidence", 0.0)),
                            value={
                                "kind": "Unknown",
                                "value": e.get("value", ""),
                                "additional_info":
                                e.get("additional_info", {}),
                                "extractor": e.get("extractor", None),
                            },
                            raw_value=e.get("value", ""),
                            range=SlotRange(
                                start=int(e.get("start", 0)),
                                end=int(e.get("end", 1)),
                                raw_start=int(e.get("start", 0)),
                                raw_end=int(e.get("end", 1)),
                            ),
                        ) for e in intent_json.get("entities", [])
                    ]

                    # intentParsed
                    yield NluIntentParsed(
                        input=input_text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(intent_name=intent_name,
                                      confidence_score=confidence_score),
                        slots=slots,
                    )
                else:
                    # Not recognized
                    yield NluIntentNotRecognized(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                    )
        except Exception as e:
            _LOGGER.exception("nlu query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=query.input,
            )
コード例 #16
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Do intent recognition against the intent graph.

        Steps:

        1. If no graph is loaded yet and ``self.graph_path`` is an existing
           file, the graph is loaded with ``rhasspynlu.gzip_pickle_to_graph``.
        2. When ``self.replace_numbers`` is set, digits in ``query.input`` are
           replaced with words; this mutates ``query``.
        3. If ``self.failure_token`` appears among the input's words, the query
           is treated as not recognized.
        4. Otherwise ``recognize`` runs, honoring the intent filter (an
           empty/missing filter allows everything). ``query.custom_entities``
           are appended to the slots of a successful recognition.

        Yields:
            ``NluIntentParsed`` then ``(NluIntent, {"intent_name": ...})`` on
            success; ``NluIntentNotRecognized`` otherwise; or ``NluError``
            (with the original input as context) if any step raises.
        """
        original_input = query.input

        try:
            if not self.intent_graph and self.graph_path and self.graph_path.is_file(
            ):
                # Load graph from file
                _LOGGER.debug("Loading %s", self.graph_path)
                with open(self.graph_path, mode="rb") as graph_file:
                    self.intent_graph = rhasspynlu.gzip_pickle_to_graph(
                        graph_file)

            if self.intent_graph:

                def intent_filter(intent_name: str) -> bool:
                    """Filter out intents."""
                    if query.intent_filter:
                        return intent_name in query.intent_filter
                    return True

                # Replace digits with words
                if self.replace_numbers:
                    # Have to assume whitespace tokenization
                    words = rhasspynlu.replace_numbers(query.input.split(),
                                                       self.language)
                    query.input = " ".join(words)

                # self.word_transform is handed to recognize() below rather
                # than being applied to the input here.
                if self.failure_token and (self.failure_token
                                           in query.input.split()):
                    # Failure token was found in input
                    recognitions = []
                else:
                    # Pass in raw query input so raw values will be correct
                    recognitions = recognize(
                        query.input,
                        self.intent_graph,
                        intent_filter=intent_filter,
                        word_transform=self.word_transform,
                        fuzzy=self.fuzzy,
                        extra_converters=self.extra_converters,
                    )
            else:
                _LOGGER.error("No intent graph loaded")
                recognitions = []

            if NluHermesMqtt.is_success(recognitions):
                # Use first recognition only.
                recognition = recognitions[0]
                assert recognition is not None
                assert recognition.intent is not None

                intent = Intent(
                    intent_name=recognition.intent.name,
                    confidence_score=recognition.intent.confidence,
                )
                slots = [
                    Slot(
                        entity=(e.source or e.entity),
                        slot_name=e.entity,
                        confidence=1.0,
                        value=e.value_dict,
                        raw_value=e.raw_value,
                        range=SlotRange(
                            start=e.start,
                            end=e.end,
                            raw_start=e.raw_start,
                            raw_end=e.raw_end,
                        ),
                    ) for e in recognition.entities
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    for entity_name, entity_value in query.custom_entities.items(
                    ):
                        slots.append(
                            Slot(
                                entity=entity_name,
                                confidence=1.0,
                                value={"value": entity_value},
                            ))

                # intentParsed
                yield NluIntentParsed(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=recognition.text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(recognition.tokens)
                        ],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_input,
                        wakeword_id=query.wakeword_id,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {
                        "intent_name": recognition.intent.name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_input,
            )
コード例 #17
0
    def on_message(self, client, userdata, msg):
        """Dispatch a message received from the MQTT broker.

        For every known topic the JSON payload is decoded and filtered:
        session lifecycle and hotword messages by site id, TTS/ASR/NLU
        results by session id. Accepted messages go to the matching handler.
        Handlers that need the event loop (for TTS) are scheduled on
        ``self.loop`` via ``asyncio.run_coroutine_threadsafe``; the others are
        called directly, outside the event loop. Messages on unknown topics
        are ignored, and any exception is logged instead of propagated.
        """
        try:
            topic = msg.topic
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), topic)

            if topic == DialogueStartSession.topic():
                # Start session -> event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    start_session = DialogueStartSession(**payload)
                    asyncio.run_coroutine_threadsafe(
                        self.handle_start(start_session), self.loop
                    )
            elif topic == DialogueContinueSession.topic():
                # Continue session -> event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    continue_session = DialogueContinueSession(**payload)
                    asyncio.run_coroutine_threadsafe(
                        self.handle_continue(continue_session), self.loop
                    )
            elif topic == DialogueEndSession.topic():
                # End session -> called directly, outside the event loop
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    self.handle_end(DialogueEndSession(**payload))
            elif topic == TtsSayFinished.topic():
                # TTS finished -> set say_finished_event from the event loop
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif topic == AsrTextCaptured.topic():
                # ASR text captured -> called directly, outside the event loop
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.handle_text_captured(AsrTextCaptured(**payload))
            elif topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized -> called directly, outside the event loop
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.handle_recognized(NluIntent(**payload))
            elif topic.startswith(NluIntentNotRecognized.topic()):
                # Intent NOT recognized -> event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    not_recognized = NluIntentNotRecognized(**payload)
                    asyncio.run_coroutine_threadsafe(
                        self.handle_not_recognized(not_recognized), self.loop
                    )
            elif topic in self.wakeword_topics:
                # Hotword detected on one of the wake word topics -> event loop
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    wakeword_id = self.wakeword_topics[topic]
                    detected = HotwordDetected(**payload)
                    asyncio.run_coroutine_threadsafe(
                        self.handle_wake(wakeword_id, detected), self.loop
                    )
        except Exception:
            _LOGGER.exception("on_message")
コード例 #18
0
ファイル: test_dialogue.py プロジェクト: rhasspy/rhasspy-test
    async def on_message_test_multi_session(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ):
        """Drive and verify the dialogue flow for test_multi_session.

        This is an async generator. Depending on the incoming message it may
        yield a reply message (or a ``(message, topic_args)`` tuple), and it
        always yields ``None`` last.
        """
        if not isinstance(message, AudioPlayBytes):
            # Audio chunks are not logged
            _LOGGER.debug(message)

        if isinstance(message, DialogueSessionStarted):
            started_site = message.site_id

            # Session must be on the base or a satellite and carry the custom
            # data currently recorded for that site
            self.assertIn(started_site, [self.base_id] + self.satellite_ids)
            self.assertEqual(message.custom_data, self.custom_data[started_site])

            self.events[f"{started_site}_started"].set()
            self.session_ids[started_site] = message.session_id

            if started_site != self.continue_site_id:
                self.custom_data[started_site] = str(uuid4())

                # Abort the session immediately with "intent not recognized"
                yield NluIntentNotRecognized(
                    input=f"test intent (site={message.site_id})",
                    site_id=started_site,
                    session_id=message.session_id,
                    custom_data=self.custom_data[started_site],
                )
            else:
                # Keep this session going for one more step
                self.custom_data[started_site] = "done"

                yield DialogueContinueSession(
                    session_id=message.session_id, custom_data="done"
                )
        elif isinstance(message, DialogueSessionEnded):
            ended_site = message.site_id

            # Session must have ended on the base or a satellite with the
            # custom data currently recorded for that site
            self.assertIn(ended_site, [self.base_id] + self.satellite_ids)
            self.assertEqual(message.custom_data, self.custom_data[ended_site])

            self.events[f"{ended_site}_ended"].set()

            # Every session in this test is expected to be aborted this way
            self.assertEqual(
                message.termination.reason,
                DialogueSessionTerminationReason.INTENT_NOT_RECOGNIZED,
            )
            self.assertEqual(self.session_ids[ended_site], message.session_id)
        elif isinstance(message, AsrStartListening):
            listening_site = message.site_id
            continued_session = (listening_site == self.continue_site_id) and (
                self.custom_data[listening_site] == "done"
            )

            if continued_session:
                # Second step of the continued session: abort it now
                yield NluIntentNotRecognized(
                    input=f"test intent (site={message.site_id})",
                    site_id=listening_site,
                    session_id=message.session_id,
                )
        elif isinstance(message, AudioPlayBytes):
            # Report playback as finished using the ids passed to this handler
            yield (AudioPlayFinished(id=session_id), {"site_id": site_id})

        yield None
コード例 #19
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Do intent recognition with fuzzywuzzy.

        If no intent graph is in memory, it is loaded from
        ``intent_graph_path`` (when that file exists). Recognition only runs
        when both a graph and an ``examples_path`` file are available;
        otherwise an error is logged and the query is treated as not
        recognized.

        Yields:
            ``NluIntentParsed`` followed by
            ``(NluIntent, {"intent_name": ...})`` when the first recognition
            reaches ``confidence_threshold``; otherwise
            ``NluIntentNotRecognized``. Any exception is logged and reported
            as an ``NluError`` whose context is the original input text.

        Note:
            ``query.input`` is overwritten in place when ``replace_numbers``
            is enabled.
        """
        # Bind the untouched input BEFORE the try block. The error handler
        # below references it, and graph loading can raise before any later
        # assignment would run. Otherwise the handler itself would fail with
        # UnboundLocalError and hide the real exception.
        original_text = query.input

        try:
            # Check intent graph (load lazily)
            if (not self.intent_graph and self.intent_graph_path
                    and self.intent_graph_path.is_file()):
                _LOGGER.debug("Loading %s", self.intent_graph_path)
                with open(self.intent_graph_path, mode="rb") as graph_file:
                    self.intent_graph = rhasspynlu.gzip_pickle_to_graph(
                        graph_file)

            # Check examples
            if (self.intent_graph and self.examples_path
                    and self.examples_path.is_file()):

                def intent_filter(intent_name: str) -> bool:
                    """Keep intents in query.intent_filter (all if it's empty)."""
                    if query.intent_filter:
                        return intent_name in query.intent_filter
                    return True

                # Replace digits with words
                if self.replace_numbers:
                    # Have to assume whitespace tokenization
                    words = rhasspynlu.replace_numbers(query.input.split(),
                                                       self.language)
                    query.input = " ".join(words)

                input_text = query.input

                # Fix casing
                if self.word_transform:
                    input_text = self.word_transform(input_text)

                recognitions: typing.List[rhasspynlu.intent.Recognition] = []

                if input_text:
                    recognitions = rhasspyfuzzywuzzy.recognize(
                        input_text,
                        self.intent_graph,
                        str(self.examples_path),
                        intent_filter=intent_filter,
                        extra_converters=self.extra_converters,
                    )
            else:
                _LOGGER.error("No intent graph or examples loaded")
                recognitions = []

            # Use first recognition only if above threshold
            if (recognitions and recognitions[0] and recognitions[0].intent
                    and (recognitions[0].intent.confidence >=
                         self.confidence_threshold)):
                recognition = recognitions[0]
                assert recognition.intent
                intent = Intent(
                    intent_name=recognition.intent.name,
                    confidence_score=recognition.intent.confidence,
                )
                slots = [
                    Slot(
                        entity=(e.source or e.entity),
                        slot_name=e.entity,
                        confidence=1.0,
                        value=e.value_dict,
                        raw_value=e.raw_value,
                        range=SlotRange(
                            start=e.start,
                            end=e.end,
                            raw_start=e.raw_start,
                            raw_end=e.raw_end,
                        ),
                    ) for e in recognition.entities
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    for entity_name, entity_value in query.custom_entities.items(
                    ):
                        slots.append(
                            Slot(
                                entity=entity_name,
                                confidence=1.0,
                                value={"value": entity_value},
                            ))

                # intentParsed
                yield NluIntentParsed(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=recognition.text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(recognition.tokens)
                        ],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_text,
                        wakeword_id=query.wakeword_id,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {
                        "intent_name": recognition.intent.name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_text,
            )
コード例 #20
0
    async def handle_text_captured(
        self, text_captured: AsrTextCaptured
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrStopListening, HotwordToggleOn, NluQuery, NluIntentNotRecognized
        ]
    ]:
        """Handle ASR text captured for a session.

        Records the transcription on the session, stops ASR listening and
        re-enables the hotword for the site. If ``min_asr_confidence`` is set
        and the transcription likelihood is below it, the text is rejected
        as not recognized without querying NLU. Otherwise an ``NluQuery``
        is issued.

        Messages without a session id, or for an unknown session, are logged
        and dropped. Any exception is logged, not propagated.
        """
        try:
            if not text_captured.session_id:
                _LOGGER.warning("Missing session id on text captured message.")
                return

            site_session = self.all_sessions.get(text_captured.session_id)
            if site_session is None:
                _LOGGER.warning(
                    "No session for id %s. Dropping captured text from ASR.",
                    text_captured.session_id,
                )
                return

            _LOGGER.debug("Received text: %s", text_captured.text)

            # Record result
            site_session.text_captured = text_captured

            # Stop listening
            yield AsrStopListening(
                site_id=text_captured.site_id, session_id=site_session.session_id
            )

            # Enable hotword
            yield HotwordToggleOn(
                site_id=text_captured.site_id,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )

            if (self.min_asr_confidence is not None) and (
                text_captured.likelihood < self.min_asr_confidence
            ):
                # Transcription is below threshold.
                # Don't actually do an NLU query, just reject as "not recognized".
                _LOGGER.debug(
                    "Transcription is below confidence threshold (%s < %s): %s",
                    text_captured.likelihood,
                    self.min_asr_confidence,
                    text_captured.text,
                )

                # Carry the session's custom data, as the NluQuery branch
                # below does, so listeners on intentNotRecognized receive it.
                yield NluIntentNotRecognized(
                    input=text_captured.text,
                    site_id=site_session.site_id,
                    session_id=site_session.session_id,
                    custom_data=site_session.custom_data,
                )
            else:
                # Perform query
                custom_entities: typing.Optional[typing.Dict[str, typing.Any]] = None

                # Copy custom entities from hotword detected
                if site_session.detected:
                    custom_entities = site_session.detected.custom_entities

                yield NluQuery(
                    input=text_captured.text,
                    intent_filter=site_session.intent_filter
                    or self.default_intent_filter,
                    session_id=site_session.session_id,
                    site_id=site_session.site_id,
                    wakeword_id=text_captured.wakeword_id or site_session.wakeword_id,
                    lang=text_captured.lang or site_session.lang,
                    custom_data=site_session.custom_data,
                    asr_confidence=text_captured.likelihood,
                    custom_entities=custom_entities,
                )
        except Exception:
            _LOGGER.exception("handle_text_captured")
コード例 #21
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Do intent recognition with a remote Rasa NLU server.

        The input (optionally with digits replaced by words and passed
        through ``word_transform``) is POSTed to ``<rasa_url>/model/parse``.

        Yields:
            ``NluIntentParsed`` followed by
            ``(NluIntent, {"intent_name": ...})`` when the response names an
            intent that passes ``query.intent_filter`` (``None`` allows all);
            otherwise ``NluIntentNotRecognized``. Any exception (HTTP status,
            malformed response, ...) is logged and reported as ``NluError``
            with the original input as context.

        Note:
            ``query.input`` is overwritten in place when ``replace_numbers``
            is enabled.
        """
        # Untouched input, used for raw_input and the error context
        original_input = query.input

        try:
            # Replace digits with words
            if self.replace_numbers:
                # Have to assume whitespace tokenization
                words = rhasspynlu.replace_numbers(query.input.split(),
                                                   self.number_language)
                query.input = " ".join(words)

            input_text = query.input

            # Fix casing for output event
            if self.word_transform:
                input_text = self.word_transform(input_text)

            parse_url = urljoin(self.rasa_url, "model/parse")
            _LOGGER.debug(parse_url)

            # Leave the response context as soon as the body is read, rather
            # than holding it open across the yields below.
            async with self.http_session.post(
                    parse_url,
                    json={
                        "text": input_text,
                        "project": self.rasa_project
                    },
                    ssl=self.ssl_context,
            ) as response:
                response.raise_for_status()
                intent_json = await response.json()

            # Treat a missing or null "intent"/"name" as "no intent"
            intent_info = intent_json.get("intent") or {}
            intent_name = intent_info.get("name") or ""

            if intent_name and (query.intent_filter is None
                                or intent_name in query.intent_filter):
                intent = Intent(
                    intent_name=intent_name,
                    confidence_score=float(intent_info.get("confidence") or 0.0),
                )
                slots = [
                    Slot(
                        entity=e.get("entity", ""),
                        slot_name=e.get("entity", ""),
                        confidence=float(e.get("confidence", 0.0)),
                        value={
                            "kind": "Unknown",
                            "value": e.get("value", "")
                        },
                        raw_value=e.get("value", ""),
                        range=SlotRange(
                            start=int(e.get("start", 0)),
                            end=int(e.get("end", 1)),
                            raw_start=int(e.get("start", 0)),
                            raw_end=int(e.get("end", 1)),
                        ),
                    ) for e in (intent_json.get("entities") or [])
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    for entity_name, entity_value in query.custom_entities.items(
                    ):
                        slots.append(
                            Slot(
                                entity=entity_name,
                                confidence=1.0,
                                value={"value": entity_value},
                            ))

                # intentParsed
                yield NluIntentParsed(
                    input=input_text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=input_text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(input_text.split())
                        ],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_input,
                        wakeword_id=query.wakeword_id,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {
                        "intent_name": intent_name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
        except Exception as e:
            _LOGGER.exception("nlu query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_input,
            )