def test_swap_intent_with2():
    """Substitution rule #1 must leave the unrelated intent "whatever" as is.

    No previous action is supplied (``None``).
    """
    # NOTE(review): the original comment described this as "not swapped
    # after another action than the one specified in the rule", yet no
    # action is passed here -- confirm which case this test should cover.
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    message = {"intent": {"name": "whatever", "confidence": 1.0}}

    Rules._swap_intent(message, None, validator.rules[1])

    resulting_name = message["intent"]["name"]
    assert resulting_name == "whatever"
def test_swap_intent_with1():
    """Substitution rule #1 rewrites ``chitchat.i_am_angry`` to ``request.handover``."""
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    message = {"intent": {"name": "chitchat.i_am_angry", "confidence": 1.0}}

    # No previous action is supplied for this rule.
    Rules._swap_intent(message, None, validator.rules[1])

    resulting_name = message["intent"]["name"]
    assert resulting_name == "request.handover"
def test_swap_intent_after1():
    """After ``utter_something``, rule #0 rewrites "whatever" to ``intent_something``."""
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    previous_action = "utter_something"
    message = {"intent": {"name": "whatever", "confidence": 1.0}}

    Rules._swap_intent(message, previous_action, validator.rules[0])

    resulting_name = message["intent"]["name"]
    assert resulting_name == "intent_something"
def __init__(self,
             interpreter,  # type: NaturalLanguageInterpreter
             policy_ensemble,  # type: PolicyEnsemble
             domain,  # type: Domain
             tracker_store,  # type: TrackerStore
             generator,  # type: NaturalLanguageGenerator
             max_number_of_predictions=10,  # type: int
             message_preprocessor=None,  # type: Optional[LambdaType]
             on_circuit_break=None,  # type: Optional[LambdaType]
             create_dispatcher=None,  # type: Optional[LambdaType]
             rules_file=None  # type: Optional[str]
             ):
    """Set up the processor, optionally loading interception rules.

    ``rules_file``, when given, is loaded into a ``Rules`` object stored on
    ``self.rules``; otherwise ``self.rules`` is ``None``. All other standard
    arguments are forwarded positionally to the parent constructor.
    ``create_dispatcher`` is a factory called as
    ``(sender_id, output_channel, nlg)``; when omitted, a factory that builds
    a plain ``Dispatcher`` from those three arguments is used.
    """
    # Rules are built before the parent constructor runs.
    if rules_file is None:
        self.rules = None
    else:
        self.rules = Rules(rules_file)

    super(SuperMessageProcessor, self).__init__(
        interpreter, policy_ensemble, domain, tracker_store, generator,
        max_number_of_predictions, message_preprocessor, on_circuit_break)

    def _default_dispatcher(sender_id, output_channel, nlg):
        return Dispatcher(sender_id, output_channel, nlg)

    if create_dispatcher is None:
        self.create_dispatcher = _default_dispatcher
    else:
        self.create_dispatcher = create_dispatcher
def test_swap_intent_with3():
    """Rule #2 turns ``chitchat.bye`` into ``chitchat`` and records the original.

    The original intent name is appended as an ``intent`` entity.
    """
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    message = {"intent": {"name": "chitchat.bye", "confidence": 1.0}}

    Rules._swap_intent(message, None, validator.rules[2])

    assert message["intent"]["name"] == "chitchat"
    added_entity = message["entities"][0]
    assert added_entity["entity"] == "intent"
    assert added_entity["value"] == "chitchat.bye"
def test_swap_intent_with4():
    """Rule #2 must not rewrite ``chitchat.this_is_frustrating``.

    The original note said this is "just checking regex is ok".
    """
    # NOTE(review): presumably the rule's pattern or exclusion list is what
    # spares this intent -- confirm against the rules YAML.
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    untouched = "chitchat.this_is_frustrating"
    message = {"intent": {"name": untouched, "confidence": 1.0}}

    Rules._swap_intent(message, None, validator.rules[2])

    assert message["intent"]["name"] == untouched
def test_swap_intent_after2():
    """Rule #0 must leave ``chitchat.this_is_frustrating`` untouched.

    This holds even after ``utter_something``. According to the original
    comment, this intent is on the rule's "unless" list.
    """
    config = InputValidator._load_yaml(VALIDATOR_RULES_YAML)
    validator = InputValidator(config['intent_substitution'])
    untouched = "chitchat.this_is_frustrating"
    message = {"intent": {"name": untouched, "confidence": 1.0}}

    Rules._swap_intent(message, "utter_something", validator.rules[0])

    assert message["intent"]["name"] == untouched
class SuperMessageProcessor(MessageProcessor):
    """``MessageProcessor`` that can let a ``Rules`` object intercept messages.

    When a rules file is configured, each parsed user message is first
    offered to ``Rules.interrupts``. If that returns a truthy value, the
    message is not logged to the tracker. Dispatcher construction is
    pluggable through the ``create_dispatcher`` factory.
    """

    def __init__(
            self,
            interpreter,  # type: NaturalLanguageInterpreter
            policy_ensemble,  # type: PolicyEnsemble
            domain,  # type: Domain
            tracker_store,  # type: TrackerStore
            generator,  # type: NaturalLanguageGenerator
            max_number_of_predictions=10,  # type: int
            message_preprocessor=None,  # type: Optional[LambdaType]
            on_circuit_break=None,  # type: Optional[LambdaType]
            create_dispatcher=None,  # type: Optional[LambdaType]
            rules_file=None  # type: Optional[str]
    ):
        """Initialise the processor.

        ``rules_file``: optional path loaded into ``Rules``; ``self.rules``
        is ``None`` when it is omitted.
        ``create_dispatcher``: optional factory called as
        ``(sender_id, output_channel, nlg)``; defaults to building a plain
        ``Dispatcher`` from those arguments.
        The remaining arguments are forwarded to ``MessageProcessor``.
        """
        self.rules = Rules(rules_file) if rules_file is not None else None
        super(SuperMessageProcessor, self).__init__(interpreter,
                                                    policy_ensemble,
                                                    domain,
                                                    tracker_store,
                                                    generator,
                                                    max_number_of_predictions,
                                                    message_preprocessor,
                                                    on_circuit_break)
        self.create_dispatcher = create_dispatcher
        if self.create_dispatcher is None:
            self.create_dispatcher = \
                lambda sender_id, output_channel, nlg: Dispatcher(
                    sender_id, output_channel, nlg)

    def _handle_message_with_tracker(self, message, tracker):
        # type: (UserMessage, DialogueStateTracker) -> None
        """Parse ``message`` and log it on ``tracker`` unless a rule interrupts.

        On the normal path this records a ``UserUttered`` event, followed by
        one slot event per entity the domain maps to a slot.
        """
        parse_data = self._parse_message(message)

        # rules section: a truthy result means the rules took over, so the
        # utterance is not logged.
        if self._rule_interrupts(parse_data, tracker, message):
            return

        # don't ever directly mutate the tracker
        # - instead pass its events to log
        tracker.update(UserUttered(message.text, parse_data["intent"],
                                   parse_data["entities"], parse_data))
        # store all entities as slots
        for e in self.domain.slots_for_entities(parse_data["entities"]):
            tracker.update(e)

        logger.debug("Logged UserUtterance - tracker now has %s events",
                     len(tracker.events))

    def _rule_interrupts(self, parse_data, tracker, message):
        """Return the rules' verdict on whether this message is intercepted.

        Returns ``False`` when no rules file was configured. Otherwise it
        returns whatever ``Rules.interrupts`` returns, called with a fresh
        dispatcher and ``self._run_action``.
        """
        if self.rules is None:
            return False
        dispatcher = self.create_dispatcher(message.sender_id,
                                            message.output_channel,
                                            self.nlg)
        return self.rules.interrupts(dispatcher, parse_data, tracker,
                                     self._run_action)

    def _predict_and_execute_next_action(self, message, tracker):
        """Run predicted actions until the loop stops or the cap is reached.

        The loop stops when ``_run_action`` returns falsy or
        ``_should_handle_message`` is false. It also stops after
        ``max_number_of_predictions`` iterations. If the cap is reached
        while another action was still requested, a warning is logged and
        ``on_circuit_break`` (if set) is called with
        ``(tracker, dispatcher)``.
        """
        # this will actually send the response to the user
        dispatcher = self.create_dispatcher(message.sender_id,
                                            message.output_channel,
                                            self.nlg)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self._log_slots(tracker)

        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action and
               self._should_handle_message(tracker) and
               num_predicted_actions < self.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            action = self._get_next_action(tracker)

            should_predict_another_action = self._run_action(action,
                                                             tracker,
                                                             dispatcher)
            num_predicted_actions += 1

        if (num_predicted_actions == self.max_number_of_predictions and
                should_predict_another_action):
            # circuit breaker was tripped
            logger.warning("Circuit breaker tripped. Stopped predicting "
                           "more actions for sender '%s'", tracker.sender_id)
            if self.on_circuit_break:
                # call a registered callback
                self.on_circuit_break(tracker, dispatcher)