コード例 #1
0
async def load_agent_on_start(core_model, endpoints, nlu_model, app, loop):
    """Create the agent, store it on ``app.agent`` and return it.

    When ``endpoints.model`` is configured, an empty agent is created and
    its model is pulled from that model server; otherwise the agent is
    loaded from the local ``core_model`` path.

    The function used to be scheduled on server start, which is why it
    accepts the ``app`` and ``loop`` arguments (``loop`` is not used here).
    """
    from rasa_core import broker
    from rasa_core.agent import Agent

    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_model,
                                                        endpoints.nlu)
    event_broker = broker.from_endpoint_config(endpoints.event_broker)
    store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker)

    if not (endpoints and endpoints.model):
        # No model server configured: load the persisted model from disk.
        app.agent = Agent.load(
            core_model,
            interpreter=nlu_interpreter,
            generator=endpoints.nlg,
            tracker_store=store,
            action_endpoint=endpoints.action)
        return app.agent

    from rasa_core import agent

    # Model server configured: start with a bare agent and let the
    # server supply the model.
    app.agent = Agent(
        interpreter=nlu_interpreter,
        generator=endpoints.nlg,
        tracker_store=store,
        action_endpoint=endpoints.action)
    await agent.load_from_server(app.agent, model_server=endpoints.model)

    return app.agent
コード例 #2
0
def create_agent(model: Text, endpoints: Text = None) -> 'Agent':
    """Load an agent from a combined model directory.

    Args:
        model: Path of the model; it is split into a Core and an NLU
            sub-directory via ``get_model_subdirectories``.
        endpoints: Path of the endpoints configuration passed to
            ``AvailableEndpoints.read_endpoints``.

    Returns:
        The agent loaded from the Core sub-directory, wired up with the
        NLU interpreter (if an NLU model exists), the NLG endpoint, the
        tracker store and the action endpoint.
    """
    from rasa_core.interpreter import RasaNLUInterpreter
    from rasa_core.tracker_store import TrackerStore
    from rasa_core import broker
    from rasa_core.utils import AvailableEndpoints

    core_path, nlu_path = get_model_subdirectories(model)
    _endpoints = AvailableEndpoints.read_endpoints(endpoints)

    if os.path.exists(nlu_path):
        _interpreter = RasaNLUInterpreter(model_directory=nlu_path)
    else:
        _interpreter = None
        logging.info("No NLU model found. Running without NLU.")

    _broker = broker.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(None,
                                                     _endpoints.tracker_store,
                                                     _broker)

    # The interpreter has to be handed to the agent explicitly; previously
    # it was built above and then dropped, so the NLU model was never used.
    return Agent.load(core_path,
                      interpreter=_interpreter,
                      generator=_endpoints.nlg,
                      tracker_store=_tracker_store,
                      action_endpoint=_endpoints.action)
コード例 #3
0
def test_file_broker_from_config():
    """The file endpoint config yields a FileProducer with its path."""
    endpoint_file = ("data/test_endpoints/event_brokers/"
                     "file_endpoint.yml")
    broker_config = utils.read_endpoint_config(endpoint_file, "event_broker")

    producer = broker.from_endpoint_config(broker_config)

    assert isinstance(producer, FileProducer)
    assert producer.path == "rasa_event.log"
コード例 #4
0
def test_pika_broker_from_config():
    """The pika endpoint config yields a configured PikaProducer."""
    endpoint_file = ("data/test_endpoints/event_brokers/"
                     "pika_endpoint.yml")
    broker_config = utils.read_endpoint_config(endpoint_file, "event_broker")

    producer = broker.from_endpoint_config(broker_config)

    assert isinstance(producer, PikaProducer)
    assert producer.host == "localhost"
    assert producer.credentials.username == "username"
    assert producer.queue == "queue"
コード例 #5
0
def test_file_broker_logs_to_file(tmpdir):
    """Events published to a file broker can be read back from its file."""
    log_path = tmpdir.join("events.log").strpath
    broker_settings = {
        "type": "file",
        "path": log_path
    }

    file_broker = broker.from_endpoint_config(
        EndpointConfig(**broker_settings))

    for event in TEST_EVENTS:
        file_broker.publish(event.as_dict())

    # the log is expected to hold one JSON-encoded event per line
    with open(log_path, "r") as log_file:
        recovered = [Event.from_parameters(json.loads(line))
                     for line in log_file]

    assert recovered == TEST_EVENTS
コード例 #6
0
def test_file_broker_properly_logs_newlines(tmpdir):
    """An event whose text contains a newline survives a round trip."""
    log_path = tmpdir.join("events.log").strpath
    broker_settings = {
        "type": "file",
        "path": log_path
    }

    file_broker = broker.from_endpoint_config(
        EndpointConfig(**broker_settings))

    event_with_newline = UserUttered("hello \n there")
    file_broker.publish(event_with_newline.as_dict())

    # reading line by line only recovers the event if the embedded
    # newline did not split it across several lines of the log
    with open(log_path, "r") as log_file:
        recovered = [Event.from_parameters(json.loads(line))
                     for line in log_file]

    assert recovered == [event_with_newline]
コード例 #7
0
ファイル: train.py プロジェクト: mbabby/rasa_core
def do_interactive_learning(cmdline_args, stories, additional_arguments=None):
    """Run interactive learning on a pre-trained or freshly trained agent.

    If ``cmdline_args.core`` is set, the model at that path is loaded
    (this cannot be combined with ``cmdline_args.finetune``). Otherwise a
    new model is trained into ``cmdline_args.out`` or, when that is not
    set, into a temporary directory.

    Raises:
        ValueError: If both ``--core`` and ``--finetune`` were given.
    """
    endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
    interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                    endpoints.nlu)
    from rasa_core.agent import Agent
    from rasa_core.training import interactive

    if not cmdline_args.core:
        # No pre-trained model given: train one first.
        model_directory = (cmdline_args.out
                           or tempfile.mkdtemp(suffix="_core_model"))

        learning_agent = train(cmdline_args.domain, stories, model_directory,
                               interpreter, endpoints,
                               cmdline_args.dump_stories,
                               cmdline_args.config[0], None,
                               additional_arguments)
    else:
        if cmdline_args.finetune:
            raise ValueError("--core can only be used without "
                             "--finetune flag.")

        logger.info("Loading a pre-trained model. This means that "
                    "all training-related parameters will be ignored.")

        event_broker = broker.from_endpoint_config(endpoints.event_broker)
        store = TrackerStore.find_tracker_store(
            None, endpoints.tracker_store, event_broker)

        learning_agent = Agent.load(
            cmdline_args.core,
            interpreter=interpreter,
            generator=endpoints.nlg,
            tracker_store=store,
            action_endpoint=endpoints.action)

    interactive.run_interactive_learning(
        learning_agent,
        stories,
        finetune=cmdline_args.finetune,
        skip_visualization=cmdline_args.skip_visualization)
コード例 #8
0
ファイル: up.py プロジェクト: jayd2446/rasa_stack
def start_core(platform_token):
    """Load an agent from the local platform's model server and serve it.

    The model and NLG endpoints both point at ``http://localhost:5002``
    and are authenticated with ``platform_token``; events are sent to a
    file broker.
    """
    from rasa_core.utils import AvailableEndpoints

    # TODO: make endpoints more configurable, esp ports
    model_endpoint = EndpointConfig(
        "http://localhost:5002"
        "/api/projects/default/models/tags/production",
        token=platform_token,
        wait_time_between_pulls=1)
    broker_endpoint = EndpointConfig(**{"type": "file"})
    nlg_endpoint = EndpointConfig(
        "http://localhost:5002"
        "/api/nlg",
        token=platform_token)
    endpoints = AvailableEndpoints(
        model=model_endpoint,
        event_broker=broker_endpoint,
        nlg=nlg_endpoint)

    from rasa_core import broker
    event_broker = broker.from_endpoint_config(endpoints.event_broker)

    from rasa_core.tracker_store import TrackerStore
    store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker)

    from rasa_core.run import load_agent
    core_agent = load_agent(
        "models",
        interpreter=None,
        tracker_store=store,
        endpoints=endpoints)

    from rasa_core.run import serve_application
    print_success("About to start core")

    auth_token = None  # TODO: configure auth token
    serve_application(
        core_agent,
        "rasa",
        5005,
        "credentials.yml",
        "*",
        auth_token,
        True)
コード例 #9
0
def test_load_non_existent_custom_broker_name():
    """A class path that does not exist in the module yields no broker."""
    unknown_producer = {"type": "rasa_core.broker.MyProducer"}

    loaded = broker.from_endpoint_config(EndpointConfig(**unknown_producer))

    assert loaded is None
コード例 #10
0
def test_load_custom_broker_name():
    """A broker given by its full class path is loaded."""
    custom_producer = {"type": "rasa_core.broker.FileProducer"}

    loaded = broker.from_endpoint_config(EndpointConfig(**custom_producer))

    assert loaded
コード例 #11
0
def test_no_broker_in_config():
    """Reading the default endpoints file results in no event broker."""
    broker_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                               "event_broker")

    loaded = broker.from_endpoint_config(broker_config)

    assert loaded is None
コード例 #12
0
    # Parse the command line options.
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    # Raise the log threshold of several third-party loggers
    # (WARN for some, ERROR for others).
    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('engineio').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)
    logging.getLogger('socketio').setLevel(logging.ERROR)
    logging.getLogger('pika').setLevel(logging.ERROR)

    # Configure colored and file logging with the requested log level.
    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel, cmdline_args.log_file)

    logger.info("Rasa process starting")

    # Build the interpreter, event broker and tracker store from the
    # endpoints configuration.
    _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
    _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                     _endpoints.nlu)
    _broker = broker.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(None,
                                                     _endpoints.tracker_store,
                                                     _broker)
    # Load the agent from the given Core model and hand it to the server.
    _agent = load_agent(cmdline_args.core,
                        interpreter=_interpreter,
                        tracker_store=_tracker_store,
                        endpoints=_endpoints)
    serve_application(_agent, cmdline_args.connector, cmdline_args.port,
                      cmdline_args.credentials, cmdline_args.cors,
                      cmdline_args.auth_token, cmdline_args.enable_api,
                      cmdline_args.jwt_secret, cmdline_args.jwt_method)