Ejemplo n.º 1
0
def train_dialogue_model(domain_file,
                         stories_file,
                         output_path,
                         use_online_learning=False,
                         nlu_model_path=None,
                         max_history=None,
                         kwargs=None):
    """Train a dialogue agent (memoization + Keras policies) and persist it.

    Args:
        domain_file: domain passed straight to ``Agent``.
        stories_file: story data passed to ``agent.load_data``.
        output_path: where the trained agent is persisted; also used as
            ``model_path`` for online training.
        use_online_learning: if true, train interactively on the console via
            ``agent.train_online``; otherwise run ``agent.train``.
        nlu_model_path: when given (and online learning is on), use a
            ``RasaNLUInterpreter`` loaded from this path; otherwise a
            ``RegexInterpreter``.
        max_history: forwarded to ``MemoizationPolicy``.
        kwargs: optional dict of extra options. The data-loading keys listed
            below are handed to ``agent.load_data``. The rest go to the
            training call. This split is presumably done by
            ``utils.extract_args``; confirm against its implementation.
    """
    train_kwargs = kwargs or {}

    dialogue_policies = [MemoizationPolicy(max_history=max_history),
                         KerasPolicy()]
    agent = Agent(domain_file, policies=dialogue_policies)

    # Options understood by the story loader rather than by the trainer.
    loader_keys = {"use_story_concatenation", "unique_last_num_states",
                   "augmentation_factor", "remove_duplicates", "debug_plots"}
    load_kwargs, train_kwargs = utils.extract_args(train_kwargs, loader_keys)
    training_data = agent.load_data(stories_file, **load_kwargs)

    if not use_online_learning:
        agent.train(training_data, **train_kwargs)
    else:
        agent.interpreter = (RasaNLUInterpreter(nlu_model_path)
                             if nlu_model_path else RegexInterpreter())
        agent.train_online(training_data,
                           input_channel=ConsoleInputChannel(),
                           model_path=output_path,
                           **train_kwargs)

    agent.persist(output_path)
Ejemplo n.º 2
0
def train_dialogue_model(domain_file, stories_file, output_path,
                         use_online_learning=False, nlu_model_path=None,
                         kwargs=None):
    """Train a dialogue agent (memoization + Keras policies) and persist it.

    Args:
        domain_file: domain passed straight to ``Agent``.
        stories_file: story data passed to the training call.
        output_path: where the trained agent is persisted; also used as
            ``model_path`` for online training.
        use_online_learning: if true, train interactively on the console via
            ``agent.train_online`` (10 epochs); otherwise run ``agent.train``.
        nlu_model_path: when given (and online learning is on), use a
            ``RasaNLUInterpreter`` loaded from this path; otherwise a
            ``RegexInterpreter``.
        kwargs: optional dict of extra keyword arguments for ``agent.train``
            (offline only; the online branch does not use them).
            ``validation_split`` defaults to 0.1 but may be overridden here.
    """
    if not kwargs:
        kwargs = {}

    agent = Agent(domain_file, policies=[MemoizationPolicy(), KerasPolicy()])

    if use_online_learning:
        if nlu_model_path:
            agent.interpreter = RasaNLUInterpreter(nlu_model_path)
        else:
            agent.interpreter = RegexInterpreter()
        agent.train_online(
                stories_file,
                input_channel=ConsoleInputChannel(),
                epochs=10,
                model_path=output_path)
    else:
        # Treat validation_split=0.1 as a default that kwargs may override.
        # Passing it alongside **kwargs raised "got multiple values for
        # keyword argument 'validation_split'" when the caller supplied it.
        train_kwargs = {"validation_split": 0.1}
        train_kwargs.update(kwargs)
        agent.train(stories_file, **train_kwargs)

    agent.persist(output_path)
Ejemplo n.º 3
0
def train_dialogue_model(domain_file, stories_file, output_path,
                         use_online_learning=False,
                         nlu_model_path=None,
                         max_history=None,
                         kwargs=None):
    """Load stories, train a dialogue agent and persist it to ``output_path``.

    The agent uses a ``MemoizationPolicy`` (configured with ``max_history``)
    followed by a ``KerasPolicy``. With ``use_online_learning`` the agent is
    trained interactively on the console. In that mode the interpreter is a
    ``RasaNLUInterpreter`` when ``nlu_model_path`` is given, and a
    ``RegexInterpreter`` otherwise. Without it, ``agent.train`` is used.
    ``kwargs`` (a dict, optional) is forwarded to whichever training call runs.
    """
    extra_args = kwargs or {}

    memo_policy = MemoizationPolicy(max_history=max_history)
    keras_policy = KerasPolicy()
    agent = Agent(domain_file, policies=[memo_policy, keras_policy])
    training_data = agent.load_data(stories_file)

    if not use_online_learning:
        agent.train(training_data, **extra_args)
    else:
        if not nlu_model_path:
            interpreter = RegexInterpreter()
        else:
            interpreter = RasaNLUInterpreter(nlu_model_path)
        agent.interpreter = interpreter
        agent.train_online(training_data,
                           input_channel=ConsoleInputChannel(),
                           model_path=output_path,
                           **extra_args)

    agent.persist(output_path)
Ejemplo n.º 4
0
def train_dialog(online=False, nlu=False, policies=None):
    """Train the dialogue agent for ``domain.yml`` and persist it.

    Args:
        online: if true, train interactively on the console (10 epochs);
            otherwise train offline for 300 epochs with a 0.1 validation split.
        nlu: in online mode, use the NLU model at
            ``models/infosys_cui/current`` instead of a ``RegexInterpreter``.
        policies: policies for the agent; defaults to a fresh
            ``[KerasPolicy()]`` on every call.
    """
    import os

    if policies is None:
        # Build the default per call. The old ``policies=[KerasPolicy()]``
        # default was evaluated once at definition time and shared by all calls.
        policies = [KerasPolicy()]
    agent = Agent("domain.yml", policies=policies)
    # Portable paths. The former "data\stories.md" / "models\dialog" literals
    # only worked on Windows and relied on "\s"/"\d" being unrecognised escapes.
    stories = os.path.join("data", "stories.md")
    output_path = os.path.join("models", "dialog")
    if online:
        if nlu:
            agent.interpreter = RasaNLUInterpreter(
                "models/infosys_cui/current")
        else:
            agent.interpreter = RegexInterpreter()
        agent.train_online(stories,
                           input_channel=ConsoleInputChannel(),
                           epochs=10,
                           model_path=output_path)
    else:
        kwargs = {"epochs": 300}
        agent.train(stories, validation_split=0.1, **kwargs)

    agent.persist(output_path)
def train_model_online():
    """Train an agent online from the core training-data file and return it.

    The stories file is also replayed as the input channel: lines of the form
    ``* <text>`` are captured by the regex and fed as messages, up to 10 of
    them. Training uses a ``RegexInterpreter``. The returned agent's
    interpreter is then switched to a ``RasaNLUInterpreter`` loaded from
    ``RASA_NLU_MODEL_PATH``. The agent is not persisted here.

    NOTE(review): the ``RASA_*`` names are presumably module-level constants
    defined elsewhere in this file; confirm.
    """
    agent = Agent(RASA_CORE_DOMAIN_PATH,
                  policies=[MemoizationPolicy(),
                            StatusPolicy()],
                  interpreter=RegexInterpreter())

    agent.train_online(RASA_CORE_TRAINING_DATA_PATH,
                       input_channel=FileInputChannel(
                           RASA_CORE_TRAINING_DATA_PATH,
                           # Raw string: "\s" / "\*" are invalid escapes in a
                           # normal literal (Deprecation/SyntaxWarning).
                           message_line_pattern=r'^\s*\*\s(.*)$',
                           max_messages=10),
                       epochs=RASA_CORE_EPOCHS)

    agent.interpreter = RasaNLUInterpreter(RASA_NLU_MODEL_PATH)
    return agent
Ejemplo n.º 6
0
def run_babi_online(max_messages=10):
    """Train the restaurant agent online by replaying bAbI stories; return it.

    Messages are read from the stories file itself. Lines of the form
    ``* <text>`` are captured, up to ``max_messages`` of them. Training runs
    for 10 epochs with a ``RegexInterpreter``. The interpreter is then
    replaced by a ``RasaNLUInterpreter``.

    NOTE(review): ``nlu_model_path`` is not defined in this function. It must
    exist as a module-level global when this runs, or a NameError is raised;
    confirm where it is set.
    """
    training_data = 'examples/babi/data/babi_task5_dev_rasa_even_smaller.md'
    logger.info("Starting to train policy")
    agent = Agent("examples/restaurant_domain.yml",
                  policies=[MemoizationPolicy(),
                            RestaurantPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: "\s" / "\*" are invalid escapes in a normal literal.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
Ejemplo n.º 7
0
def run_babi_online(max_messages=10):
    """Train the music-player agent online by replaying ``stories.md``; return it.

    Messages are read from the stories file itself. Lines of the form
    ``* <text>`` are captured, up to ``max_messages`` of them. Training runs
    for 10 epochs with a ``RegexInterpreter``. The interpreter is then
    replaced by a ``RasaNLUInterpreter``.

    NOTE(review): ``nlu_model_path`` is not defined in this function. It must
    exist as a module-level global when this runs, or a NameError is raised;
    confirm where it is set.
    """
    training_data = 'stories.md'
    logger.info("Starting to train policy")
    agent = Agent("domain.yml",
                  policies=[MemoizationPolicy(),
                            MusicPlayerPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: "\s" / "\*" are invalid escapes in a normal literal.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUInterpreter(nlu_model_path)
    return agent
Ejemplo n.º 8
0
def run_babi_online(max_messages=10):
    """Train the weather agent online by replaying ``data/weather.md``; return it.

    Messages are read from the stories file itself. Lines of the form
    ``* <text>`` are captured, up to ``max_messages`` of them (default 10, as
    before). Training runs for 10 epochs with a ``RegexInterpreter``. The
    interpreter is then replaced by a ``RasaNLUHttpInterpreter`` pointing at
    ``http://localhost:7000`` (model ``model_20171013-084449``).
    """
    training_data = 'data/weather.md'
    logger.info("Starting to train policy")
    agent = Agent("../weather_domain.yml",
                  policies=[MemoizationPolicy(),
                            WeatherPolicy()],
                  interpreter=RegexInterpreter())

    # Raw string: "\s" / "\*" are invalid escapes in a normal literal.
    input_c = FileInputChannel(training_data,
                               message_line_pattern=r'^\s*\*\s(.*)$',
                               max_messages=max_messages)
    agent.train_online(training_data, input_channel=input_c, epochs=10)

    agent.interpreter = RasaNLUHttpInterpreter(
        model_name='model_20171013-084449',
        token=None,
        server='http://localhost:7000')
    return agent