Example #1
0
def init_and_tear_down_server(request):
    """
    Start a server backed by a fresh temporary file store, hand control to the
    tests, then shut the server down and delete the file store.

    Written as a generator: everything before ``yield`` is setup and everything
    after it is teardown. The chosen port is published through the module-level
    ``SERVER_PORT`` global so that other code in the module can reach the server.

    Cleanup is guarded by ``try``/``finally`` so that the child process is
    terminated and the temporary directory is removed even when waiting for the
    server to come up (or to go down) raises.

    :param request: Not used by this function; presumably the pytest fixture
                    request object.
    """
    # Imported locally so this function does not rely on a module-level import.
    import shutil

    global SERVER_PORT
    SERVER_PORT = _get_safe_port()
    # The positional argument of mkdtemp is the *suffix* of the directory name.
    file_store_path = tempfile.mkdtemp("test_rest_tracking_file_store")
    try:
        env = {FILE_STORE_ENV_VAR: file_store_path}
        # The environment is patched only while the child process is started;
        # mock.patch.dict restores os.environ when the ``with`` block exits.
        with mock.patch.dict(os.environ, env):
            process = Process(target=lambda: app.run(LOCALHOST, SERVER_PORT))
            process.start()
        try:
            _await_server_up_or_die(SERVER_PORT)

            # Yielding here causes pytest to resume execution at the end of all tests.
            yield
        finally:
            # Runs on normal teardown and also when the startup wait fails, so
            # the child process is never left running.
            print("Terminating server...")
            process.terminate()
            _await_server_down_or_die(process)
    finally:
        # Best-effort removal: a failure to delete the temp dir must not mask
        # an earlier error from startup or shutdown.
        shutil.rmtree(file_store_path, ignore_errors=True)
Example #2
0
def init_and_tear_down_server(request):
    """
    Start a server whose backend store is ``server_root_dir``, hand control to
    the tests, then shut the server down and delete ``server_root_dir``.

    Written as a generator: everything before ``yield`` is setup and everything
    after it is teardown. The chosen port is published through the module-level
    ``SERVER_PORT`` global, and the MLflow tracking URI is reset with
    ``mlflow.set_tracking_uri(None)`` before the server is launched.

    Cleanup is guarded by ``try``/``finally`` so that the child process is
    terminated and ``server_root_dir`` is removed even when waiting for the
    server to come up (or to go down) raises.

    :param request: Not used by this function; presumably the pytest fixture
                    request object.
    """
    mlflow.set_tracking_uri(None)
    global SERVER_PORT
    SERVER_PORT = _get_safe_port()
    # NOTE(review): server_root_dir is not defined in this function; it is
    # expected to be provided by the enclosing module.
    env = {BACKEND_STORE_URI_ENV_VAR: server_root_dir}
    # The environment is patched only while the child process is started;
    # mock.patch.dict restores os.environ when the ``with`` block exits.
    with mock.patch.dict(os.environ, env):
        process = Process(target=lambda: app.run(LOCALHOST, SERVER_PORT))
        process.start()
    try:
        try:
            _await_server_up_or_die(SERVER_PORT)

            # Yielding here causes pytest to resume execution at the end of all tests.
            yield
        finally:
            # Runs on normal teardown and also when the startup wait fails, so
            # the child process is never left running.
            print("Terminating server...")
            process.terminate()
            _await_server_down_or_die(process)
    finally:
        # Remove the store directory even if shutting the server down failed.
        shutil.rmtree(server_root_dir)
def _init_server(backend_uri, root_artifact_uri):
    """
    Start a REST server in a child process and block until it is reachable.

    The server is configured through environment variables that are patched in
    only while the child process is being started: the backend store is set to
    ``backend_uri`` and the artifact root is a fresh temporary directory created
    inside ``root_artifact_uri``. The MLflow tracking URI is reset with
    ``mlflow.set_tracking_uri(None)`` first.

    :param backend_uri: Value exported as the server's backend store URI.
    :param root_artifact_uri: Directory in which the temporary artifact root is created.
    :returns A tuple (url, process) containing the string URL of the server and a handle to the
             server process (a multiprocessing.Process object).
    """
    mlflow.set_tracking_uri(None)
    port = _get_safe_port()

    # Each server gets its own artifact directory beneath the given root.
    artifact_root = tempfile.mkdtemp(dir=root_artifact_uri)
    server_env = {
        BACKEND_STORE_URI_ENV_VAR: backend_uri,
        ARTIFACT_ROOT_ENV_VAR: artifact_root,
    }

    server_process = Process(target=lambda: app.run(LOCALHOST, port))
    # Only the start of the child happens under the patched environment.
    with mock.patch.dict(os.environ, server_env):
        server_process.start()
    _await_server_up_or_die(port)

    server_url = "http://{hostname}:{port}".format(hostname=LOCALHOST, port=port)
    message = "Launching tracking server against backend URI %s. Server URL: %s"
    print(message % (backend_uri, server_url))
    return server_url, server_process
from mlflow.server import app
from auth_middleware import AuthMiddleware

# Replace the app's WSGI callable with an AuthMiddleware instance that is given
# the original callable. This happens at import time and mutates the `app`
# object imported from mlflow.server, so every user of that object sees it.
# NOTE(review): presumably AuthMiddleware enforces authentication before
# delegating to the wrapped callable -- confirm in auth_middleware.
app.wsgi_app = AuthMiddleware(app.wsgi_app)

# When this file is executed directly as a script, start the app with
# app.run()'s default arguments.
if __name__ == '__main__':
    app.run()