def test_local_action_factory_fails_on_duplicated_actions():
    """The local action factory must reject a duplicated action.

    The list names ``action_listen`` and also the class path
    ``rasa_core.actions.action.ActionListen`` (presumably the same action),
    so ``instantiate_actions`` is expected to raise ValueError.
    """
    duplicated = [
        "action_listen",
        "rasa_core.actions.action.ActionListen",
        "utter_test",
    ]
    template_names = ["utter_test"]
    with pytest.raises(ValueError):
        TemplateDomain.instantiate_actions(
            "local", duplicated, None, template_names)
def test_local_action_factory_fails_on_duplicated_actions():
    """Duplicate entries given to the "local" action factory raise ValueError."""
    with pytest.raises(ValueError):
        TemplateDomain.instantiate_actions(
            "local",
            ["action_listen",
             "rasa_core.actions.action.ActionListen",
             "utter_test"],
            None,
            ["utter_test"],
        )
def test_custom_slot_type(tmpdir):
    """A domain may declare a slot whose ``type`` is a custom slot class path.

    Writes a small domain file using ``tests.conftest.CustomSlot`` as the
    slot type and checks that ``TemplateDomain.load`` accepts it.
    """
    # The YAML must keep its line structure: collapsed onto one line
    # ("slots: custom: type: ...") it is not valid YAML.
    domain_path = utilities.write_text_to_file(tmpdir, "domain.yml", """
        slots:
          custom:
            type: tests.conftest.CustomSlot

        templates:
          utter_greet:
            - hey there!

        actions:
          - utter_greet """)
    TemplateDomain.load(domain_path)
def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
    """Loading a domain whose slot type names an unknown class raises ValueError."""
    # The YAML must keep its line structure. Collapsed onto one line it is
    # invalid YAML, and the parser error (not the ValueError this test is
    # about) would be raised.
    domain_path = utilities.write_text_to_file(tmpdir, "domain.yml", """
        slots:
          custom:
            type: tests.conftest.Unknown

        templates:
          utter_greet:
            - hey there!

        actions:
          - utter_greet""")
    with pytest.raises(ValueError):
        TemplateDomain.load(domain_path)
def test_custom_slot_type(tmpdir):
    """A domain may declare a slot whose ``type`` is a custom slot class path.

    Uses ``write_domain_yml`` to create the file and checks that
    ``TemplateDomain.load`` accepts ``tests.conftest.CustomSlot``.
    """
    # The YAML must keep its line structure: collapsed onto one line it is
    # not valid YAML.
    domain_path = write_domain_yml(tmpdir, """
        slots:
          custom:
            type: tests.conftest.CustomSlot

        templates:
          utter_greet:
            - hey there!

        actions:
          - utter_greet """)
    TemplateDomain.load(domain_path)
def main():
    """Start an ExtendedRasaCoreServer configured from the command line.

    Parses CLI arguments and configures logging at the requested level. It
    then loads ``domain.yml`` from ``MODEL_PATH`` and builds an
    ``ExtendedRedisTrackerStore`` (db=2, timeout=900). Finally it runs the
    server app on ``0.0.0.0`` at the requested port.
    """
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()
    logging.basicConfig(level=cmdline_args.loglevel)

    # Set HTTP log path
    http_logs = os.path.join(dir_path, "bot/logs/http.logs")

    # Load Domain
    domain = TemplateDomain.load(os.path.join(MODEL_PATH, "domain.yml"))

    # Create tracker store
    tracker_store = ExtendedRedisTrackerStore(domain, db=2, timeout=900,
                                              mock=False)

    # Initialize Rasa Core Server
    model_directory = MODEL_PATH
    interpreter = INTEPRETER_PATH
    rasa = ExtendedRasaCoreServer(model_directory,
                                  interpreter,
                                  cmdline_args.loglevel,
                                  http_logs,
                                  cmdline_args.cors,
                                  auth_token=cmdline_args.auth_token,
                                  tracker_store=tracker_store)

    # Pass the port as a logging argument so formatting is deferred to the
    # logging framework instead of being done eagerly with `%`.
    logger.info("Started http server on port %s", cmdline_args.port)

    # Run Rasa Core Server
    rasa.app.run("0.0.0.0", cmdline_args.port)
def test_inmemory_tracker_store(filename):
    """A tracker saved to an InMemoryTrackerStore comes back equal."""
    domain = TemplateDomain.load("data/test_domains/default_with_topic.yml")
    original = tracker_from_dialogue_file(filename, domain)

    store = InMemoryTrackerStore(domain)
    store.save(original)

    assert store.retrieve(original.sender_id) == original
def main():
    """Run the Rasa Core HTTP server with settings taken from the CLI.

    Steps: parse arguments, configure logging, load the domain from
    ``MODEL_PATH/domain.yml``, create an ``ExtendedRedisTrackerStore``, wrap
    everything in an ``ExtendedRasaCoreServer`` and run its app on
    ``0.0.0.0:<port>``.
    """
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()
    logging.basicConfig(level=cmdline_args.loglevel)

    # Set HTTP log path
    http_logs = os.path.join(dir_path, "bot/logs/http.logs")

    # Load Domain
    domain = TemplateDomain.load(os.path.join(MODEL_PATH, "domain.yml"))

    # Create tracker store
    tracker_store = ExtendedRedisTrackerStore(domain, db=2, timeout=900,
                                              mock=False)

    # Initialize Rasa Core Server
    model_directory = MODEL_PATH
    interpreter = INTEPRETER_PATH
    rasa = ExtendedRasaCoreServer(model_directory,
                                  interpreter,
                                  cmdline_args.loglevel,
                                  http_logs,
                                  cmdline_args.cors,
                                  auth_token=cmdline_args.auth_token,
                                  tracker_store=tracker_store)

    # Lazy logging arguments instead of eager `%` interpolation.
    logger.info("Started http server on port %s", cmdline_args.port)

    # Run Rasa Core Server
    rasa.app.run("0.0.0.0", cmdline_args.port)
def load(cls, path, interpreter=None, tracker_store=None,
         action_factory=None, rules_file=None, generator=None,
         create_dispatcher=None):
    # type: (...) -> Agent
    """Load a persisted agent from the model directory ``path``.

    Reads ``<path>/domain.yml`` and checks it against the stored
    specification. Then it loads the policy ensemble from ``path`` and
    builds the interpreter and tracker store. The remaining arguments are
    forwarded to the constructor.

    Raises:
        ValueError: if ``path`` is None.
    """
    # The signature type comment used to list only 3 argument types for
    # 7 parameters; `(...)` keeps the return annotation without
    # misdescribing the arguments.
    if path is None:
        raise ValueError("No domain path specified.")
    domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                 action_factory)
    # ensures the domain hasn't changed between test and train
    domain.compare_with_specification(path)
    ensemble = PolicyEnsemble.load(path)
    _interpreter = NaturalLanguageInterpreter.create(interpreter)
    _tracker_store = cls.create_tracker_store(tracker_store, domain)
    return cls(domain=domain,
               policies=ensemble,
               interpreter=_interpreter,
               tracker_store=_tracker_store,
               rules_file=rules_file,
               generator=generator,
               create_dispatcher=create_dispatcher)
def run_hello_world(max_training_samples=10, serve_forever=True):
    """Train the "mom" bot online from the console and return the agent.

    The agent uses Memoization and Keras policies, a ``HelloInterpreter``
    and an in-memory tracker store. ``train_online`` runs one epoch over
    ``../mom/data/stories.md`` using console input.

    Note: ``serve_forever`` is currently not used.
    """
    stories_path = '../mom/data/stories.md'
    domain = TemplateDomain.load("../mom/domain.yml")

    agent = Agent(
        domain,
        policies=[MemoizationPolicy(), KerasPolicy()],
        interpreter=HelloInterpreter(),
        tracker_store=InMemoryTrackerStore(domain),
    )

    logger.info("Starting to train policy")
    agent.train_online(
        stories_path,
        input_channel=ConsoleInputChannel(),
        epochs=1,
        max_training_samples=max_training_samples,
    )
    return agent
def train(max_training_samples=3, serve_forever=True):
    """Train a dialogue agent on ``stories.md`` and return it.

    If a model exists at ``agent_model_path`` it is loaded with
    ``Agent.load`` and then trained. Otherwise a new agent with Memoization
    and Keras policies is created. Training runs for 50 epochs.

    Note: ``max_training_samples`` and ``serve_forever`` are unused.
    """
    story_file = 'stories.md'
    from rasa_core.interpreter import RasaNLUInterpreter
    nlu = RasaNLUInterpreter(nlu_model_path)

    # Domain configuration file
    domain = TemplateDomain.load("./domain.yml")

    if not os.path.exists(agent_model_path):
        # NOTE(review): the new Agent is built from `domain_conf_path`, while
        # the tracker store uses the domain loaded from "./domain.yml".
        # Confirm both refer to the same domain.
        agent = Agent(domain_conf_path,
                      policies=[MemoizationPolicy(), KerasPolicy()],
                      interpreter=nlu,
                      tracker_store=InMemoryTrackerStore(domain))
    else:
        agent = Agent.load(agent_model_path,
                           interpreter=nlu,
                           tracker_store=InMemoryTrackerStore(domain))

    # For debugging: show how the NLU model parses a greeting ("hello").
    print(nlu.parse(u"你好"))

    logger.info("开始训练...")
    training_data = agent.load_data(story_file)
    agent.train(training_data, epochs=50)
    return agent
def load(
        cls,
        path,  # type: Text
        interpreter=None,  # type: Union[NLI, Text, None]
        tracker_store=None,  # type: Optional[TrackerStore]
        action_factory=None  # type: Optional[Text]
):
    # type: (...) -> Agent
    """Load a persisted model from the passed path.

    ``path`` must be a directory containing the persisted policy ensemble
    and a ``domain.yml``. The domain is checked against the stored
    specification before the agent is built.

    Raises:
        ValueError: if ``path`` is None or points to a file.
    """
    # With per-argument type comments, the signature comment may only give
    # the return type via `(...)`. The previous 3-argument signature
    # duplicated and contradicted the per-argument annotations.
    if path is None:
        raise ValueError("No domain path specified.")
    if os.path.isfile(path):
        raise ValueError("You are trying to load a MODEL from a file "
                         "('{}'), which is not possible. \n"
                         "The persisted path should be a directory "
                         "containing the various model files. \n\n"
                         "If you want to load training data instead of "
                         "a model, use `agent.load_data(...)` "
                         "instead.".format(path))

    ensemble = PolicyEnsemble.load(path)
    domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                 action_factory)
    # ensures the domain hasn't changed between test and train
    domain.compare_with_specification(path)
    _interpreter = NaturalLanguageInterpreter.create(interpreter)
    _tracker_store = cls.create_tracker_store(tracker_store, domain)
    return cls(domain, ensemble, _interpreter, _tracker_store)
def __init__(
        self,
        model=None,
        domain=None,
        test_cases=None,
        shuffle=False,
        distinct=True,
        rules=None,
        interpreter=RegexInterpreter(),
        create_output_channel=lambda on_response, domain, processor:
        TestOutputChannel(on_response, domain, processor),
):
    """Set up an agent and immediately run the test stories against it.

    Args:
        model: passed to ``_create_agent``.
        domain: passed to ``TemplateDomain.load``.
        test_cases: passed to ``_build_stories_from_path``.
        shuffle: if True, the built stories are shuffled before running.
        distinct: stored as ``self.distinct``.
        rules: passed to ``_create_agent``.
        interpreter: interpreter handed to ``_create_agent``. The default
            ``RegexInterpreter()`` is created once, when the function is
            defined, and shared by all instances.
        create_output_channel: factory stored as
            ``self.create_output_channel``.
    """
    self.model = model
    self.distinct = distinct
    self.interpreter = interpreter
    self.domain = TemplateDomain.load(domain)
    self.input_channel = TestInputChannel()
    self.create_output_channel = create_output_channel
    self.agent = self._create_agent(self.input_channel, self.interpreter,
                                    model, rules)
    self.stories = self._build_stories_from_path(test_cases)
    if shuffle:
        # Shuffle the stories that were just built. The old code referenced
        # an undefined name `stories` and raised NameError. Assumes
        # `_build_stories_from_path` returns a mutable sequence.
        random.shuffle(self.stories)
    # Initialise before running so that any failure recorded while running
    # (presumably by `_run_test_cases`) is not overwritten afterwards.
    self.failed = False
    self._run_test_cases()
def test_query_form_set_username_in_form():
    """The query form asks for ``username``, then stores it and asks for ``query``."""
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-form"
    dispatcher = Dispatcher(sender, channel, generator)
    tracker = store.get_or_create_tracker(sender)

    # Turn 1: an `inform` with no text -> the form requests the username.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    first_turn = ActionSearchQuery().run(dispatcher, tracker, domain)
    reply = dispatcher.latest_bot_messages[-1]
    assert len(first_turn) == 1
    assert isinstance(first_turn[0], SlotSet)
    assert first_turn[0].key == "requested_slot"
    assert first_turn[0].value == "username"
    assert reply.text == 'what is your name?'
    tracker.update(first_turn[0])

    # Turn 2: the user's text fills `username`; next requested slot is `query`.
    username = '******'
    tracker.update(UserUttered(username, intent={"name": "inform"}))
    second_turn = ActionSearchQuery().run(dispatcher, tracker, domain)
    reply = dispatcher.latest_bot_messages[-1]
    assert len(second_turn) == 2
    assert isinstance(second_turn[0], SlotSet)
    assert second_turn[0].key == "username"
    assert second_turn[0].value == username
    assert second_turn[1].key == "requested_slot"
    assert second_turn[1].value == "query"
    assert username in reply.text
def run_my_world(serve_forever=True):
    """Load the persisted policy model and optionally chat on the console.

    Returns the agent loaded from ``models/policy/current`` with a
    ``RasaNLUInterpreter``. When ``serve_forever`` is true, the agent is
    attached to a ``ConsoleInputChannel`` first.
    """
    # NOTE(review): the loaded domain is never used. The call is kept only
    # because loading the file is still a side effect of this function.
    TemplateDomain.load('common_domain.yml')

    nlu = RasaNLUInterpreter(nlu_model_path)
    agent = Agent.load("models/policy/current", interpreter=nlu)

    if serve_forever:
        agent.handle_channel(ConsoleInputChannel())
    return agent
def test_people_form():
    """The people form asks for ``person_name`` and then fills it from the user's text."""
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-people"
    dispatcher = Dispatcher(sender, channel, generator)
    tracker = store.get_or_create_tracker(sender)

    # Turn 1: nothing provided -> the form requests `person_name`.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "person_name"
    tracker.update(request)

    # Turn 2: the utterance text becomes the person's name.
    person = "Rasa Due"
    tracker.update(UserUttered(person, intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    filled = produced[0]
    assert isinstance(filled, SlotSet)
    assert filled.key == "person_name"
    assert filled.value == person
def test_travel_form():
    """The travel form asks for ``GPE_origin``, fills it from a GPE entity, then asks for ``GPE_destination``."""
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-travel"
    dispatcher = Dispatcher(sender, channel, generator)
    tracker = store.get_or_create_tracker(sender)

    # Turn 1: no entities -> origin is requested.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "GPE_origin"
    tracker.update(produced[0])

    # Turn 2: a GPE entity answers the origin question.
    gpe_entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=gpe_entities))
    produced = ActionSearchTravel().run(dispatcher, tracker, domain)
    for event in produced:
        print(event.as_story_string())
    assert len(produced) == 2
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "GPE_origin"
    assert produced[0].value == "Berlin"
    assert produced[1].key == "requested_slot"
    assert produced[1].value == "GPE_destination"
def test_restaurant_form_unhappy_1():
    """If the user does not answer the requested slot, the same slot is requested again."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-restaurant"
    dispatcher = Dispatcher(sender, channel, generator)
    tracker = store.get_or_create_tracker(sender)

    # Turn 1: the form opens by requesting `cuisine`.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(result) == 1
    assert isinstance(result[0], SlotSet)
    assert result[0].key == "requested_slot"
    assert result[0].value == "cuisine"
    tracker.update(result[0])

    # Turn 2: the user gives no cuisine.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    print([(event.key, event.value) for event in result])
    assert len(result) == 1
    assert isinstance(result[0], SlotSet)
    # `cuisine` is asked for once more
    assert result[0].key == "requested_slot"
    assert result[0].value == "cuisine"
def test_restaurant_form():
    """The restaurant form asks for ``cuisine``, stores it, then asks for ``people``."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-restaurant"
    dispatcher = Dispatcher(sender, channel, generator)
    tracker = store.get_or_create_tracker(sender)

    # Turn 1: empty inform -> cuisine is requested.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(result) == 1
    assert isinstance(result[0], SlotSet)
    assert result[0].key == "requested_slot"
    assert result[0].value == "cuisine"
    tracker.update(result[0])

    # Turn 2: a cuisine entity fills the slot and moves on to `people`.
    cuisine_entities = [{"entity": "cuisine", "value": "chinese"}]
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=cuisine_entities))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(result) == 2
    filled, follow_up = result[0], result[1]
    assert isinstance(filled, SlotSet)
    assert isinstance(follow_up, SlotSet)
    assert filled.key == "cuisine"
    assert filled.value == "chinese"
    assert follow_up.key == "requested_slot"
    assert follow_up.value == "people"
def test_restaurant_form_unhappy_1():
    """An utterance that does not answer ``cuisine`` leads to ``cuisine`` being requested again."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg_generator = TemplatedNaturalLanguageGenerator(domain.templates)
    trackers = InMemoryTrackerStore(domain)
    collector = CollectingOutputChannel()
    sender = "test-restaurant"
    dispatcher = Dispatcher(sender, collector, nlg_generator)
    tracker = trackers.get_or_create_tracker(sender)

    # First utterance: the form requests `cuisine`.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    emitted = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(emitted) == 1
    assert isinstance(emitted[0], SlotSet)
    assert emitted[0].key == "requested_slot"
    assert emitted[0].value == "cuisine"
    tracker.update(emitted[0])

    # Second utterance provides nothing useful.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    emitted = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    print([(ev.key, ev.value) for ev in emitted])
    assert len(emitted) == 1
    assert isinstance(emitted[0], SlotSet)
    # the same slot is requested again
    assert emitted[0].key == "requested_slot"
    assert emitted[0].value == "cuisine"
def test_restaurant_form_skipahead():
    """Entities for several slots in one utterance let the form skip ahead.

    With ``cuisine`` and ``number`` provided up front, the form should emit
    three events, the last of which requests the ``vegetarian`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # Check the count before indexing, so a short result fails with a clear
    # assertion instead of an IndexError. The unused `s` local was removed.
    assert len(events) == 3
    print(events[0].as_story_string())
    print(events[1].as_story_string())
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def load(cls,
         path,  # type: Text
         interpreter=None,  # type: Union[NLI, Text, None]
         tracker_store=None,  # type: Optional[TrackerStore]
         action_factory=None,  # type: Optional[Text]
         generator=None  # type: Union[EndpointConfig, NLG]
         ):
    # type: (...) -> Agent
    """Load a persisted model from the passed path.

    ``path`` must be a directory containing the persisted policy ensemble
    and a ``domain.yml``. ``interpreter`` and ``generator`` are passed to
    the constructor as given.

    Raises:
        ValueError: if ``path`` is None or points to a file.
    """
    # Per-argument type comments already describe the parameters. The
    # signature comment therefore uses `(...)`; the old 3-argument form
    # conflicted with them.
    if path is None:
        raise ValueError("No domain path specified.")
    if os.path.isfile(path):
        raise ValueError("You are trying to load a MODEL from a file "
                         "('{}'), which is not possible. \n"
                         "The persisted path should be a directory "
                         "containing the various model files. \n\n"
                         "If you want to load training data instead of "
                         "a model, use `agent.load_data(...)` "
                         "instead.".format(path))

    ensemble = PolicyEnsemble.load(path)
    domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                 action_factory)
    # ensures the domain hasn't changed between test and train
    domain.compare_with_specification(path)
    _tracker_store = cls.create_tracker_store(tracker_store, domain)
    return cls(domain, ensemble, interpreter, generator, _tracker_store)
def test_travel_form():
    """Origin is requested first, filled from a GPE entity, then destination is requested."""
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    nlg_generator = TemplatedNaturalLanguageGenerator(domain.templates)
    trackers = InMemoryTrackerStore(domain)
    collector = CollectingOutputChannel()
    sender = "test-travel"
    dispatcher = Dispatcher(sender, collector, nlg_generator)
    tracker = trackers.get_or_create_tracker(sender)

    # First utterance without entities.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    emitted = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(emitted) == 1
    assert isinstance(emitted[0], SlotSet)
    assert emitted[0].key == "requested_slot"
    assert emitted[0].value == "GPE_origin"
    tracker.update(emitted[0])

    # Second utterance carries the origin as a GPE entity.
    origin = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=origin))
    emitted = ActionSearchTravel().run(dispatcher, tracker, domain)
    for ev in emitted:
        print(ev.as_story_string())
    assert len(emitted) == 2
    assert isinstance(emitted[0], SlotSet)
    assert emitted[0].key == "GPE_origin"
    assert emitted[0].value == "Berlin"
    assert emitted[1].key == "requested_slot"
    assert emitted[1].value == "GPE_destination"
def test_people_form():
    """``person_name`` is requested, then set from the text of the next utterance."""
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    nlg_generator = TemplatedNaturalLanguageGenerator(domain.templates)
    trackers = InMemoryTrackerStore(domain)
    collector = CollectingOutputChannel()
    sender = "test-people"
    dispatcher = Dispatcher(sender, collector, nlg_generator)
    tracker = trackers.get_or_create_tracker(sender)

    # First utterance: the form asks who to look for.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    emitted = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(emitted) == 1
    assert isinstance(emitted[0], SlotSet)
    assert emitted[0].key == "requested_slot"
    assert emitted[0].value == "person_name"
    tracker.update(emitted[0])

    # Second utterance: its text is the name.
    full_name = "Rasa Due"
    tracker.update(UserUttered(full_name, intent={"name": "inform"}))
    emitted = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(emitted) == 1
    assert isinstance(emitted[0], SlotSet)
    assert emitted[0].key == "person_name"
    assert emitted[0].value == full_name
def train_dialogue(domain_file=DOMAIN_FILE,
                   model_path=MODEL_PATH,
                   training_data_file=STORIES):
    """Train an ``ExtendedAgent`` with an ``SklearnPolicy`` and persist it.

    Args:
        domain_file: domain loaded with ``TemplateDomain.load``.
        model_path: directory passed to ``agent.persist``.
        training_data_file: stories passed to ``agent.train``
            (with ``max_history=3``).
    """
    domain = TemplateDomain.load(domain_file)
    agent = ExtendedAgent(domain=domain, policies=[SklearnPolicy()])

    logger.info("Begin dialogue training")
    agent.train(filename=training_data_file, max_history=3)
    agent.persist(model_path)
def test_domain_to_yaml():
    """A domain loaded from YAML serializes back to the same YAML.

    Also checks that the serialized output can itself be loaded again.
    """
    # The literal must stay flush-left and line-structured. It is compared
    # verbatim with `as_yaml()`, and collapsed onto a single line it would
    # not even parse as YAML.
    test_yaml = """action_factory: null
action_names:
- utter_greet
actions:
- utter_greet
config:
  store_entities_as_slots: true
entities: []
intents: []
slots: {}
templates:
  utter_greet:
  - text: hey there!"""

    domain = TemplateDomain.load_from_yaml(test_yaml)
    assert test_yaml.strip() == domain.as_yaml().strip()

    domain = TemplateDomain.load_from_yaml(domain.as_yaml())
def trained_policy(self):
    """Return a policy from ``create_policy`` trained on the example default domain.

    The policy gets ``self.max_history`` and a ``BinaryFeaturizer`` before
    ``train`` is called with data from ``train_data``.
    """
    domain = TemplateDomain.load("examples/default_domain.yml")
    policy = self.create_policy()
    features, labels = train_data(self.max_history, domain)

    policy.max_history = self.max_history
    policy.featurizer = BinaryFeaturizer()
    policy.train(features, labels, domain)
    return policy
def trained_policy(self):
    """Build a policy via ``create_policy`` and train it on ``DEFAULT_DOMAIN_PATH``.

    Sets ``max_history`` from ``self`` and a ``BinaryFeaturizer`` on the
    policy before training.
    """
    domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    trained = self.create_policy()
    inputs, targets = train_data(self.max_history, domain)

    trained.max_history = self.max_history
    trained.featurizer = BinaryFeaturizer()
    trained.train(inputs, targets, domain)
    return trained
async def get_no_of_stories(story_file, domain):
    """Return how many stories ``StoryFileReader`` reads from ``story_file``.

    ``domain`` is a path that is loaded with ``TemplateDomain.load``.
    """
    from rasa_core.domain import TemplateDomain
    from rasa_core.training.dsl import StoryFileReader

    loaded_domain = TemplateDomain.load(domain)
    stories = await StoryFileReader.read_from_folder(story_file,
                                                     loaded_domain)
    return len(stories)
def train_dialogue(self):
    """Train a Memoization + Keras agent from ``self.TRAINING_DIR`` and persist it.

    Reads ``domain.yml`` and ``story.md`` from the training directory,
    trains with ``validation_split=0.1`` and writes the model to
    ``<TRAINING_DIR>/dialogue``.
    """
    domain_path = os.path.join(self.TRAINING_DIR, 'domain.yml')
    story_path = os.path.abspath(
        os.path.join(self.TRAINING_DIR, 'story.md'))
    domain = TemplateDomain.load(domain_path)
    # NOTE: the domain.compare_with_specification(...) check against the
    # persisted "dialogue" directory is intentionally not performed here.

    agent = Agent(domain, policies=[MemoizationPolicy(), KerasPolicy()])
    agent.train(story_path, validation_split=0.1)
    agent.persist(os.path.join(self.TRAINING_DIR, 'dialogue'))
def test_utter_templates():
    """``random_template_for("utter_greet")`` returns the moodbot greeting with its two buttons."""
    expected = {
        "text": "Hey! How are you?",
        "buttons": [
            {"title": "great", "payload": "great"},
            {"title": "super sad", "payload": "super sad"},
        ],
    }
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    assert domain.random_template_for("utter_greet") == expected
def simple():
    """Build an agent for the "mom" domain that uses only ``SimplePolicy``.

    The agent uses a ``HelloInterpreter`` and an in-memory tracker store.
    """
    from rasa_core.tracker_store import InMemoryTrackerStore
    from rasa_core.domain import TemplateDomain

    domain = TemplateDomain.load("../mom/domain.yml")
    return Agent(domain,
                 policies=[SimplePolicy()],
                 interpreter=HelloInterpreter(),
                 tracker_store=InMemoryTrackerStore(domain))
def test_domain_action_instantiation():
    """Remote instantiation yields ``action_listen`` and ``action_restart`` before the custom actions."""
    instantiated = TemplateDomain.instantiate_actions(
        "remote",
        ["my_module.ActionTest", "utter_test"],
        ["action_test", "utter_test"],
        ["utter_test"],
    )
    expected_names = [
        "action_listen", "action_restart", "action_test", "utter_test"]

    assert len(instantiated) == 4
    for index, name in enumerate(expected_names):
        assert instantiated[index].name() == name
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """Uttering ``utter_ask_price`` sends the question, then one numbered line per button."""
    domain = TemplateDomain.load("examples/restaurant_domain.yml")
    channel = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", channel, domain)

    dispatcher.utter_template("utter_ask_price")

    expected_texts = [
        "in which price range?",
        "1: cheap (cheap)",
        "2: expensive (expensive)",
    ]
    for position, text in enumerate(expected_texts):
        assert channel.messages[position][1] == text
def test_utter_templates():
    """``random_template_for("utter_ask_price")`` returns the price question with two buttons."""
    price_buttons = [
        {"title": "cheap", "payload": "cheap"},
        {"title": "expensive", "payload": "expensive"},
    ]
    expected = {"text": "in which price range?", "buttons": price_buttons}

    domain = TemplateDomain.load("examples/restaurant_domain.yml")
    assert domain.random_template_for("utter_ask_price") == expected
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """The moodbot greeting is sent as text followed by numbered button lines."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    channel = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", channel, domain)

    dispatcher.utter_template("utter_greet")

    expected_texts = [
        "Hey! How are you?",
        "1: great (great)",
        "2: super sad (super sad)",
    ]
    for position, text in enumerate(expected_texts):
        assert channel.messages[position][1] == text
def test_domain_to_yaml():
    """Serializing a YAML-loaded domain reproduces the original YAML.

    Also checks that the serialized output can be loaded again.
    """
    # Keep the literal flush-left and line-structured: it is matched against
    # `as_yaml()` output, and collapsed onto one line it would not parse.
    test_yaml = """action_factory: null
action_names:
- utter_greet
actions:
- utter_greet
config:
  store_entities_as_slots: true
entities: []
intents: []
slots: {}
templates:
  utter_greet:
  - text: hey there!"""

    domain = TemplateDomain.load_from_yaml(test_yaml)
    # python 3 and 2 are different here, python 3 will have a leading set
    # of --- at the beginning of the yml
    assert domain.as_yaml().strip().endswith(test_yaml.strip())

    domain = TemplateDomain.load_from_yaml(domain.as_yaml())
def tracker_from_dialogue_file(filename, domain=None):
    """Build a DialogueStateTracker by replaying a dialogue read from ``filename``.

    Args:
        filename: file read with ``read_dialogue_file``.
        domain: domain whose ``slots`` the tracker is created with. When
            None, the domain at ``DEFAULT_DOMAIN_PATH`` is loaded.

    Returns:
        A tracker constructed with ``dialogue.name`` and recreated from the
        dialogue's events via ``recreate_from_dialogue``.
    """
    dialogue = read_dialogue_file(filename)
    # Only fall back to the default domain when none was given. This
    # replaces the old no-op `domain = domain` branch.
    if domain is None:
        domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def test_domain_action_instantiation():
    """Remote instantiation yields listen, restart and default-fallback actions before the custom ones."""
    produced = TemplateDomain.instantiate_actions(
        "remote",
        ["my_module.ActionTest", "utter_test"],
        ["action_test", "utter_test"],
        ["utter_test"],
    )
    expected_order = [
        "action_listen",
        "action_restart",
        "action_default_fallback",
        "action_test",
        "utter_test",
    ]

    assert len(produced) == 5
    for slot, action_name in enumerate(expected_order):
        assert produced[slot].name() == action_name
def _create_domain(domain):
    # type: (Union[Domain, Text]) -> Domain
    """Return ``domain`` as a Domain instance.

    A string is treated as a path and loaded with ``TemplateDomain.load``.
    A ``Domain`` is returned unchanged. Anything else raises ValueError.
    """
    if isinstance(domain, string_types):
        return TemplateDomain.load(domain)
    if isinstance(domain, Domain):
        return domain
    raise ValueError(
        "Invalid param `domain`. Expected a path to a domain "
        "specification or a domain instance. But got "
        "type '{}' with value '{}'".format(type(domain), domain))
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """``utter_greet`` is sent as one message whose ``data`` holds the buttons."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    output = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", output, generator)

    dispatcher.utter_template("utter_greet", default_tracker)

    assert len(output.messages) == 1
    message = output.messages[0]
    assert message['text'] == "Hey! How are you?"
    assert message['data'] == [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'},
    ]
def test_restaurant_form_unhappy_2():
    """After skipping ahead to ``vegetarian``, an off-topic intent re-requests ``vegetarian``.

    The first run stores the ``cuisine`` and ``number`` entities as slots.
    The second run requests ``vegetarian``. An utterance with an unrelated
    intent must lead to the same slot being requested again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [
        {"entity": "cuisine", "value": "chinese"},
        {"entity": "number", "value": 8}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(
        UserUttered("", intent={"name": "random"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Check the count before indexing so a short result fails as an
    # assertion, not an IndexError. The unused `s` local was removed.
    assert len(events) == 1
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
def load(cls,
         path,  # type: Text
         core_endpoint,  # type: EndpointConfig
         nlg_endpoint=None,  # type: EndpointConfig
         action_factory=None  # type: Optional[Text]
         ):
    # type: (...) -> RemoteAgent
    """Create a RemoteAgent for the model stored in ``path``.

    Loads ``<path>/domain.yml``, uploads the model directory through a
    ``RasaCoreClient`` (up to 5 retries) and returns a ``RemoteAgent``.

    Raises:
        Exception: if ``core_endpoint`` is a string (the old URL-based API)
            instead of an ``EndpointConfig``.
    """
    if isinstance(core_endpoint, string_types):
        raise Exception("This API has changed. Instead of passing in a url "
                        "for Rasa Core, you now need to pass in an "
                        "instance of 'EndpointConfig'. "
                        "(from rasa_core.utils import EndpointConfig )")

    domain_file = os.path.join(path, "domain.yml")
    domain = TemplateDomain.load(domain_file, action_factory)

    client = RasaCoreClient(core_endpoint)
    client.upload_model(path, max_retries=5)
    return RemoteAgent(domain, client, nlg_endpoint)
def test_restaurant_form_skipahead():
    """Providing ``cuisine`` and ``number`` at once makes the form request ``vegetarian`` next.

    The form should emit three events, the last of which requests the
    ``vegetarian`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=entities))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # Assert the count before indexing so a short result fails with a clear
    # message rather than an IndexError. The unused `s` local was removed.
    assert len(events) == 3
    print(events[0].as_story_string())
    print(events[1].as_story_string())
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def test_remote_action_factory_fails_on_duplicated_actions():
    """The remote action factory rejects an action list containing a duplicate name."""
    names_with_duplicate = ["action_listen", "random_name", "random_name"]
    with pytest.raises(ValueError):
        TemplateDomain.instantiate_actions(
            "remote", names_with_duplicate, None, ["utter_test"])
def test_domain_from_template():
    """The default test domain declares 10 intents and 6 actions."""
    domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    assert len(domain.intents) == 10
    assert len(domain.actions) == 6
from rasa_core import training, restore from rasa_core import utils from rasa_core.actions.action import ActionListen, ACTION_LISTEN_NAME from rasa_core.channels import UserMessage from rasa_core.domain import TemplateDomain from rasa_core.events import ( UserUttered, ActionExecuted, Restarted, ActionReverted, UserUtteranceReverted) from rasa_core.tracker_store import InMemoryTrackerStore, RedisTrackerStore from rasa_core.tracker_store import ( TrackerStore) from rasa_core.trackers import DialogueStateTracker from tests.conftest import DEFAULT_STORIES_FILE from tests.utilities import tracker_from_dialogue_file, read_dialogue_file domain = TemplateDomain.load("data/test_domains/default.yml") class MockRedisTrackerStore(RedisTrackerStore): def __init__(self, domain): self.red = fakeredis.FakeStrictRedis() TrackerStore.__init__(self, domain) def stores_to_be_tested(): return [MockRedisTrackerStore(domain), InMemoryTrackerStore(domain)] def stores_to_be_tested_ids(): return ["redis-tracker",
def default_domain(self):
    """Return the domain loaded from ``DEFAULT_DOMAIN_PATH``."""
    loaded = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    return loaded
def trained_policy(self, featurizer):
    """Create a policy with ``featurizer`` and train it on trackers for the default domain."""
    domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    policy = self.create_policy(featurizer)
    policy.train(train_trackers(domain), domain)
    return policy
def test_restaurant_domain_is_valid():
    """The restaurant bot's domain file passes ``validate_domain_yaml``."""
    # The test passes as long as validation raises no exception.
    TemplateDomain.validate_domain_yaml(
        read_file('examples/restaurantbot/restaurant_domain.yml'))