def test_people_form():
    """Drive ActionSearchPeople through a two-turn form fill.

    Turn one carries no entities, so the form must request ``person_name``.
    Turn two supplies the name as plain text, which the form must store in
    the ``person_name`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    user = "test-people"
    dispatcher = Dispatcher(user, channel, generator)
    tracker = store.get_or_create_tracker(user)

    # Turn 1: an empty "inform" -> the form asks for the first slot.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "person_name"
    tracker.update(request)

    # Turn 2: the raw text is the answer to the requested slot.
    person = "Rasa Due"
    tracker.update(UserUttered(person, intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    filled = produced[0]
    assert isinstance(filled, SlotSet)
    assert filled.key == "person_name"
    assert filled.value == person
def test_inmemory_tracker_store(filename):
    """A tracker saved to the in-memory store is retrieved unchanged."""
    domain = Domain.load("data/test_domains/default.yml")
    original = tracker_from_dialogue_file(filename, domain)
    store = InMemoryTrackerStore(domain)
    store.save(original)
    assert store.retrieve(original.sender_id) == original
def test_travel_form():
    """Drive ActionSearchTravel through its first two turns.

    Turn one carries no entities, so the form must request ``GPE_origin``.
    Turn two supplies a ``GPE`` entity ("Berlin"). The form must store it as
    ``GPE_origin`` and then request ``GPE_destination``.
    """
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-travel"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "GPE_origin"
    tracker.update(events[0])

    # second user utterance
    entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert isinstance(events[1], SlotSet)
    assert events[0].key == "GPE_origin"
    assert events[0].value == "Berlin"
    assert events[1].key == "requested_slot"
    assert events[1].value == "GPE_destination"
def test_query_form_set_username_in_form():
    """ActionSearchQuery asks for a username, stores it, then asks for a query.

    The first prompt must be the literal name question. The second prompt
    must mention the username the user gave.
    """
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    user = "test-form"
    dispatcher = Dispatcher(user, channel, generator)
    tracker = store.get_or_create_tracker(user)

    # Turn 1: nothing known yet -> prompt for the username.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchQuery().run(dispatcher, tracker, domain)
    prompt = dispatcher.latest_bot_messages[-1]
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "username"
    assert prompt.text == 'what is your name?'
    tracker.update(request)

    # Turn 2: the raw text fills the username, then the query is requested.
    username = '******'
    tracker.update(UserUttered(username, intent={"name": "inform"}))
    produced = ActionSearchQuery().run(dispatcher, tracker, domain)
    prompt = dispatcher.latest_bot_messages[-1]
    assert len(produced) == 2
    filled, follow_up = produced[0], produced[1]
    assert isinstance(filled, SlotSet)
    assert filled.key == "username"
    assert filled.value == username
    assert follow_up.key == "requested_slot"
    assert follow_up.value == "query"
    assert username in prompt.text
def test_restaurant_form_unhappy_1():
    """The form re-requests ``cuisine`` when the answer does not contain it.

    Turn one must request ``cuisine``. Turn two is another empty ``inform``,
    so the same slot must be requested again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Check shape/type before touching .key/.value so a failure reports
    # clearly instead of raising AttributeError.
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
def test_restaurant_form():
    """Happy path for the first two turns of ActionSearchRestaurants.

    Turn one must request ``cuisine``. Turn two provides a ``cuisine`` entity,
    which must be stored before ``people`` is requested.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    user = "test-restaurant"
    dispatcher = Dispatcher(user, channel, generator)
    tracker = store.get_or_create_tracker(user)

    # Turn 1: empty "inform" -> ask for cuisine.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "cuisine"
    tracker.update(request)

    # Turn 2: cuisine entity supplied -> store it, then ask for people.
    chinese = [{"entity": "cuisine", "value": "chinese"}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=chinese))
    produced = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(produced) == 2
    filled, follow_up = produced[0], produced[1]
    assert isinstance(filled, SlotSet)
    assert isinstance(follow_up, SlotSet)
    assert filled.key == "cuisine"
    assert filled.value == "chinese"
    assert follow_up.key == "requested_slot"
    assert follow_up.value == "people"
def test_get_or_create():
    """A slot set on a new tracker survives a save followed by get-or-create."""
    key, value = 'location', 'Easter Island'
    store = InMemoryTrackerStore(domain)
    sender = UserMessage.DEFAULT_SENDER_ID

    fresh = store.get_or_create_tracker(sender)
    fresh.update(SlotSet(key, value))
    assert fresh.get_slot(key) == value

    store.save(fresh)
    reloaded = store.get_or_create_tracker(sender)
    assert reloaded.get_slot(key) == value
def test_restart_after_retrieval_from_tracker_store(default_domain):
    """The index after the latest restart survives a save/retrieve round-trip."""
    store = InMemoryTrackerStore(default_domain)
    tracker = store.get_or_create_tracker("myuser")

    # Four listen actions, then a restart.
    listens = [ActionExecuted("action_listen") for _ in range(4)]
    for event in listens:
        tracker.update(event)
    tracker.update(Restarted())

    before = tracker.idx_after_latest_restart()
    store.save(tracker)
    after = store.retrieve("myuser").idx_after_latest_restart()
    assert before == after
def test_restaurant_form_unhappy_2():
    """The form keeps requesting ``vegetarian`` when the user goes off-script.

    The first utterance already carries ``cuisine`` and ``number`` entities.
    Once the resulting events are applied, the ``cuisine`` and ``people``
    slots must be filled. The next run must request ``vegetarian``. After an
    unrelated intent it must request ``vegetarian`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [
        {"entity": "cuisine", "value": "chinese"},
        {"entity": "number", "value": 8}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert isinstance(events[2], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "random"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Length/type first: indexing an empty result must fail as an assertion,
    # not as an IndexError.
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
def test_restaurant_form_skipahead():
    """Entities supplied up front let the form skip straight to ``vegetarian``.

    With ``cuisine`` and ``number`` given in the first utterance, the action
    must return three events, the last of which requests ``vegetarian``.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Assert the length before indexing so a short result fails cleanly.
    assert len(events) == 3
    assert isinstance(events[2], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def run(serve_forever=True, port=5002, debug=False):
    """Build an agent that uses a remote Rasa NLU HTTP interpreter.

    :param serve_forever: if true, hand the input channel to the agent
        before returning
    :param port: HTTP port for the ``/ai`` endpoint (ignored when ``debug``)
    :param debug: use a console channel instead of the HTTP channel
    :return: the loaded agent
    """
    model_path = os.path.abspath("")
    nlu = RasaNLUHttpInterpreter(
        server="http://rasanlu:5000", token="", model_name="", project="")
    store = InMemoryTrackerStore(TemplateDomain.load(os.path.abspath("")))
    endpoint = BotInput()

    if not debug:
        channel = HttpInputChannel(port, "/ai", endpoint)
    else:
        channel = ConsoleInputChannel()

    agent = Agent.load(model_path, interpreter=nlu, tracker_store=store)
    if serve_forever:
        agent.handle_channel(channel)
    return agent
def create_tracker_store(core_model, endpoints):
    """Create the tracker store described by the endpoints configuration.

    :param core_model: model location passed to ``get_domain`` to obtain the
        domain the store is built for
    :param endpoints: endpoints configuration file passed to
        ``utils.read_endpoint_config``
    :return: an ``InMemoryTrackerStore`` when the ``tracker-store`` type is
        ``memory`` or no ``tracker-store`` section is configured; a
        ``RedisTrackerStore`` for type ``redis``; ``None`` for any other type.

    If ``read_endpoint_config`` returns ``None`` (section not configured), that
    section is treated as absent instead of being dereferenced.
    """
    domain = get_domain(core_model)

    tracker_publish = utils.read_endpoint_config(endpoints, "tracker-publish")
    # No "tracker-publish" section simply means: do not publish events.
    publish_url = tracker_publish.url if tracker_publish is not None else None

    # Setup tracker store
    store = utils.read_endpoint_config(endpoints, "tracker-store")
    store_type = store.type if store is not None else 'memory'

    if store_type == 'memory':
        return InMemoryTrackerStore(domain=domain, publish_url=publish_url)
    if store_type == 'redis':
        return RedisTrackerStore(domain=domain,
                                 host=store.host,
                                 port=store.port,
                                 db=store.db,
                                 password=store.password,
                                 publish_url=publish_url)
    # NOTE(review): unknown store types still yield None, as before.
    return None
def test_get_or_create():
    """Saving a tracker and fetching it again preserves a slot value."""
    slot_name = 'location'
    expected = 'Easter Island'
    store = InMemoryTrackerStore(domain)

    first = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    first.update(SlotSet(slot_name, expected))
    assert first.get_slot(slot_name) == expected
    store.save(first)

    second = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    assert second.get_slot(slot_name) == expected
def test_restart_after_retrieval_from_tracker_store(default_domain):
    """Restart bookkeeping is identical before saving and after retrieval."""
    store = InMemoryTrackerStore(default_domain)
    original = store.get_or_create_tracker("myuser")

    pending = [ActionExecuted("action_listen") for _ in range(4)]
    for action in pending:
        original.update(action)
    original.update(Restarted())

    restart_idx = original.idx_after_latest_restart()
    store.save(original)
    restored = store.retrieve("myuser")
    assert restart_idx == restored.idx_after_latest_restart()
def train(max_training_samples=3, serve_forever=True):
    """Load or create a dialogue agent, then start interactive online training.

    If ``agent_model_path`` exists, the trained agent in ``../model/agent`` is
    loaded with the NLU model from ``../model/latest``. Otherwise, a new
    agent is created from ``domain_conf_path`` with Memoization and Keras
    policies.

    NOTE(review): ``max_training_samples`` and ``serve_forever`` are accepted
    but not used (the former was commented out of the ``train_online`` call).

    :return: the agent after the online-training session
    """
    from rasa_core.interpreter import RasaNLUInterpreter

    stories = 'stories.md'
    nlu_interpreter = RasaNLUInterpreter("../model/latest/")
    domain = TemplateDomain.load("domain.yml")

    if not os.path.exists(agent_model_path):
        logger.info("从头创建模型")  # "building the model from scratch"
        agent = Agent(domain_conf_path,
                      interpreter=nlu_interpreter,
                      policies=[MemoizationPolicy(), KerasPolicy()],
                      tracker_store=InMemoryTrackerStore(domain))
    else:
        logger.info("加载已有的模型")  # "loading the existing model"
        # Dialogue model directory, then the NLU model directory.
        agent = Agent.load("../model/agent", '../model/latest')

    # Debug aid: show how the NLU model parses a greeting ("hello").
    print(nlu_interpreter.parse(u"你好"))
    logger.info("开始在线训练...")  # "starting online training..."
    agent.train_online(stories,
                       input_channel=ConsoleInputChannel(),
                       epochs=1)
    return agent
def run(core_dir, nlu_dir):
    """Serve the bot over a Rocket.Chat input channel on port 5005.

    When ``ENABLE_ANALYTICS`` is set, tracker events go to a ``PikaProducer``
    event broker. Rocket.Chat credentials come from the
    ``ROCKETCHAT_BOT_USERNAME``, ``ROCKETCHAT_BOT_PASSWORD`` and
    ``ROCKETCHAT_URL`` environment variables. Any exception raised while
    serving is logged.

    :param core_dir: dialogue model directory passed to ``load_agent``
    :param nlu_dir: NLU model passed to ``NaturalLanguageInterpreter.create``
    """
    event_broker = (PikaProducer(url, username, password, queue=queue)
                    if ENABLE_ANALYTICS else None)

    rocketchat = RocketChatInput(
        user=os.getenv("ROCKETCHAT_BOT_USERNAME"),
        password=os.getenv("ROCKETCHAT_BOT_PASSWORD"),
        server_url=os.getenv("ROCKETCHAT_URL"),
    )

    store = InMemoryTrackerStore(domain=None, event_broker=event_broker)
    endpoints = AvailableEndpoints.read_endpoints(None)
    nlu = NaturalLanguageInterpreter.create(nlu_dir)
    agent = load_agent(
        core_dir,
        interpreter=nlu,
        tracker_store=store,
        endpoints=endpoints,
    )

    server = start_server([rocketchat], "", "", 5005, agent)
    try:
        server.serve_forever()
    except Exception as err:
        logger.exception(err)
def stores_to_be_tested():
    """Return one instance of each tracker store under test, Redis mock first."""
    redis_store = MockRedisTrackerStore(domain)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]
def stores_to_be_tested():
    """Return a mocked Redis store and an in-memory store for the same domain."""
    redis_store = RedisTrackerStore(domain, mock=True)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]
def _create_tracker_store(cls, store, domain): return store if store is not None else InMemoryTrackerStore(domain)
def load_agent(nlu_folder="models/nlu/default/current",
               agent_folder="models/dialogue",
               path_to_scenario_file="wismo/v1/data/test/dialogue",
               online=True,
               standalone=True,
               vui=True,
               nlu_config_file='ensemble.json',
               nlu_off=False,
               preprocessor_off=False,
               featurization="ents",
               feedback=False,
               server_bot_folder="/var/www/feedback/"):
    """Load the dialogue agent and run it on an HTTP or console channel.

    :param nlu_folder: the relative path starting from the current file folder
        to the nlu model directory (used by the EnsembleInterpreter fallback)
    :param agent_folder: the relative path starting from the current file
        folder to the dialogue model directory; first adjusted by
        ``choose_featurization``
    :param path_to_scenario_file: scenario resource for the offline console
        channel
    :param online: if true, serve over the HTTP channel; otherwise use
        ``ExtendedConsoleInputChannel``
    :param standalone: if the app runs standalone or we want to create only
        the wsgi app (passed to ``set_standalone``)
    :param vui: use a ``VUIInput`` HTTP channel instead of ``AlexaInput``; also
        selects which NLU interpreter is loaded when ``online``
    :param nlu_config_file: ensemble config file inside ``nlu_folder`` used by
        the fallback interpreter
    :param nlu_off: use a ``RegexInterpreter`` instead of a trained NLU model
    :param preprocessor_off: do not pass the message preprocessor to the
        channel
    :param featurization: forwarded to ``choose_featurization``
    :param feedback: forwarded to ``agent.handle_channel``
    :param server_bot_folder: forwarded to ``agent.handle_channel``
    :return: tuple ``(agent, http_input_channel)``; the HTTP channel is built
        and returned even when running offline
    """
    import os
    import sys

    agent_folder = choose_featurization(agent_folder, featurization)
    main_folder_path = os.path.dirname(sys.argv[0])
    domain_folder = os.path.join(main_folder_path, agent_folder)

    if nlu_off:
        interpreter = RegexInterpreter()
    elif online and not vui:
        interpreter = RasaNLUInterpreter("models/nlu/default/")
    else:
        try:
            interpreter = RasaNLUInterpreter("models/nlu/default/current")
        except FileNotFoundError:
            # If it fails, resolve paths relative to this source file instead
            # of the launching script and use the ensemble interpreter.
            main_folder_path = os.path.dirname(os.path.realpath(__file__))
            domain_folder = os.path.join(main_folder_path, agent_folder)
            nlu_path = os.path.join(main_folder_path, nlu_folder)
            interpreter = EnsembleInterpreter(
                os.path.join(nlu_path, nlu_config_file))

    print("Loading Agent")
    domain_file = os.path.join(domain_folder, 'domain.yml')
    # NOTE(review): the tracker store receives the domain *file path*, not a
    # loaded domain object; confirm CookieCutterAgent.load replaces it with
    # the agent's domain, otherwise tracker creation may fail.
    in_memory_tracker = InMemoryTrackerStore(domain_file)
    message_preprocessor = MessagePreprocessor(w2n=True)
    print("The current domain file is {}".format(domain_file))
    agent = CookieCutterAgent.load(os.path.join(main_folder_path, agent_folder),
                                   tracker_store=in_memory_tracker,
                                   interpreter=interpreter)

    if vui:
        http_input_channel = HttpInputChannel(
            5000, "", False,
            VUIInput(display_always=True, ssml_enabled=True,
                     action_enabled=True))
    else:
        http_input_channel = HttpInputChannel(5000, "", False, AlexaInput())
    print("Server is running")

    random_sender_id = "TEST.COMMANDLINE." + generate_random_string(32)
    # Whichever channel is used gets the same preprocessor (or none).
    preprocessor = None if preprocessor_off else message_preprocessor

    if online:
        http_input_channel.set_standalone(standalone)
        agent.handle_channel(http_input_channel,
                             message_preprocessor=preprocessor,
                             feedback=feedback,
                             server_bot_folder=server_bot_folder)
    else:
        console_channel = ExtendedConsoleInputChannel(
            random_sender_id, resource_path=path_to_scenario_file)
        agent.handle_channel(console_channel,
                             message_preprocessor=preprocessor,
                             feedback=feedback,
                             server_bot_folder=server_bot_folder)
    return agent, http_input_channel
def create_tracker_store(cls, store, domain):
    # type: (Optional[TrackerStore], Domain) -> TrackerStore
    """Use the supplied tracker store, or fall back to an in-memory one.

    The fallback is built only when ``store`` is ``None``.
    """
    if store is None:
        return InMemoryTrackerStore(domain)
    return store
elif isinstance(domain, Domain): return domain elif domain is not None: raise ValueError( "Invalid param `domain`. Expected a path to a domain " "specification or a domain instance. But got " "type '{}' with value '{}'".format(type(domain), domain)) @staticmethod def create_tracker_store(store: Optional['TrackerStore'], domain: Domain) -> 'TrackerStore': if store is not None: store.domain = domain return store else: return InMemoryTrackerStore(domain) @staticmethod def _create_ensemble( policies: Union[List[Policy], PolicyEnsemble, None] ) -> Optional[PolicyEnsemble]: if policies is None: return None if isinstance(policies, list): return SimplePolicyEnsemble(policies) elif isinstance(policies, PolicyEnsemble): return policies else: passed_type = type(policies).__name__ raise ValueError( "Invalid param `policies`. Passed object is "
def compile_and_finetune(
        nlu_folder="models/nlu/default/current",
        agent_folder="models/dialogue",
        summary="/home/doomdiskday/Desktop/experiments/error_analysis/summaries/summary.json",
        dump_f="/home/doomdiskday/Desktop/experiments/error_analysis/summaries/full_interaction.json",
        fine_tune_data_file="data/fine_tune_story.md",
        validation_data_file="data/validation.md",
        featurization="ents"):
    """Replay re-tagged failed dialogues through the agent and fine-tune on them.

    1. Read the error-analysis ``summary`` JSON (iterated as a list of
       dialogue records).
    2. For each record with ``correct == False`` and a truthy ``retagged``,
       collect the user turns (steps starting with ``*``) of its
       ``retagged_interaction``.
    3. Reset ``dump_f`` to ``{}``. Then, for every collected turn, append
       ``"* <intent>"`` to that dialogue's list in ``dump_f`` and send
       ``"/<intent>"`` through ``agent.compile_message``, which is also handed
       ``dump_f``.
    4. Write every dialogue in ``dump_f`` as a Markdown story to
       ``fine_tune_data_file``, skipping ``action_listen`` steps, and
       fine-tune a freshly loaded agent on it.

    :param nlu_folder: currently unused (the agent runs with no interpreter)
    :param agent_folder: dialogue model directory, first adjusted by
        ``choose_featurization``
    :param summary: error-analysis summary JSON file
    :param dump_f: JSON file accumulating the replayed interactions
    :param fine_tune_data_file: output Markdown story file
    :param validation_data_file: validation stories passed to ``fine_tune``
    :param featurization: forwarded to ``choose_featurization``
    """
    import os
    import sys

    with open(summary) as f:
        dials = json.load(f)

    agent_folder = choose_featurization(agent_folder, featurization)
    main_folder_path = os.path.dirname(sys.argv[0])
    domain_folder = os.path.join(main_folder_path, agent_folder)
    # Derive the domain file from the model folder, as load_agent does;
    # ``domain_file`` was otherwise never defined in this function.
    domain_file = os.path.join(domain_folder, 'domain.yml')
    in_memory_tracker = InMemoryTrackerStore(domain_file)
    model_dir = os.path.join(main_folder_path, agent_folder)
    agent = CookieCutterAgent.load(model_dir,
                                   tracker_store=in_memory_tracker,
                                   interpreter=None)

    messages = {}
    for d in dials:
        intents = []
        # Explicit ``== False`` so a ``None`` value is not treated as incorrect.
        if d['correct'] == False and d['retagged']:  # noqa: E712
            print(d['id'])
            for step in d['retagged_interaction']:
                # startswith() tolerates empty steps (step[0] would raise).
                if step.startswith("*"):
                    intents.append(' '.join(step.split("*")[1:]))
        messages[d['id']] = intents

    with open(dump_f, "w+") as f:
        json.dump({}, f)

    for dial_id, intents in messages.items():
        # JSON object keys always come back as strings, so use the string
        # form to keep appending to the same list across round-trips.
        dump_key = str(dial_id)
        for m in intents:
            # compile_message is given dump_f too (presumably it appends to
            # it), so the file is re-read and re-written around every call
            # instead of being kept in memory.
            with open(dump_f) as f:
                old_dump = json.load(f)
            old_dump.setdefault(dump_key, []).append("* {}".format(m))
            with open(dump_f, "w+") as f:
                json.dump(old_dump, f)

            mes = "/{}".format(m.strip())
            print(mes)
            agent.compile_message(mes, sender_id=dial_id, summary=summary,
                                  dump_interaction=dump_f)

    with open(dump_f) as f:
        interactions = json.load(f)

    with open(fine_tune_data_file, "w+") as f:
        for story_id, dial in enumerate(interactions):
            f.write("## story_{}\n".format(story_id))
            for step in interactions[dial]:
                if step.startswith("*"):
                    f.write("{}\n".format(step))
                elif "action_listen" not in step:
                    f.write(" {}\n".format(step))
            f.write("\n")

    agent = CookieCutterAgent.load(model_dir,
                                   tracker_store=in_memory_tracker,
                                   interpreter=None)
    agent.fine_tune(training_resource_name=fine_tune_data_file,
                    validation_resource_name=validation_data_file,
                    model_path=agent_folder,
                    remove_duplicates=True,
                    augmentation_factor=0,
                    max_history=1)