def test_get_next_action_probabilities_passes_interpreter_to_policies(
    monkeypatch: MonkeyPatch,
):
    """The interpreter handed to the processor must reach the policy's predict call."""
    ted_policy = TEDPolicy()
    expected_interpreter = Mock()

    def fake_predict(
        tracker: DialogueStateTracker,
        domain: Domain,
        interpreter: NaturalLanguageInterpreter,
        **kwargs,
    ) -> List[float]:
        # Fails the test if the processor forwards a different interpreter.
        assert interpreter == expected_interpreter
        return [1, 0]

    # Replace the real prediction with the checking stub above.
    ted_policy.predict_action_probabilities = fake_predict
    policy_ensemble = SimplePolicyEnsemble(policies=[ted_policy])

    empty_domain = Domain.empty()
    processor = MessageProcessor(
        expected_interpreter,
        policy_ensemble,
        empty_domain,
        InMemoryTrackerStore(empty_domain),
        Mock(),
    )

    listen_tracker = DialogueStateTracker.from_events(
        "lala", [ActionExecuted(ACTION_LISTEN_NAME)]
    )
    # This should not raise
    processor._get_next_action_probabilities(listen_tracker)
def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg(
    predict_function: Callable,
):
    """A policy predict function without an `interpreter` argument warns on use.

    `predict_function` is presumably a fixture that supplies predict
    implementations lacking the `interpreter` parameter (TODO confirm against
    the fixture definition). The processor must emit a `DeprecationWarning`
    when it calls one.
    """
    ted_policy = TEDPolicy()
    ted_policy.predict_action_probabilities = predict_function
    policy_ensemble = SimplePolicyEnsemble(policies=[ted_policy])

    nlu_interpreter = Mock()
    empty_domain = Domain.empty()
    processor = MessageProcessor(
        nlu_interpreter,
        policy_ensemble,
        empty_domain,
        InMemoryTrackerStore(empty_domain),
        InMemoryLockStore(),
        Mock(),
    )

    with pytest.warns(DeprecationWarning):
        listen_tracker = DialogueStateTracker.from_events(
            "lala", [ActionExecuted(ACTION_LISTEN_NAME)]
        )
        processor._get_next_action_probabilities(listen_tracker)
def test_predict_next_action_raises_limit_reached_exception(domain: Domain):
    """Exceeding ``max_number_of_predictions`` must raise ``ActionLimitReached``.

    The processor is capped at one prediction. The tracker already holds an
    action executed after the last user message, so predicting again must
    raise.
    """
    regex_interpreter = RegexInterpreter()
    policy_ensemble = SimplePolicyEnsemble(
        policies=[RulePolicy(), MemoizationPolicy()]
    )
    processor = MessageProcessor(
        regex_interpreter,
        policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        max_number_of_predictions=1,
    )

    conversation = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("Hi!"),
        ActionExecuted("test_action"),
    ]
    tracker = DialogueStateTracker.from_events("test", evts=conversation)
    tracker.set_latest_action({"action_name": "test_action"})

    with pytest.raises(ActionLimitReached):
        processor.predict_next_action(tracker)
def _in_training_data_fraction(action_list): """Given a list of action items, returns the fraction of actions that were predicted using one of the Memoization policies.""" from rasa.core.policies import SimplePolicyEnsemble in_training_data = [ a["action"] for a in action_list if not SimplePolicyEnsemble.is_not_memo_policy(a["policy"]) ] return len(in_training_data) / len(action_list)
async def train(self):
    """Train the NLU model and the dialogue policies for the configured skill.

    The ``punkt`` NLTK resource is downloaded on every call. All training is
    skipped when the directory ``data/<skill-id>`` already exists. Otherwise:
    - the NLU model is persisted under that directory (as ``'nlu'``);
    - the policy ensemble and the domain are persisted under its ``core``
      subdirectory.
    """
    nltk.download('punkt')
    lang = self.config['language']
    skill_id = self.config['skill-id']
    model_dir = "data/" + skill_id

    if os.path.exists(model_dir):
        # Nothing to do: a model directory for this skill is already present.
        return

    _LOGGER.info("Starting Skill training.")
    _LOGGER.info("Generating stories.")
    data, domain_data, stories = await GenerateStories.run(
        skill_id, lang, self.asm)

    # ---- NLU ----
    nlu_examples = TrainingData(training_examples=data)
    nlu_config = RasaNLUModelConfig({
        "language": lang,
        "pipeline": self.config['pipeline'],
        "data": None,
    })
    nlu_trainer = Trainer(nlu_config, None, True)
    _LOGGER.info("Training Arcus NLU")
    nlu_trainer.train(nlu_examples)
    nlu_trainer.persist(model_dir, None, 'nlu')

    # ---- Rasa core ----
    domain = Domain.from_dict(domain_data)
    story_reader = StoryFileReader(domain, RegexInterpreter(), None, False)
    story_graph = StoryGraph(await story_reader.process_lines(stories))
    tracker_generator = TrainingDataGenerator(
        story_graph,
        domain,
        remove_duplicates=True,
        unique_last_num_states=None,
        augmentation_factor=20,
        tracker_limit=None,
        use_story_concatenation=True,
        debug_plots=False,
    )
    trackers = tracker_generator.generate()

    policies = SimplePolicyEnsemble.from_dict(
        {"policies": self.config['policies']})
    ensemble = SimplePolicyEnsemble(policies)
    _LOGGER.info("Training Arcus Core")
    ensemble.train(trackers, domain)

    core_dir = model_dir + "/core"
    ensemble.persist(core_dir, False)
    domain.persist(core_dir + "/model")
    domain.persist_specification(core_dir)
def is_predicted_event_in_training_data(policy: Optional[Text]) -> bool:
    """Determine whether event predicted by `policy` was in the training data.

    A predicted event is considered to be in training data if it was predicted
    by the `MemoizationPolicy` or the `AugmentedMemoizationPolicy`.

    Args:
        policy: Policy of the predicted event (`None` or empty if the event
            was not predicted by a policy).

    Returns:
        `True` if the event was not predicted. Otherwise `True` if it was
        predicted by a memo policy, else `False`.
    """
    if not policy:
        # event was not predicted by a policy
        return True

    # Double negation: "not a non-memo policy" means "predicted by a memo policy".
    return not SimplePolicyEnsemble.is_not_memo_policy(policy)
async def get_response(self, request):
    """Run one user message through the dialogue engine and return the reply.

    On first use, the domain is loaded and the tracker store is created; both
    are cached in ``self.config``. The policy ensemble is loaded from disk on
    every call.

    Args:
        request: Incoming message. The keys ``'text'``, ``'user'`` and
            ``'channel'`` are read here, and the whole object is passed to
            ``LocalNLUInterpreter``.

    Returns:
        ``{"text": <text of the first bot response>}``, or
        ``{"text": "error"}`` when the processor produced no response.
    """
    skill_dir = "data/" + self.config['skill-id']

    # Assign explicitly rather than via ``setdefault``. ``setdefault`` is a
    # no-op when the key exists with a ``None`` value, which would leave the
    # domain unset.
    if self.config.get('domain') is None:
        self.config['domain'] = Domain.from_file(skill_dir + "/core/model")
    domain = self.config['domain']

    # Build the tracker store only when it is missing. Passing a new instance
    # to ``setdefault`` constructs it eagerly even when one is already cached.
    if self.config.get('tracker_store') is None:
        self.config['tracker_store'] = ArcusTrackerStore(domain, self.asm)
    tracker_store = self.config['tracker_store']

    nlg = NaturalLanguageGenerator.create(None, domain)
    policy_ensemble = SimplePolicyEnsemble.load(skill_dir + "/core")
    interpreter = LocalNLUInterpreter(request)
    url = 'http://localhost:8080/api/v1/skill/generic_action'
    processor = MessageProcessor(interpreter,
                                 policy_ensemble,
                                 domain,
                                 tracker_store,
                                 nlg,
                                 action_endpoint=EndpointConfig(url),
                                 message_preprocessor=None)

    message_nlu = UserMessage(request['text'],
                              None,
                              request['user'],
                              input_channel=request['channel'])
    result = await processor.handle_message(message_nlu)

    # ``result`` may be None or an empty list; either means no bot reply.
    if result:
        return {"text": result[0]['text']}

    _LOGGER.info(result)
    return {"text": "error"}
def test_predict_next_action_with_hidden_rules():
    """Rule turns are flagged as hidden only when a rule drives the prediction.

    After the rule intent, both the rule action and the following listen are
    predicted with ``hide_rule_turn`` set. After the story intent, neither
    prediction sets it.
    """
    rule_intent, story_intent = "rule_intent", "story_intent"
    rule_action, story_action = "rule_action", "story_action"
    rule_slot, story_slot = "rule_slot", "story_slot"

    domain = Domain.from_yaml(
        f"""
    version: "2.0"
    intents:
    - {rule_intent}
    - {story_intent}
    actions:
    - {rule_action}
    - {story_action}
    slots:
      {rule_slot}:
        type: text
      {story_slot}:
        type: text
    """
    )

    rule_tracker = TrackerWithCachedStates.from_events(
        "rule",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
            ActionExecuted(rule_action),
            SlotSet(rule_slot, rule_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    story_tracker = TrackerWithCachedStates.from_events(
        "story",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": story_intent}),
            ActionExecuted(story_action),
            SlotSet(story_slot, story_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    regex_interpreter = RegexInterpreter()
    policy_ensemble = SimplePolicyEnsemble(
        policies=[RulePolicy(), MemoizationPolicy()]
    )
    policy_ensemble.train([rule_tracker, story_tracker], domain, regex_interpreter)

    processor = MessageProcessor(
        regex_interpreter,
        policy_ensemble,
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
    )

    conversation = DialogueStateTracker.from_events(
        "casd",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
        ],
        slots=domain.slots,
    )

    # Rule turn: the rule action is predicted with the turn hidden ...
    next_action, prediction = processor.predict_next_action(conversation)
    assert next_action._name == rule_action
    assert prediction.hide_rule_turn
    processor._log_action_on_tracker(
        conversation, next_action, [SlotSet(rule_slot, rule_slot)], prediction
    )

    # ... and so is the listen that follows it.
    next_action, prediction = processor.predict_next_action(conversation)
    assert isinstance(next_action, ActionListen)
    assert prediction.hide_rule_turn
    processor._log_action_on_tracker(conversation, next_action, None, prediction)

    conversation.events.append(UserUttered(intent={"name": story_intent}))

    # rules are hidden correctly if memo policy predicts next actions correctly
    next_action, prediction = processor.predict_next_action(conversation)
    assert next_action._name == story_action
    assert not prediction.hide_rule_turn
    processor._log_action_on_tracker(
        conversation, next_action, [SlotSet(story_slot, story_slot)], prediction
    )

    next_action, prediction = processor.predict_next_action(conversation)
    assert isinstance(next_action, ActionListen)
    assert not prediction.hide_rule_turn