Example #1
0
    def _parse_message(self, message):
        """Turn an incoming user message into NLU parse data.

        Text beginning with ``INTENT_MESSAGE_PREFIX`` skips the configured
        interpreter and is parsed by a fresh ``RegexInterpreter`` instead,
        so tests can inject intents and entities directly, e.g. a message of
        the form ``<prefix>intent{"entity1": val1, "entity2": val2}``.

        :param message: object exposing the raw text as ``message.text``.
        :return: the interpreter's parse result; its ``"intent"`` and
            ``"entities"`` entries are read for the debug log below.
        """
        text = message.text
        use_shortcut = text.startswith(INTENT_MESSAGE_PREFIX)
        interpreter = RegexInterpreter() if use_shortcut else self.interpreter
        parse_data = interpreter.parse(text)

        log_line = ("Received user message '{}' with intent '{}' "
                    "and entities '{}'").format(text,
                                                parse_data["intent"],
                                                parse_data["entities"])
        logger.debug(log_line)
        return parse_data
Example #2
0
def test_training_script_with_max_history_set(tmpdir):
    """Train with the max-history test config and check every featurizer.

    Policies whose featurizer exposes ``max_history`` must report 2 when the
    policy is exactly a ``FormPolicy`` and 5 otherwise (the values this test
    expects from ``data/test_config/max_hist_config.yml``).
    """
    model_dir = tmpdir.strpath
    train_dialogue_model(DEFAULT_DOMAIN_PATH,
                         DEFAULT_STORIES_FILE,
                         model_dir,
                         interpreter=RegexInterpreter(),
                         policy_config='data/test_config/max_hist_config.yml',
                         kwargs={})
    agent = Agent.load(model_dir)

    for policy in agent.policy_ensemble.policies:
        featurizer = policy.featurizer
        if not hasattr(featurizer, 'max_history'):
            continue
        expected_history = 2 if type(policy) == FormPolicy else 5
        assert featurizer.max_history == expected_history
Example #3
0
def test_random_seed(tmpdir, config_file):
    """Two agents trained from the same seeded config must predict alike.

    ``config_file`` is expected to set a random seed so that training is
    reproducible; the next-action confidences of both agents must match.
    """
    def _train(dir_suffix):
        # Each run gets its own output directory next to tmpdir.
        return train_dialogue_model(DEFAULT_DOMAIN_PATH,
                                    DEFAULT_STORIES_FILE,
                                    tmpdir.strpath + dir_suffix,
                                    interpreter=RegexInterpreter(),
                                    policy_config=config_file,
                                    kwargs={})

    agent_1 = _train("1")
    agent_2 = _train("2")

    processors = [agent.create_processor() for agent in (agent_1, agent_2)]

    probs_1 = processors[0].predict_next("1")
    probs_2 = processors[1].predict_next("2")
    assert probs_1["confidence"] == probs_2["confidence"]
Example #4
0
 def learnonline(self, msg, args):
     """Command to trigger learn_online on the rasa agent.

     Builds a fresh ``Agent`` from ``self.domain_file`` (memoization with
     ``max_history=2`` plus a Keras policy, regex interpreter), loads
     ``self.training_data_file`` and runs online training over
     ``self.backend_adapter``. ``msg`` and ``args`` are not used.

     :raises Exception: if no bot token is configured.
     """
     # .get() so a missing 'token' key produces the intended error rather
     # than a bare KeyError; an empty token is rejected as well. The token
     # is only checked for presence here.
     token = config.BOT_IDENTITY.get('token')
     if not token:
         raise Exception('No slack token')

     train_agent = Agent(self.domain_file,
                         policies=[MemoizationPolicy(max_history=2),
                                   KerasPolicy()],
                         interpreter=RegexInterpreter())
     training_data = train_agent.load_data(self.training_data_file)
     train_agent.train_online(training_data,
                              input_channel=self.backend_adapter,
                              batch_size=50,
                              epochs=200,
                              max_training_samples=300)
Example #5
0
    def _parse_message(self, message):
        """Parse a user message into a dict of intent & entities.

        For testing you can short-cut the NLU part with a message in the
        format ``_intent[entity1=val1,entity2=val2]``: any text starting with
        ``_`` is parsed by a fresh ``RegexInterpreter`` instead of
        ``self.interpreter``.

        :param message: object exposing the raw text as ``message.text``.
        :return: the parse result; its ``"intent"`` and ``"entities"``
            entries are logged at DEBUG level.
        """
        if message.text.startswith('_'):
            parse_data = RegexInterpreter().parse(message.text)
        else:
            parse_data = self.interpreter.parse(message.text)

        # Lazy %-style args: the message is only built when DEBUG is enabled.
        logger.debug("Received user message '%s' with intent '%s' "
                     "and entities '%s'",
                     message.text,
                     parse_data["intent"],
                     parse_data["entities"])
        return parse_data
Example #6
0
    def read_from_file(file_name,
                       domain,
                       interpreter=None,
                       template_variables=None):
        """Read the stories contained in the story file ``file_name``.

        :param file_name: path of the story file; it is read line by line.
        :param domain: domain passed through to ``StoryFileReader``.
        :param interpreter: interpreter for user messages in the stories; a
            fresh ``RegexInterpreter`` is created when omitted or ``None``.
        :param template_variables: passed through to ``StoryFileReader``.
        :return: the result of ``StoryFileReader.process_lines``.
        :raises Exception: wrapping any failure, with the file's absolute
            path in the message.
        """
        # Build the default per call instead of once at definition time, so
        # separate calls never share one interpreter instance.
        if interpreter is None:
            interpreter = RegexInterpreter()

        try:
            # Explicit UTF-8 so parsing does not depend on the locale.
            with io.open(file_name, "r", encoding="utf-8") as f:
                lines = f.readlines()
            reader = StoryFileReader(domain, interpreter, template_variables)
            return reader.process_lines(lines)
        except Exception as e:
            raise Exception("Failed to parse '{}'. {}".format(
                os.path.abspath(file_name), e))
def run_fake_user(input_channel, max_training_samples=10, serve_forever=True):
    """Train an agent online from ``input_channel``, then optionally serve it.

    NOTE(review): ``back`` is not defined in this function, so it must be a
    module-level name (otherwise this raises NameError). When
    ``serve_forever`` is true the loop below never exits, so the agent is
    only returned when ``serve_forever`` is false.
    """
    logger.info("Starting to train policy")
    policies = [MemoizationPolicy(), KerasPolicy()]
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=policies,
                  interpreter=RegexInterpreter())

    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=input_channel,
                       epochs=RASA_CORE_EPOCHS,
                       max_training_samples=max_training_samples)

    if serve_forever:
        while True:
            agent.handle_message(UserMessage(back, ConsoleOutputChannel()))

    return agent
Example #8
0
def test_message_processor(default_domain, capsys):
    """Train a ScoringPolicy and check that '_greet' prints 'hey there!'."""
    stories = "data/dsl_stories/stories_defaultdomain.md"
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain,
                            BinaryFeaturizer())
    trainer.train(stories, max_history=3)

    processor = MessageProcessor(regex_interpreter, policy_ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    processor.handle_message(UserMessage("_greet", ConsoleOutputChannel()))
    # The console output channel prints the bot's reply to stdout.
    captured_stdout = capsys.readouterr()[0]
    assert "hey there!" in captured_stdout
Example #9
0
def test_channel_inheritance():
    """Serve a loaded agent on a RestInput plus a RasaChatInput channel."""
    from rasa_core.channels import RestInput, RasaChatInput
    from rasa_core.agent import Agent
    from rasa_core.interpreter import RegexInterpreter

    agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter())
    chat_input = RasaChatInput("https://example.com")

    # serve_forever=False: handle_channels hands back the server object
    # (presumably without blocking on it) so the test can inspect it.
    server = agent.handle_channels([RestInput(), chat_input], 5004,
                                   serve_forever=False)
    assert server.started
def train_model_online():
    """Train an agent online from the core stories, then attach Rasa NLU.

    User lines of the stories file (lines like ``* <message>``) are replayed
    through a ``FileInputChannel`` (at most 10 messages) and parsed with the
    regex interpreter. After training, the agent's interpreter is replaced
    by a ``RasaNLUInterpreter`` loaded from ``RASA_NLU_MODEL_PATH``.
    """
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=[MemoizationPolicy(),
                            StatusPolicy()],
                  interpreter=RegexInterpreter())

    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=FileInputChannel(
                           RASA_CORE_TRAINING_DATA_PATH,
                           # Raw string: '\s' and '\*' are invalid escapes
                           # in a plain literal (Deprecation/SyntaxWarning).
                           message_line_pattern=r'^\s*\*\s(.*)$',
                           max_messages=10),
                       epochs=RASA_CORE_EPOCHS)

    agent.interpreter = RasaNLUInterpreter(RASA_NLU_MODEL_PATH)
    return agent
Example #11
0
def core_server(tmpdir_factory):
    """Build a server app around a trained, persisted and reloaded agent.

    An agent with an ``AugmentedMemoizationPolicy`` is trained on
    ``DEFAULT_STORIES_FILE``, persisted to a temp dir and loaded back with a
    ``RegexInterpreter``. The app and its ``/webhooks/`` REST channel both
    use that reloaded agent.

    :return: the created server app.
    """
    model_path = tmpdir_factory.mktemp("model").strpath

    agent = Agent("data/test_domains/default.yml",
                  policies=[AugmentedMemoizationPolicy(max_history=3)])

    training_data = agent.load_data(DEFAULT_STORIES_FILE)
    agent.train(training_data)
    agent.persist(model_path)

    loaded_agent = Agent.load(model_path, interpreter=RegexInterpreter())

    app = server.create_app(loaded_agent)
    # Route webhook messages to the same agent the app serves. The
    # training-time ``agent`` is a separate object (presumably with its own
    # tracker store), so webhook conversations would not be visible through
    # the app's other endpoints.
    channel.register([RestInput()], app, loaded_agent.handle_message,
                     "/webhooks/")
    return app
Example #12
0
def run_babi_online(max_messages=10):
    """Train the restaurant bot online by replaying bAbI dev stories.

    Up to ``max_messages`` user lines (``* <message>``) are read from the
    story file through a ``FileInputChannel`` and parsed with the regex
    interpreter. Afterwards the agent's interpreter is replaced by a
    ``RasaNLUInterpreter`` for ``nlu_model_path``, a name this function
    reads from module scope.

    :return: the trained agent.
    """
    training_data = 'examples/babi/data/babi_task5_dev_rasa_even_smaller.md'
    logger.info("Starting to train policy")
    agent = Agent("examples/restaurant_domain.yml",
                  policies=[MemoizationPolicy(),
                            RestaurantPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: '\s' and '\*' are invalid escapes in a plain literal and
    # trigger a DeprecationWarning (SyntaxWarning on newer Pythons).
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
Example #13
0
def test_story_visualization(default_domain, tmpdir):
    """Render the test stories as a graph and check its node/edge counts."""
    steps = StoryFileReader.read_from_file("data/test_stories/stories.md",
                                           default_domain,
                                           interpreter=RegexInterpreter())
    graph_file = tmpdir.join("graph.png").strpath
    graph = visualize_stories(steps,
                              default_domain,
                              output_file=graph_file,
                              max_history=3,
                              should_merge_nodes=False)

    # Node merging is disabled, so these are the unmerged graph sizes.
    assert len(graph.nodes()) == 51
    assert len(graph.edges()) == 56
def run_babi_online(max_messages=10):
    """Train the music-player bot online by replaying ``stories.md``.

    Up to ``max_messages`` user lines (``* <message>``) are read through a
    ``FileInputChannel`` and parsed with the regex interpreter. Afterwards
    the agent's interpreter is replaced by a ``RasaNLUInterpreter`` for
    ``nlu_model_path``, a name this function reads from module scope.

    :return: the trained agent.
    """
    training_data = 'stories.md'
    logger.info("Starting to train policy")
    agent = Agent("domain.yml",
                  policies=[MemoizationPolicy(),
                            MusicPlayerPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: '\s' and '\*' are invalid escapes in a plain literal and
    # trigger a DeprecationWarning (SyntaxWarning on newer Pythons).
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
Example #15
0
def extract_story_graph(
    resource_name,  # type: Text
    domain,  # type: Domain
    interpreter=None  # type: Optional[NaturalLanguageInterpreter]
):
    # type: (...) -> StoryGraph
    """Read the stories found under ``resource_name`` into a StoryGraph.

    ``StoryFileReader.read_from_folder`` does the reading; a fresh
    ``RegexInterpreter`` is used whenever ``interpreter`` is falsy.
    """
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    active_interpreter = interpreter or RegexInterpreter()
    return StoryGraph(StoryFileReader.read_from_folder(resource_name,
                                                       domain,
                                                       active_interpreter))
Example #16
0
def test_message_processor(default_domain, capsys):
    """Check that the entity in '_greet[name=Core]' reaches the bot reply."""
    policy_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(policy_ensemble, default_domain,
                            BinaryFeaturizer())
    trainer.train("data/dsl_stories/stories_defaultdomain.md", max_history=3)

    processor = MessageProcessor(regex_interpreter, policy_ensemble,
                                 default_domain,
                                 InMemoryTrackerStore(default_domain))

    collector = CollectingOutputChannel()
    processor.handle_message(UserMessage("_greet[name=Core]", collector))
    assert ("default", "hey there Core!") == collector.latest_output()
Example #17
0
    def run_online_training(self, ensemble, domain, interpreter=None,
                            input_channel=None):
        """Run interactive online training of ``ensemble`` on ``domain``.

        An ``Agent`` is built with this object's featurizer and memoization
        disabled. It then handles ``input_channel`` (a ``ConsoleInputChannel``
        when none is given) until ``TrainingFinishedException`` is raised.
        A ``RegexInterpreter`` is used when ``interpreter`` is ``None``.
        """
        from rasa_core.agent import Agent

        interpreter = (RegexInterpreter() if interpreter is None
                       else interpreter)

        bot = Agent(domain, ensemble,
                    featurizer=self.featurizer,
                    interpreter=interpreter)
        bot.toggle_memoization(False)

        try:
            bot.handle_channel(input_channel or ConsoleInputChannel())
        except TrainingFinishedException:
            # Raised to end the session; training finished normally.
            pass
Example #18
0
def test_regex_interpreter():
    """Check RegexInterpreter's parsing of the intent short-cut syntax.

    Each case is ``INTENT_MESSAGE_PREFIX`` followed by the intent name, an
    optional ``@<confidence>`` suffix and optional JSON entities. The parse
    must echo the text, rank exactly one intent, and report the expected
    confidence and entities.
    """
    interp = RegexInterpreter()

    def assert_parses(suffix, confidence, entities):
        # ``entities`` lists the expected (entity, value) pairs in order.
        text = INTENT_MESSAGE_PREFIX + suffix
        result = interp.parse(text)
        assert result['text'] == text
        assert len(result['intent_ranking']) == 1
        assert result['intent']['name'] == \
            result['intent_ranking'][0]['name'] == \
            'my_intent'
        assert result['intent']['confidence'] == \
            result['intent_ranking'][0]['confidence'] == \
            pytest.approx(confidence)
        assert [(e["entity"], e["value"])
                for e in result["entities"]] == entities

    assert_parses('my_intent', 1.0, [])
    assert_parses('my_intent{"foo":"bar"}', 1.0, [("foo", "bar")])
    # '@<float>' sets an explicit confidence. Keep these literals intact:
    # they resemble e-mail addresses and are prone to being obfuscated.
    assert_parses('my_intent@0.5', 0.5, [])
    assert_parses('my_intent@0.5{"foo":"bar"}', 0.5, [("foo", "bar")])
Example #19
0
def train_dialogue_model(domain_file, stories_file, output_path,
                         use_online_learning, nlu_model_path, kwargs):
    """Train a dialogue model and persist it to ``output_path``.

    With ``use_online_learning`` the agent trains interactively over a
    ``ConsoleInputChannel`` for 10 epochs, parsing messages with a
    ``RasaNLUInterpreter`` when ``nlu_model_path`` is set and with a
    ``RegexInterpreter`` otherwise. Without it, ``Agent.train`` runs with a
    0.1 validation split and ``kwargs`` forwarded as keyword arguments.
    """
    agent = Agent(domain_file, policies=[MemoizationPolicy(), KerasPolicy()])

    if not use_online_learning:
        agent.train(stories_file, validation_split=0.1, **kwargs)
    else:
        agent.interpreter = (RasaNLUInterpreter(nlu_model_path)
                             if nlu_model_path else RegexInterpreter())
        agent.train_online(stories_file,
                           input_channel=ConsoleInputChannel(),
                           epochs=10,
                           model_path=output_path)

    agent.persist(output_path)
def extract_training_data_from_file(
        filename,  # type: Text
        domain,  # type: Domain
        featurizer=None,  # type: Featurizer
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        augmentation_factor=20,  # type: int
        max_history=1,  # type: int
        remove_duplicates=True,  # type: bool
        max_number_of_trackers=2000  # type: int
):
    # type: (...) -> DialogueTrainingData
    """Build dialogue training data from the stories in ``filename``.

    :param interpreter: interpreter for user messages in the stories; a
        fresh ``RegexInterpreter`` is created when omitted or ``None``.
    The story graph and the remaining arguments are passed positionally to
    ``TrainingsDataGenerator``.

    :return: the result of ``TrainingsDataGenerator.generate()``.
    """
    # Create the default per call instead of sharing one instance that was
    # built when the function was defined.
    if interpreter is None:
        interpreter = RegexInterpreter()

    graph = extract_story_graph_from_file(filename, domain, interpreter)
    g = TrainingsDataGenerator(graph, domain, featurizer, remove_duplicates,
                               augmentation_factor, max_history,
                               max_number_of_trackers)
    return g.generate()
Example #21
0
    def read_from_file(filename, domain, interpreter=None,
                       template_variables=None):
        """Read the stories contained in the story file ``filename``.

        :param filename: path of the story file; it is read line by line.
        :param domain: domain passed through to ``StoryFileReader``.
        :param interpreter: interpreter for user messages in the stories; a
            fresh ``RegexInterpreter`` is created when omitted or ``None``.
        :param template_variables: passed through to ``StoryFileReader``.
        :return: the result of ``StoryFileReader.process_lines``.
        :raises ValueError: re-raised (same object) with a message naming
            the file appended to its args; this includes decoding errors,
            since ``UnicodeDecodeError`` subclasses ``ValueError``.
        """
        # Build the default per call instead of once at definition time, so
        # separate calls never share one interpreter instance.
        if interpreter is None:
            interpreter = RegexInterpreter()

        try:
            # Explicit UTF-8 so parsing does not depend on the locale.
            with io.open(filename, "r", encoding="utf-8") as f:
                lines = f.readlines()
            reader = StoryFileReader(domain, interpreter, template_variables)
            return reader.process_lines(lines)
        except ValueError as err:
            file_info = ("Invalid story file format. Failed to parse "
                         "'{}'".format(os.path.abspath(filename)))
            logger.exception(file_info)
            # Extend the original exception rather than wrapping it, so the
            # caller still sees the original type and traceback.
            if not err.args:
                err.args = ('',)
            err.args = err.args + (file_info,)
            raise
Example #22
0
def run_babi_online():
    """Train the weather bot online by replaying ``data/weather.md``.

    Up to 10 user lines (``* <message>``) are read through a
    ``FileInputChannel`` and parsed with the regex interpreter. Afterwards
    the agent's interpreter is switched to a ``RasaNLUHttpInterpreter``
    for model ``model_20171013-084449`` at ``http://localhost:7000``.

    :return: the trained agent.
    """
    training_data = 'data/weather.md'
    logger.info("Starting to train policy")
    agent = Agent("../weather_domain.yml",
                  policies=[MemoizationPolicy(),
                            WeatherPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: '\s' and '\*' are invalid escapes in a plain literal and
    # trigger a DeprecationWarning (SyntaxWarning on newer Pythons).
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=10)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUHttpInterpreter(
        model_name='model_20171013-084449',
        token=None,
        server='http://localhost:7000')
    return agent
Example #23
0
def run_fake_user(max_training_samples=10, serve_forever=True):
    """Train the restaurant agent online against a simulated user.

    Runs one epoch of online training capped at ``max_training_samples``.
    ``serve_forever`` is accepted but not used by this function.

    :return: the trained agent.
    """
    stories = 'examples/babi/data/babi_task5_fu_rasa_fewer_actions.md'

    logger.info("Starting to train policy")

    policies = [MemoizationPolicy(), KerasPolicy()]
    agent = Agent("examples/restaurant_domain.yml",
                  policies=policies,
                  interpreter=RegexInterpreter())

    # The fake user builds its input messages from the dialogue state, read
    # through the agent's tracker store, rather than us scripting them.
    fake_user = FakeUserInputChannel(agent.tracker_store)

    agent.train_online(stories,
                       input_channel=fake_user,
                       epochs=1,
                       max_training_samples=max_training_samples)
    return agent
Example #24
0
def extract_story_graph(
    resource_name: Text,
    domain: 'Domain',
    interpreter: Optional['NaturalLanguageInterpreter'] = None,
    use_e2e: bool = False,
    exclusion_percentage: Optional[int] = None
) -> 'StoryGraph':
    """Read the stories under ``resource_name`` into a ``StoryGraph``.

    :param resource_name: resource passed to
        ``StoryFileReader.read_from_folder``.
    :param domain: domain passed through to the reader.
    :param interpreter: interpreter for user messages; a fresh
        ``RegexInterpreter`` is used when it is falsy.
    :param use_e2e: forwarded to ``read_from_folder``.
    :param exclusion_percentage: forwarded to ``read_from_folder``; ``None``
        by default.
    :return: a ``StoryGraph`` built from the story steps read.
    """
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    if not interpreter:
        interpreter = RegexInterpreter()
    story_steps = StoryFileReader.read_from_folder(
        resource_name,
        domain, interpreter,
        use_e2e=use_e2e,
        exclusion_percentage=exclusion_percentage)
    return StoryGraph(story_steps)
Example #25
0
    def run_online_training(
        self,
        domain,  # type: Domain
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        input_channel=None  # type: Optional[InputChannel]
    ):
        # type: (...) -> None
        """Run interactive online training with this object as the ensemble.

        ``self`` is passed as the second ``Agent`` argument, memoization is
        disabled, and the agent handles the input channel until
        ``TrainingFinishedException`` ends the session.

        :param domain: domain for the training ``Agent``.
        :param interpreter: interpreter for user messages; a
            ``RegexInterpreter`` is used when omitted or ``None``.
        :param input_channel: channel to read from; a ``ConsoleInputChannel``
            is used when it is falsy.
        """
        from rasa_core.agent import Agent
        if interpreter is None:
            interpreter = RegexInterpreter()

        bot = Agent(domain, self, interpreter=interpreter)
        bot.toggle_memoization(False)

        try:
            bot.handle_channel(
                input_channel if input_channel else ConsoleInputChannel())
        except TrainingFinishedException:
            pass  # training has finished
Example #26
0
def test_register_channel_without_route():
    """Check we properly connect the input channel blueprint if route is None"""
    from rasa_core.channels import RestInput
    from flask import Flask
    import rasa_core

    agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter())
    rest_channel = RestInput()

    flask_app = Flask(__name__)
    register = rasa_core.channels.channel.register
    register([rest_channel], flask_app, agent.handle_message, route=None)

    # Without a route prefix the webhook should sit directly at /webhook.
    endpoint = utils.list_routes(flask_app).get("/webhook")
    assert endpoint.startswith("custom_webhook_RestInput.receive")
Example #27
0
def train_dialog(online=False, nlu=False, policies=None):
    """Train the dialogue model on ``data/stories.md`` and persist it.

    :param online: train interactively on the console (10 epochs) instead
        of a regular 300-epoch run with a 0.1 validation split.
    :param nlu: with ``online``, parse messages with the Rasa NLU model at
        ``models/infosys_cui/current`` instead of the regex interpreter.
    :param policies: policies for the agent; a fresh ``[KerasPolicy()]`` is
        created per call when omitted or ``None``.
    """
    # A list default would be built once, and the same KerasPolicy instance
    # would be reused and retrained on every call.
    if policies is None:
        policies = [KerasPolicy()]

    agent = Agent("domain.yml", policies=policies)
    # Forward slashes work on every OS and avoid the invalid '\s' / '\d'
    # escape sequences that backslash-separated literals contain.
    stories = "data/stories.md"
    output_path = "models/dialog"
    if online:
        if nlu:
            agent.interpreter = RasaNLUInterpreter(
                "models/infosys_cui/current")
        else:
            agent.interpreter = RegexInterpreter()
        agent.train_online(stories,
                           input_channel=ConsoleInputChannel(),
                           epochs=10,
                           model_path=output_path)
    else:
        kwargs = {"epochs": 300}
        agent.train(stories, validation_split=0.1, **kwargs)

    agent.persist(output_path)
Example #28
0
def extract_trackers_from_file(
        filename,  # type: Text
        domain,  # type: Domain
        featurizer,  # type: Featurizer
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        max_history=1,  # type: int
        max_number_of_trackers=2000  # type: int
):
    # type: (...) -> List[DialogueStateTracker]
    """Generate dialogue trackers from the stories in ``filename``.

    Story concatenation and duplicate removal are disabled and
    ``tracker_limit`` is fixed at 1000 for the generator.

    :param interpreter: interpreter for user messages in the stories; a
        fresh ``RegexInterpreter`` is created when omitted or ``None``.
    :return: the ``"trackers"`` entry of the generated data's ``metadata``.
    """
    # Create the default per call instead of sharing one instance that was
    # built when the function was defined.
    if interpreter is None:
        interpreter = RegexInterpreter()

    graph = extract_story_graph_from_file(filename, domain, interpreter)
    g = TrainingsDataGenerator(graph, domain, featurizer,
                               use_story_concatenation=False,
                               max_history=max_history,
                               tracker_limit=1000,
                               remove_duplicates=False,
                               max_number_of_trackers=max_number_of_trackers)
    data = g.generate()
    return data.metadata["trackers"]
Example #29
0
def test_concerts_online_example():
    """Run the concertbot online-training example end to end.

    Three user lines (``* <message>``) are replayed from the concertbot
    stories, and ``utils.input`` is replaced so it always returns "2". The
    resulting agent must answer ``_greet`` with one of the known responses.
    """
    sys.path.append("examples/concertbot/")
    from train_online import run_concertbot_online
    from rasa_core import utils

    # Simulated command-line input: presumably the interactive training
    # prompts read through utils.input, which now always answers "2".
    utils.input = lambda _=None: "2"

    # Raw string: '\s' and '\*' are invalid escapes in a plain literal.
    input_channel = FileInputChannel('examples/concertbot/data/stories.md',
                                     message_line_pattern=r'^\s*\*\s(.*)$',
                                     max_messages=3)
    domain_file = os.path.join("examples", "concertbot", "concert_domain.yml")
    training_file = os.path.join("examples", "concertbot", "data",
                                 "stories.md")
    agent = run_concertbot_online(input_channel, RegexInterpreter(),
                                  domain_file, training_file)
    responses = agent.handle_message("_greet")
    assert responses[-1] in {
        "hey there!", "how can I help you?", "default message"
    }
Example #30
0
def _parse_message(self, message):
    """Parse a user message into a dict of intent & entities.

    Text starting with ``INTENT_MESSAGE_PREFIX`` is parsed by a fresh
    ``RegexInterpreter`` (a short-cut around NLU, e.g. for testing). A
    ``RasaMultiNLUHttpInterpreter`` also receives the output channel's
    ``language``, falling back to ``'en'`` when the channel has no such
    attribute. Any other interpreter parses the text alone.

    :param message: object with ``text`` and ``output_channel`` attributes.
    :return: the parse result; its ``"intent"`` and ``"entities"`` entries
        are logged at DEBUG level.
    """
    if message.text.startswith(INTENT_MESSAGE_PREFIX):
        parse_data = RegexInterpreter().parse(message.text)
    elif isinstance(self.interpreter, RasaMultiNLUHttpInterpreter):
        # Channels without a ``language`` attribute fall back to English
        # instead of raising AttributeError.
        language = getattr(message.output_channel, 'language', 'en')
        parse_data = self.interpreter.parse(message.text, language)
    else:
        parse_data = self.interpreter.parse(message.text)

    logger.debug("Received user message '{}' with intent '{}' "
                 "and entities '{}'".format(message.text, parse_data["intent"],
                                            parse_data["entities"]))
    return parse_data
Example #31
0
    def _parse_message(self, message):
        """Parse a user message into a dict of intent & entities.

        Text starting with ``INTENT_MESSAGE_PREFIX`` or with ``_`` skips NLU
        and is parsed by a fresh ``RegexInterpreter``. If
        ``RegexInterpreter.is_using_deprecated_format`` reports the legacy
        format, a deprecation message is emitted via ``warnings.warn``
        first. All other text goes through ``self.interpreter``.

        :param message: object exposing the raw text as ``message.text``.
        :return: the parse result; its ``"intent"`` and ``"entities"``
            entries are logged at DEBUG level.
        """
        text = message.text
        is_shortcut = (text.startswith(INTENT_MESSAGE_PREFIX) or
                       text.startswith("_"))

        if not is_shortcut:
            parse_data = self.interpreter.parse(text)
        else:
            if RegexInterpreter.is_using_deprecated_format(text):
                warnings.warn(
                        "Parsing messages with leading `_` is deprecated and "
                        "will be removed. Instead, prepend your intents with "
                        "`{0}`, e.g. `{0}mood_greet` "
                        "or `{0}restart`.".format(INTENT_MESSAGE_PREFIX))
            parse_data = RegexInterpreter().parse(text)

        log_line = ("Received user message '{}' with intent '{}' "
                    "and entities '{}'").format(text,
                                                parse_data["intent"],
                                                parse_data["entities"])
        logger.debug(log_line)
        return parse_data