def _parse_message(self, message):
    """Turn the text of an incoming message into parse data.

    Text that begins with ``INTENT_MESSAGE_PREFIX`` bypasses the configured
    interpreter and is handed to a fresh ``RegexInterpreter``. This is a
    testing shortcut for messages written as
    ``/intent{"entity1": val1, "entity2": val2}``. Any other text is parsed
    by ``self.interpreter``.

    Returns the parse result; its ``"intent"`` and ``"entities"`` entries
    are read for the debug log.
    """
    text = message.text

    # Choose the parser first, then run it exactly once.
    if text.startswith(INTENT_MESSAGE_PREFIX):
        parser = RegexInterpreter()
    else:
        parser = self.interpreter
    parse_data = parser.parse(text)

    summary = ("Received user message '{}' with intent '{}' "
               "and entities '{}'".format(text,
                                          parse_data["intent"],
                                          parse_data["entities"]))
    logger.debug(summary)
    return parse_data
def test_training_script_with_max_history_set(tmpdir):
    """Train with the max-history config and check each featurizer.

    After training into ``tmpdir`` and reloading the agent, every policy
    whose featurizer exposes ``max_history`` must report 2 for a
    ``FormPolicy`` and 5 for any other policy.
    """
    train_dialogue_model(DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE,
                         tmpdir.strpath,
                         interpreter=RegexInterpreter(),
                         policy_config='data/test_config/max_hist_config.yml',
                         kwargs={})
    agent = Agent.load(tmpdir.strpath)

    for policy in agent.policy_ensemble.policies:
        # Featurizers without a max_history attribute are not checked.
        if not hasattr(policy.featurizer, 'max_history'):
            continue
        # isinstance instead of comparing type() objects with ==.
        expected = 2 if isinstance(policy, FormPolicy) else 5
        assert policy.featurizer.max_history == expected
def test_random_seed(tmpdir, config_file):
    """Two trainings from the same config must give the same confidence.

    The random seed is set in the config file so that the training result
    is reproducible.
    """
    def train_into(suffix):
        # Each run gets its own output directory next to tmpdir.
        return train_dialogue_model(DEFAULT_DOMAIN_PATH,
                                    DEFAULT_STORIES_FILE,
                                    tmpdir.strpath + suffix,
                                    interpreter=RegexInterpreter(),
                                    policy_config=config_file,
                                    kwargs={})

    agent_1 = train_into("1")
    agent_2 = train_into("2")

    processor_1 = agent_1.create_processor()
    processor_2 = agent_2.create_processor()

    first = processor_1.predict_next("1")
    second = processor_2.predict_next("2")
    assert first["confidence"] == second["confidence"]
def learnonline(self, msg, args):
    """Command to trigger learn_online on rasa agent"""
    # Refuse to start without a configured token.
    if config.BOT_IDENTITY['token'] is None:
        raise Exception('No slack token')

    train_agent = Agent(
        self.domain_file,
        policies=[MemoizationPolicy(max_history=2), KerasPolicy()],
        interpreter=RegexInterpreter(),
    )
    stories = train_agent.load_data(self.training_data_file)

    # Online training is driven through the bot's own backend adapter.
    train_agent.train_online(
        stories,
        input_channel=self.backend_adapter,
        batch_size=50,
        epochs=200,
        max_training_samples=300,
    )
def _parse_message(self, message):
    """Turn the text of an incoming message into parse data.

    Text that begins with ``_`` bypasses the configured interpreter and is
    handed to a fresh ``RegexInterpreter``. This is a testing shortcut for
    messages written as ``_intent[entity1=val1,entity=val2]``. Any other
    text is parsed by ``self.interpreter``.

    Returns the parse result; its ``"intent"`` and ``"entities"`` entries
    are read for the debug log.
    """
    text = message.text

    # Choose the parser first, then run it exactly once.
    if text.startswith('_'):
        parser = RegexInterpreter()
    else:
        parser = self.interpreter
    parse_data = parser.parse(text)

    summary = ("Received user message '{}' with intent '{}' "
               "and entities '{}'".format(text,
                                          parse_data["intent"],
                                          parse_data["entities"]))
    logger.debug(summary)
    return parse_data
def read_from_file(file_name, domain, interpreter=RegexInterpreter(),
                   template_variables=None):
    """Read the stories contained in a single story file.

    Args:
        file_name: path of the file to read.
        domain: passed through to ``StoryFileReader``.
        interpreter: passed through to ``StoryFileReader``; the default
            instance is created once, when the function is defined.
        template_variables: passed through to ``StoryFileReader``.

    Returns:
        Whatever ``StoryFileReader.process_lines`` returns for the lines of
        the file.

    Raises:
        Exception: if reading or parsing fails for any reason. The message
            names the absolute path and includes the original error text.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # locale's default encoding.
        with io.open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        reader = StoryFileReader(domain, interpreter, template_variables)
        return reader.process_lines(lines)
    except Exception as e:
        raise Exception("Failed to parse '{}'. {}".format(
            os.path.abspath(file_name), e))
def run_fake_user(input_channel, max_training_samples=10, serve_forever=True):
    """Train an agent online from ``input_channel`` and optionally serve.

    With a truthy ``serve_forever`` the function keeps handling messages
    and never reaches the ``return``; otherwise the trained agent is
    returned right after training.
    """
    logger.info("Starting to train policy")

    policies = [MemoizationPolicy(), KerasPolicy()]
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=policies,
                  interpreter=RegexInterpreter())

    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=input_channel,
                       epochs=RASA_CORE_EPOCHS,
                       max_training_samples=max_training_samples)

    # NOTE(review): `back` is not defined in this function; presumably it
    # is a module-level name -- confirm.
    while serve_forever:
        outgoing = UserMessage(back, ConsoleOutputChannel())
        agent.handle_message(outgoing)
    return agent
def test_message_processor(default_domain, capsys):
    """A ``_greet`` message must make the processor print the greeting."""
    story_filename = "data/dsl_stories/stories_defaultdomain.md"

    ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    interpreter = RegexInterpreter()

    trainer = PolicyTrainer(ensemble, default_domain, BinaryFeaturizer())
    trainer.train(story_filename, max_history=3)

    processor = MessageProcessor(interpreter,
                                 ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    # The console channel writes to stdout, which capsys captures.
    greeting = UserMessage("_greet", ConsoleOutputChannel())
    processor.handle_message(greeting)

    out, _ = capsys.readouterr()
    assert "hey there!" in out
def test_channel_inheritance():
    """Serve a REST and a Rasa chat input together and check the server."""
    from rasa_core.channels import RestInput
    from rasa_core.channels import RasaChatInput
    from rasa_core.agent import Agent
    from rasa_core.interpreter import RegexInterpreter

    # load your trained agent
    agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter())

    rasa_input = RasaChatInput("https://example.com")
    channels = [RestInput(), rasa_input]

    # set serve_forever=False if you want to keep the server running
    http_server = agent.handle_channels(channels, 5004,
                                        serve_forever=False)
    assert http_server.started
def train_model_online():
    """Train an agent online from the training file and return it.

    Messages are fed from the training data file itself (at most 10 of
    them) through a ``FileInputChannel``. After training, the agent's
    interpreter is replaced by a ``RasaNLUInterpreter`` loaded from
    ``RASA_NLU_MODEL_PATH``.
    """
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=[MemoizationPolicy(), StatusPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: `\s` and `\*` are regex escapes, not string escapes. A
    # plain literal triggers an invalid-escape warning on Python 3.
    file_channel = FileInputChannel(
        RASA_CORE_TRAINING_DATA_PATH,
        message_line_pattern=r'^\s*\*\s(.*)$',
        max_messages=10)

    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=file_channel,
                       epochs=RASA_CORE_EPOCHS)

    agent.interpreter = RasaNLUInterpreter(RASA_NLU_MODEL_PATH)
    return agent
def core_server(tmpdir_factory):
    """Build a server app from an agent that is trained, saved and reloaded.

    An agent is trained on ``DEFAULT_STORIES_FILE`` and persisted into a
    temporary "model" directory. The app is created from the agent loaded
    back from that directory, and a ``RestInput`` channel is registered
    under ``/webhooks/``.
    """
    model_path = tmpdir_factory.mktemp("model").strpath

    agent = Agent("data/test_domains/default.yml",
                  policies=[AugmentedMemoizationPolicy(max_history=3)])
    agent.train(agent.load_data(DEFAULT_STORIES_FILE))
    agent.persist(model_path)

    loaded_agent = Agent.load(model_path, interpreter=RegexInterpreter())
    app = server.create_app(loaded_agent)

    # NOTE(review): the webhook handler is the originally trained `agent`,
    # not `loaded_agent` -- confirm this is intended.
    rest_channels = [RestInput()]
    channel.register(rest_channels, app, agent.handle_message, "/webhooks/")
    return app
def run_babi_online(max_messages=10):
    """Train the restaurant agent online from the bAbI stories file.

    Args:
        max_messages: upper bound handed to the ``FileInputChannel`` that
            feeds messages from the training file.

    Returns:
        The trained agent, with its interpreter replaced by a
        ``RasaNLUInterpreter``.
    """
    training_data = 'examples/babi/data/babi_task5_dev_rasa_even_smaller.md'
    logger.info("Starting to train policy")

    agent = Agent("examples/restaurant_domain.yml",
                  policies=[MemoizationPolicy(), RestaurantPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: `\s` and `\*` are regex escapes, not string escapes. A
    # plain literal triggers an invalid-escape warning on Python 3.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    # NOTE(review): `nlu_model_path` is not defined in this function;
    # presumably it is a module-level name -- confirm.
    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
def test_story_visualization(default_domain, tmpdir):
    """Visualizing the test stories must give 51 nodes and 56 edges."""
    steps = StoryFileReader.read_from_file("data/test_stories/stories.md",
                                           default_domain,
                                           interpreter=RegexInterpreter())
    png_path = tmpdir.join("graph.png").strpath

    graph = visualize_stories(steps,
                              default_domain,
                              output_file=png_path,
                              max_history=3,
                              should_merge_nodes=False)

    node_count = len(graph.nodes())
    assert node_count == 51
    edge_count = len(graph.edges())
    assert edge_count == 56
def run_babi_online(max_messages=10):
    """Train the agent online from ``stories.md`` and return it.

    Args:
        max_messages: upper bound handed to the ``FileInputChannel`` that
            feeds messages from the stories file.

    Returns:
        The trained agent, with its interpreter replaced by a
        ``RasaNLUInterpreter``.
    """
    training_data = 'stories.md'
    logger.info("Starting to train policy")

    agent = Agent("domain.yml",
                  policies=[MemoizationPolicy(), MusicPlayerPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: `\s` and `\*` are regex escapes, not string escapes. A
    # plain literal triggers an invalid-escape warning on Python 3.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    # NOTE(review): `nlu_model_path` is not defined in this function;
    # presumably it is a module-level name -- confirm.
    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
def extract_story_graph(
        resource_name,  # type: Text
        domain,  # type: Domain
        interpreter=None  # type: Optional[NaturalLanguageInterpreter]
):
    # type: (...) -> StoryGraph
    """Read the story steps under ``resource_name`` into a ``StoryGraph``.

    A falsy ``interpreter`` is replaced by a new ``RegexInterpreter``
    before the stories are read.
    """
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    interpreter = interpreter or RegexInterpreter()
    return StoryGraph(
        StoryFileReader.read_from_folder(resource_name, domain, interpreter))
def test_message_processor(default_domain, capsys):
    """A ``_greet`` message carrying a name must yield the named greeting."""
    story_filename = "data/dsl_stories/stories_defaultdomain.md"

    ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    interpreter = RegexInterpreter()

    trainer = PolicyTrainer(ensemble, default_domain, BinaryFeaturizer())
    trainer.train(story_filename, max_history=3)

    processor = MessageProcessor(interpreter,
                                 ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    # Collect the bot output so it can be compared directly.
    out = CollectingOutputChannel()
    greeting = UserMessage("_greet[name=Core]", out)
    processor.handle_message(greeting)

    assert ("default", "hey there Core!") == out.latest_output()
def run_online_training(self, ensemble, domain, interpreter=None,
                        input_channel=None):
    """Run an online-training session until it signals completion.

    A missing ``interpreter`` defaults to a ``RegexInterpreter`` and a
    falsy ``input_channel`` to a ``ConsoleInputChannel``. Memoization is
    switched off on the agent before the channel is handled.
    ``TrainingFinishedException`` ends the session and is swallowed.
    """
    from rasa_core.agent import Agent

    chosen_interpreter = (RegexInterpreter() if interpreter is None
                          else interpreter)
    bot = Agent(domain, ensemble,
                featurizer=self.featurizer,
                interpreter=chosen_interpreter)
    bot.toggle_memoization(False)

    try:
        bot.handle_channel(input_channel or ConsoleInputChannel())
    except TrainingFinishedException:
        # training has finished
        pass
def _assert_regex_parse(result, text, confidence, entities):
    """Check one ``RegexInterpreter`` result against the expected values.

    ``entities`` is a list of ``(entity, value)`` pairs in expected order.
    """
    assert result['text'] == text
    assert len(result['intent_ranking']) == 1
    assert result['intent']['name'] == \
        result['intent_ranking'][0]['name'] == \
        'my_intent'
    assert result['intent']['confidence'] == \
        result['intent_ranking'][0]['confidence'] == \
        pytest.approx(confidence)
    assert len(result['entities']) == len(entities)
    for found, (name, value) in zip(result['entities'], entities):
        assert found["entity"] == name
        assert found["value"] == value


def test_regex_interpreter():
    """Parse prefixed intent messages with and without extras.

    Covers a bare intent, an intent with entities, an intent with an
    explicit ``@confidence`` suffix, and both extras combined.
    """
    interp = RegexInterpreter()

    # The two confidence cases previously read '[email protected]', which
    # is e-mail-obfuscation residue and contradicts the assertions (name
    # 'my_intent', confidence 0.5). 'my_intent@0.5' is the text they
    # describe.
    cases = [
        ('my_intent', 1.0, []),
        ('my_intent{"foo":"bar"}', 1.0, [("foo", "bar")]),
        ('my_intent@0.5', 0.5, []),
        ('my_intent@0.5{"foo":"bar"}', 0.5, [("foo", "bar")]),
    ]
    for suffix, confidence, entities in cases:
        text = INTENT_MESSAGE_PREFIX + suffix
        _assert_regex_parse(interp.parse(text), text, confidence, entities)
def train_dialogue_model(domain_file, stories_file, output_path,
                         use_online_learning, nlu_model_path, kwargs):
    """Train a dialogue model and persist it to ``output_path``.

    With ``use_online_learning`` the agent is trained interactively from
    the console, using a ``RasaNLUInterpreter`` when ``nlu_model_path`` is
    truthy and a ``RegexInterpreter`` otherwise. Without it, the agent is
    trained from ``stories_file`` with ``kwargs`` forwarded to ``train``.
    """
    agent = Agent(domain_file, policies=[MemoizationPolicy(), KerasPolicy()])

    if not use_online_learning:
        agent.train(stories_file, validation_split=0.1, **kwargs)
    else:
        agent.interpreter = (RasaNLUInterpreter(nlu_model_path)
                             if nlu_model_path
                             else RegexInterpreter())
        agent.train_online(stories_file,
                           input_channel=ConsoleInputChannel(),
                           epochs=10,
                           model_path=output_path)

    agent.persist(output_path)
def extract_training_data_from_file(
        filename,  # type: Text
        domain,  # type: Domain
        featurizer=None,  # type: Featurizer
        interpreter=RegexInterpreter(),  # type: NaturalLanguageInterpreter
        augmentation_factor=20,  # type: int
        max_history=1,  # type: int
        remove_duplicates=True,
        max_number_of_trackers=2000  # type: int
):
    # type: (...) -> DialogueTrainingData
    """Generate dialogue training data from the stories in ``filename``.

    The story graph extracted from the file is handed to a
    ``TrainingsDataGenerator`` together with the remaining options, and
    the result of its ``generate()`` call is returned.
    """
    story_graph = extract_story_graph_from_file(filename, domain,
                                                interpreter)
    generator = TrainingsDataGenerator(story_graph,
                                       domain,
                                       featurizer,
                                       remove_duplicates,
                                       augmentation_factor,
                                       max_history,
                                       max_number_of_trackers)
    return generator.generate()
def read_from_file(filename, domain, interpreter=RegexInterpreter(),
                   template_variables=None):
    """Read the stories contained in a single story file.

    Args:
        filename: path of the file to read.
        domain: passed through to ``StoryFileReader``.
        interpreter: passed through to ``StoryFileReader``; the default
            instance is created once, when the function is defined.
        template_variables: passed through to ``StoryFileReader``.

    Returns:
        Whatever ``StoryFileReader.process_lines`` returns for the lines of
        the file.

    Raises:
        ValueError: re-raised after being logged, with a message naming
            the absolute file path appended to its ``args``.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # locale's default encoding.
        with io.open(filename, "r", encoding="utf-8") as f:
            lines = f.readlines()
        reader = StoryFileReader(domain, interpreter, template_variables)
        return reader.process_lines(lines)
    except ValueError as err:
        file_info = ("Invalid story file format. Failed to parse "
                     "'{}'".format(os.path.abspath(filename)))
        logger.exception(file_info)
        # Guarantee a first element so the file info is always appended
        # after an (possibly empty) original message.
        if not err.args:
            err.args = ('',)
        err.args = err.args + (file_info,)
        raise
def run_babi_online():
    """Train the weather agent online from ``data/weather.md``.

    Messages are fed from the training file itself (at most 10 of them).

    Returns:
        The trained agent, with its interpreter replaced by a
        ``RasaNLUHttpInterpreter`` pointing at ``http://localhost:7000``.
    """
    training_data = 'data/weather.md'
    logger.info("Starting to train policy")

    agent = Agent("../weather_domain.yml",
                  policies=[MemoizationPolicy(), WeatherPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: `\s` and `\*` are regex escapes, not string escapes. A
    # plain literal triggers an invalid-escape warning on Python 3.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=10)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUHttpInterpreter(
        model_name='model_20171013-084449',
        token=None,
        server='http://localhost:7000')
    return agent
def run_fake_user(max_training_samples=10, serve_forever=True):
    """Train the restaurant agent online against a fake user.

    ``serve_forever`` is accepted but not used by this function. Returns
    the trained agent.
    """
    training_data = 'examples/babi/data/babi_task5_fu_rasa_fewer_actions.md'
    logger.info("Starting to train policy")

    policies = [MemoizationPolicy(), KerasPolicy()]
    agent = Agent("examples/restaurant_domain.yml",
                  policies=policies,
                  interpreter=RegexInterpreter())

    # Instead of generating the response messages ourselves, the fake user
    # will generate input messages based on the dialogue state
    fake_user = FakeUserInputChannel(agent.tracker_store)

    agent.train_online(training_data,
                       input_channel=fake_user,
                       epochs=1,
                       max_training_samples=max_training_samples)
    return agent
def extract_story_graph(
    resource_name: Text,
    domain: 'Domain',
    interpreter: Optional['NaturalLanguageInterpreter'] = None,
    use_e2e: bool = False,
    exclusion_percentage: Optional[int] = None
) -> 'StoryGraph':
    """Read the story steps under ``resource_name`` into a ``StoryGraph``.

    Args:
        resource_name: passed to ``StoryFileReader.read_from_folder``.
        domain: passed through to the reader.
        interpreter: a falsy value is replaced by a new
            ``RegexInterpreter``.
        use_e2e: passed through to the reader.
        exclusion_percentage: passed through to the reader; ``None`` by
            default, hence annotated as ``Optional[int]``.

    Returns:
        A ``StoryGraph`` built from the story steps that were read.
    """
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    if not interpreter:
        interpreter = RegexInterpreter()

    story_steps = StoryFileReader.read_from_folder(
        resource_name, domain, interpreter,
        use_e2e=use_e2e,
        exclusion_percentage=exclusion_percentage)
    return StoryGraph(story_steps)
def run_online_training(
        self,
        domain,  # type: Domain
        interpreter,  # type: NaturalLanguageInterpreter
        input_channel=None  # type: Optional[InputChannel]
):
    # type: (...) -> None
    """Run an online-training session until it signals completion.

    ``self`` is passed to the ``Agent`` as its second positional argument.
    A ``None`` interpreter defaults to a ``RegexInterpreter`` and a falsy
    ``input_channel`` to a ``ConsoleInputChannel``. Memoization is switched
    off before the channel is handled. ``TrainingFinishedException`` ends
    the session and is swallowed.
    """
    from rasa_core.agent import Agent

    chosen_interpreter = (RegexInterpreter() if interpreter is None
                          else interpreter)
    bot = Agent(domain, self, interpreter=chosen_interpreter)
    bot.toggle_memoization(False)

    try:
        bot.handle_channel(input_channel or ConsoleInputChannel())
    except TrainingFinishedException:
        # training has finished
        pass
def test_register_channel_without_route():
    """Check we properly connect the input channel blueprint if route is None"""
    from rasa_core.channels import RestInput
    from flask import Flask
    import rasa_core

    # load your trained agent
    agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter())

    rest_input = RestInput()
    app = Flask(__name__)
    rasa_core.channels.channel.register([rest_input],
                                        app,
                                        agent.handle_message,
                                        route=None)

    # With no route given, the webhook is expected directly at /webhook.
    webhook_endpoint = utils.list_routes(app).get("/webhook")
    assert webhook_endpoint.startswith("custom_webhook_RestInput.receive")
def train_dialog(online=False, nlu=False, policies=None):
    """Train the dialogue agent and persist it under ``models\\dialog``.

    Args:
        online: when true, train interactively from the console for 10
            epochs; otherwise train from the stories file for 300 epochs.
        nlu: only used when ``online`` is true. Selects a
            ``RasaNLUInterpreter`` over a ``RegexInterpreter``.
        policies: policies for the agent. Defaults to a new
            ``[KerasPolicy()]`` on every call; the previous list default
            was built once and its policy instance shared across calls.
    """
    if policies is None:
        policies = [KerasPolicy()]

    agent = Agent("domain.yml", policies=policies)

    # Backslashes are escaped explicitly. The path values are unchanged,
    # but `\s` and `\d` no longer trigger invalid-escape warnings.
    stories = "data\\stories.md"
    output_path = "models\\dialog"

    if online:
        if nlu:
            agent.interpreter = RasaNLUInterpreter(
                "models/infosys_cui/current")
        else:
            agent.interpreter = RegexInterpreter()
        agent.train_online(stories,
                           input_channel=ConsoleInputChannel(),
                           epochs=10,
                           model_path=output_path)
    else:
        kwargs = {"epochs": 300}
        agent.train(stories, validation_split=0.1, **kwargs)

    agent.persist(output_path)
def extract_trackers_from_file(
        filename,  # type: Text
        domain,  # type: Domain
        featurizer,  # type: Featurizer
        interpreter=RegexInterpreter(),  # type: NaturalLanguageInterpreter
        max_history=1,  # type: int
        max_number_of_trackers=2000  # type: int
):
    # type: (...) -> List[DialogueStateTracker]
    """Return the trackers generated from the stories in ``filename``.

    The generator runs without story concatenation and without duplicate
    removal. The trackers are read from the ``"trackers"`` entry of the
    generated data's metadata.
    """
    story_graph = extract_story_graph_from_file(filename, domain,
                                                interpreter)
    generator = TrainingsDataGenerator(
        story_graph,
        domain,
        featurizer,
        use_story_concatenation=False,
        max_history=max_history,
        tracker_limit=1000,
        remove_duplicates=False,
        max_number_of_trackers=max_number_of_trackers)
    return generator.generate().metadata["trackers"]
def test_concerts_online_example():
    """Run the concertbot online-training example on file-fed input.

    The trained agent must answer ``_greet`` with one of the known
    responses as its last message.
    """
    sys.path.append("examples/concertbot/")
    from train_online import run_concertbot_online
    from rasa_core import utils

    # simulates cmdline input / detailed explanation above
    utils.input = lambda _=None: "2"

    # Raw string: `\s` and `\*` are regex escapes, not string escapes. A
    # plain literal triggers an invalid-escape warning on Python 3.
    input_channel = FileInputChannel('examples/concertbot/data/stories.md',
                                     message_line_pattern=r'^\s*\*\s(.*)$',
                                     max_messages=3)
    domain_file = os.path.join("examples", "concertbot",
                               "concert_domain.yml")
    training_file = os.path.join("examples", "concertbot",
                                 "data", "stories.md")

    agent = run_concertbot_online(input_channel, RegexInterpreter(),
                                  domain_file, training_file)
    responses = agent.handle_message("_greet")

    assert responses[-1] in {
        "hey there!",
        "how can I help you?",
        "default message"
    }
def _parse_message(self, message):
    """Turn the text of an incoming message into parse data.

    Three cases, checked in order:

    * Text beginning with ``INTENT_MESSAGE_PREFIX`` bypasses the
      configured interpreter and is parsed by a fresh
      ``RegexInterpreter`` (a testing shortcut).
    * If ``self.interpreter`` is a ``RasaMultiNLUHttpInterpreter``, the
      text is parsed together with the output channel's ``language``,
      falling back to ``'en'`` when the channel has no such attribute.
    * Otherwise the text is parsed by ``self.interpreter`` alone.

    Returns the parse result; its ``"intent"`` and ``"entities"`` entries
    are read for the debug log.
    """
    text = message.text

    if text.startswith(INTENT_MESSAGE_PREFIX):
        parse_data = RegexInterpreter().parse(text)
    elif isinstance(self.interpreter, RasaMultiNLUHttpInterpreter):
        # Use the computed fallback. Previously the channel attribute was
        # read directly in the call, so a channel without `language`
        # raised AttributeError despite the 'en' default.
        language = getattr(message.output_channel, 'language', 'en')
        parse_data = self.interpreter.parse(text, language)
    else:
        parse_data = self.interpreter.parse(text)

    logger.debug("Received user message '{}' with intent '{}' "
                 "and entities '{}'".format(text,
                                            parse_data["intent"],
                                            parse_data["entities"]))
    return parse_data
def _parse_message(self, message):
    """Turn the text of an incoming message into parse data.

    Text beginning with ``INTENT_MESSAGE_PREFIX`` or with ``_`` bypasses
    the configured interpreter and is parsed by a fresh
    ``RegexInterpreter`` (a testing shortcut). A deprecation warning is
    issued first when ``RegexInterpreter.is_using_deprecated_format``
    reports the old format. Any other text is parsed by
    ``self.interpreter``.

    Returns the parse result; its ``"intent"`` and ``"entities"`` entries
    are read for the debug log.
    """
    text = message.text

    if not text.startswith((INTENT_MESSAGE_PREFIX, "_")):
        parse_data = self.interpreter.parse(text)
    else:
        if RegexInterpreter.is_using_deprecated_format(text):
            warnings.warn(
                "Parsing messages with leading `_` is deprecated and "
                "will be removed. Instead, prepend your intents with "
                "`{0}`, e.g. `{0}mood_greet` "
                "or `{0}restart`.".format(INTENT_MESSAGE_PREFIX))
        parse_data = RegexInterpreter().parse(text)

    summary = ("Received user message '{}' with intent '{}' "
               "and entities '{}'".format(text,
                                          parse_data["intent"],
                                          parse_data["entities"]))
    logger.debug(summary)
    return parse_data