示例#1
0
def test_people_form():
    """The people form asks for ``person_name`` and then fills it from the text."""
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    user = "test-people"
    dispatcher = Dispatcher(user, channel, generator)
    tracker = store.get_or_create_tracker(user)

    def run_form():
        return ActionSearchPeople().run(dispatcher, tracker, domain)

    # Turn 1: a bare "inform" makes the form request the person's name.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = run_form()
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "person_name"
    tracker.update(request)

    # Turn 2: the raw utterance text is taken as the requested name.
    person = "Rasa Due"
    tracker.update(UserUttered(person, intent={"name": "inform"}))

    produced = run_form()
    assert len(produced) == 1
    filled = produced[0]
    assert isinstance(filled, SlotSet)

    assert filled.key == "person_name"
    assert filled.value == person
示例#2
0
def test_inmemory_tracker_store(filename):
    """A tracker saved into an in-memory store comes back equal to the original."""
    default_domain = Domain.load("data/test_domains/default.yml")
    original = tracker_from_dialogue_file(filename, default_domain)
    store = InMemoryTrackerStore(default_domain)
    store.save(original)
    assert store.retrieve(original.sender_id) == original
示例#3
0
def test_travel_form():
    """The travel form asks for the origin, stores a GPE entity, then asks for the destination."""
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-travel"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance: nothing provided, so the origin is requested
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "GPE_origin"
    tracker.update(events[0])

    # second user utterance: a GPE entity fills the origin slot
    entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(UserUttered("",
                               intent={"name": "inform"},
                               entities=entities))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "GPE_origin"
    assert events[0].value == "Berlin"
    # the follow-up request must also be a SlotSet (not only have .key/.value)
    assert isinstance(events[1], SlotSet)
    assert events[1].key == "requested_slot"
    assert events[1].value == "GPE_destination"
示例#4
0
def test_query_form_set_username_in_form():
    """The query form asks for a username, stores it, then requests the query.

    After the username is supplied, the latest bot message is expected to
    contain that username.
    """
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-form"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "username"
    assert last_message.text == 'what is your name?'
    tracker.update(events[0])

    # second user utterance
    username = '******'
    tracker.update(UserUttered(username, intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "username"
    assert events[0].value == username
    # consistent with the other form tests: check the type of every event
    assert isinstance(events[1], SlotSet)
    assert events[1].key == "requested_slot"
    assert events[1].value == "query"
    assert username in last_message.text
示例#5
0
def test_restaurant_form_unhappy_1():
    """If the user does not provide the requested cuisine, it is requested again."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(
        UserUttered("",
                    intent={"name": "inform"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # No debug print of e.key/e.value here: it would raise AttributeError on a
    # non-SlotSet event before the isinstance assertion could report it.
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)

    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
示例#6
0
def test_restaurant_form():
    """Happy path: cuisine is requested, filled from an entity, then people is requested."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    user = "test-restaurant"
    dispatcher = Dispatcher(user, channel, generator)
    tracker = store.get_or_create_tracker(user)
    form = ActionSearchRestaurants

    # Turn 1: empty inform -> the form asks for the cuisine.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = form().run(dispatcher, tracker, domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "cuisine"
    tracker.update(produced[0])

    # Turn 2: a cuisine entity fills the slot and the next slot is requested.
    cuisine_entity = {"entity": "cuisine", "value": "chinese"}
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=[cuisine_entity]))

    produced = form().run(dispatcher, tracker, domain)
    assert len(produced) == 2
    filled, follow_up = produced[0], produced[1]
    assert isinstance(filled, SlotSet)
    assert isinstance(follow_up, SlotSet)

    assert filled.key == "cuisine"
    assert filled.value == "chinese"

    assert follow_up.key == "requested_slot"
    assert follow_up.value == "people"
def test_get_or_create():
    """A slot set on a created tracker is still present after save and re-fetch."""
    key, value = 'location', 'Easter Island'
    store = InMemoryTrackerStore(domain)

    created = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    created.update(SlotSet(key, value))
    assert created.get_slot(key) == value

    store.save(created)

    fetched = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    assert fetched.get_slot(key) == value
def test_restart_after_retrieval_from_tracker_store(default_domain):
    """The index after the latest restart is unchanged by a save/retrieve round trip."""
    store = InMemoryTrackerStore(default_domain)
    tracker = store.get_or_create_tracker("myuser")
    listens = [ActionExecuted("action_listen") for _ in range(4)]

    for event in listens:
        tracker.update(event)
    tracker.update(Restarted())
    before = tracker.idx_after_latest_restart()

    store.save(tracker)
    reloaded = store.retrieve("myuser")
    assert before == reloaded.idx_after_latest_restart()
示例#9
0
def test_restaurant_form_unhappy_2():
    """After cuisine and people are filled, an off-topic turn re-requests ``vegetarian``."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [
        {"entity": "cuisine", "value": "chinese"},
        {"entity": "number", "value": 8}]

    tracker.update(
        UserUttered("",
                    intent={"name": "inform"},
                    entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert isinstance(events[2], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(
        UserUttered("",
                    intent={"name": "random"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # check the length first so an empty result fails as an assertion,
    # not as an IndexError
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
示例#10
0
def test_restaurant_form_skipahead():
    """Entities given in the first turn let the form skip ahead to ``vegetarian``."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance already carries cuisine and number of people
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(UserUttered("",
                               intent={"name": "inform"},
                               entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # length is asserted before any indexing, so a short result fails cleanly
    assert len(events) == 3
    assert isinstance(events[2], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
示例#11
0
def run(serve_forever=True, port=5002, debug=False):
    """Load the agent with an HTTP NLU interpreter and optionally serve it.

    :param serve_forever: when True, hand the input channel to
        ``agent.handle_channel`` before returning.
    :param port: port of the HTTP input channel (ignored when ``debug``).
    :param debug: use a console input channel instead of HTTP ``/ai``.
    :return: the loaded agent.

    NOTE(review): both the agent path and the tracker domain path are
    ``os.path.abspath("")`` (the current working directory) — looks like a
    placeholder; confirm the intended paths.
    """
    agent_dir = os.path.abspath("")
    nlu = RasaNLUHttpInterpreter(server="http://rasanlu:5000",
                                 token="", model_name="", project="")
    store = InMemoryTrackerStore(TemplateDomain.load(os.path.abspath("")))
    endpoint = BotInput()
    channel = (ConsoleInputChannel() if debug
               else HttpInputChannel(port, "/ai", endpoint))
    agent = Agent.load(agent_dir, interpreter=nlu, tracker_store=store)

    if serve_forever:
        agent.handle_channel(channel)
    return agent
示例#12
0
def create_tracker_store(core_model, endpoints):
    """Build the tracker store configured in the endpoints file.

    :param core_model: Core model whose domain (via ``get_domain``) is
        attached to the tracker store.
    :param endpoints: endpoints configuration passed to
        ``utils.read_endpoint_config``.
    :return: an ``InMemoryTrackerStore`` for type ``memory`` (also used when
        no ``tracker-store`` entry is found), a ``RedisTrackerStore`` for type
        ``redis``, and ``None`` for any other type.
    """
    domain = get_domain(core_model)

    # read_endpoint_config may return None when the entry is not configured;
    # only pass a publish URL when a tracker-publish endpoint exists.
    tracker_publish = utils.read_endpoint_config(endpoints, "tracker-publish")
    publish_url = tracker_publish.url if tracker_publish is not None else None

    store = utils.read_endpoint_config(endpoints, "tracker-store")
    store_type = store.type if store is not None else 'memory'

    if store_type == 'memory':
        return InMemoryTrackerStore(domain=domain, publish_url=publish_url)
    if store_type == 'redis':
        return RedisTrackerStore(domain=domain,
                                 host=store.host,
                                 port=store.port,
                                 db=store.db,
                                 password=store.password,
                                 publish_url=publish_url)
    # Unrecognised store types are not handled here.
    return None
示例#13
0
def test_get_or_create():
    """get_or_create_tracker returns the saved tracker, with its slots, on a second call."""
    slot_name = 'location'
    expected = 'Easter Island'
    tracker_store = InMemoryTrackerStore(domain)
    sender = UserMessage.DEFAULT_SENDER_ID

    first = tracker_store.get_or_create_tracker(sender)
    event = SlotSet(slot_name, expected)
    first.update(event)
    assert first.get_slot(slot_name) == expected

    tracker_store.save(first)

    second = tracker_store.get_or_create_tracker(sender)
    assert second.get_slot(slot_name) == expected
示例#14
0
def test_restart_after_retrieval_from_tracker_store(default_domain):
    """Saving and retrieving a tracker keeps its post-restart index stable."""
    user_id = "myuser"
    tracker_store = InMemoryTrackerStore(default_domain)
    tracker = tracker_store.get_or_create_tracker(user_id)
    synthetic = [ActionExecuted("action_listen") for _ in range(4)]

    for event in synthetic + [Restarted()]:
        tracker.update(event)
    expected_idx = tracker.idx_after_latest_restart()

    tracker_store.save(tracker)
    loaded_idx = tracker_store.retrieve(user_id).idx_after_latest_restart()
    assert expected_idx == loaded_idx
示例#15
0
def train(max_training_samples=3, serve_forever=True):
    """Load or build a dialogue agent, then run interactive online training.

    If ``agent_model_path`` exists, the dialogue model in ``../model/agent``
    is loaded together with the NLU model in ``../model/latest``. Otherwise a
    new agent is built from ``domain_conf_path`` with a Memoization + Keras
    policy ensemble and an in-memory tracker store. Online training then uses
    ``stories.md`` and a console input channel for one epoch.

    NOTE(review): ``max_training_samples`` and ``serve_forever`` are not used.

    :return: the trained agent.
    """
    from rasa_core.interpreter import RasaNLUInterpreter

    stories_file = 'stories.md'
    nlu_interpreter = RasaNLUInterpreter("../model/latest/")
    domain = TemplateDomain.load("domain.yml")

    if not os.path.exists(agent_model_path):
        logger.info("从头创建模型")
        agent = Agent(domain_conf_path,
                      interpreter=nlu_interpreter,
                      policies=[MemoizationPolicy(), KerasPolicy()],
                      tracker_store=InMemoryTrackerStore(domain))
    else:
        logger.info("加载已有的模型")
        # Dialogue model directory, then NLU model directory.
        dialogue_dir, nlu_dir = "../model/agent", '../model/latest'
        agent = Agent.load(dialogue_dir, nlu_dir)

    logger.info("开始在线训练...")
    agent.train_online(stories_file,
                       input_channel=ConsoleInputChannel(),
                       epochs=1)
    return agent
示例#16
0
def run(core_dir, nlu_dir):
    """Serve the agent over a Rocket.Chat input channel on port 5005.

    When ``ENABLE_ANALYTICS`` is set, tracker events are sent to a
    ``PikaProducer`` broker. Rocket.Chat credentials come from the
    ``ROCKETCHAT_BOT_USERNAME``, ``ROCKETCHAT_BOT_PASSWORD`` and
    ``ROCKETCHAT_URL`` environment variables. Any exception raised while
    serving is logged rather than propagated.
    """
    broker = (PikaProducer(url, username, password, queue=queue)
              if ENABLE_ANALYTICS else None)

    rocketchat = RocketChatInput(
        user=os.getenv("ROCKETCHAT_BOT_USERNAME"),
        password=os.getenv("ROCKETCHAT_BOT_PASSWORD"),
        server_url=os.getenv("ROCKETCHAT_URL"),
    )

    store = InMemoryTrackerStore(domain=None, event_broker=broker)
    endpoints = AvailableEndpoints.read_endpoints(None)
    nlu = NaturalLanguageInterpreter.create(nlu_dir)

    agent = load_agent(core_dir,
                       interpreter=nlu,
                       tracker_store=store,
                       endpoints=endpoints)

    server = start_server([rocketchat], "", "", 5005, agent)

    try:
        server.serve_forever()
    except Exception as err:
        logger.exception(err)
示例#17
0
def stores_to_be_tested():
    """Return one instance of each tracker store implementation under test."""
    redis_store = MockRedisTrackerStore(domain)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]
示例#18
0
def stores_to_be_tested():
    """Return a mocked Redis store and an in-memory store, both for ``domain``."""
    mocked_redis = RedisTrackerStore(domain, mock=True)
    in_memory = InMemoryTrackerStore(domain)
    return [mocked_redis, in_memory]
示例#19
0
 def _create_tracker_store(cls, store, domain):
     return store if store is not None else InMemoryTrackerStore(domain)
示例#20
0
def load_agent(nlu_folder="models/nlu/default/current",
               agent_folder="models/dialogue",
               path_to_scenario_file="wismo/v1/data/test/dialogue",
               online=True,
               standalone=True,
               vui=True,
               nlu_config_file='ensemble.json',
               nlu_off=False,
               preprocessor_off=False,
               featurization="ents",
               feedback=False,
               server_bot_folder="/var/www/feedback/"):
    """Load the dialogue agent, attach an input channel and start handling it.

    :param nlu_folder: NLU model directory, relative to the folder of the
        current file; only used by the ``EnsembleInterpreter`` fallback.
    :param agent_folder: dialogue model directory, relative to the folder of
        the launched script; first passed through ``choose_featurization``.
    :param path_to_scenario_file: resource path for the offline
        ``ExtendedConsoleInputChannel``.
    :param online: serve over HTTP (True) or through the console (False).
    :param standalone: passed to ``set_standalone`` on the HTTP channel when
        online (presumably: run standalone vs. only create the wsgi app).
    :param vui: use ``VUIInput`` (True) or ``AlexaInput`` (False).
    :param nlu_config_file: ensemble config file name for the fallback.
    :param nlu_off: use a ``RegexInterpreter`` instead of an NLU model.
    :param preprocessor_off: disable the ``MessagePreprocessor``.
    :param featurization: forwarded to ``choose_featurization``.
    :param feedback: forwarded to ``agent.handle_channel``.
    :param server_bot_folder: forwarded to ``agent.handle_channel``.
    :return: ``(agent, http_input_channel)``; the HTTP channel is created
        even when running offline.
    """
    import os
    import sys

    agent_folder = choose_featurization(agent_folder, featurization)

    main_folder_path = os.path.dirname(sys.argv[0])
    domain_folder = os.path.join(main_folder_path, agent_folder)

    if nlu_off:
        interpreter = RegexInterpreter()
    elif online and not vui:
        interpreter = RasaNLUInterpreter("models/nlu/default/")
    else:
        try:
            interpreter = RasaNLUInterpreter("models/nlu/default/current")
        except FileNotFoundError:
            # Retry relative to this source file instead of the launched script;
            # the later domain/agent paths follow this new base folder.
            main_folder_path = os.path.dirname(os.path.realpath(__file__))
            domain_folder = os.path.join(main_folder_path, agent_folder)
            nlu_path = os.path.join(main_folder_path, nlu_folder)
            interpreter = EnsembleInterpreter(
                os.path.join(nlu_path, nlu_config_file))

    print("Loading Agent")
    domain_file = os.path.join(domain_folder, 'domain.yml')
    # NOTE(review): this passes the domain *file path* where the tracker store
    # may expect a loaded domain object — confirm CookieCutterAgent.load
    # replaces the store's domain.
    inMemo_tracker = InMemoryTrackerStore(domain_file)
    message_preprocessor = (None if preprocessor_off
                            else MessagePreprocessor(w2n=True))

    print("The current domain file is {}".format(domain_file))

    agent = CookieCutterAgent.load(os.path.join(main_folder_path,
                                                agent_folder),
                                   tracker_store=inMemo_tracker,
                                   interpreter=interpreter)
    if vui:
        channel_input = VUIInput(display_always=True,
                                 ssml_enabled=True,
                                 action_enabled=True)
    else:
        channel_input = AlexaInput()
    http_input_channel = HttpInputChannel(5000, "", False, channel_input)
    print("Server is running")

    # Arguments shared by both the online and the offline channel.
    handle_kwargs = dict(message_preprocessor=message_preprocessor,
                         feedback=feedback,
                         server_bot_folder=server_bot_folder)
    if online:
        http_input_channel.set_standalone(standalone)
        agent.handle_channel(http_input_channel, **handle_kwargs)
    else:
        random_sender_id = "TEST.COMMANDLINE." + generate_random_string(32)
        console_channel = ExtendedConsoleInputChannel(
            random_sender_id, resource_path=path_to_scenario_file)
        agent.handle_channel(console_channel, **handle_kwargs)

    return agent, http_input_channel
示例#21
0
 def create_tracker_store(cls, store, domain):
     # type: (Optional[TrackerStore], Domain) -> TrackerStore
     """Use the supplied tracker store, falling back to an in-memory one for ``domain``."""
     if store is None:
         return InMemoryTrackerStore(domain)
     return store
示例#22
0
文件: agent.py 项目: jonDel/rasa_core
        elif isinstance(domain, Domain):
            return domain
        elif domain is not None:
            raise ValueError(
                "Invalid param `domain`. Expected a path to a domain "
                "specification or a domain instance. But got "
                "type '{}' with value '{}'".format(type(domain), domain))

    @staticmethod
    def create_tracker_store(store: Optional['TrackerStore'],
                             domain: Domain) -> 'TrackerStore':
        """Bind ``domain`` to ``store``, or create an in-memory store if none is given.

        A supplied store is mutated in place (its ``domain`` attribute is
        overwritten) and returned.
        """
        if store is None:
            return InMemoryTrackerStore(domain)
        store.domain = domain
        return store

    @staticmethod
    def _create_ensemble(
        policies: Union[List[Policy], PolicyEnsemble, None]
    ) -> Optional[PolicyEnsemble]:
        if policies is None:
            return None
        if isinstance(policies, list):
            return SimplePolicyEnsemble(policies)
        elif isinstance(policies, PolicyEnsemble):
            return policies
        else:
            passed_type = type(policies).__name__
            raise ValueError(
                "Invalid param `policies`. Passed object is "
示例#23
0
def _collect_retagged_intents(dials):
    """Map each dialogue id marked incorrect-but-retagged to its user-intent lines."""
    messages = {}
    for d in dials:
        # `== False` kept deliberately: it also matches 0, like the original check.
        if d['correct'] == False and d['retagged']:  # noqa: E712
            print(d['id'])
            # User turns start with "*"; keep the text after the marker(s).
            messages[d['id']] = [' '.join(step.split("*")[1:])
                                 for step in d['retagged_interaction']
                                 if step.startswith("*")]
    return messages


def _append_to_dump(dump_f, key, line):
    """Append ``line`` to the list stored under ``key`` in the JSON file ``dump_f``."""
    with open(dump_f) as f:
        dump = json.load(f)
    dump.setdefault(key, []).append(line)
    with open(dump_f, "w") as f:
        json.dump(dump, f)


def _write_fine_tune_stories(interactions, path):
    """Write each dumped interaction to ``path`` as a Markdown story.

    User lines (starting with "*") are written as-is; other lines are
    indented by one space, and lines mentioning ``action_listen`` are dropped.
    """
    with open(path, "w") as f:
        for story_id, dial in enumerate(interactions):
            f.write("## story_{}\n".format(story_id))
            for step in interactions[dial]:
                if step.startswith("*"):
                    f.write("{}\n".format(step))
                elif "action_listen" not in step:
                    f.write(" {}\n".format(step))
            f.write("\n")


def compile_and_finetune(
        nlu_folder="models/nlu/default/current",
        agent_folder="models/dialogue",
        summary="/home/doomdiskday/Desktop/experiments/error_analysis/summaries/summary.json",
        dump_f="/home/doomdiskday/Desktop/experiments/error_analysis/summaries/full_interaction.json",
        fine_tune_data_file="data/fine_tune_story.md",
        validation_data_file="data/validation.md",
        featurization="ents"):
    """Replay retagged dialogues through the agent and fine-tune it on them.

    Steps:
      1. Read the JSON ``summary`` and collect the user intents of every
         dialogue marked incorrect but retagged.
      2. Send each intent as ``/<intent>`` to the agent via
         ``compile_message``, while also recording it in ``dump_f``.
      3. Turn the recorded interactions into Markdown stories in
         ``fine_tune_data_file``.
      4. Reload the agent and fine-tune it, validating on
         ``validation_data_file``.

    :param nlu_folder: currently unused (no NLU interpreter is needed because
        messages are sent as ``/<intent>``).
    """
    import os

    with open(summary) as f:
        dials = json.load(f)

    agent_folder = choose_featurization(agent_folder, featurization)

    main_folder_path = os.path.dirname(os.sys.argv[0])
    domain_folder = os.path.join(main_folder_path, agent_folder)
    # `domain_file` was referenced without ever being defined; build it the
    # same way load_agent does.
    domain_file = os.path.join(domain_folder, 'domain.yml')
    # NOTE(review): passes the domain file path, mirroring load_agent —
    # confirm the tracker store accepts a path here.
    inMemo_tracker = InMemoryTrackerStore(domain_file)

    agent = CookieCutterAgent.load(domain_folder,
                                   tracker_store=inMemo_tracker,
                                   interpreter=None)

    messages = _collect_retagged_intents(dials)

    # Start from an empty dump. compile_message is handed the same file
    # (dump_interaction=dump_f), so it is re-read before every append
    # instead of being cached in memory.
    with open(dump_f, "w") as f:
        json.dump({}, f)

    for dial_id, intents in messages.items():
        for m in intents:
            # JSON object keys are always strings; normalise so a non-str id
            # does not produce a duplicate key that overwrites earlier turns.
            _append_to_dump(dump_f, str(dial_id), "* " + m)

            mes = "/{}".format(m.strip())
            print(mes)
            agent.compile_message(mes,
                                  sender_id=dial_id,
                                  summary=summary,
                                  dump_interaction=dump_f)

    with open(dump_f) as f:
        interactions = json.load(f)

    _write_fine_tune_stories(interactions, fine_tune_data_file)

    # Reload the agent from disk before fine-tuning.
    agent = CookieCutterAgent.load(domain_folder,
                                   tracker_store=inMemo_tracker,
                                   interpreter=None)

    agent.fine_tune(training_resource_name=fine_tune_data_file,
                    validation_resource_name=validation_data_file,
                    model_path=agent_folder,
                    remove_duplicates=True,
                    augmentation_factor=0,
                    max_history=1)