예제 #1
0
def test_query_form_set_username_in_form():
    """Walk the query form through two turns and check its slot requests.

    Turn 1: an ``inform`` intent with empty text should make
    ``ActionSearchQuery`` request the ``username`` slot and utter
    ``'what is your name?'``.
    Turn 2: answering with free text should set ``username`` to that text,
    request the ``query`` slot next, and mention the username in the reply.
    """
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-form"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "username"
    assert last_message.text == 'what is your name?'
    tracker.update(events[0])

    # second user utterance: its raw text is expected to become the username
    username = '******'
    tracker.update(UserUttered(username, intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "username"
    assert events[0].value == username
    # both events must be SlotSet, not merely expose a matching key/value
    assert isinstance(events[1], SlotSet)
    assert events[1].key == "requested_slot"
    assert events[1].value == "query"
    assert username in last_message.text
예제 #2
0
def test_restaurant_form():
    """Check that the restaurant form fills ``cuisine`` then asks for ``people``.

    An empty ``inform`` should trigger a request for the ``cuisine`` slot.
    A follow-up ``inform`` carrying a ``cuisine`` entity should set that slot
    and move on to requesting ``people``.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-restaurant"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # turn 1: bare inform -> the form asks for the cuisine
    tracker.update(UserUttered("", intent={"name": "inform"}))
    first_turn = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(first_turn) == 1
    request = first_turn[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "cuisine")
    tracker.update(request)

    # turn 2: inform with a cuisine entity -> slot filled, people requested
    cuisine_entity = {"entity": "cuisine", "value": "chinese"}
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=[cuisine_entity]))

    second_turn = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(second_turn) == 2
    filled, next_request = second_turn[0], second_turn[1]
    assert isinstance(filled, SlotSet)
    assert isinstance(next_request, SlotSet)

    assert (filled.key, filled.value) == ("cuisine", "chinese")

    assert (next_request.key, next_request.value) == ("requested_slot",
                                                      "people")
예제 #3
0
def test_people_form():
    """Check that the people form stores a free-text answer as ``person_name``.

    An empty ``inform`` should request ``person_name``. The next ``inform``,
    whose text is the name, should yield a single ``SlotSet`` filling it.
    """
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-people"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    def run_form():
        # a fresh action instance per turn, mirroring a real dialogue
        return ActionSearchPeople().run(dispatcher, tracker, domain)

    # turn 1: nothing is known yet, so the name is requested
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = run_form()
    assert len(produced) == 1
    request = produced[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot" and request.value == "person_name"
    tracker.update(request)

    # turn 2: the raw user text is taken as the person's name
    person = "Rasa Due"
    tracker.update(UserUttered(person, intent={"name": "inform"}))

    produced = run_form()
    assert len(produced) == 1
    filled = produced[0]
    assert isinstance(filled, SlotSet)

    assert filled.key == "person_name"
    assert filled.value == person
예제 #4
0
def test_travel_form():
    """Check that a ``GPE`` entity fills ``GPE_origin`` before the destination.

    An empty ``inform`` should request ``GPE_origin``. A following ``inform``
    with a single ``GPE`` entity is expected to set ``GPE_origin`` to it and
    then request ``GPE_destination``.
    """
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-travel"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "GPE_origin"
    tracker.update(events[0])

    # second user utterance
    entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    # assert on the events directly (no debug printing) so a failure is
    # reported by the assertion rather than by a side-effecting print loop
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "GPE_origin"
    assert events[0].value == "Berlin"
    assert isinstance(events[1], SlotSet)
    assert events[1].key == "requested_slot"
    assert events[1].value == "GPE_destination"
예제 #5
0
def test_restaurant_form_unhappy_1():
    """An answer that does not fill the requested slot re-asks for it.

    After ``cuisine`` is requested, an ``inform`` with no entities should
    make the form emit a single ``SlotSet`` requesting ``cuisine`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "inform"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # no debug print of ``e.key``/``e.value`` here: it ran before the
    # isinstance check and would raise AttributeError on a non-SlotSet event
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)

    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
예제 #6
0
def test_restaurant_form_skipahead():
    """Entities supplied up front let the form skip ahead.

    A first ``inform`` that already carries ``cuisine`` and ``number``
    entities should fill those slots immediately (``cuisine`` and
    ``people``) and go straight to requesting ``vegetarian``.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # check the length before indexing so a short result fails the assertion
    # instead of raising IndexError
    assert len(events) == 3
    assert all(isinstance(e, SlotSet) for e in events)
    # the two skipped-ahead slots are filled; order is not pinned here.
    # Expected values match what test_restaurant_form_unhappy_2 reads back
    # from the tracker for the same entities.
    assert {(e.key, e.value) for e in events[:2]} == {("cuisine", "chinese"),
                                                       ("people", 8)}
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
예제 #7
0
def test_action():
    """Smoke-test ``QuoraSearch`` against the domain in ``domain.yml``.

    Only checks that running the action on a fresh tracker (with a random
    ``uuid1`` sender id) does not raise; its events and bot output are not
    inspected.
    """
    domain = Domain.load('domain.yml')
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), generator)
    tracker = DialogueStateTracker(str(uuid.uuid1()), domain.slots)
    search = QuoraSearch()
    search.run(dispatcher, tracker, domain)
예제 #8
0
def generate_response(nlg_call, domain):
    """Render a template described by an NLG request payload.

    ``nlg_call`` is a dict (presumably the decoded body of an NLG callback
    request — confirm against the caller). The keys read here are:

    * ``"template"`` — name of the template to render;
    * ``"arguments"`` — extra keyword arguments forwarded to ``generate``;
    * ``"channel"`` — output channel name;
    * ``"tracker"`` — dict whose ``"sender_id"`` and ``"events"`` are used to
      rebuild a ``DialogueStateTracker`` over ``domain.slots``.

    A missing *or null* ``"arguments"`` / ``"tracker"`` entry is treated as an
    empty dict.

    Returns:
        Whatever ``TemplatedNaturalLanguageGenerator(domain.templates)
        .generate`` returns for this template, tracker and channel.
    """
    # ``.get(key, {})`` still yields None when the key is present with a null
    # value, which would break ``**kwargs`` / ``.get`` below; ``or {}`` covers
    # both the missing and the null case.
    kwargs = nlg_call.get("arguments") or {}
    template = nlg_call.get("template")
    tracker_state = nlg_call.get("tracker") or {}
    sender_id = tracker_state.get("sender_id")
    events = tracker_state.get("events")
    tracker = DialogueStateTracker.from_dict(sender_id, events, domain.slots)
    channel_name = nlg_call.get("channel")

    return TemplatedNaturalLanguageGenerator(domain.templates).generate(
        template, tracker, channel_name, **kwargs)
def test_action():
    """Run ``ActionJoke`` on a fresh tracker and inspect the bot's reply.

    The latest message collected by the output channel must contain
    ``"norris"`` (case-insensitive).
    """
    domain = Domain.load('domain.yml')
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), generator)
    conversation_id = str(uuid.uuid1())
    tracker = DialogueStateTracker(conversation_id, domain.slots)

    joke = ActionJoke()
    joke.run(dispatcher, tracker, domain)

    reply_text = dispatcher.output_channel.latest_output()['text']
    assert 'norris' in reply_text.lower()
예제 #10
0
def test_dispatcher_template_invalid_vars():
    """A template referencing an unknown ``{variable}`` is still sent.

    No slot or keyword argument provides ``variable``, so the collected
    message text is expected to begin with the raw template text,
    placeholder included.
    """
    template_text = "a template referencing an invalid {variable}."
    templates = {"my_made_up_template": [{"text": template_text}]}
    channel = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", channel,
                            TemplatedNaturalLanguageGenerator(templates))
    tracker = DialogueStateTracker("my-sender", slots=[])

    dispatcher.utter_template("my_made_up_template", tracker)

    latest = dispatcher.output_channel.latest_output()
    assert latest['text'].startswith(template_text)
예제 #11
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """``utter_greet`` from the moodbot domain is sent together with buttons.

    Exactly one message should be collected. It carries the greeting text,
    and its ``data`` holds the two buttons.
    """
    expected_buttons = [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'},
    ]
    domain = Domain.load("examples/moodbot/domain.yml")
    channel = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", channel,
                            TemplatedNaturalLanguageGenerator(domain.templates))

    dispatcher.utter_template("utter_greet", default_tracker)

    assert len(channel.messages) == 1
    greeting = channel.messages[0]
    assert greeting['text'] == "Hey! How are you?"
    assert greeting['data'] == expected_buttons
예제 #12
0
파일: generator.py 프로젝트: mann2107/NLP
    def create(obj, domain):
        """Factory to create a generator.

        Args:
            obj: one of the following:
                * an existing ``NaturalLanguageGenerator``, returned as-is;
                * an ``EndpointConfig``, wrapped in a
                  ``CallbackNaturalLanguageGenerator``;
                * ``None``, which yields a ``TemplatedNaturalLanguageGenerator``
                  over the domain's templates.
            domain: object providing ``templates``. It may be ``None``, in which
                case the templated generator starts with no templates.

        Returns:
            A ``NaturalLanguageGenerator`` instance.

        Raises:
            Exception: if ``obj`` is of any other type.
        """
        if isinstance(obj, NaturalLanguageGenerator):
            return obj
        elif isinstance(obj, EndpointConfig):
            from rasa_core.nlg import CallbackNaturalLanguageGenerator
            return CallbackNaturalLanguageGenerator(obj)
        elif obj is None:
            from rasa_core.nlg import TemplatedNaturalLanguageGenerator
            # without a domain there are no templates; avoid the
            # AttributeError that ``None.templates`` would raise
            templates = domain.templates if domain is not None else {}
            return TemplatedNaturalLanguageGenerator(templates)
        else:
            raise Exception("Cannot create a NaturalLanguageGenerator "
                            "based on the passed object. Type: `{}`"
                            "".format(type(obj)))
예제 #13
0
def test_restaurant_form_unhappy_2():
    """An off-topic reply while ``vegetarian`` is requested re-asks for it.

    The first ``inform`` supplies ``cuisine`` and ``number`` entities. Applying
    the resulting events must fill ``cuisine="chinese"`` and ``people=8``, and
    the form then requests ``vegetarian``. A following ``random`` intent that
    provides nothing should yield a single ``SlotSet`` requesting
    ``vegetarian`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]

    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "random"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # assert the length before touching events[0] so an empty result fails
    # the assertion instead of raising IndexError
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    # the form re-asks for the same slot
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
예제 #14
0
    def create(
        obj,  # type: Union[NaturalLanguageGenerator, EndpointConfig, None]
        domain  # type: Optional[Domain]
    ):
        # type: (...) -> NaturalLanguageGenerator
        """Factory to create a generator.

        Args:
            obj: one of the following:
                * an existing ``NaturalLanguageGenerator``, returned as-is;
                * an ``EndpointConfig``, wrapped in a
                  ``CallbackNaturalLanguageGenerator``;
                * ``None``, which yields a ``TemplatedNaturalLanguageGenerator``
                  over the domain's templates.
            domain: domain providing ``templates``, or ``None``. With ``None``,
                the templated generator starts with an empty template mapping.

        Returns:
            A ``NaturalLanguageGenerator`` instance.

        Raises:
            Exception: if ``obj`` is of any other type.
        """
        if isinstance(obj, NaturalLanguageGenerator):
            return obj
        elif isinstance(obj, EndpointConfig):
            from rasa_core.nlg import CallbackNaturalLanguageGenerator
            return CallbackNaturalLanguageGenerator(obj)
        elif obj is None:
            from rasa_core.nlg import TemplatedNaturalLanguageGenerator
            # templates map a template name to its variants, so the empty
            # fallback is a dict rather than a list
            templates = domain.templates if domain is not None else {}
            return TemplatedNaturalLanguageGenerator(templates)
        else:
            raise Exception("Cannot create a NaturalLanguageGenerator "
                            "based on the passed object. Type: `{}`"
                            "".format(type(obj)))
예제 #15
0
def default_nlg(default_domain):
    """Build a templated NLG over the templates of ``default_domain``."""
    domain_templates = default_domain.templates
    return TemplatedNaturalLanguageGenerator(domain_templates)