def run_fake_user(input_channel, max_training_samples=10, serve_forever=True):
    """Train a Memoization + Keras policy online, then let a scripted customer talk to it.

    Args:
        input_channel: channel handed to ``Agent.train_online`` for the
            interactive training session.
        max_training_samples: cap forwarded to ``Agent.train_online``.
        serve_forever: when true, keep feeding the customer's replies to the
            agent in an endless loop (the flag is never changed inside the
            loop, so the function does not return in that case).

    Returns:
        The trained ``Agent`` (only reached when ``serve_forever`` is false).
    """
    customer = Customer()
    stories_path = 'examples/babi/data/babi_task5_fu_rasa_fewer_actions.md'

    logger.info("Starting to train policy")
    agent = Agent("examples/restaurant_domain.yml",
                  policies=[MemoizationPolicy(), KerasPolicy()],
                  interpreter=RegexInterpreter())
    agent.train_online(stories_path,
                       input_channel=input_channel,
                       epochs=1,
                       max_training_samples=max_training_samples)

    while serve_forever:
        current_tracker = agent.tracker_store.retrieve('default')
        reply = customer.respond_to_action(current_tracker)
        # A 'reset' from the customer is answered by sending "_greet" instead.
        text = "_greet" if reply == 'reset' else reply
        agent.handle_message(text, output_channel=ConsoleOutputChannel())
    return agent
def replay_events(tracker, agent):
    # type: (DialogueStateTracker, Agent) -> None
    """Replay the logged user utterances of ``tracker`` against ``agent``.

    While the utterances are replayed, the actions the agent executes are
    compared with the actions logged in the tracker being replayed, and
    ``_check_prediction_aligns_with_story`` is called to report any
    mismatch. Afterwards, the tracker held in the agent's tracker store for
    the same sender id should be in nearly the same state as the replayed
    one.

    Args:
        tracker: tracker whose events (since the latest restart) are replayed.
        agent: agent that handles the replayed utterances.
    """
    actions_between_utterances = []
    last_prediction = [ACTION_LISTEN_NAME]

    # The iterable is obtained from the tracker passed in; rebinding
    # ``tracker`` below (to the agent's stored tracker) does not change
    # which events are being walked.
    for event in tracker.events_after_latest_restart():
        if isinstance(event, UserUttered):
            # Compare what the agent did after the previous utterance with
            # what the story logged before this one.
            _check_prediction_aligns_with_story(last_prediction,
                                                actions_between_utterances)
            actions_between_utterances = []
            print(utils.wrap_with_color(event.text, utils.bcolors.OKGREEN))
            agent.handle_message(event.text, sender_id=tracker.sender_id,
                                 output_channel=ConsoleOutputChannel())
            tracker = agent.tracker_store.retrieve(tracker.sender_id)
            last_prediction = evaluate.actions_since_last_utterance(tracker)
        elif isinstance(event, ActionExecuted):
            actions_between_utterances.append(event.action_name)

    # Check the actions logged after the final utterance as well.
    _check_prediction_aligns_with_story(last_prediction,
                                        actions_between_utterances)
def _record_messages(self, on_message, max_message_limit=None):
    """Feed fake-customer replies to ``on_message`` until the limit is reached.

    Each reply is computed by ``self.customer`` from the tracker retrieved
    for sender id 'nlu', then wrapped in a ``UserMessage`` bound to a fresh
    ``ConsoleOutputChannel``. With ``max_message_limit`` set to None the
    loop never stops.
    """
    logger.info("Bot loaded. Fake user will automatically respond!")

    def _budget_left(sent_so_far):
        # No limit means unlimited messages.
        return max_message_limit is None or sent_so_far < max_message_limit

    sent = 0
    while _budget_left(sent):
        current_tracker = self.tracker_store.retrieve('nlu')
        reply_text = self.customer.respond_to_action(current_tracker)
        on_message(UserMessage(reply_text, ConsoleOutputChannel()))
        sent += 1
def setUp(self):
    """Load the trained dialogue agent and build the objects the tests use."""
    # Attribute name (including its spelling) is kept as other tests may use it.
    self.not_undestood = ActionNotUnderstood()

    # Path of the Rasa NLU model used as the agent's interpreter.
    nlu_model_dir = 'rasa-nlu/models/rasa-nlu/default/socialcompanionnlu'
    self.interpreter = nlu_model_dir

    dialogue_model_dir = './models/dialogue'
    self.agent = Agent.load(dialogue_model_dir, self.interpreter)
    # NOTE(review): attaching a ConsoleInputChannel may start reading from
    # the console and block the test run — confirm this is intended in setUp.
    self.agent.handle_channel(ConsoleInputChannel())

    # TODO mock dispatcher, tracker and domain
    console_out = ConsoleOutputChannel()
    self.dispatcher = Dispatcher(output_channel=console_out)
    self.tracker = DialogueStateTracker()
    self.domain = Domain()
def __init__(self, file_name, output_channel=None, message_line_pattern=".*",
             max_messages=None):
    """Store the file name, a compiled line filter and the message cap.

    Args:
        file_name: stored as ``self.file_name``.
        output_channel: channel stored as ``self.output_channel``; any falsy
            value is replaced by a new ``ConsoleOutputChannel``.
        message_line_pattern: regular expression compiled into
            ``self.message_filter``.
        max_messages: optional cap stored as ``self.max_messages``.
    """
    # Imported at call time, as in the original (likely to avoid an import
    # cycle — TODO confirm).
    from rasa_core.channels.console import ConsoleOutputChannel

    self.message_filter = re.compile(message_line_pattern)
    self.file_name = file_name
    self.max_messages = max_messages
    self.output_channel = (output_channel if output_channel
                           else ConsoleOutputChannel())
def test_message_processor(default_domain, capsys):
    """Train a ScoringPolicy on the default-domain stories and check that a
    '_greet' message makes the processor print "hey there!"."""
    stories_path = "data/dsl_stories/stories_defaultdomain.md"
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain, BinaryFeaturizer())
    trainer.train(stories_path, max_history=3)

    store = InMemoryTrackerStore(default_domain)
    message_processor = MessageProcessor(regex_interpreter, policy_ensemble,
                                         default_domain, store)
    message_processor.handle_message(
        UserMessage("_greet", ConsoleOutputChannel()))

    # readouterr() yields (out, err); only stdout matters here.
    printed = capsys.readouterr()[0]
    assert "hey there!" in printed
def run_fake_user(input_channel, max_training_samples=10, serve_forever=True):
    """Train a Memoization + Keras policy online, then let a scripted customer talk to it.

    Args:
        input_channel: channel handed to ``Agent.train_online`` for the
            interactive training session.
        max_training_samples: cap forwarded to ``Agent.train_online``.
        serve_forever: when true, keep feeding the customer's replies to the
            agent in an endless loop (the flag is never changed inside the
            loop, so the function does not return in that case).

    Returns:
        The trained ``Agent`` (only reached when ``serve_forever`` is false).
    """
    logger.info("Starting to train policy")
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=[MemoizationPolicy(), KerasPolicy()],
                  interpreter=RegexInterpreter())
    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=input_channel,
                       epochs=RASA_CORE_EPOCHS,
                       max_training_samples=max_training_samples)

    # Previously the loop sent an undefined name ``back`` (NameError on the
    # first iteration). The reply is now produced by a simulated customer
    # from the current dialogue state.
    customer = Customer()
    while serve_forever:
        # NOTE(review): 'default' is assumed to be the sender id used by
        # UserMessage when none is given — confirm.
        tracker = agent.tracker_store.retrieve('default')
        back = customer.respond_to_action(tracker)
        # NOTE(review): a 'reset' reply is answered with "_greet", matching
        # the fake-user protocol used elsewhere — confirm Customer still
        # emits 'reset'.
        if back == 'reset':
            back = "_greet"
        agent.handle_message(UserMessage(back, ConsoleOutputChannel()))
    return agent
def __init__(self, filename, output_channel=None, message_line_pattern=".*",
             max_messages=None):
    # type: (Text, OutputChannel, Text, Optional[int]) -> None
    """Store the file name, a compiled line filter and the message cap.

    Args:
        filename: stored as ``self.filename``.
        output_channel: channel stored as ``self.output_channel``; any falsy
            value is replaced by a new ``ConsoleOutputChannel``.
        message_line_pattern: regular expression compiled into
            ``self.message_filter``.
        max_messages: optional cap stored as ``self.max_messages``.
    """
    # Imported at call time, as in the original (likely to avoid an import
    # cycle — TODO confirm).
    from rasa_core.channels.console import ConsoleOutputChannel

    self.message_filter = re.compile(message_line_pattern)
    self.filename = filename
    self.max_messages = max_messages
    self.output_channel = (output_channel if output_channel
                           else ConsoleOutputChannel())
def default_dispatcher_cmd(default_domain):
    """Return a Dispatcher for sender "my-sender" that writes to the console."""
    return Dispatcher("my-sender", ConsoleOutputChannel(), default_domain)