Example #1
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Transform sentences/slots into Snips NLU training dataset."""
        try:
            assert train.sentences, "No training sentences"

            start_time = time.perf_counter()

            new_engine = rhasspysnips_nlu.train(
                sentences_dict=train.sentences,
                language=self.snips_language,
                slots_dict=train.slots,
                engine_path=self.engine_path,
                dataset_path=self.dataset_path,
            )

            end_time = time.perf_counter()

            _LOGGER.debug("Trained Snips engine in %s second(s)",
                          end_time - start_time)
            self.engine = new_engine

            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           session_id=train.id,
                           error=str(e),
                           context=train.id)
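A minimal sketch of how these yields might be consumed: tuples carry a success message plus topic arguments for MQTT publishing, while a bare NluError signals failure. The FakeNluTrain stand-in and the print-based publisher are illustrative only; the real NluTrain message class lives in rhasspyhermes and its fields may differ.

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class FakeNluTrain:
        """Hypothetical stand-in for the NluTrain message read above."""
        id: str
        sentences: dict = field(default_factory=dict)
        slots: dict = field(default_factory=dict)

    async def drive_training(handler):
        """Consume handle_train() results and route them."""
        # handler is assumed to expose handle_train() as in the example
        train = FakeNluTrain(id="train-1",
                             sentences={"GetTime": ["what time is it"]})
        async for result in handler.handle_train(train, site_id="default"):
            if isinstance(result, tuple):
                message, topic_args = result
                print("publish:", message, topic_args)  # stand-in for MQTT
            else:
                print("training failed:", result.error)

    # asyncio.run(drive_training(my_handler))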
Example #2
    async def async_test_train_success(self):
        """Verify successful training."""
        train_id = str(uuid.uuid4())

        def fake_read_graph(*args, **kwargs):
            return MagicMock()

        def fake_train(*args, **kwargs):
            return MagicMock()

        # Create temporary file for "open"
        with tempfile.NamedTemporaryFile(mode="wb+",
                                         suffix=".gz") as graph_file:
            train = NluTrain(id=train_id, graph_path=graph_file.name)

            # Ensure fake graph "loads" and training goes through
            with patch("rhasspynlu.gzip_pickle_to_graph", new=fake_read_graph):
                with patch("rhasspyfuzzywuzzy.train", new=fake_train):
                    results = []
                    async for result in self.hermes.on_message(
                            train, site_id=self.site_id):
                        results.append(result)

            self.assertEqual(results, [(NluTrainSuccess(id=train_id), {
                "site_id": self.site_id
            })])
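A complementary failure-path sketch under the same fixtures (hypothetical test; assumes NluError is imported from rhasspyhermes.nlu): a raising trainer should come back as a bare NluError rather than a success tuple.

    async def async_test_train_failure(self):
        """Verify that a failing trainer yields NluError."""
        train_id = str(uuid.uuid4())

        def failing_train(*args, **kwargs):
            raise RuntimeError("training failed")

        with tempfile.NamedTemporaryFile(mode="wb+",
                                         suffix=".gz") as graph_file:
            train = NluTrain(id=train_id, graph_path=graph_file.name)

            with patch("rhasspynlu.gzip_pickle_to_graph", new=MagicMock()):
                with patch("rhasspyfuzzywuzzy.train", new=failing_train):
                    results = []
                    async for result in self.hermes.on_message(
                            train, site_id=self.site_id):
                        results.append(result)

            self.assertEqual(len(results), 1)
            self.assertIsInstance(results[0], NluError)
            self.assertEqual(results[0].session_id, train_id)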
Example #3
    async def handle_train(
        self, train: NluTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
    ]:
        """Transform sentences to intent examples"""
        try:
            _LOGGER.debug("Loading %s", train.graph_path)
            with open(train.graph_path, mode="rb") as graph_file:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

            self.examples = rhasspyfuzzywuzzy.train(self.intent_graph)

            if self.examples_path:
                # Write examples to JSON file
                with open(self.examples_path, "w") as examples_file:
                    json.dump(self.examples, examples_file)

                _LOGGER.debug("Wrote %s", str(self.examples_path))

            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(
                site_id=site_id, session_id=train.id, error=str(e), context=train.id
            )
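Since the examples are written as plain JSON, they can be reloaded at startup with a helper along these lines (load_examples is an illustrative name, not part of the library):

    import json

    def load_examples(examples_path):
        """Reload the examples written by handle_train() above."""
        try:
            with open(examples_path, "r") as examples_file:
                return json.load(examples_file)
        except FileNotFoundError:
            return {}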
Example #4
    async def handle_nlu_train(
        self, train: NluTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
    ]:
        """Re-trains NLU system"""
        try:
            # Load gzipped graph pickle
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Get JSON intent graph
            json_graph = rhasspynlu.graph_to_json(intent_graph)

            if self.nlu_train_url:
                # Remote NLU server
                _LOGGER.debug(self.nlu_train_url)

                async with self.http_session.post(
                    self.nlu_train_url, json=json_graph, ssl=self.ssl_context
                ) as response:
                    # No data expected in response
                    response.raise_for_status()
            elif self.nlu_train_command:
                # Local NLU training command
                _LOGGER.debug(self.nlu_train_command)

                proc = await asyncio.create_subprocess_exec(
                    *self.nlu_train_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                output, error = await proc.communicate(json.dumps(json_graph).encode())

                if output:
                    _LOGGER.debug(output.decode())

                if error:
                    _LOGGER.debug(error.decode())
            else:
                _LOGGER.warning("Can't train NLU system. No train URL or command.")

            # Report success
            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_nlu_train")
            yield NluError(
                error=str(e),
                context=f"url='{self.nlu_train_url}', command='{self.nlu_train_command}'",
                site_id=site_id,
                session_id=train.id,
            )
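When nlu_train_command is configured, the command receives the JSON intent graph on stdin. A minimal sketch of such a command (the node count assumes the networkx node-link format that rhasspynlu.graph_to_json appears to produce; treat that as an assumption):

    #!/usr/bin/env python3
    """Minimal local NLU train command: reads the JSON intent graph
    that handle_nlu_train() above pipes to stdin."""
    import json
    import sys

    def main():
        json_graph = json.load(sys.stdin)
        # ... re-train the local NLU system from the graph here ...
        # "nodes" assumes networkx node-link format (an assumption)
        print(f"Received {len(json_graph.get('nodes', []))} graph node(s)",
              file=sys.stderr)

    if __name__ == "__main__":
        main()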
Example #5
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Transform sentences to intent examples"""
        try:
            _LOGGER.debug("Loading %s", train.graph_path)
            with open(train.graph_path, mode="rb") as graph_file:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

            examples = rhasspyfuzzywuzzy.train(self.intent_graph)

            if self.examples_path:
                if self.examples_path.is_file():
                    # Delete existing file
                    self.examples_path.unlink()

                # Write examples to SQLite database
                conn = sqlite3.connect(str(self.examples_path))
                c = conn.cursor()
                c.execute("""DROP TABLE IF EXISTS intents""")
                c.execute(
                    """CREATE TABLE intents (sentence text, path text)""")

                for _, sentences in examples.items():
                    for sentence, path in sentences.items():
                        c.execute(
                            "INSERT INTO intents VALUES (?, ?)",
                            (sentence, json.dumps(path, ensure_ascii=False)),
                        )

                conn.commit()
                conn.close()

                _LOGGER.debug("Wrote %s", str(self.examples_path))
            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           session_id=train.id,
                           error=str(e),
                           context=train.id)
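Reading the database back is straightforward, since each row pairs a sentence with its JSON-encoded graph path (iter_examples is an illustrative helper, not a library function):

    import json
    import sqlite3

    def iter_examples(examples_path):
        """Yield (sentence, path) pairs from the intents table above."""
        conn = sqlite3.connect(str(examples_path))
        try:
            for sentence, path_json in conn.execute(
                    "SELECT sentence, path FROM intents"):
                yield sentence, json.loads(path_json)
        finally:
            conn.close()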
Example #6
    async def handle_train(
        self,
        train: NluTrain,
        site_id: str = "default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[NluTrainSuccess,
                                                        TopicArgs], NluError]]:
        """Transform sentences to intent graph"""
        try:
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                self.intent_graph = nx.readwrite.gpickle.read_gpickle(
                    graph_gzip)

            # Build Markdown sentences
            sentences_by_intent = NluHermesMqtt.make_sentences_by_intent(
                self.intent_graph)

            if self.examples_md_path is not None:
                # Use user-specified file
                examples_md_file = open(self.examples_md_path, "w+")
            else:
                # Use temporary file
                examples_md_file = typing.cast(
                    typing.TextIO, tempfile.TemporaryFile(mode="w+"))

            with examples_md_file:
                # Write Rasa Markdown training examples
                for intent_name, intent_sents in sentences_by_intent.items():
                    # Rasa Markdown training format
                    print(f"## intent:{intent_name}", file=examples_md_file)
                    for intent_sent in intent_sents:
                        raw_index = 0
                        index_entity = {
                            e.raw_start: e
                            for e in intent_sent.entities
                        }
                        entity: typing.Optional[Entity] = None
                        sentence_tokens: typing.List[str] = []
                        entity_tokens: typing.List[str] = []
                        for raw_token in intent_sent.raw_tokens:
                            token = raw_token
                            if entity and (raw_index >= entity.raw_end):
                                # Finish current entity
                                last_token = entity_tokens[-1]
                                if entity.value != entity.raw_value:
                                    synonym = f":{entity.value}"
                                else:
                                    synonym = ""
                                entity_tokens[
                                    -1] = f"{last_token}]({entity.entity}{synonym})"
                                sentence_tokens.extend(entity_tokens)
                                entity = None
                                entity_tokens = []

                            new_entity = index_entity.get(raw_index)
                            if new_entity:
                                # Begin new entity
                                assert entity is None, "Unclosed entity"
                                entity = new_entity
                                entity_tokens = []
                                token = f"[{token}"

                            if entity:
                                # Add to current entity
                                entity_tokens.append(token)
                            else:
                                # Add directly to sentence
                                sentence_tokens.append(token)

                            raw_index += len(raw_token) + 1

                        if entity:
                            # Finish final entity
                            last_token = entity_tokens[-1]
                            if entity.value != entity.raw_value:
                                synonym = f":{entity.value}"
                            else:
                                synonym = ""
                            entity_tokens[
                                -1] = f"{last_token}]({entity.entity}{synonym})"
                            sentence_tokens.extend(entity_tokens)

                        # Print single example
                        print("-",
                              " ".join(sentence_tokens),
                              file=examples_md_file)

                    # Newline between intents
                    print("", file=examples_md_file)

                # Create temporary file for the indented training examples
                with tempfile.NamedTemporaryFile(
                        suffix=".json", mode="w+",
                        delete=False) as training_file:

                    training_config = io.StringIO()

                    if self.config_path:
                        # Use provided config
                        with open(self.config_path, "r") as config_file:
                            # Copy verbatim
                            for line in config_file:
                                training_config.write(line)
                    else:
                        # Use default config
                        training_config.write(
                            f'language: "{self.rasa_language}"\n')
                        training_config.write(
                            'pipeline: "pretrained_embeddings_spacy"\n')

                    # Write markdown directly into YAML.
                    # Because reasons.
                    examples_md_file.seek(0)
                    blank_line = False
                    for line in examples_md_file:
                        line = line.strip()
                        if line:
                            if blank_line:
                                print("", file=training_file)
                                blank_line = False

                            print(f"  {line}", file=training_file)
                        else:
                            blank_line = True

                    # Do training via HTTP API
                    training_file.seek(0)
                    with open(training_file.name, "rb") as training_data:

                        training_body = {
                            "config": training_config.getvalue(),
                            "nlu": training_data.read().decode("utf-8"),
                        }
                        training_config.close()

                        # POST training data
                        training_response: typing.Optional[bytes] = None

                        try:
                            training_url = urljoin(self.rasa_url,
                                                   "model/train")
                            _LOGGER.debug(training_url)
                            async with self.http_session.post(
                                    training_url,
                                    json=training_body,
                                    params=json.dumps(
                                        {"project": self.rasa_project},
                                        ensure_ascii=False),
                                    ssl=self.ssl_context,
                            ) as response:
                                training_response = await response.read()
                                response.raise_for_status()

                                model_file = os.path.join(
                                    self.rasa_model_dir,
                                    response.headers["filename"])
                                _LOGGER.debug("Received model %s", model_file)

                                # Replace model with PUT.
                                # Do we really have to do this?
                                model_url = urljoin(self.rasa_url, "model")
                                _LOGGER.debug(model_url)
                                async with self.http_session.put(
                                        model_url,
                                        json={"model_file":
                                              model_file}) as response:
                                    response.raise_for_status()
                        except Exception as e:
                            if training_response:
                                _LOGGER.exception("rasa train")

                                # Rasa gives quite helpful error messages, so extract them from the response.
                                error_message = json.loads(
                                    training_response)["message"]
                                raise Exception(
                                    f"{response.reason}: {error_message}")

                            # Empty response; re-raise exception
                            raise e

            yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield NluError(site_id=site_id,
                           error=str(e),
                           context=train.id,
                           session_id=train.id)
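For reference, the token loop above emits Rasa's Markdown training format: one "## intent:" header per intent, one "- " line per example, entities written as [raw value](entity), with a :value suffix when the substituted value differs from the raw text. A hand-written sample (intent and entity names invented):

    ## intent:ChangeLightColor
    - set the [bedroom light](name) to [red](color)
    - make the light [crimson](color:red)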