示例#1
0
    def on_message(self, client, userdata, msg):
        """Handle a message received from the MQTT broker.

        Only messages on the NluQuery topic are processed: the payload is
        decoded as JSON, checked with ``self._check_siteId``, turned into an
        NluQuery and passed to ``self.handle_query``. A failure while building
        or handling the query is logged and published as an NluError. Any
        other failure is logged and not re-raised.

        Args:
            client: Not used by this method.
            userdata: Not used by this method.
            msg: Received message; ``msg.topic`` and ``msg.payload`` are read.
        """
        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)
            if msg.topic == NluQuery.topic():
                json_payload = json.loads(msg.payload)

                # Check siteId
                if not self._check_siteId(json_payload):
                    return

                # Stays None when NluQuery construction itself fails, so the
                # error handler below never touches an unbound name.
                query = None
                try:
                    query = NluQuery(**json_payload)
                    _LOGGER.debug("<- %s", query)
                    self.handle_query(query)
                except Exception as e:
                    _LOGGER.exception("nlu query")

                    if query is not None:
                        error_site_id = query.siteId
                    else:
                        # Fall back to the raw payload.
                        # NOTE(review): "default" is assumed to be the
                        # fallback site id -- confirm against NluQuery.
                        error_site_id = json_payload.get("siteId", "default")

                    self.publish(
                        NluError(
                            siteId=error_site_id,
                            sessionId=json_payload.get("sessionId", ""),
                            error=str(e),
                            context="",
                        ))
        except Exception:
            _LOGGER.exception("on_message")
示例#2
0
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Train a new Snips NLU engine from the sentences/slots in *train*.

        On success the newly trained engine replaces ``self.engine``.

        Args:
            train: Training request; ``sentences``, ``slots`` and ``id`` are read.
            site_id: Site id attached to the emitted message.

        Yields:
            ``(NluTrainSuccess, {"site_id": site_id})`` on success, or a single
            NluError describing the failure. Exceptions are logged and
            reported, never propagated.
        """
        try:
            if not train.sentences:
                # Explicit check rather than ``assert`` so the validation
                # still runs when Python is started with -O.
                raise ValueError("No training sentences")

            start_time = time.perf_counter()

            new_engine = rhasspysnips_nlu.train(
                sentences_dict=train.sentences,
                language=self.snips_language,
                slots_dict=train.slots,
                engine_path=self.engine_path,
                dataset_path=self.dataset_path,
            )

            end_time = time.perf_counter()

            _LOGGER.debug("Trained Snips engine in %s second(s)",
                          end_time - start_time)

            # Only swap in the engine once training has finished without error
            self.engine = new_engine

            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           session_id=train.id,
                           error=str(e),
                           context=train.id)
示例#3
0
    async def handle_train(
        self, train: NluTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
    ]:
        """Rebuild the fuzzy-matching examples from a pickled intent graph.

        The graph at ``train.graph_path`` is loaded into ``self.intent_graph``
        and the trained examples are stored in ``self.examples``. When
        ``self.examples_path`` is set, the examples are also written there as
        JSON.

        Yields:
            ``(NluTrainSuccess, {"site_id": site_id})`` on success, otherwise
            an NluError carrying the exception text.
        """
        try:
            graph_path = train.graph_path
            _LOGGER.debug("Loading %s", graph_path)
            with open(graph_path, mode="rb") as pickled_graph:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(pickled_graph)

            self.examples = rhasspyfuzzywuzzy.train(self.intent_graph)

            json_path = self.examples_path
            if json_path:
                # Persist the examples as JSON
                with open(json_path, "w") as json_out:
                    json.dump(self.examples, json_out)

                _LOGGER.debug("Wrote %s", str(json_path))

            success = NluTrainSuccess(id=train.id)
            yield (success, {"site_id": site_id})
        except Exception as err:
            _LOGGER.exception("handle_train")
            yield NluError(
                site_id=site_id, session_id=train.id, error=str(err), context=train.id
            )
示例#4
0
    async def handle_nlu_train(
        self, train: NluTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
    ]:
        """Forward an intent graph to the external NLU trainer.

        The gzipped graph pickle at ``train.graph_path`` is converted to JSON
        and then either POSTed to ``self.nlu_train_url`` or written to the
        stdin of ``self.nlu_train_command``. If neither is configured, a
        warning is logged and success is still reported.

        Yields:
            ``(NluTrainSuccess, {"site_id": site_id})`` on success, otherwise
            an NluError whose context names the configured URL and command.
        """
        try:
            # NOTE(review): read_gpickle unpickles the file, so the graph
            # path must come from a trusted source.
            graph_path = train.graph_path
            _LOGGER.debug("Loading %s", graph_path)
            with gzip.GzipFile(graph_path, mode="rb") as gzipped_pickle:
                graph = nx.readwrite.gpickle.read_gpickle(gzipped_pickle)

            # JSON form of the intent graph
            graph_as_json = rhasspynlu.graph_to_json(graph)

            if self.nlu_train_url:
                # Train through the remote server
                _LOGGER.debug(self.nlu_train_url)

                async with self.http_session.post(
                    self.nlu_train_url, json=graph_as_json, ssl=self.ssl_context
                ) as response:
                    # Only the HTTP status matters; no body is read
                    response.raise_for_status()
            elif self.nlu_train_command:
                # Train through a local command fed via stdin
                _LOGGER.debug(self.nlu_train_command)

                trainer = await asyncio.create_subprocess_exec(
                    *self.nlu_train_command,
                    stdin=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                stdin_bytes = json.dumps(graph_as_json).encode()
                streams = await trainer.communicate(stdin_bytes)

                # Log whatever came back on (stdout, stderr), in that order
                for stream_bytes in streams:
                    if stream_bytes:
                        _LOGGER.debug(stream_bytes.decode())
            else:
                _LOGGER.warning("Can't train NLU system. No train URL or command.")

            # Report success
            success = NluTrainSuccess(id=train.id)
            yield (success, {"site_id": site_id})
        except Exception as err:
            _LOGGER.exception("handle_nlu_train")
            yield NluError(
                error=str(err),
                context=f"url='{self.nlu_train_url}', command='{self.nlu_train_command}'",
                site_id=site_id,
                session_id=train.id,
            )
示例#5
0
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Train fuzzy-matching examples and store them in a SQLite database.

        The graph at ``train.graph_path`` is loaded into ``self.intent_graph``.
        When ``self.examples_path`` is set, any existing file there is deleted
        and a fresh ``intents (sentence, path)`` table is written, with each
        path stored as a JSON string.

        Yields:
            ``(NluTrainSuccess, {"site_id": site_id})`` on success, otherwise
            an NluError carrying the exception text.
        """
        try:
            _LOGGER.debug("Loading %s", train.graph_path)
            with open(train.graph_path, mode="rb") as graph_file:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

            examples = rhasspyfuzzywuzzy.train(self.intent_graph)

            if self.examples_path:
                if self.examples_path.is_file():
                    # Delete existing file
                    self.examples_path.unlink()

                # Write examples to SQLite database
                conn = sqlite3.connect(str(self.examples_path))
                try:
                    c = conn.cursor()
                    c.execute("""DROP TABLE IF EXISTS intents""")
                    c.execute(
                        """CREATE TABLE intents (sentence text, path text)""")

                    # Intent names (the outer keys) are not stored
                    for sentences in examples.values():
                        for sentence, path in sentences.items():
                            c.execute(
                                "INSERT INTO intents VALUES (?, ?)",
                                (sentence, json.dumps(path, ensure_ascii=False)),
                            )

                    conn.commit()
                finally:
                    # Always release the database handle, even when a
                    # statement above fails.
                    conn.close()

                _LOGGER.debug("Wrote %s", str(self.examples_path))
            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           session_id=train.id,
                           error=str(e),
                           context=train.id)
示例#6
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Recognize the intent of *query* against the loaded intent graph.

        The graph is loaded from ``self.graph_path`` on first use if it is not
        already in memory. When ``self.replace_numbers`` is set, ``query.input``
        is rewritten in place with digits replaced by words. The original text
        is still reported as ``raw_input`` and as the error context.

        Yields:
            On success, an NluIntentParsed followed by
            ``(NluIntent, {"intent_name": ...})`` built from the first
            recognition. Otherwise a single NluIntentNotRecognized. If an
            exception is raised, an NluError is yielded instead of propagating it.
        """
        original_input = query.input

        try:
            if (not self.intent_graph and self.graph_path
                    and self.graph_path.is_file()):
                # Load graph from file
                _LOGGER.debug("Loading %s", self.graph_path)
                with open(self.graph_path, mode="rb") as graph_file:
                    self.intent_graph = rhasspynlu.gzip_pickle_to_graph(
                        graph_file)

            if self.intent_graph:

                def intent_filter(intent_name: str) -> bool:
                    """Filter out intents."""
                    if query.intent_filter:
                        return intent_name in query.intent_filter
                    return True

                # Replace digits with words
                if self.replace_numbers:
                    # Have to assume whitespace tokenization
                    words = rhasspynlu.replace_numbers(query.input.split(),
                                                       self.language)
                    query.input = " ".join(words)

                if self.failure_token and (self.failure_token
                                           in query.input.split()):
                    # Failure token was found in input
                    recognitions = []
                else:
                    # Pass in raw query input so raw values will be correct;
                    # the word transform is applied by recognize() itself.
                    recognitions = recognize(
                        query.input,
                        self.intent_graph,
                        intent_filter=intent_filter,
                        word_transform=self.word_transform,
                        fuzzy=self.fuzzy,
                        extra_converters=self.extra_converters,
                    )
            else:
                _LOGGER.error("No intent graph loaded")
                recognitions = []

            if NluHermesMqtt.is_success(recognitions):
                # Use first recognition only.
                recognition = recognitions[0]
                assert recognition is not None
                assert recognition.intent is not None

                intent = Intent(
                    intent_name=recognition.intent.name,
                    confidence_score=recognition.intent.confidence,
                )
                slots = [
                    Slot(
                        entity=(e.source or e.entity),
                        slot_name=e.entity,
                        confidence=1.0,
                        value=e.value_dict,
                        raw_value=e.raw_value,
                        range=SlotRange(
                            start=e.start,
                            end=e.end,
                            raw_start=e.raw_start,
                            raw_end=e.raw_end,
                        ),
                    ) for e in recognition.entities
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    slots.extend(
                        Slot(
                            entity=entity_name,
                            confidence=1.0,
                            value={"value": entity_value},
                        ) for entity_name, entity_value in
                        query.custom_entities.items())

                # intentParsed
                yield NluIntentParsed(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=recognition.text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(recognition.tokens)
                        ],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_input,
                        wakeword_id=query.wakeword_id,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {
                        "intent_name": recognition.intent.name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_input,
            )
示例#7
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Recognize the intent of *query* with the Snips NLU engine.

        The engine is loaded on demand through ``self.maybe_load_engine()``.
        The text passed to the engine goes through ``self.word_transform`` when
        one is set. The emitted messages carry the untransformed ``query.input``.

        Yields:
            On success, an NluIntentParsed followed by
            ``(NluIntent, {"intent_name": ...})``. Otherwise a single
            NluIntentNotRecognized. If an exception is raised (including a
            missing engine), an NluError is yielded instead of propagating it.
        """
        original_input = query.input

        try:
            self.maybe_load_engine()
            if not self.engine:
                # Explicit check rather than ``assert`` so the validation
                # still runs when Python is started with -O.
                raise RuntimeError(
                    "Snips engine not loaded. You may need to train.")

            input_text = query.input

            # Fix casing before parsing
            if self.word_transform:
                input_text = self.word_transform(input_text)

            # Do parsing
            result = self.engine.parse(input_text, query.intent_filter)
            intent_name = result.get("intent", {}).get("intentName")

            if intent_name:
                slots = [
                    Slot(
                        slot_name=s["slotName"],
                        entity=s["entity"],
                        value=s["value"],
                        raw_value=s["rawValue"],
                        range=SlotRange(start=s["range"]["start"],
                                        end=s["range"]["end"]),
                    ) for s in result.get("slots", [])
                ]

                # The confidence is always reported as 1.0 here; the same
                # Intent is shared by both messages.
                intent = Intent(intent_name=intent_name, confidence_score=1.0)

                # intentParsed
                yield NluIntentParsed(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(query.input.split())
                        ],
                        raw_input=original_input,
                        wakeword_id=query.wakeword_id,
                        lang=query.lang,
                    ),
                    {
                        "intent_name": intent_name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_input,
            )
示例#8
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Recognize the intent of *query* by fuzzy matching against examples.

        The intent graph is loaded from ``self.intent_graph_path`` on first use
        if it is not already in memory. Recognition runs only when a graph is
        loaded and ``self.examples_path`` exists. When ``self.replace_numbers``
        is set, ``query.input`` is rewritten in place with digits replaced by
        words. The original text is still reported as ``raw_input`` and as
        the error context.

        Yields:
            When the first recognition meets ``self.confidence_threshold``, an
            NluIntentParsed followed by ``(NluIntent, {"intent_name": ...})``.
            Otherwise a single NluIntentNotRecognized. If an exception is
            raised, an NluError is yielded instead of propagating it.
        """
        # Captured before the try block so the error handler can always report
        # it, even when the failure happens while loading the graph.
        original_text = query.input

        try:
            # Check intent graph
            if (not self.intent_graph and self.intent_graph_path
                    and self.intent_graph_path.is_file()):
                _LOGGER.debug("Loading %s", self.intent_graph_path)
                with open(self.intent_graph_path, mode="rb") as graph_file:
                    self.intent_graph = rhasspynlu.gzip_pickle_to_graph(
                        graph_file)

            recognitions: typing.List[rhasspynlu.intent.Recognition] = []

            # Check examples
            if (self.intent_graph and self.examples_path
                    and self.examples_path.is_file()):

                def intent_filter(intent_name: str) -> bool:
                    """Filter out intents."""
                    if query.intent_filter:
                        return intent_name in query.intent_filter
                    return True

                # Replace digits with words
                if self.replace_numbers:
                    # Have to assume whitespace tokenization
                    words = rhasspynlu.replace_numbers(query.input.split(),
                                                       self.language)
                    query.input = " ".join(words)

                input_text = query.input

                # Fix casing
                if self.word_transform:
                    input_text = self.word_transform(input_text)

                # Empty input is never sent to the recognizer
                if input_text:
                    recognitions = rhasspyfuzzywuzzy.recognize(
                        input_text,
                        self.intent_graph,
                        str(self.examples_path),
                        intent_filter=intent_filter,
                        extra_converters=self.extra_converters,
                    )
            else:
                _LOGGER.error("No intent graph or examples loaded")

            # Use first recognition only if above threshold
            if (recognitions and recognitions[0] and recognitions[0].intent
                    and (recognitions[0].intent.confidence >=
                         self.confidence_threshold)):
                recognition = recognitions[0]
                assert recognition.intent
                intent = Intent(
                    intent_name=recognition.intent.name,
                    confidence_score=recognition.intent.confidence,
                )
                slots = [
                    Slot(
                        entity=(e.source or e.entity),
                        slot_name=e.entity,
                        confidence=1.0,
                        value=e.value_dict,
                        raw_value=e.raw_value,
                        range=SlotRange(
                            start=e.start,
                            end=e.end,
                            raw_start=e.raw_start,
                            raw_end=e.raw_end,
                        ),
                    ) for e in recognition.entities
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    for entity_name, entity_value in query.custom_entities.items(
                    ):
                        slots.append(
                            Slot(
                                entity=entity_name,
                                confidence=1.0,
                                value={"value": entity_value},
                            ))

                # intentParsed
                yield NluIntentParsed(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=recognition.text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[
                            NluIntent.make_asr_tokens(recognition.tokens)
                        ],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_text,
                        wakeword_id=query.wakeword_id,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {
                        "intent_name": recognition.intent.name
                    },
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=original_text,
            )
示例#9
0
def test_nlu_error():
    """Check that NluError reports the expected MQTT topic."""
    expected_topic = "hermes/error/nlu"
    assert NluError.topic() == expected_topic
示例#10
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[
        typing.Union[
            typing.Tuple[NluIntent, TopicArgs],
            NluIntentParsed,
            NluIntentNotRecognized,
            NluError,
        ]
    ]:
        """Recognize the intent of *query* via a remote server or a command.

        The (optionally word-transformed) input text is either POSTed to
        ``self.nlu_url`` or written to the stdin of ``self.nlu_command``. The
        JSON reply is expected to contain an ``"intent"`` object and may
        contain an ``"entities"`` list. When neither a URL nor a command is
        configured, a warning is logged and nothing is yielded.

        Yields:
            On success, an NluIntentParsed followed by
            ``(NluIntent, {"intent_name": ...})``. Otherwise a single
            NluIntentNotRecognized. If an exception is raised, an NluError is
            yielded instead of propagating it.
        """
        try:
            input_text = query.input

            # Fix casing
            if self.word_transform:
                input_text = self.word_transform(input_text)

            if self.nlu_url:
                # Use remote server
                _LOGGER.debug(self.nlu_url)

                params = {}

                # Add intent filter
                if query.intent_filter:
                    params["intentFilter"] = ",".join(query.intent_filter)

                async with self.http_session.post(
                    self.nlu_url, data=input_text, params=params, ssl=self.ssl_context
                ) as response:
                    response.raise_for_status()
                    intent_dict = await response.json()
            elif self.nlu_command:
                # Run external command
                _LOGGER.debug(self.nlu_command)
                proc = await asyncio.create_subprocess_exec(
                    *self.nlu_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                )

                input_bytes = (input_text.strip() + "\n").encode()
                output, error = await proc.communicate(input_bytes)
                if error:
                    _LOGGER.debug(error.decode())

                intent_dict = json.loads(output)
            else:
                _LOGGER.warning("Not handling NLU query (no URL or command)")
                return

            intent_name = intent_dict["intent"].get("name", "")

            if intent_name:
                # Recognized
                tokens = query.input.split()
                slots = [
                    Slot(
                        entity=e["entity"],
                        slot_name=e["entity"],
                        confidence=1,
                        # Fall back to the entity's plain value when the
                        # reply carries no detailed value.
                        value=e.get("value_details", {"value": e["value"]}),
                        raw_value=e.get("raw_value", e["value"]),
                        range=SlotRange(
                            start=e.get("start", 0),
                            end=e.get("end", 1),
                            raw_start=e.get("raw_start"),
                            raw_end=e.get("raw_end"),
                        ),
                    )
                    for e in intent_dict.get("entities", [])
                ]

                # Same Intent is shared by both messages
                intent = Intent(
                    intent_name=intent_name,
                    confidence_score=intent_dict["intent"].get("confidence", 1.0),
                )

                yield NluIntentParsed(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                )

                yield (
                    NluIntent(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=intent,
                        slots=slots,
                        asr_tokens=[NluIntent.make_asr_tokens(tokens)],
                        raw_input=query.input,
                        wakeword_id=query.wakeword_id,
                        lang=query.lang,
                    ),
                    {"intent_name": intent_name},
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                )
        except Exception as e:
            _LOGGER.exception("handle_query")
            yield NluError(
                error=repr(e),
                context=repr(query),
                site_id=query.site_id,
                session_id=query.session_id,
            )
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[
            NluIntentParsed, NluIntentNotRecognized, NluError, ]]:
        """Recognize the intent of *query* through a Rasa NLU server.

        The text is POSTed to ``model/parse`` under ``self.rasa_url``. When
        ``self.replace_numbers`` is set, ``query.input`` is rewritten in place
        with digits replaced by words before the request is made.

        Yields:
            An NluIntentParsed when the server names an intent that passes
            ``query.intent_filter`` (a filter of None accepts every intent).
            Otherwise an NluIntentNotRecognized. If an exception is raised,
            an NluError is yielded instead of propagating it.
        """
        try:
            if self.replace_numbers:
                # Digits become words; whitespace tokenization is assumed
                tokens = query.input.split()
                query.input = " ".join(
                    rhasspynlu.replace_numbers(tokens, self.number_language))

            text = query.input

            # Apply casing fix to the text sent out
            if self.word_transform:
                text = self.word_transform(text)

            url = urljoin(self.rasa_url, "model/parse")
            _LOGGER.debug(url)

            async with self.http_session.post(
                    url,
                    json={"text": text, "project": self.rasa_project},
                    ssl=self.ssl_context,
            ) as response:
                response.raise_for_status()
                reply = await response.json()
                intent_info = reply.get("intent", {})
                name = intent_info.get("name", "")

                rejected = (not name) or (
                    query.intent_filter is not None
                    and name not in query.intent_filter)

                if rejected:
                    # Not recognized
                    yield NluIntentNotRecognized(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                    )
                else:
                    score = float(intent_info.get("confidence", 0.0))

                    slots = []
                    for ent in reply.get("entities", []):
                        slot = Slot(
                            entity=ent.get("entity", ""),
                            slot_name=ent.get("entity", ""),
                            confidence=float(ent.get("confidence", 0.0)),
                            value={
                                "kind": "Unknown",
                                "value": ent.get("value", ""),
                                "additional_info":
                                ent.get("additional_info", {}),
                                "extractor": ent.get("extractor", None),
                            },
                            raw_value=ent.get("value", ""),
                            # Raw offsets mirror the regular offsets
                            range=SlotRange(
                                start=int(ent.get("start", 0)),
                                end=int(ent.get("end", 1)),
                                raw_start=int(ent.get("start", 0)),
                                raw_end=int(ent.get("end", 1)),
                            ),
                        )
                        slots.append(slot)

                    # intentParsed
                    recognized = Intent(intent_name=name,
                                        confidence_score=score)
                    yield NluIntentParsed(
                        input=text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=recognized,
                        slots=slots,
                    )
        except Exception as err:
            _LOGGER.exception("nlu query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(err),
                context=query.input,
            )
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Transform sentences to intent graph"""
        try:
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                self.intent_graph = nx.readwrite.gpickle.read_gpickle(
                    graph_gzip)

            # Build Markdown sentences
            sentences_by_intent = NluHermesMqtt.make_sentences_by_intent(
                self.intent_graph)

            if self.examples_md_path is not None:
                # Use user-specified file
                examples_md_file = open(self.examples_md_path, "w+")
            else:
                # Use temporary file
                examples_md_file = typing.cast(
                    typing.TextIO, tempfile.TemporaryFile(mode="w+"))

            with examples_md_file:
                # Write to YAML/Markdown file
                for intent_name, intent_sents in sentences_by_intent.items():
                    # Rasa Markdown training format
                    print(f"## intent:{intent_name}", file=examples_md_file)
                    for intent_sent in intent_sents:
                        raw_index = 0
                        index_entity = {
                            e.raw_start: e
                            for e in intent_sent.entities
                        }
                        entity: typing.Optional[Entity] = None
                        sentence_tokens: typing.List[str] = []
                        entity_tokens: typing.List[str] = []
                        for raw_token in intent_sent.raw_tokens:
                            token = raw_token
                            if entity and (raw_index >= entity.raw_end):
                                # Finish current entity
                                last_token = entity_tokens[-1]
                                entity_tokens[
                                    -1] = f"{last_token}]({entity.entity})"
                                if entity.value != entity.raw_value:
                                    synonym = f":{entity.value}"
                                else:
                                    synonym = ""
                                entity_tokens[
                                    -1] = f"{last_token}]({entity.entity}{synonym})"
                                sentence_tokens.extend(entity_tokens)
                                entity = None
                                entity_tokens = []

                            new_entity = index_entity.get(raw_index)
                            if new_entity:
                                # Begin new entity
                                assert entity is None, "Unclosed entity"
                                entity = new_entity
                                entity_tokens = []
                                token = f"[{token}"

                            if entity:
                                # Add to current entity
                                entity_tokens.append(token)
                            else:
                                # Add directly to sentence
                                sentence_tokens.append(token)

                            raw_index += len(raw_token) + 1

                        if entity:
                            # Finish final entity
                            last_token = entity_tokens[-1]
                            entity_tokens[
                                -1] = f"{last_token}]({entity.entity})"
                            if entity.value != entity.raw_value:
                                synonym = f":{entity.value}"
                            else:
                                synonym = ""
                            entity_tokens[
                                -1] = f"{last_token}]({entity.entity}{synonym})"
                            sentence_tokens.extend(entity_tokens)

                        # Print single example
                        print("-",
                              " ".join(sentence_tokens),
                              file=examples_md_file)

                    # Newline between intents
                    print("", file=examples_md_file)

                # Create training YAML file
                with tempfile.NamedTemporaryFile(
                        suffix=".json", mode="w+",
                        delete=False) as training_file:

                    training_config = io.StringIO()

                    if self.config_path:
                        # Use provided config
                        with open(self.config_path, "r") as config_file:
                            # Copy verbatim
                            for line in config_file:
                                training_config.write(line)
                    else:
                        # Use default config
                        training_config.write(
                            f'language: "{self.rasa_language}"\n')
                        training_config.write(
                            'pipeline: "pretrained_embeddings_spacy"\n')

                    # Write markdown directly into YAML.
                    # Because reasons.
                    examples_md_file.seek(0)
                    blank_line = False
                    for line in examples_md_file:
                        line = line.strip()
                        if line:
                            if blank_line:
                                print("", file=training_file)
                                blank_line = False

                            print(f"  {line}", file=training_file)
                        else:
                            blank_line = True

                    # Do training via HTTP API
                    training_file.seek(0)
                    with open(training_file.name, "rb") as training_data:

                        training_body = {
                            "config": training_config.getvalue(),
                            "nlu": training_data.read().decode("utf-8"),
                        }
                        training_config.close()

                        # POST training data
                        training_response: typing.Optional[bytes] = None

                        try:
                            training_url = urljoin(self.rasa_url,
                                                   "model/train")
                            _LOGGER.debug(training_url)
                            async with self.http_session.post(
                                    training_url,
                                    json=training_body,
                                    params=json.dumps(
                                        {"project": self.rasa_project},
                                        ensure_ascii=False),
                                    ssl=self.ssl_context,
                            ) as response:
                                training_response = await response.read()
                                response.raise_for_status()

                                model_file = os.path.join(
                                    self.rasa_model_dir,
                                    response.headers["filename"])
                                _LOGGER.debug("Received model %s", model_file)

                                # Replace model with PUT.
                                # Do we really have to do this?
                                model_url = urljoin(self.rasa_url, "model")
                                _LOGGER.debug(model_url)
                                async with self.http_session.put(
                                        model_url,
                                        json={"model_file":
                                              model_file}) as response:
                                    response.raise_for_status()
                        except Exception as e:
                            if training_response:
                                _LOGGER.exception("rasa train")

                                # Rasa gives quite helpful error messages, so extract them from the response.
                                error_message = json.loads(
                                    training_response)["message"]
                                raise Exception(
                                    f"{response.reason}: {error_message}")

                            # Empty response; re-raise exception
                            raise e

            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           error=str(e),
                           context=train.id,
                           session_id=train.id)
示例#13
0
    async def handle_query(
        self, query: NluQuery
    ) -> typing.AsyncIterable[typing.Union[NluIntentParsed, typing.Tuple[
            NluIntent, TopicArgs], NluIntentNotRecognized, NluError, ]]:
        """Recognize the intent of a query through the Rasa NLU HTTP server.

        The (optionally number-replaced and case-transformed) query text is
        POSTed to ``model/parse`` relative to ``self.rasa_url`` and the JSON
        reply is converted into Hermes messages.

        Note: when ``self.replace_numbers`` is set, ``query.input`` is
        overwritten in place with the number-replaced text.

        Args:
            query: The NLU query to resolve.

        Yields:
            ``NluIntentParsed`` followed by an ``(NluIntent, topic_args)``
            tuple when the reply names an intent that passes
            ``query.intent_filter``; ``NluIntentNotRecognized`` otherwise;
            ``NluError`` if any exception is raised along the way.
        """
        try:
            # Remember the untouched text; it is reported as raw_input later
            raw_text = query.input

            if self.replace_numbers:
                # Digits become words; whitespace tokenization is assumed
                query.input = " ".join(
                    rhasspynlu.replace_numbers(query.input.split(),
                                               self.number_language))

            # Casing is adjusted on a local copy only, not on query.input
            text = query.input
            if self.word_transform:
                text = self.word_transform(text)

            parse_url = urljoin(self.rasa_url, "model/parse")
            _LOGGER.debug(parse_url)

            async with self.http_session.post(
                    parse_url,
                    json={
                        "text": text,
                        "project": self.rasa_project
                    },
                    ssl=self.ssl_context,
            ) as response:
                response.raise_for_status()
                parsed = await response.json()
                intent_info = parsed.get("intent", {})
                name = intent_info.get("name", "")

                if not name or (query.intent_filter is not None
                                and name not in query.intent_filter):
                    # Either no intent came back or it was filtered out
                    yield NluIntentNotRecognized(
                        input=query.input,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        custom_data=query.custom_data,
                    )
                else:
                    confidence = float(intent_info.get("confidence", 0.0))

                    # One slot per entity in the reply; the entity label is
                    # used for both the entity and the slot name.
                    slots = []
                    for found in parsed.get("entities", []):
                        label = found.get("entity", "")
                        found_confidence = float(found.get("confidence", 0.0))
                        found_value = found.get("value", "")
                        start = int(found.get("start", 0))
                        end = int(found.get("end", 1))
                        slots.append(
                            Slot(
                                entity=label,
                                slot_name=label,
                                confidence=found_confidence,
                                value={
                                    "kind": "Unknown",
                                    "value": found_value
                                },
                                raw_value=found_value,
                                range=SlotRange(
                                    start=start,
                                    end=end,
                                    raw_start=start,
                                    raw_end=end,
                                ),
                            ))

                    if query.custom_entities:
                        # Carry over the entities supplied with the query
                        slots.extend(
                            Slot(
                                entity=custom_name,
                                confidence=1.0,
                                value={"value": custom_value},
                            ) for custom_name, custom_value in
                            query.custom_entities.items())

                    # intentParsed
                    yield NluIntentParsed(
                        input=text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(intent_name=name,
                                      confidence_score=confidence),
                        slots=slots,
                    )

                    # intent (published with the intent name as topic argument)
                    recognized = NluIntent(
                        input=text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(
                            intent_name=name,
                            confidence_score=confidence,
                        ),
                        slots=slots,
                        asr_tokens=[NluIntent.make_asr_tokens(text.split())],
                        asr_confidence=query.asr_confidence,
                        raw_input=raw_text,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    )
                    yield (recognized, {"intent_name": name})
        except Exception as e:
            # Report any failure as an error message instead of propagating
            _LOGGER.exception("nlu query")
            yield NluError(
                site_id=query.site_id,
                session_id=query.session_id,
                error=str(e),
                context=query.input,
            )