Пример #1
0
def test_custom_slot_type(tmpdir):
    """A domain declaring a slot with a custom slot class path loads."""
    domain_yaml = """
       slots:
         custom:
           type: tests.conftest.CustomSlot

       templates:
         utter_greet:
           - hey there!

       actions:
         - utter_greet """
    written_path = utilities.write_text_to_file(tmpdir, "domain.yml",
                                                domain_yaml)
    # The test passes as long as loading does not raise.
    TemplateDomain.load(written_path)
Пример #2
0
def test_custom_slot_type(tmpdir):
    """Loading a domain whose slot type is a dotted custom class succeeds."""
    yaml_source = """
       slots:
         custom:
           type: tests.conftest.CustomSlot

       templates:
         utter_greet:
           - hey there!

       actions:
         - utter_greet """
    yml_file = utilities.write_text_to_file(tmpdir, "domain.yml", yaml_source)
    # No assertion needed: an exception from load() fails the test.
    TemplateDomain.load(yml_file)
Пример #3
0
def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
    """A slot type that cannot be resolved makes domain loading raise."""
    broken_yaml = """
        slots:
            custom:
             type: tests.conftest.Unknown
        
        templates:
            utter_greet:
             - hey there!
        
        actions:
            - utter_greet"""
    yml_file = utilities.write_text_to_file(tmpdir, "domain.yml", broken_yaml)
    with pytest.raises(ValueError):
        TemplateDomain.load(yml_file)
Пример #4
0
def test_custom_slot_type(tmpdir):
    """A custom slot class referenced by dotted path loads without error."""
    yaml_text = """
       slots:
         custom:
           type: tests.conftest.CustomSlot

       templates:
         utter_greet:
           - hey there!

       actions:
         - utter_greet """
    path_on_disk = write_domain_yml(tmpdir, yaml_text)
    TemplateDomain.load(path_on_disk)
Пример #5
0
def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
    """An unresolvable custom slot type makes domain loading raise ValueError."""
    bad_domain = """
        slots:
            custom:
             type: tests.conftest.Unknown

        templates:
            utter_greet:
             - hey there!

        actions:
            - utter_greet"""
    target = utilities.write_text_to_file(tmpdir, "domain.yml", bad_domain)
    with pytest.raises(ValueError):
        TemplateDomain.load(target)
Пример #6
0
def test_restaurant_form_unhappy_1():
    """The form asks for the same slot again when the user omits it.

    Turn 1: a bare ``inform`` intent makes ``ActionSearchRestaurants``
    request the ``cuisine`` slot. Turn 2: the user again supplies no
    entities, so exactly one event re-requesting ``cuisine`` is expected.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "inform"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Check the event type before reading ``.key``/``.value``: other event
    # types may lack those attributes, and an AttributeError would mask the
    # real assertion failure.
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)

    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
Пример #7
0
def main():
    """Parse CLI arguments, build the Redis-backed Rasa Core server and run it.

    The server listens on 0.0.0.0 at the port given on the command line.
    Model and interpreter locations come from the module-level
    ``MODEL_PATH`` / ``INTEPRETER_PATH`` constants.
    """
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.basicConfig(level=cmdline_args.loglevel)

    # Set HTTP log path
    http_logs = os.path.join(dir_path, "bot/logs/http.logs")

    # Load Domain
    domain = TemplateDomain.load(os.path.join(MODEL_PATH, "domain.yml"))

    # Create tracker store
    tracker_store = ExtendedRedisTrackerStore(domain, db=2, timeout=900,
                                              mock=False)

    # Initialize Rasa Core Server
    model_directory = MODEL_PATH
    interpreter = INTEPRETER_PATH
    rasa = ExtendedRasaCoreServer(model_directory,
                                  interpreter,
                                  cmdline_args.loglevel,
                                  http_logs,
                                  cmdline_args.cors,
                                  auth_token=cmdline_args.auth_token,
                                  tracker_store=tracker_store)

    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info("Started http server on port %s", cmdline_args.port)

    # Run Rasa Core Server
    rasa.app.run("0.0.0.0", cmdline_args.port)
Пример #8
0
def main():
    """Parse CLI arguments, build the Redis-backed Rasa Core server and run it.

    The server listens on 0.0.0.0 at the port given on the command line.
    Model and interpreter locations come from the module-level
    ``MODEL_PATH`` / ``INTEPRETER_PATH`` constants.
    """
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.basicConfig(level=cmdline_args.loglevel)

    # Set HTTP log path
    http_logs = os.path.join(dir_path, "bot/logs/http.logs")

    # Load Domain
    domain = TemplateDomain.load(os.path.join(MODEL_PATH, "domain.yml"))

    # Create tracker store
    tracker_store = ExtendedRedisTrackerStore(domain,
                                              db=2,
                                              timeout=900,
                                              mock=False)

    # Initialize Rasa Core Server
    model_directory = MODEL_PATH
    interpreter = INTEPRETER_PATH
    rasa = ExtendedRasaCoreServer(model_directory,
                                  interpreter,
                                  cmdline_args.loglevel,
                                  http_logs,
                                  cmdline_args.cors,
                                  auth_token=cmdline_args.auth_token,
                                  tracker_store=tracker_store)

    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info("Started http server on port %s", cmdline_args.port)

    # Run Rasa Core Server
    rasa.app.run("0.0.0.0", cmdline_args.port)
Пример #9
0
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Any
             tracker_store=None,  # type: Optional[TrackerStore]
             action_factory=None,  # type: Optional[Text]
             rules_file=None,  # type: Any
             generator=None,  # type: Any
             create_dispatcher=None  # type: Any
             ):
        # type: (...) -> Agent
        """Load a persisted agent (domain + policy ensemble) from ``path``.

        Args:
            path: directory holding ``domain.yml`` and the persisted policies.
            interpreter: passed to ``NaturalLanguageInterpreter.create``.
            tracker_store: optional store; resolved via
                ``cls.create_tracker_store`` together with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.
            rules_file, generator, create_dispatcher: forwarded unchanged to
                the constructor.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: if ``path`` is None.
        """
        if path is None:
            raise ValueError("No domain path specified.")
        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)
        ensemble = PolicyEnsemble.load(path)
        _interpreter = NaturalLanguageInterpreter.create(interpreter)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)
        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=_interpreter,
                   tracker_store=_tracker_store,
                   rules_file=rules_file,
                   generator=generator,
                   create_dispatcher=create_dispatcher)
Пример #10
0
def test_restaurant_form_skipahead():
    """Entities given up front let the form skip ahead.

    The first utterance already carries ``cuisine`` and ``number``
    entities, so the action should emit three events, the last one
    requesting the ``vegetarian`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Check the count before indexing, so a short result fails this
    # assertion instead of raising IndexError.
    assert len(events) == 3
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
Пример #11
0
def test_restaurant_form():
    """Happy path: a provided cuisine is stored and the next slot requested."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-restaurant"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # Turn 1: a bare "inform" -> the form asks for the cuisine.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(result) == 1
    request = result[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "cuisine")
    tracker.update(request)

    # Turn 2: the user supplies the cuisine entity.
    cuisine_entity = {"entity": "cuisine", "value": "chinese"}
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=[cuisine_entity]))
    result = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(result) == 2
    filled, next_request = result
    assert isinstance(filled, SlotSet)
    assert isinstance(next_request, SlotSet)
    assert (filled.key, filled.value) == ("cuisine", "chinese")
    assert (next_request.key, next_request.value) == ("requested_slot",
                                                      "people")
Пример #12
0
def run_hello_world(max_training_samples=10, serve_forever=True):
    """Train the "mom" bot interactively (online) from the console.

    Args:
        max_training_samples: forwarded to ``agent.train_online``.
        serve_forever: currently unused; kept so existing callers still work.

    Returns:
        The ``Agent`` after online training.
    """
    training_data = '../mom/data/stories.md'

    default_domain = TemplateDomain.load("../mom/domain.yml")
    agent = Agent(
        default_domain,
        policies=[MemoizationPolicy(), KerasPolicy()],
        interpreter=HelloInterpreter(),
        tracker_store=InMemoryTrackerStore(default_domain))
    logger.info("Starting to train policy")

    # Online training reads user input from the console channel.
    agent.train_online(training_data,
                       input_channel=ConsoleInputChannel(),
                       epochs=1,
                       max_training_samples=max_training_samples)

    return agent
Пример #13
0
def test_travel_form():
    """A GPE entity fills the origin slot; the destination is asked next."""
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-travel"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # Turn 1: no entities -> the form asks for the origin.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(result) == 1
    ask_origin = result[0]
    assert isinstance(ask_origin, SlotSet)
    assert (ask_origin.key, ask_origin.value) == ("requested_slot",
                                                  "GPE_origin")
    tracker.update(ask_origin)

    # Turn 2: a GPE entity answers the origin question.
    gpe = {"entity": "GPE", "value": "Berlin"}
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=[gpe]))
    result = ActionSearchTravel().run(dispatcher, tracker, domain)
    for evt in result:
        print(evt.as_story_string())
    assert len(result) == 2
    origin, ask_destination = result
    assert isinstance(origin, SlotSet)
    assert (origin.key, origin.value) == ("GPE_origin", "Berlin")
    assert (ask_destination.key, ask_destination.value) == (
        "requested_slot", "GPE_destination")
Пример #14
0
    def load(
        cls,
        path,  # type: Text
        interpreter=None,  # type: Union[NLI, Text, None]
        tracker_store=None,  # type: Optional[TrackerStore]
        action_factory=None  # type: Optional[Text]
    ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the policy files.
            interpreter: passed to ``NaturalLanguageInterpreter.create``.
            tracker_store: optional store; resolved via
                ``cls.create_tracker_store`` with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: if ``path`` is None or points to a file instead of
                a model directory.
        """

        if path is None:
            raise ValueError("No domain path specified.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        ensemble = PolicyEnsemble.load(path)
        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)
        _interpreter = NaturalLanguageInterpreter.create(interpreter)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)

        return cls(domain, ensemble, _interpreter, _tracker_store)
Пример #15
0
    def __init__(
        self,
        model=None,
        domain=None,
        test_cases=None,
        shuffle=False,
        distinct=True,
        rules=None,
        interpreter=None,
        create_output_channel=lambda on_response, domain, processor:
        TestOutputChannel(on_response, domain, processor),
    ):
        """Set up the agent under test and immediately run the test cases.

        Args:
            model: forwarded to ``self._create_agent``.
            domain: domain path passed to ``TemplateDomain.load``.
            test_cases: path the test stories are built from.
            shuffle: when true, shuffle the built stories before running.
            distinct: stored as ``self.distinct``.
            rules: forwarded to ``self._create_agent``.
            interpreter: NLU interpreter; a new ``RegexInterpreter`` is
                created when omitted.
            create_output_channel: factory called as
                ``(on_response, domain, processor)``.
        """
        self.model = model
        self.distinct = distinct
        # Build the default per instance rather than sharing one object
        # created once at function-definition time.
        self.interpreter = (interpreter if interpreter is not None
                            else RegexInterpreter())
        self.domain = TemplateDomain.load(domain)
        self.input_channel = TestInputChannel()
        self.create_output_channel = create_output_channel
        self.agent = self._create_agent(self.input_channel, self.interpreter,
                                        model, rules)

        self.stories = self._build_stories_from_path(test_cases)
        if shuffle:
            # Shuffle the stories that the run actually uses.
            random.shuffle(self.stories)
        # Initialize before running, so a failure flag that
        # _run_test_cases may set (not visible here) is not overwritten.
        self.failed = False
        self._run_test_cases()
Пример #16
0
def test_people_form():
    """The form asks for a person's name and stores the raw reply text."""
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-people"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # Turn 1: nothing provided -> the form asks for person_name.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(result) == 1
    ask_name = result[0]
    assert isinstance(ask_name, SlotSet)
    assert (ask_name.key, ask_name.value) == ("requested_slot", "person_name")
    tracker.update(ask_name)

    # Turn 2: the utterance text itself becomes the slot value.
    spoken_name = "Rasa Due"
    tracker.update(UserUttered(spoken_name, intent={"name": "inform"}))
    result = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(result) == 1
    stored = result[0]
    assert isinstance(stored, SlotSet)
    assert (stored.key, stored.value) == ("person_name", spoken_name)
Пример #17
0
def train(max_training_samples=3, serve_forever=True):
    """Train the dialogue agent on ``stories.md`` for 50 epochs.

    Reuses the persisted agent at ``agent_model_path`` when that path
    exists; otherwise builds a fresh Memoization + Keras agent.

    NOTE(review): ``max_training_samples`` and ``serve_forever`` are unused.
    """
    from rasa_core.interpreter import RasaNLUInterpreter

    stories_path = 'stories.md'
    nlu = RasaNLUInterpreter(nlu_model_path)

    # Domain configuration file.
    domain = TemplateDomain.load("./domain.yml")

    if os.path.exists(agent_model_path):
        agent = Agent.load(agent_model_path,
                           interpreter=nlu,
                           tracker_store=InMemoryTrackerStore(domain))
    else:
        # NOTE(review): the agent is built from ``domain_conf_path`` while the
        # tracker store uses "./domain.yml" -- confirm they are the same file.
        agent = Agent(domain_conf_path,
                      policies=[MemoizationPolicy(), KerasPolicy()],
                      interpreter=nlu,
                      tracker_store=InMemoryTrackerStore(domain))

    logger.info("开始训练...")

    training_data = agent.load_data(stories_path)
    agent.train(training_data, epochs=50)
    return agent
Пример #18
0
def test_query_form_set_username_in_form():
    """The query form asks for a name, stores it, then asks for the query."""
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-form"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # Turn 1: an empty "inform" -> the bot asks for the username.
    tracker.update(UserUttered("", intent={"name": "inform"}))
    result = ActionSearchQuery().run(dispatcher, tracker, domain)
    latest = dispatcher.latest_bot_messages[-1]
    assert len(result) == 1
    ask_name = result[0]
    assert isinstance(ask_name, SlotSet)
    assert (ask_name.key, ask_name.value) == ("requested_slot", "username")
    assert latest.text == 'what is your name?'
    tracker.update(ask_name)

    # Turn 2: the reply text itself is taken as the username.
    typed_name = '******'
    tracker.update(UserUttered(typed_name, intent={"name": "inform"}))
    result = ActionSearchQuery().run(dispatcher, tracker, domain)
    latest = dispatcher.latest_bot_messages[-1]
    assert len(result) == 2
    stored_name, ask_query = result
    assert isinstance(stored_name, SlotSet)
    assert (stored_name.key, stored_name.value) == ("username", typed_name)
    assert (ask_query.key, ask_query.value) == ("requested_slot", "query")
    assert typed_name in latest.text
Пример #19
0
def test_travel_form():
    """Origin is filled from a GPE entity, then the destination is requested."""
    travel_domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    dispatcher = Dispatcher(
        "test-travel",
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(travel_domain.templates))
    tracker = InMemoryTrackerStore(travel_domain).get_or_create_tracker(
        "test-travel")
    inform = {"name": "inform"}

    # first turn: the form requests the origin
    tracker.update(UserUttered("", intent=inform))
    produced = ActionSearchTravel().run(dispatcher, tracker, travel_domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "GPE_origin"
    tracker.update(produced[0])

    # second turn: Berlin answers the origin question
    tracker.update(UserUttered(
        "", intent=inform, entities=[{"entity": "GPE", "value": "Berlin"}]))
    produced = ActionSearchTravel().run(dispatcher, tracker, travel_domain)
    for item in produced:
        print(item.as_story_string())
    assert len(produced) == 2
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "GPE_origin"
    assert produced[0].value == "Berlin"
    assert produced[1].key == "requested_slot"
    assert produced[1].value == "GPE_destination"
Пример #20
0
def test_query_form_set_username_in_form():
    """Username is requested, stored from the reply, then query is requested."""
    query_domain = TemplateDomain.load("data/test_domains/query_form.yml")
    dispatcher = Dispatcher(
        "test-form",
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(query_domain.templates))
    tracker = InMemoryTrackerStore(query_domain).get_or_create_tracker(
        "test-form")
    inform = {"name": "inform"}

    # first turn
    tracker.update(UserUttered("", intent=inform))
    produced = ActionSearchQuery().run(dispatcher, tracker, query_domain)
    bot_said = dispatcher.latest_bot_messages[-1]
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "username"
    assert bot_said.text == 'what is your name?'
    tracker.update(produced[0])

    # second turn
    reply = '******'
    tracker.update(UserUttered(reply, intent=inform))
    produced = ActionSearchQuery().run(dispatcher, tracker, query_domain)
    bot_said = dispatcher.latest_bot_messages[-1]
    assert len(produced) == 2
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "username"
    assert produced[0].value == reply
    assert produced[1].key == "requested_slot"
    assert produced[1].value == "query"
    assert reply in bot_said.text
Пример #21
0
def run_my_world(serve_forever=True):
    """Load the persisted policy model and optionally chat on the console."""
    # NOTE(review): this domain is loaded but never used below; the call is
    # kept only so the file load still happens -- confirm it is needed.
    _unused_domain = TemplateDomain.load('common_domain.yml')
    nlu = RasaNLUInterpreter(nlu_model_path)
    agent = Agent.load("models/policy/current", interpreter=nlu)
    if not serve_forever:
        return agent
    agent.handle_channel(ConsoleInputChannel())
    return agent
Пример #22
0
def test_people_form():
    """person_name is requested, then set from the user's utterance text."""
    people_domain = TemplateDomain.load("data/test_domains/people_form.yml")
    dispatcher = Dispatcher(
        "test-people",
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(people_domain.templates))
    tracker = InMemoryTrackerStore(people_domain).get_or_create_tracker(
        "test-people")

    # first turn: ask for the name
    tracker.update(UserUttered("", intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, people_domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "person_name"
    tracker.update(produced[0])

    # second turn: the text is the name
    reply = "Rasa Due"
    tracker.update(UserUttered(reply, intent={"name": "inform"}))
    produced = ActionSearchPeople().run(dispatcher, tracker, people_domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "person_name"
    assert produced[0].value == reply
Пример #23
0
def test_restaurant_form_unhappy_1():
    """The form asks for the same slot again when the user omits it.

    Turn 1: a bare ``inform`` intent makes ``ActionSearchRestaurants``
    request the ``cuisine`` slot. Turn 2: the user again supplies no
    entities, so exactly one event re-requesting ``cuisine`` is expected.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(
        UserUttered("",
                    intent={"name": "inform"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # Check the event type before reading ``.key``/``.value``: other event
    # types may lack those attributes, and an AttributeError would mask the
    # real assertion failure.
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)

    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
Пример #24
0
def test_restaurant_form():
    """Giving a cuisine stores it and moves the form on to ``people``."""
    rest_domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    dispatcher = Dispatcher(
        "test-restaurant",
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(rest_domain.templates))
    tracker = InMemoryTrackerStore(rest_domain).get_or_create_tracker(
        "test-restaurant")
    inform = {"name": "inform"}

    # first turn: cuisine is requested
    tracker.update(UserUttered("", intent=inform))
    produced = ActionSearchRestaurants().run(dispatcher, tracker, rest_domain)
    assert len(produced) == 1
    assert isinstance(produced[0], SlotSet)
    assert produced[0].key == "requested_slot"
    assert produced[0].value == "cuisine"
    tracker.update(produced[0])

    # second turn: cuisine is provided
    tracker.update(UserUttered(
        "", intent=inform,
        entities=[{"entity": "cuisine", "value": "chinese"}]))
    produced = ActionSearchRestaurants().run(dispatcher, tracker, rest_domain)
    assert len(produced) == 2
    assert all(isinstance(evt, SlotSet) for evt in produced[:2])

    assert produced[0].key == "cuisine"
    assert produced[0].value == "chinese"

    assert produced[1].key == "requested_slot"
    assert produced[1].value == "people"
Пример #25
0
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Union[NLI, Text, None]
             tracker_store=None,  # type: Optional[TrackerStore]
             action_factory=None,  # type: Optional[Text]
             generator=None  # type: Union[EndpointConfig, NLG, None]
             ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: directory containing ``domain.yml`` and the policy files.
            interpreter: passed unchanged to the constructor.
            tracker_store: optional store; resolved via
                ``cls.create_tracker_store`` with the loaded domain.
            action_factory: forwarded to ``TemplateDomain.load``.
            generator: passed unchanged to the constructor.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: if ``path`` is None or points to a file instead of
                a model directory.
        """

        if path is None:
            raise ValueError("No domain path specified.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        ensemble = PolicyEnsemble.load(path)
        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)
        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)
        _tracker_store = cls.create_tracker_store(tracker_store, domain)

        return cls(domain, ensemble, interpreter, generator, _tracker_store)
Пример #26
0
def test_inmemory_tracker_store(filename):
    """A tracker saved to an InMemoryTrackerStore is retrieved unchanged."""
    domain = TemplateDomain.load("data/test_domains/default_with_topic.yml")
    original = tracker_from_dialogue_file(filename, domain)
    store = InMemoryTrackerStore(domain)
    store.save(original)
    assert store.retrieve(original.sender_id) == original
Пример #27
0
def train_dialogue(domain_file=DOMAIN_FILE,
                   model_path=MODEL_PATH,
                   training_data_file=STORIES):
    """Train an SklearnPolicy agent on the stories and persist it.

    Args:
        domain_file: path given to ``TemplateDomain.load``.
        model_path: directory passed to ``agent.persist``.
        training_data_file: stories file passed to ``agent.train``.
    """
    loaded_domain = TemplateDomain.load(domain_file)
    agent = ExtendedAgent(domain=loaded_domain, policies=[SklearnPolicy()])
    logger.info("Begin dialogue training")
    agent.train(filename=training_data_file, max_history=3)
    agent.persist(model_path)
Пример #28
0
async def get_no_of_stories(story_file, domain):
    """Return how many stories ``StoryFileReader`` reads from ``story_file``.

    ``domain`` is a path that is loaded with ``TemplateDomain.load`` first.
    """
    from rasa_core.domain import TemplateDomain
    from rasa_core.training.dsl import StoryFileReader

    loaded_domain = TemplateDomain.load(domain)
    parsed = await StoryFileReader.read_from_folder(story_file, loaded_domain)
    return len(parsed)
Пример #29
0
 def trained_policy(self):
     """Create a policy and train it on data built from the example domain."""
     example_domain = TemplateDomain.load("examples/default_domain.yml")
     candidate = self.create_policy()
     features, labels = train_data(self.max_history, example_domain)
     candidate.max_history = self.max_history
     candidate.featurizer = BinaryFeaturizer()
     candidate.train(features, labels, example_domain)
     return candidate
 def trained_policy(self):
     """Create a policy and train it on data from DEFAULT_DOMAIN_PATH."""
     domain_obj = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
     fresh_policy = self.create_policy()
     inputs, targets = train_data(self.max_history, domain_obj)
     fresh_policy.max_history = self.max_history
     fresh_policy.featurizer = BinaryFeaturizer()
     fresh_policy.train(inputs, targets, domain_obj)
     return fresh_policy
Пример #31
0
def test_utter_templates():
    """utter_greet in the moodbot domain resolves to text plus two buttons."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    buttons = [{"title": "great", "payload": "great"},
               {"title": "super sad", "payload": "super sad"}]
    expected = {"text": "Hey! How are you?", "buttons": buttons}
    assert domain.random_template_for("utter_greet") == expected
Пример #32
0
def test_utter_templates():
    """utter_ask_price in the restaurant domain carries two price buttons."""
    domain = TemplateDomain.load("examples/restaurant_domain.yml")
    price_buttons = [{"title": "cheap", "payload": "cheap"},
                     {"title": "expensive", "payload": "expensive"}]
    wanted = dict(text="in which price range?", buttons=price_buttons)
    assert domain.random_template_for("utter_ask_price") == wanted
Пример #33
0
    def train_dialogue(self):
        """Train a Memoization + Keras agent from TRAINING_DIR and persist it.

        Reads ``domain.yml`` and ``story.md`` from ``self.TRAINING_DIR`` and
        writes the model to its ``dialogue`` subdirectory.
        """
        domain = TemplateDomain.load(
            os.path.join(self.TRAINING_DIR, 'domain.yml'))
        # Disabled check, kept for reference:
        # domain.compare_with_specification(os.path.join(self.TRAINING_DIR, 'dialogue'))
        stories_file = os.path.abspath(
            os.path.join(self.TRAINING_DIR, 'story.md'))
        policies = [MemoizationPolicy(), KerasPolicy()]
        agent = Agent(domain, policies=policies)
        agent.train(stories_file, validation_split=0.1)
        agent.persist(os.path.join(self.TRAINING_DIR, 'dialogue'))
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """utter_greet is emitted as its text followed by numbered buttons."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    collector = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", collector, domain)
    dispatcher.utter_template("utter_greet")
    expected_lines = ["Hey! How are you?",
                      "1: great (great)",
                      "2: super sad (super sad)"]
    for idx, line in enumerate(expected_lines):
        assert collector.messages[idx][1] == line
Пример #35
0
def test_utter_templates():
    """The moodbot greeting template exposes its text and button options."""
    moodbot = TemplateDomain.load("examples/moodbot/domain.yml")
    options = [("great", "great"), ("super sad", "super sad")]
    expected_template = {
        "text": "Hey! How are you?",
        "buttons": [{"title": title, "payload": payload}
                    for title, payload in options],
    }
    assert moodbot.random_template_for("utter_greet") == expected_template
Пример #36
0
def simple():
    """Build an Agent for the "mom" domain using SimplePolicy (no training)."""
    from rasa_core.tracker_store import InMemoryTrackerStore
    from rasa_core.domain import TemplateDomain

    domain = TemplateDomain.load("../mom/domain.yml")
    return Agent(domain,
                 policies=[SimplePolicy()],
                 interpreter=HelloInterpreter(),
                 tracker_store=InMemoryTrackerStore(domain))
Пример #37
0
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """utter_ask_price is emitted as its question plus numbered price buttons."""
    domain = TemplateDomain.load("examples/restaurant_domain.yml")
    collector = CollectingOutputChannel()
    Dispatcher("my-sender", collector, domain).utter_template(
        "utter_ask_price")
    texts = ["in which price range?",
             "1: cheap (cheap)",
             "2: expensive (expensive)"]
    for position, text in enumerate(texts):
        assert collector.messages[position][1] == text
Пример #38
0
 def _create_domain(cls, domain):
     """Return a Domain: load it from a path string, or pass an instance through.

     Raises:
         ValueError: if ``domain`` is neither a ``str`` nor a ``Domain``.
     """
     if isinstance(domain, str):
         return TemplateDomain.load(domain)
     if isinstance(domain, Domain):
         return domain
     message = ("Invalid param `domain`. Expected a path to a domain "
                "specification or a domain instance. But got "
                "type '{}' with value '{}'".format(type(domain), domain))
     raise ValueError(message)
Пример #39
0
def tracker_from_dialogue_file(filename, domain=None):
    """Rebuild a DialogueStateTracker from a serialized dialogue file.

    Args:
        filename: dialogue file, read via ``read_dialogue_file``.
        domain: domain whose slots the tracker is created with; when
            ``None``, the domain at ``DEFAULT_DOMAIN_PATH`` is loaded.

    Returns:
        A tracker named after the dialogue and restored via
        ``recreate_from_dialogue``.
    """
    dialogue = read_dialogue_file(filename)

    if domain is None:
        domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
Пример #40
0
 def load(cls, path, interpreter=None, tracker_store=None):
     # type: (Text, Any, Optional[TrackerStore]) -> Agent
     """Restore an agent (domain, featurizer, policies) persisted under ``path``."""
     domain = TemplateDomain.load(os.path.join(path, "domain.yml"))
     # Guard against the domain changing between training and testing.
     domain.compare_with_specification(path)
     featurizer = Featurizer.load(path)
     policy_ensemble = PolicyEnsemble.load(path, featurizer)
     nlu = NaturalLanguageInterpreter.create(interpreter)
     store = cls._create_tracker_store(tracker_store, domain)
     return cls(domain, policy_ensemble, featurizer, nlu, store)
Пример #41
0
    def _create_domain(domain):
        # type: (Union[Domain, Text]) -> Domain
        """Coerce ``domain`` into a Domain instance.

        A string is treated as a path and loaded with ``TemplateDomain.load``;
        an existing ``Domain`` is returned unchanged.

        Raises:
            ValueError: for any other type.
        """
        if isinstance(domain, string_types):
            return TemplateDomain.load(domain)
        if isinstance(domain, Domain):
            return domain
        problem = ("Invalid param `domain`. Expected a path to a domain "
                   "specification or a domain instance. But got "
                   "type '{}' with value '{}'".format(type(domain), domain))
        raise ValueError(problem)
Пример #42
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """Uttering the moodbot greeting sends one message with text and buttons."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", channel, generator)

    dispatcher.utter_template("utter_greet", default_tracker)

    expected_buttons = [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'},
    ]
    assert len(channel.messages) == 1
    sent = channel.messages[0]
    assert sent['text'] == "Hey! How are you?"
    assert sent['data'] == expected_buttons
Пример #43
0
def test_restaurant_form_unhappy_2():
    """The restaurant form re-requests a slot when the user goes off-topic.

    The first utterance fills ``cuisine`` and ``people``, after which the
    action requests ``vegetarian``. A follow-up utterance with an unrelated
    intent must yield a single event that requests ``vegetarian`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [
        {"entity": "cuisine", "value": "chinese"},
        {"entity": "number", "value": 8}]

    tracker.update(
        UserUttered("",
                    intent={"name": "inform"},
                    entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    # with cuisine and people filled, the form asks for "vegetarian"
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(
        UserUttered("",
                    intent={"name": "random"}))

    # check the count before indexing so a wrong number of events fails
    # with an assertion rather than an IndexError
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
Пример #44
0
    def load(cls,
             path,  # type: Text
             core_endpoint,  # type: EndpointConfig
             nlg_endpoint=None,  # type: EndpointConfig
             action_factory=None  # type: Optional[Text]
             ):
        # type: (...) -> RemoteAgent
        """Create a ``RemoteAgent`` for the model stored in ``path``.

        Loads ``domain.yml`` from ``path`` (forwarding ``action_factory`` to
        ``TemplateDomain.load``), uploads ``path`` to the Rasa Core server
        described by ``core_endpoint`` (``max_retries=5``), and wraps the
        domain, the core client and ``nlg_endpoint`` in a ``RemoteAgent``.

        Args:
            path: model directory containing ``domain.yml``.
            core_endpoint: ``EndpointConfig`` of the Rasa Core server.
            nlg_endpoint: optional endpoint forwarded to ``RemoteAgent``.
            action_factory: optional value forwarded to
                ``TemplateDomain.load``.

        Returns:
            The constructed ``RemoteAgent``.

        Raises:
            TypeError: if ``core_endpoint`` is a plain string (the old API
                took a URL here). ``TypeError`` subclasses ``Exception``, so
                callers catching ``Exception`` are unaffected.
        """
        if isinstance(core_endpoint, string_types):
            # A wrong argument type is a TypeError, not a bare Exception.
            raise TypeError(
                "This API has changed. Instead of passing in a url "
                "for Rasa Core, you now need to pass in an "
                "instance of 'EndpointConfig'. "
                "(from rasa_core.utils import EndpointConfig )")

        domain = TemplateDomain.load(os.path.join(path, "domain.yml"),
                                     action_factory)

        core_client = RasaCoreClient(core_endpoint)
        core_client.upload_model(path, max_retries=5)

        return RemoteAgent(domain, core_client, nlg_endpoint)
Пример #45
0
def test_restaurant_form_skipahead():
    """Slots given up front are skipped and the form requests the next one.

    The first utterance already carries ``cuisine`` and ``number``; the
    action must return three events, the last one requesting ``vegetarian``.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(UserUttered("",
                               intent={"name": "inform"},
                               entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # check the count before indexing (the unused story-string local and
    # debug prints that used to index events first were removed)
    assert len(events) == 3
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
Пример #46
0
 def default_domain(self):
     """Load and return the domain found at ``DEFAULT_DOMAIN_PATH``."""
     domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
     return domain
Пример #47
0
 def trained_policy(self, featurizer):
     """Return a policy from ``create_policy`` trained on the default domain.

     Training data comes from ``train_trackers`` for the domain loaded from
     ``DEFAULT_DOMAIN_PATH``.
     """
     domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
     trained = self.create_policy(featurizer)
     trackers = train_trackers(domain)
     trained.train(trackers, domain)
     return trained
Пример #48
0
def train_dialogue(domain_file=DOMAIN_FILE, model_path=MODEL_PATH,
                   training_data_file=STORIES):
    """Train an ``ExtendedAgent`` with one ``SklearnPolicy`` and persist it.

    Args:
        domain_file: domain specification loaded with ``TemplateDomain.load``.
        model_path: destination passed to ``agent.persist``.
        training_data_file: stories file passed to ``agent.train``.
    """
    domain = TemplateDomain.load(domain_file)
    agent = ExtendedAgent(domain=domain, policies=[SklearnPolicy()])
    logger.info("Begin dialogue training")
    agent.train(filename=training_data_file, max_history=3)
    agent.persist(model_path)
Пример #49
0
from rasa_core import training, restore
from rasa_core import utils
from rasa_core.actions.action import ActionListen, ACTION_LISTEN_NAME
from rasa_core.channels import UserMessage
from rasa_core.domain import TemplateDomain
from rasa_core.events import (
    UserUttered, ActionExecuted, Restarted, ActionReverted,
    UserUtteranceReverted)
from rasa_core.tracker_store import InMemoryTrackerStore, RedisTrackerStore
from rasa_core.tracker_store import (
    TrackerStore)
from rasa_core.trackers import DialogueStateTracker
from tests.conftest import DEFAULT_STORIES_FILE
from tests.utilities import tracker_from_dialogue_file, read_dialogue_file

# Domain shared by the tracker stores constructed in this module.
_TEST_DOMAIN_FILE = "data/test_domains/default.yml"
domain = TemplateDomain.load(_TEST_DOMAIN_FILE)


class MockRedisTrackerStore(RedisTrackerStore):
    """``RedisTrackerStore`` whose client is a ``fakeredis.FakeStrictRedis``.

    Used so the Redis tracker store can be exercised in tests.
    NOTE(review): ``fakeredis`` is not among the visible imports of this
    snippet — confirm it is imported.
    """

    def __init__(self, domain):
        # Put a fakeredis client where the real store keeps its connection.
        self.red = fakeredis.FakeStrictRedis()
        # Deliberately bypass RedisTrackerStore.__init__ (presumably it would
        # connect to a real Redis server — TODO confirm) and run only the
        # base TrackerStore initialisation. Keep this after `self.red` is set.
        TrackerStore.__init__(self, domain)


def stores_to_be_tested():
    """Return a Redis-backed (mocked) and an in-memory tracker store.

    Both are built on the module-level ``domain``; the Redis one comes first.
    """
    redis_store = MockRedisTrackerStore(domain)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]


def stores_to_be_tested_ids():
    return ["redis-tracker",
Пример #50
0
def test_domain_from_template():
    domain_file = DEFAULT_DOMAIN_PATH
    domain = TemplateDomain.load(domain_file)
    assert len(domain.intents) == 10
    assert len(domain.actions) == 6