def on_message(self, client, userdata, msg):
    """Received message from MQTT broker."""
    try:
        _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

        if msg.topic == NluQuery.topic():
            json_payload = json.loads(msg.payload)

            # Check siteId
            if not self._check_siteId(json_payload):
                return

            try:
                query = NluQuery(**json_payload)
                _LOGGER.debug("<- %s", query)
                self.handle_query(query)
            except Exception as e:
                _LOGGER.exception("nlu query")
                self.publish(
                    NluError(
                        # Read from the payload: query may be unbound if
                        # NluQuery(**json_payload) raised above.
                        siteId=json_payload.get("siteId", "default"),
                        sessionId=json_payload.get("sessionId", ""),
                        error=str(e),
                        context="",
                    )
                )
    except Exception:
        _LOGGER.exception("on_message")
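# A minimal sketch of how a query reaches on_message(): publish JSON to
# NluQuery.topic() ("hermes/nlu/query"). Field names follow the Hermes
# protocol; the sentence and ids below are purely illustrative.
example_query_payload = json.dumps(
    {
        "input": "turn on the living room lamp",
        "siteId": "default",
        "sessionId": "",
    }
)
# client.publish(NluQuery.topic(), example_query_payload)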
async def handle_train(
    self, train: NluTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
]:
    """Transform sentences/slots into Snips NLU training dataset."""
    try:
        assert train.sentences, "No training sentences"

        start_time = time.perf_counter()
        new_engine = rhasspysnips_nlu.train(
            sentences_dict=train.sentences,
            language=self.snips_language,
            slots_dict=train.slots,
            engine_path=self.engine_path,
            dataset_path=self.dataset_path,
        )
        end_time = time.perf_counter()

        _LOGGER.debug("Trained Snips engine in %s second(s)", end_time - start_time)
        self.engine = new_engine

        yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield NluError(
            site_id=site_id, session_id=train.id, error=str(e), context=train.id
        )
async def handle_train(
    self, train: NluTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
]:
    """Transform sentences to intent examples"""
    try:
        _LOGGER.debug("Loading %s", train.graph_path)
        with open(train.graph_path, mode="rb") as graph_file:
            self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

        self.examples = rhasspyfuzzywuzzy.train(self.intent_graph)

        if self.examples_path:
            # Write examples to JSON file
            with open(self.examples_path, "w") as examples_file:
                json.dump(self.examples, examples_file)

            _LOGGER.debug("Wrote %s", str(self.examples_path))

        yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield NluError(
            site_id=site_id, session_id=train.id, error=str(e), context=train.id
        )
async def handle_nlu_train(
    self, train: NluTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
]:
    """Re-trains NLU system"""
    try:
        # Load gzipped graph pickle
        _LOGGER.debug("Loading %s", train.graph_path)
        with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
            intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

        # Get JSON intent graph
        json_graph = rhasspynlu.graph_to_json(intent_graph)

        if self.nlu_train_url:
            # Remote NLU server
            _LOGGER.debug(self.nlu_train_url)

            async with self.http_session.post(
                self.nlu_train_url, json=json_graph, ssl=self.ssl_context
            ) as response:
                # No data expected in response
                response.raise_for_status()
        elif self.nlu_train_command:
            # Local NLU training command
            _LOGGER.debug(self.nlu_train_command)

            proc = await asyncio.create_subprocess_exec(
                *self.nlu_train_command,
                stdin=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            output, error = await proc.communicate(json.dumps(json_graph).encode())

            if output:
                _LOGGER.debug(output.decode())

            if error:
                _LOGGER.debug(error.decode())
        else:
            _LOGGER.warning("Can't train NLU system. No train URL or command.")

        # Report success
        yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_nlu_train")
        yield NluError(
            error=str(e),
            context=f"url='{self.nlu_train_url}', command='{self.nlu_train_command}'",
            site_id=site_id,
            session_id=train.id,
        )
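# For illustration only: nlu_train_command can be any executable that reads
# the JSON intent graph from stdin (assumed here to be networkx's node-link
# dictionary format, as produced by rhasspynlu.graph_to_json) and writes
# diagnostics to stderr, which is logged above. A hypothetical stand-in:
#
#   #!/usr/bin/env python3
#   import json, sys
#   graph = json.load(sys.stdin)
#   print(len(graph.get("nodes", [])), "node(s)", file=sys.stderr)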
async def handle_train(
    self, train: NluTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
]:
    """Transform sentences to intent examples"""
    try:
        _LOGGER.debug("Loading %s", train.graph_path)
        with open(train.graph_path, mode="rb") as graph_file:
            self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

        examples = rhasspyfuzzywuzzy.train(self.intent_graph)

        if self.examples_path:
            if self.examples_path.is_file():
                # Delete existing file
                self.examples_path.unlink()

            # Write examples to SQLite database
            conn = sqlite3.connect(str(self.examples_path))
            c = conn.cursor()
            c.execute("""DROP TABLE IF EXISTS intents""")
            c.execute("""CREATE TABLE intents (sentence text, path text)""")

            for _, sentences in examples.items():
                for sentence, path in sentences.items():
                    c.execute(
                        "INSERT INTO intents VALUES (?, ?)",
                        (sentence, json.dumps(path, ensure_ascii=False)),
                    )

            conn.commit()
            conn.close()

            _LOGGER.debug("Wrote %s", str(self.examples_path))

        yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield NluError(
            site_id=site_id, session_id=train.id, error=str(e), context=train.id
        )
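# A small companion sketch (not in the original) showing how the examples
# database written above can be read back. The table layout
# intents(sentence, path) and the JSON-encoded path come from the training
# code; the helper itself is illustrative.
def load_examples(examples_path):
    """Yield (sentence, path) rows from the fuzzywuzzy examples database."""
    conn = sqlite3.connect(str(examples_path))
    try:
        for sentence, path_json in conn.execute("SELECT sentence, path FROM intents"):
            yield sentence, json.loads(path_json)
    finally:
        conn.close()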
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[
        NluIntentParsed,
        typing.Tuple[NluIntent, TopicArgs],
        NluIntentNotRecognized,
        NluError,
    ]
]:
    """Do intent recognition."""
    original_input = query.input

    try:
        if not self.intent_graph and self.graph_path and self.graph_path.is_file():
            # Load graph from file
            _LOGGER.debug("Loading %s", self.graph_path)
            with open(self.graph_path, mode="rb") as graph_file:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

        if self.intent_graph:

            def intent_filter(intent_name: str) -> bool:
                """Filter out intents."""
                if query.intent_filter:
                    return intent_name in query.intent_filter
                return True

            # Replace digits with words
            if self.replace_numbers:
                # Have to assume whitespace tokenization
                words = rhasspynlu.replace_numbers(query.input.split(), self.language)
                query.input = " ".join(words)

            input_text = query.input

            # Fix casing for output event
            if self.word_transform:
                input_text = self.word_transform(input_text)

            if self.failure_token and (self.failure_token in query.input.split()):
                # Failure token was found in input
                recognitions = []
            else:
                # Pass in raw query input so raw values will be correct
                recognitions = recognize(
                    query.input,
                    self.intent_graph,
                    intent_filter=intent_filter,
                    word_transform=self.word_transform,
                    fuzzy=self.fuzzy,
                    extra_converters=self.extra_converters,
                )
        else:
            _LOGGER.error("No intent graph loaded")
            recognitions = []

        if NluHermesMqtt.is_success(recognitions):
            # Use first recognition only.
            recognition = recognitions[0]
            assert recognition is not None
            assert recognition.intent is not None

            intent = Intent(
                intent_name=recognition.intent.name,
                confidence_score=recognition.intent.confidence,
            )
            slots = [
                Slot(
                    entity=(e.source or e.entity),
                    slot_name=e.entity,
                    confidence=1.0,
                    value=e.value_dict,
                    raw_value=e.raw_value,
                    range=SlotRange(
                        start=e.start,
                        end=e.end,
                        raw_start=e.raw_start,
                        raw_end=e.raw_end,
                    ),
                )
                for e in recognition.entities
            ]

            if query.custom_entities:
                # Copy user-defined entities
                for entity_name, entity_value in query.custom_entities.items():
                    slots.append(
                        Slot(
                            entity=entity_name,
                            confidence=1.0,
                            value={"value": entity_value},
                        )
                    )

            # intentParsed
            yield NluIntentParsed(
                input=recognition.text,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                intent=intent,
                slots=slots,
            )

            # intent
            yield (
                NluIntent(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                    asr_tokens=[NluIntent.make_asr_tokens(recognition.tokens)],
                    asr_confidence=query.asr_confidence,
                    raw_input=original_input,
                    wakeword_id=query.wakeword_id,
                    lang=(query.lang or self.lang),
                    custom_data=query.custom_data,
                ),
                {"intent_name": recognition.intent.name},
            )
        else:
            # Not recognized
            yield NluIntentNotRecognized(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                custom_data=query.custom_data,
            )
    except Exception as e:
        _LOGGER.exception("handle_query")
        yield NluError(
            site_id=query.site_id,
            session_id=query.session_id,
            error=str(e),
            context=original_input,
        )
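# A hypothetical driver for handle_query(), not part of the original class:
# it builds a single NluQuery, iterates the async generator above, and prints
# each yielded Hermes message with its topic, unpacking (message, topic_args)
# tuples the way the MQTT base class is assumed to. Field names follow the
# rhasspyhermes.nlu message classes.
async def debug_query(nlu, text: str, site_id: str = "default"):
    query = NluQuery(input=text, site_id=site_id)
    async for result in nlu.handle_query(query):
        if isinstance(result, tuple):
            message, topic_args = result
        else:
            message, topic_args = result, {}

        topic = message.topic(**topic_args) if topic_args else message.topic()
        print(topic, message)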
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[
        NluIntentParsed,
        typing.Tuple[NluIntent, TopicArgs],
        NluIntentNotRecognized,
        NluError,
    ]
]:
    """Do intent recognition."""
    original_input = query.input

    try:
        self.maybe_load_engine()
        assert self.engine, "Snips engine not loaded. You may need to train."

        input_text = query.input

        # Fix casing for output event
        if self.word_transform:
            input_text = self.word_transform(input_text)

        # Do parsing
        result = self.engine.parse(input_text, query.intent_filter)
        intent_name = result.get("intent", {}).get("intentName")

        if intent_name:
            slots = [
                Slot(
                    slot_name=s["slotName"],
                    entity=s["entity"],
                    value=s["value"],
                    raw_value=s["rawValue"],
                    range=SlotRange(start=s["range"]["start"], end=s["range"]["end"]),
                )
                for s in result.get("slots", [])
            ]

            # intentParsed
            yield NluIntentParsed(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                intent=Intent(intent_name=intent_name, confidence_score=1.0),
                slots=slots,
            )

            # intent
            yield (
                NluIntent(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(intent_name=intent_name, confidence_score=1.0),
                    slots=slots,
                    asr_tokens=[NluIntent.make_asr_tokens(query.input.split())],
                    raw_input=original_input,
                    wakeword_id=query.wakeword_id,
                    lang=query.lang,
                ),
                {"intent_name": intent_name},
            )
        else:
            # Not recognized
            yield NluIntentNotRecognized(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
            )
    except Exception as e:
        _LOGGER.exception("handle_query")
        yield NluError(
            site_id=query.site_id,
            session_id=query.session_id,
            error=str(e),
            context=original_input,
        )
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[
        NluIntentParsed,
        typing.Tuple[NluIntent, TopicArgs],
        NluIntentNotRecognized,
        NluError,
    ]
]:
    """Do intent recognition."""
    # Capture original input up front so the error handler can reference it
    original_text = query.input

    # Check intent graph
    try:
        if (
            not self.intent_graph
            and self.intent_graph_path
            and self.intent_graph_path.is_file()
        ):
            _LOGGER.debug("Loading %s", self.intent_graph_path)
            with open(self.intent_graph_path, mode="rb") as graph_file:
                self.intent_graph = rhasspynlu.gzip_pickle_to_graph(graph_file)

        # Check examples
        if (
            self.intent_graph
            and self.examples_path
            and self.examples_path.is_file()
        ):

            def intent_filter(intent_name: str) -> bool:
                """Filter out intents."""
                if query.intent_filter:
                    return intent_name in query.intent_filter
                return True

            # Replace digits with words
            if self.replace_numbers:
                # Have to assume whitespace tokenization
                words = rhasspynlu.replace_numbers(query.input.split(), self.language)
                query.input = " ".join(words)

            input_text = query.input

            # Fix casing
            if self.word_transform:
                input_text = self.word_transform(input_text)

            recognitions: typing.List[rhasspynlu.intent.Recognition] = []

            if input_text:
                recognitions = rhasspyfuzzywuzzy.recognize(
                    input_text,
                    self.intent_graph,
                    str(self.examples_path),
                    intent_filter=intent_filter,
                    extra_converters=self.extra_converters,
                )
        else:
            _LOGGER.error("No intent graph or examples loaded")
            recognitions = []

        # Use first recognition only if above threshold
        if (
            recognitions
            and recognitions[0]
            and recognitions[0].intent
            and (recognitions[0].intent.confidence >= self.confidence_threshold)
        ):
            recognition = recognitions[0]
            assert recognition.intent

            intent = Intent(
                intent_name=recognition.intent.name,
                confidence_score=recognition.intent.confidence,
            )
            slots = [
                Slot(
                    entity=(e.source or e.entity),
                    slot_name=e.entity,
                    confidence=1.0,
                    value=e.value_dict,
                    raw_value=e.raw_value,
                    range=SlotRange(
                        start=e.start,
                        end=e.end,
                        raw_start=e.raw_start,
                        raw_end=e.raw_end,
                    ),
                )
                for e in recognition.entities
            ]

            if query.custom_entities:
                # Copy user-defined entities
                for entity_name, entity_value in query.custom_entities.items():
                    slots.append(
                        Slot(
                            entity=entity_name,
                            confidence=1.0,
                            value={"value": entity_value},
                        )
                    )

            # intentParsed
            yield NluIntentParsed(
                input=recognition.text,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                intent=intent,
                slots=slots,
            )

            # intent
            yield (
                NluIntent(
                    input=recognition.text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=intent,
                    slots=slots,
                    asr_tokens=[NluIntent.make_asr_tokens(recognition.tokens)],
                    asr_confidence=query.asr_confidence,
                    raw_input=original_text,
                    wakeword_id=query.wakeword_id,
                    lang=(query.lang or self.lang),
                    custom_data=query.custom_data,
                ),
                {"intent_name": recognition.intent.name},
            )
        else:
            # Not recognized
            yield NluIntentNotRecognized(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                custom_data=query.custom_data,
            )
    except Exception as e:
        _LOGGER.exception("handle_query")
        yield NluError(
            site_id=query.site_id,
            session_id=query.session_id,
            error=str(e),
            context=original_text,
        )
def test_nlu_error():
    """Test NluError."""
    assert NluError.topic() == "hermes/error/nlu"
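# A few companion checks in the same style, assuming the standard Hermes NLU
# topics used by the handlers above (query, intentParsed, intentNotRecognized,
# and the per-intent topic).
def test_nlu_topics():
    """Test other NLU message topics."""
    assert NluQuery.topic() == "hermes/nlu/query"
    assert NluIntentParsed.topic() == "hermes/nlu/intentParsed"
    assert NluIntentNotRecognized.topic() == "hermes/nlu/intentNotRecognized"
    assert NluIntent.topic(intent_name="SetTimer") == "hermes/intent/SetTimer"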
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[
        typing.Tuple[NluIntent, TopicArgs],
        NluIntentParsed,
        NluIntentNotRecognized,
        NluError,
    ]
]:
    """Do intent recognition."""
    try:
        input_text = query.input

        # Fix casing
        if self.word_transform:
            input_text = self.word_transform(input_text)

        if self.nlu_url:
            # Use remote server
            _LOGGER.debug(self.nlu_url)

            params = {}

            # Add intent filter
            if query.intent_filter:
                params["intentFilter"] = ",".join(query.intent_filter)

            async with self.http_session.post(
                self.nlu_url, data=input_text, params=params, ssl=self.ssl_context
            ) as response:
                response.raise_for_status()
                intent_dict = await response.json()
        elif self.nlu_command:
            # Run external command
            _LOGGER.debug(self.nlu_command)

            proc = await asyncio.create_subprocess_exec(
                *self.nlu_command,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
            )

            input_bytes = (input_text.strip() + "\n").encode()
            output, error = await proc.communicate(input_bytes)
            if error:
                _LOGGER.debug(error.decode())

            intent_dict = json.loads(output)
        else:
            _LOGGER.warning("Not handling NLU query (no URL or command)")
            return

        intent_name = intent_dict["intent"].get("name", "")

        if intent_name:
            # Recognized
            tokens = query.input.split()
            slots = [
                Slot(
                    entity=e["entity"],
                    slot_name=e["entity"],
                    confidence=1,
                    # Fall back to the entity's own value when no value_details given
                    value=e.get("value_details", {"value": e["value"]}),
                    raw_value=e.get("raw_value", e["value"]),
                    range=SlotRange(
                        start=e.get("start", 0),
                        end=e.get("end", 1),
                        raw_start=e.get("raw_start"),
                        raw_end=e.get("raw_end"),
                    ),
                )
                for e in intent_dict.get("entities", [])
            ]

            yield NluIntentParsed(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
                intent=Intent(
                    intent_name=intent_name,
                    confidence_score=intent_dict["intent"].get("confidence", 1.0),
                ),
                slots=slots,
            )

            yield (
                NluIntent(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(
                        intent_name=intent_name,
                        confidence_score=intent_dict["intent"].get("confidence", 1.0),
                    ),
                    slots=slots,
                    asr_tokens=[NluIntent.make_asr_tokens(tokens)],
                    raw_input=query.input,
                    wakeword_id=query.wakeword_id,
                    lang=query.lang,
                ),
                {"intent_name": intent_name},
            )
        else:
            # Not recognized
            yield NluIntentNotRecognized(
                input=query.input,
                id=query.id,
                site_id=query.site_id,
                session_id=query.session_id,
            )
    except Exception as e:
        _LOGGER.exception("handle_query")
        yield NluError(
            error=repr(e),
            context=repr(query),
            site_id=query.site_id,
            session_id=query.session_id,
        )
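# A minimal example of the intent dictionary handle_query() expects back from
# the remote NLU server or external command; the shape is inferred from the
# parsing code above, and the values are illustrative.
example_intent_dict = {
    "intent": {"name": "ChangeLightState", "confidence": 1.0},
    "entities": [
        {"entity": "state", "value": "on", "start": 5, "end": 7}
    ],
}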
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[NluIntentParsed, NluIntentNotRecognized, NluError]
]:
    """Do intent recognition."""
    try:
        # Replace digits with words
        if self.replace_numbers:
            # Have to assume whitespace tokenization
            words = rhasspynlu.replace_numbers(query.input.split(), self.number_language)
            query.input = " ".join(words)

        input_text = query.input

        # Fix casing for output event
        if self.word_transform:
            input_text = self.word_transform(input_text)

        parse_url = urljoin(self.rasa_url, "model/parse")
        _LOGGER.debug(parse_url)

        async with self.http_session.post(
            parse_url,
            json={"text": input_text, "project": self.rasa_project},
            ssl=self.ssl_context,
        ) as response:
            response.raise_for_status()
            intent_json = await response.json()
            intent = intent_json.get("intent", {})
            intent_name = intent.get("name", "")

            if intent_name and (
                query.intent_filter is None or intent_name in query.intent_filter
            ):
                confidence_score = float(intent.get("confidence", 0.0))
                slots = [
                    Slot(
                        entity=e.get("entity", ""),
                        slot_name=e.get("entity", ""),
                        confidence=float(e.get("confidence", 0.0)),
                        value={
                            "kind": "Unknown",
                            "value": e.get("value", ""),
                            "additional_info": e.get("additional_info", {}),
                            "extractor": e.get("extractor", None),
                        },
                        raw_value=e.get("value", ""),
                        range=SlotRange(
                            start=int(e.get("start", 0)),
                            end=int(e.get("end", 1)),
                            raw_start=int(e.get("start", 0)),
                            raw_end=int(e.get("end", 1)),
                        ),
                    )
                    for e in intent_json.get("entities", [])
                ]

                # intentParsed
                yield NluIntentParsed(
                    input=input_text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(
                        intent_name=intent_name, confidence_score=confidence_score
                    ),
                    slots=slots,
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                )
    except Exception as e:
        _LOGGER.exception("nlu query")
        yield NluError(
            site_id=query.site_id,
            session_id=query.session_id,
            error=str(e),
            context=query.input,
        )
async def handle_train(
    self, train: NluTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[NluTrainSuccess, TopicArgs], NluError]
]:
    """Transform sentences into Rasa NLU training examples"""
    try:
        _LOGGER.debug("Loading %s", train.graph_path)
        with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
            self.intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

        # Build Markdown sentences
        sentences_by_intent = NluHermesMqtt.make_sentences_by_intent(self.intent_graph)

        if self.examples_md_path is not None:
            # Use user-specified file
            examples_md_file = open(self.examples_md_path, "w+")
        else:
            # Use temporary file
            examples_md_file = typing.cast(
                typing.TextIO, tempfile.TemporaryFile(mode="w+")
            )

        with examples_md_file:
            # Write to YAML/Markdown file
            for intent_name, intent_sents in sentences_by_intent.items():
                # Rasa Markdown training format
                print(f"## intent:{intent_name}", file=examples_md_file)
                for intent_sent in intent_sents:
                    raw_index = 0
                    index_entity = {e.raw_start: e for e in intent_sent.entities}
                    entity: typing.Optional[Entity] = None
                    sentence_tokens: typing.List[str] = []
                    entity_tokens: typing.List[str] = []

                    for raw_token in intent_sent.raw_tokens:
                        token = raw_token

                        if entity and (raw_index >= entity.raw_end):
                            # Finish current entity
                            last_token = entity_tokens[-1]
                            if entity.value != entity.raw_value:
                                synonym = f":{entity.value}"
                            else:
                                synonym = ""

                            entity_tokens[-1] = f"{last_token}]({entity.entity}{synonym})"
                            sentence_tokens.extend(entity_tokens)
                            entity = None
                            entity_tokens = []

                        new_entity = index_entity.get(raw_index)
                        if new_entity:
                            # Begin new entity
                            assert entity is None, "Unclosed entity"
                            entity = new_entity
                            entity_tokens = []
                            token = f"[{token}"

                        if entity:
                            # Add to current entity
                            entity_tokens.append(token)
                        else:
                            # Add directly to sentence
                            sentence_tokens.append(token)

                        raw_index += len(raw_token) + 1

                    if entity:
                        # Finish final entity
                        last_token = entity_tokens[-1]
                        if entity.value != entity.raw_value:
                            synonym = f":{entity.value}"
                        else:
                            synonym = ""

                        entity_tokens[-1] = f"{last_token}]({entity.entity}{synonym})"
                        sentence_tokens.extend(entity_tokens)

                    # Print single example
                    print("-", " ".join(sentence_tokens), file=examples_md_file)

                # Newline between intents
                print("", file=examples_md_file)

            # Create training YAML file
            with tempfile.NamedTemporaryFile(
                suffix=".json", mode="w+", delete=False
            ) as training_file:
                training_config = io.StringIO()

                if self.config_path:
                    # Use provided config
                    with open(self.config_path, "r") as config_file:
                        # Copy verbatim
                        for line in config_file:
                            training_config.write(line)
                else:
                    # Use default config
                    training_config.write(f'language: "{self.rasa_language}"\n')
                    training_config.write('pipeline: "pretrained_embeddings_spacy"\n')

                # Write markdown directly into YAML.
                # Because reasons.
                examples_md_file.seek(0)
                blank_line = False
                for line in examples_md_file:
                    line = line.strip()
                    if line:
                        if blank_line:
                            print("", file=training_file)
                            blank_line = False

                        print(f" {line}", file=training_file)
                    else:
                        blank_line = True

                # Do training via HTTP API
                training_file.seek(0)
                with open(training_file.name, "rb") as training_data:
                    training_body = {
                        "config": training_config.getvalue(),
                        "nlu": training_data.read().decode("utf-8"),
                    }

                training_config.close()

                # POST training data
                training_response: typing.Optional[bytes] = None

                try:
                    training_url = urljoin(self.rasa_url, "model/train")
                    _LOGGER.debug(training_url)

                    async with self.http_session.post(
                        training_url,
                        json=training_body,
                        params=json.dumps(
                            {"project": self.rasa_project}, ensure_ascii=False
                        ),
                        ssl=self.ssl_context,
                    ) as response:
                        training_response = await response.read()
                        response.raise_for_status()

                        model_file = os.path.join(
                            self.rasa_model_dir, response.headers["filename"]
                        )
                        _LOGGER.debug("Received model %s", model_file)

                        # Replace model with PUT.
                        # Do we really have to do this?
                        model_url = urljoin(self.rasa_url, "model")
                        _LOGGER.debug(model_url)

                        async with self.http_session.put(
                            model_url, json={"model_file": model_file}
                        ) as response:
                            response.raise_for_status()
                except Exception as e:
                    if training_response:
                        _LOGGER.exception("rasa train")

                        # Rasa gives quite helpful error messages, so extract
                        # them from the response.
                        error_message = json.loads(training_response)["message"]
                        raise Exception(f"{response.reason}: {error_message}")

                    # Empty response; re-raise exception
                    raise e

        yield (NluTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield NluError(
            site_id=site_id, error=str(e), context=train.id, session_id=train.id
        )
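# For reference, a tiny sketch of the Rasa Markdown emitted by the training
# loop above; intent and entity names are illustrative, not from the source.
# A raw value that differs from the substituted value gets the synonym form
# [raw](entity:value).
def write_example_markdown(md_file):
    print("## intent:ChangeLightColor", file=md_file)
    print("- set the [bedroom](location) light to [crimson](color:red)", file=md_file)
    print("", file=md_file)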
async def handle_query(
    self, query: NluQuery
) -> typing.AsyncIterable[
    typing.Union[
        NluIntentParsed,
        typing.Tuple[NluIntent, TopicArgs],
        NluIntentNotRecognized,
        NluError,
    ]
]:
    """Do intent recognition."""
    try:
        original_input = query.input

        # Replace digits with words
        if self.replace_numbers:
            # Have to assume whitespace tokenization
            words = rhasspynlu.replace_numbers(query.input.split(), self.number_language)
            query.input = " ".join(words)

        input_text = query.input

        # Fix casing for output event
        if self.word_transform:
            input_text = self.word_transform(input_text)

        parse_url = urljoin(self.rasa_url, "model/parse")
        _LOGGER.debug(parse_url)

        async with self.http_session.post(
            parse_url,
            json={"text": input_text, "project": self.rasa_project},
            ssl=self.ssl_context,
        ) as response:
            response.raise_for_status()
            intent_json = await response.json()
            intent = intent_json.get("intent", {})
            intent_name = intent.get("name", "")

            if intent_name and (
                query.intent_filter is None or intent_name in query.intent_filter
            ):
                confidence_score = float(intent.get("confidence", 0.0))
                slots = [
                    Slot(
                        entity=e.get("entity", ""),
                        slot_name=e.get("entity", ""),
                        confidence=float(e.get("confidence", 0.0)),
                        value={"kind": "Unknown", "value": e.get("value", "")},
                        raw_value=e.get("value", ""),
                        range=SlotRange(
                            start=int(e.get("start", 0)),
                            end=int(e.get("end", 1)),
                            raw_start=int(e.get("start", 0)),
                            raw_end=int(e.get("end", 1)),
                        ),
                    )
                    for e in intent_json.get("entities", [])
                ]

                if query.custom_entities:
                    # Copy user-defined entities
                    for entity_name, entity_value in query.custom_entities.items():
                        slots.append(
                            Slot(
                                entity=entity_name,
                                confidence=1.0,
                                value={"value": entity_value},
                            )
                        )

                # intentParsed
                yield NluIntentParsed(
                    input=input_text,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    intent=Intent(
                        intent_name=intent_name, confidence_score=confidence_score
                    ),
                    slots=slots,
                )

                # intent
                yield (
                    NluIntent(
                        input=input_text,
                        id=query.id,
                        site_id=query.site_id,
                        session_id=query.session_id,
                        intent=Intent(
                            intent_name=intent_name,
                            confidence_score=confidence_score,
                        ),
                        slots=slots,
                        asr_tokens=[NluIntent.make_asr_tokens(input_text.split())],
                        asr_confidence=query.asr_confidence,
                        raw_input=original_input,
                        lang=(query.lang or self.lang),
                        custom_data=query.custom_data,
                    ),
                    {"intent_name": intent_name},
                )
            else:
                # Not recognized
                yield NluIntentNotRecognized(
                    input=query.input,
                    id=query.id,
                    site_id=query.site_id,
                    session_id=query.session_id,
                    custom_data=query.custom_data,
                )
    except Exception as e:
        _LOGGER.exception("nlu query")
        yield NluError(
            site_id=query.site_id,
            session_id=query.session_id,
            error=str(e),
            context=query.input,
        )
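# A minimal example of the Rasa /model/parse response consumed by the two Rasa
# handle_query() variants above; the shape follows the parsing code, and the
# values (intent name, entity, extractor) are illustrative.
example_rasa_response = {
    "intent": {"name": "ChangeLightState", "confidence": 0.92},
    "entities": [
        {
            "entity": "state",
            "value": "on",
            "confidence": 0.85,
            "start": 5,
            "end": 7,
            "extractor": "DIETClassifier",
        }
    ],
    "text": "turn on the lamp",
}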