示例#1
0
async def test_events_schema(
    monkeypatch: MonkeyPatch, default_agent: Agent, config_path: Text
):
    """Validate every backend telemetry event against the schemas in events.json.

    Each tracking helper is invoked once, the events handed to the patched
    ``telemetry.print_telemetry_event`` are collected, and the ``properties``
    of each one are checked with ``jsonschema`` against the schema registered
    under the event's name.
    """
    # Enable telemetry in debug mode; the printing step used in debug mode is
    # replaced by a mock below so the reported events can be collected.
    monkeypatch.setenv("RASA_TELEMETRY_DEBUG", "true")
    monkeypatch.setenv("RASA_TELEMETRY_ENABLED", "true")

    event_printer = Mock()
    monkeypatch.setattr(telemetry, "print_telemetry_event", event_printer)

    with open(TELEMETRY_EVENTS_JSON) as schema_file:
        event_schemas = json.load(schema_file)["events"]

    # Snapshot the tasks that already exist, so that only the tasks created by
    # the tracking calls below are awaited afterwards.
    tasks_before = asyncio.all_tasks()

    # Trigger all known backend telemetry events.
    training_data = TrainingDataImporter.load_from_config(config_path)
    with telemetry.track_model_training(training_data, "rasa"):
        await asyncio.sleep(1)

    telemetry.track_telemetry_disabled()
    telemetry.track_data_split(0.5, "nlu")
    telemetry.track_validate_files(True)
    telemetry.track_data_convert("yaml", "nlu")
    telemetry.track_tracker_export(5, TrackerStore(domain=None), EventBroker())
    telemetry.track_interactive_learning_start(True, False)
    telemetry.track_server_start([CmdlineInput()], None, None, 42, True)
    telemetry.track_project_init("tests/")
    telemetry.track_shell_started("nlu")
    telemetry.track_rasa_x_local()
    telemetry.track_visualization()
    telemetry.track_core_model_test(5, True, default_agent)
    telemetry.track_nlu_model_test(TrainingData())

    # Wait for everything that was scheduled since the snapshot.
    telemetry_tasks = asyncio.all_tasks() - tasks_before
    await asyncio.gather(*telemetry_tasks)

    # NOTE(review): 14 tracking calls/contexts yield 15 events; presumably the
    # model-training context manager reports two of them -- confirm.
    assert event_printer.call_count == 15

    for recorded_call in event_printer.call_args_list:
        # Each recorded call is an (args, kwargs) pair; the event is the first
        # positional argument.
        event = recorded_call[0][0]
        properties = event["properties"]
        # `metrics_id` is added to every event automatically but is not part
        # of the schema, so it has to go before validation.
        del properties["metrics_id"]
        jsonschema.validate(
            instance=properties, schema=event_schemas[event["event"]]
        )
示例#2
0
文件: run.py 项目: hitlmt/rasa
def serve_application(
    model_path: Optional[Text] = None,
    channel: Optional[Text] = None,
    port: int = constants.DEFAULT_SERVER_PORT,
    credentials: Optional[Text] = None,
    cors: Optional[Union[Text, List[Text]]] = None,
    auth_token: Optional[Text] = None,
    enable_api: bool = True,
    response_timeout: int = constants.DEFAULT_RESPONSE_TIMEOUT,
    jwt_secret: Optional[Text] = None,
    jwt_method: Optional[Text] = None,
    endpoints: Optional[AvailableEndpoints] = None,
    remote_storage: Optional[Text] = None,
    log_file: Optional[Text] = None,
    ssl_certificate: Optional[Text] = None,
    ssl_keyfile: Optional[Text] = None,
    ssl_ca_file: Optional[Text] = None,
    ssl_password: Optional[Text] = None,
    conversation_id: Optional[Text] = uuid.uuid4().hex,
) -> None:
    """Run the API entrypoint.

    Builds the input channels, configures the Sanic app, registers the
    start-up / shutdown listeners and then serves on ``0.0.0.0:<port>``.

    Args:
        model_path: Bound into the `load_agent_on_start` listener and reported
            to telemetry.
        channel: Channel name handed to `create_http_input_channels`; set to
            ``"cmdline"`` when neither `channel` nor `credentials` is given.
        port: Port the server listens on.
        credentials: Handed to `create_http_input_channels` together with
            `channel`.
        cors: Forwarded to `configure_app`.
        auth_token: Forwarded to `configure_app`.
        enable_api: Forwarded to `configure_app` and reported to telemetry.
        response_timeout: Forwarded to `configure_app`.
        jwt_secret: Forwarded to `configure_app`.
        jwt_method: Forwarded to `configure_app`.
        endpoints: Forwarded to `configure_app` and `load_agent_on_start`; its
            `lock_store` (if any) determines the number of Sanic workers.
        remote_storage: Bound into the `load_agent_on_start` listener.
        log_file: Forwarded to `configure_app` and `update_sanic_log_level`.
        ssl_certificate: Passed to `server.create_ssl_context`.
        ssl_keyfile: Passed to `server.create_ssl_context`.
        ssl_ca_file: Passed to `server.create_ssl_context`.
        ssl_password: Passed to `server.create_ssl_context`.
        conversation_id: Forwarded to `configure_app`. The default is computed
            once, when this function is defined, so all calls within a process
            that omit it share the same id.
    """
    if not channel and not credentials:
        channel = "cmdline"

    input_channels = create_http_input_channels(channel, credentials)

    app = configure_app(
        input_channels,
        cors,
        auth_token,
        enable_api,
        response_timeout,
        jwt_secret,
        jwt_method,
        port=port,
        endpoints=endpoints,
        log_file=log_file,
        conversation_id=conversation_id,
    )

    ssl_context = server.create_ssl_context(
        ssl_certificate, ssl_keyfile, ssl_ca_file, ssl_password
    )
    # Only used for the log message below; the context itself goes to app.run.
    protocol = "https" if ssl_context else "http"

    logger.info(
        f"Starting Rasa server on {constants.DEFAULT_SERVER_FORMAT.format(protocol, port)}"
    )

    app.register_listener(
        partial(load_agent_on_start, model_path, endpoints, remote_storage),
        "before_server_start",
    )
    app.register_listener(close_resources, "after_server_stop")

    # NOTE(review): `_loop` is annotated as `Text`, but Sanic appears to pass
    # the event loop as the second listener argument -- annotation left as is.
    # noinspection PyUnresolvedReferences
    async def clear_model_files(_app: Sanic, _loop: Text) -> None:
        """Remove the served agent's model directory after the server stops."""
        # Use the app instance handed to the listener for both the check and
        # the removal, instead of mixing it with the closed-over `app`.
        model_directory = _app.agent.model_directory
        if model_directory:
            shutil.rmtree(model_directory)

    number_of_workers = rasa.core.utils.number_of_sanic_workers(
        endpoints.lock_store if endpoints else None
    )

    telemetry.track_server_start(
        input_channels, endpoints, model_path, number_of_workers, enable_api
    )

    app.register_listener(clear_model_files, "after_server_stop")

    rasa.utils.common.update_sanic_log_level(log_file)
    app.run(
        host="0.0.0.0",
        port=port,
        ssl=ssl_context,
        backlog=int(os.environ.get(ENV_SANIC_BACKLOG, "100")),
        workers=number_of_workers,
    )
示例#3
0
文件: run.py 项目: zoovu/rasa
def serve_application(
    model_path: Optional[Text] = None,
    channel: Optional[Text] = None,
    interface: Optional[Text] = constants.DEFAULT_SERVER_INTERFACE,
    port: int = constants.DEFAULT_SERVER_PORT,
    credentials: Optional[Text] = None,
    cors: Optional[Union[Text, List[Text]]] = None,
    auth_token: Optional[Text] = None,
    enable_api: bool = True,
    response_timeout: int = constants.DEFAULT_RESPONSE_TIMEOUT,
    jwt_secret: Optional[Text] = None,
    jwt_method: Optional[Text] = None,
    endpoints: Optional[AvailableEndpoints] = None,
    remote_storage: Optional[Text] = None,
    log_file: Optional[Text] = None,
    ssl_certificate: Optional[Text] = None,
    ssl_keyfile: Optional[Text] = None,
    ssl_ca_file: Optional[Text] = None,
    ssl_password: Optional[Text] = None,
    conversation_id: Optional[Text] = uuid.uuid4().hex,
    use_syslog: Optional[bool] = False,
    syslog_address: Optional[Text] = None,
    syslog_port: Optional[int] = None,
    syslog_protocol: Optional[Text] = None,
) -> None:
    """Configure the Sanic app for the requested channels and serve it.

    Args:
        model_path: Bound into the `load_agent_on_start` listener and reported
            to telemetry.
        channel: Channel name for `create_http_input_channels`; becomes
            ``"cmdline"`` when neither `channel` nor `credentials` is given.
        interface: Host the server binds to; also shown in the start-up log.
        port: Port the server listens on.
        credentials: Passed to `create_http_input_channels` with `channel`.
        cors: Forwarded to `configure_app`.
        auth_token: Forwarded to `configure_app`.
        enable_api: Forwarded to `configure_app` and reported to telemetry.
        response_timeout: Forwarded to `configure_app`.
        jwt_secret: Forwarded to `configure_app`.
        jwt_method: Forwarded to `configure_app`.
        endpoints: Forwarded to `configure_app` and `load_agent_on_start`; its
            `lock_store` (if any) determines the number of Sanic workers.
        remote_storage: Bound into the `load_agent_on_start` listener.
        log_file: Forwarded to `configure_app` and `update_sanic_log_level`.
        ssl_certificate: Passed to `server.create_ssl_context`.
        ssl_keyfile: Passed to `server.create_ssl_context`.
        ssl_ca_file: Passed to `server.create_ssl_context`.
        ssl_password: Passed to `server.create_ssl_context`.
        conversation_id: Forwarded to `configure_app`. The default is computed
            once, when this function is defined.
        use_syslog: Forwarded to `configure_app` and `update_sanic_log_level`.
        syslog_address: Forwarded to `configure_app` and
            `update_sanic_log_level`.
        syslog_port: Forwarded to `configure_app` and `update_sanic_log_level`.
        syslog_protocol: Forwarded to `configure_app` and
            `update_sanic_log_level`.
    """
    # Fall back to the command line channel when nothing else was requested.
    if not (channel or credentials):
        channel = "cmdline"

    channels = create_http_input_channels(channel, credentials)

    sanic_app = configure_app(
        channels,
        cors,
        auth_token,
        enable_api,
        response_timeout,
        jwt_secret,
        jwt_method,
        port=port,
        endpoints=endpoints,
        log_file=log_file,
        conversation_id=conversation_id,
        use_syslog=use_syslog,
        syslog_address=syslog_address,
        syslog_port=syslog_port,
        syslog_protocol=syslog_protocol,
    )

    tls_context = server.create_ssl_context(
        ssl_certificate, ssl_keyfile, ssl_ca_file, ssl_password
    )
    # The scheme is only needed for the log line below.
    if tls_context:
        protocol = "https"
    else:
        protocol = "http"

    logger.info(f"Starting Rasa server on {protocol}://{interface}:{port}")

    agent_loader = partial(
        load_agent_on_start, model_path, endpoints, remote_storage
    )
    sanic_app.register_listener(agent_loader, "before_server_start")
    sanic_app.register_listener(close_resources, "after_server_stop")

    lock_store = endpoints.lock_store if endpoints else None
    worker_count = rasa.core.utils.number_of_sanic_workers(lock_store)

    telemetry.track_server_start(
        channels, endpoints, model_path, worker_count, enable_api
    )

    rasa.utils.common.update_sanic_log_level(
        log_file, use_syslog, syslog_address, syslog_port, syslog_protocol
    )

    # The backlog is read from the environment, defaulting to 100.
    backlog = int(os.environ.get(ENV_SANIC_BACKLOG, "100"))
    sanic_app.run(
        host=interface,
        port=port,
        ssl=tls_context,
        backlog=backlog,
        workers=worker_count,
    )