Ejemplo n.º 1
0
def create_agent(model: Text, endpoints: Text = None) -> 'Agent':
    """Load an agent from a model and an optional endpoints file.

    Args:
        model: Model location passed to ``get_model_subdirectories`` to
            obtain the Core and NLU sub-directories.
        endpoints: Value passed to ``AvailableEndpoints.read_endpoints``;
            defaults to ``None``.

    Returns:
        The agent loaded from the Core sub-directory, configured with the
        NLG endpoint, tracker store, action endpoint and, when the NLU
        sub-directory exists on disk, an NLU interpreter for it.
    """
    from rasa_core.interpreter import RasaNLUInterpreter
    from rasa_core.tracker_store import TrackerStore
    from rasa_core import broker
    from rasa_core.utils import AvailableEndpoints

    core_path, nlu_path = get_model_subdirectories(model)
    _endpoints = AvailableEndpoints.read_endpoints(endpoints)

    if os.path.exists(nlu_path):
        _interpreter = RasaNLUInterpreter(model_directory=nlu_path)
    else:
        _interpreter = None
        logging.info("No NLU model found. Running without NLU.")

    _broker = broker.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(None,
                                                     _endpoints.tracker_store,
                                                     _broker)

    # The interpreter must be handed to the agent explicitly; previously it
    # was constructed above and then silently discarded.
    return Agent.load(core_path,
                      interpreter=_interpreter,
                      generator=_endpoints.nlg,
                      tracker_store=_tracker_store,
                      action_endpoint=_endpoints.action)
Ejemplo n.º 2
0
async def load_agent_on_start(core_model, endpoints, nlu_model, app, loop):
    """Create an agent, store it on ``app.agent`` and return it.

    When ``endpoints.model`` is set, an empty agent is created and its model
    is fetched via ``agent.load_from_server``; otherwise the agent is loaded
    from ``core_model``.

    This used to be scheduled on server start, which is why it takes the
    ``app`` and ``loop`` arguments (``loop`` is not used in the body).
    """
    from rasa_core import broker
    from rasa_core.agent import Agent

    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_model,
                                                        endpoints.nlu)
    event_broker = broker.from_endpoint_config(endpoints.event_broker)
    store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker)

    if not (endpoints and endpoints.model):
        # No model server configured: load the model from `core_model`.
        app.agent = Agent.load(core_model,
                               interpreter=nlu_interpreter,
                               generator=endpoints.nlg,
                               tracker_store=store,
                               action_endpoint=endpoints.action)
        return app.agent

    from rasa_core import agent

    # Start with a model-less agent and let the model server populate it.
    app.agent = Agent(interpreter=nlu_interpreter,
                      generator=endpoints.nlg,
                      tracker_store=store,
                      action_endpoint=endpoints.action)
    await agent.load_from_server(app.agent, model_server=endpoints.model)
    return app.agent
Ejemplo n.º 3
0
def test_tracker_store_from_string(default_domain):
    """The custom tracker endpoints file resolves to ``ExampleTrackerStore``."""
    config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")

    found_store = TrackerStore.find_tracker_store(default_domain, config)

    assert isinstance(found_store, ExampleTrackerStore)
Ejemplo n.º 4
0
def test_tracker_store_from_invalid_string(default_domain):
    """A store type of ``"any string"`` yields an ``InMemoryTrackerStore``."""
    config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")
    # Overwrite the configured type with a value that names no store.
    config.type = "any string"

    found_store = TrackerStore.find_tracker_store(default_domain, config)

    assert isinstance(found_store, InMemoryTrackerStore)
Ejemplo n.º 5
0
def test_tracker_store_from_invalid_module(default_domain):
    """A dotted store type that is not importable yields an in-memory store."""
    config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")
    # Overwrite the configured type with a dotted module-style path.
    config.type = "a.module.which.cannot.be.found"

    found_store = TrackerStore.find_tracker_store(default_domain, config)

    assert isinstance(found_store, InMemoryTrackerStore)
Ejemplo n.º 6
0
def test_find_tracker_store(default_domain):
    """A Redis store is an instance of the type found for the default endpoints."""
    store_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                              "tracker_store")
    redis_kwargs = dict(host="localhost",
                        port=6379,
                        db=0,
                        password="******",
                        record_exp=3000)
    redis_store = RedisTrackerStore(domain=default_domain, **redis_kwargs)

    found_store = TrackerStore.find_tracker_store(default_domain,
                                                  store_config)

    assert isinstance(redis_store, type(found_store))
Ejemplo n.º 7
0
def do_interactive_learning(cmdline_args, stories, additional_arguments):
    """Run interactive learning on a pre-trained or freshly trained agent.

    If ``cmdline_args.core`` is set, that model is loaded (and combining it
    with ``cmdline_args.finetune`` raises ``ValueError``).  Otherwise a
    dialogue model is trained into ``cmdline_args.out`` or, when that is
    not set, into a new temporary directory.
    """
    endpoint_config = AvailableEndpoints.read_endpoints(
        cmdline_args.endpoints)
    nlu_interpreter = NaturalLanguageInterpreter.create(
        cmdline_args.nlu, endpoint_config.nlu)

    if not cmdline_args.core:
        # Train from scratch; fall back to a temp directory for the output.
        model_directory = (cmdline_args.out
                           or tempfile.mkdtemp(suffix="_core_model"))

        learner = train_dialogue_model(cmdline_args.domain, stories,
                                       model_directory, nlu_interpreter,
                                       endpoint_config,
                                       cmdline_args.dump_stories,
                                       cmdline_args.config[0], None,
                                       additional_arguments)
    else:
        if cmdline_args.finetune:
            raise ValueError(
                "--core can only be used without --finetune flag.")

        logger.info("Loading a pre-trained model. This means that all "
                    "training-related parameters will be ignored.")

        event_broker = PikaProducer.from_endpoint_config(
            endpoint_config.event_broker)
        store = TrackerStore.find_tracker_store(
            None, endpoint_config.tracker_store, event_broker)

        learner = Agent.load(cmdline_args.core,
                             interpreter=nlu_interpreter,
                             generator=endpoint_config.nlg,
                             tracker_store=store,
                             action_endpoint=endpoint_config.action)

    interactive.run_interactive_learning(
        learner,
        stories,
        finetune=cmdline_args.finetune,
        skip_visualization=cmdline_args.skip_visualization)
Ejemplo n.º 8
0
def do_interactive_learning(cmdline_args, stories, additional_arguments):
    """Run interactive learning on a pre-trained or freshly trained agent.

    Raises:
        ValueError: if more than one config file is passed, if
            ``cmdline_args.core`` is combined with ``cmdline_args.finetune``,
            or if a model has to be trained and ``cmdline_args.out`` is
            not set.
    """
    endpoint_config = AvailableEndpoints.read_endpoints(
        cmdline_args.endpoints)
    nlu_interpreter = NaturalLanguageInterpreter.create(
        cmdline_args.nlu, endpoint_config.nlu)

    config_files = cmdline_args.config
    if isinstance(config_files, list) and len(config_files) > 1:
        raise ValueError("You can only pass one config file at a time")

    if cmdline_args.core:
        if cmdline_args.finetune:
            raise ValueError(
                "--core can only be used without --finetune flag.")

        logger.info("loading a pre-trained model. all training-related "
                    "parameters will be ignored")

        event_broker = PikaProducer.from_endpoint_config(
            endpoint_config.event_broker)
        store = TrackerStore.find_tracker_store(
            None, endpoint_config.tracker_store, event_broker)

        learner = Agent.load(cmdline_args.core,
                             interpreter=nlu_interpreter,
                             generator=endpoint_config.nlg,
                             tracker_store=store,
                             action_endpoint=endpoint_config.action)
    else:
        # Training needs an explicit output location.
        if not cmdline_args.out:
            raise ValueError("you must provide a path where the model will "
                             "be saved using -o / --out")

        learner = train_dialogue_model(cmdline_args.domain, stories,
                                       cmdline_args.out, nlu_interpreter,
                                       endpoint_config,
                                       cmdline_args.dump_stories,
                                       cmdline_args.config[0], None,
                                       additional_arguments)

    interactive.run_interactive_learning(
        learner,
        stories,
        finetune=cmdline_args.finetune,
        skip_visualization=cmdline_args.skip_visualization)
Ejemplo n.º 9
0
def start_core(platform_token):
    """Load an agent against hard-coded local endpoints and serve it.

    The model and NLG endpoints point at ``http://localhost:5002`` and are
    authenticated with ``platform_token``; the event broker is configured
    with type ``"file"``.  The agent is then served through
    ``serve_application`` with fixed arguments.
    """
    from rasa_core.utils import AvailableEndpoints

    # TODO: make endpoints more configurable, esp ports
    model_endpoint = EndpointConfig(
        "http://localhost:5002/api/projects/default/models/tags/production",
        token=platform_token,
        wait_time_between_pulls=1)
    broker_endpoint = EndpointConfig(type="file")
    nlg_endpoint = EndpointConfig("http://localhost:5002/api/nlg",
                                  token=platform_token)
    core_endpoints = AvailableEndpoints(model=model_endpoint,
                                        event_broker=broker_endpoint,
                                        nlg=nlg_endpoint)

    from rasa_core import broker
    event_broker = broker.from_endpoint_config(core_endpoints.event_broker)

    from rasa_core.tracker_store import TrackerStore
    store = TrackerStore.find_tracker_store(
        None, core_endpoints.tracker_store, event_broker)

    from rasa_core.run import load_agent
    core_agent = load_agent("models",
                            interpreter=None,
                            tracker_store=store,
                            endpoints=core_endpoints)
    from rasa_core.run import serve_application
    print_success("About to start core")

    serve_args = (
        core_agent,
        "rasa",
        5005,
        "credentials.yml",
        "*",
        None,  # TODO: configure auth token
        True,
    )
    serve_application(*serve_args)
Ejemplo n.º 10
0
    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)

    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel,
                                 cmdline_args.log_file)

    logger.info("Rasa process starting")

    _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
    _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                     _endpoints.nlu)
    _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(
        None, _endpoints.tracker_store, _broker)
    _agent = load_agent(cmdline_args.core,
                        interpreter=_interpreter,
                        tracker_store=_tracker_store,
                        endpoints=_endpoints)
    serve_application(_agent,
                      cmdline_args.connector,
                      cmdline_args.port,
                      cmdline_args.credentials,
                      cmdline_args.cors,
                      cmdline_args.auth_token,
                      cmdline_args.enable_api,
                      cmdline_args.jwt_secret,
                      cmdline_args.jwt_method)