Ejemplo n.º 1
0
def in_process_test_workspace(instance, recon_repo):
    """Yield a request context for an in-process location built from ``recon_repo``.

    This is a generator with a single ``yield`` inside a ``with`` block, so the
    ``WorkspaceProcessContext`` is only exited once the generator is resumed
    or closed. (Presumably wrapped by ``contextlib.contextmanager``; the
    decorator is not visible here -- TODO confirm.)
    """
    origin = InProcessRepositoryLocationOrigin(recon_repo)
    load_target = TestInProcessWorkspaceLoadTarget(origin)
    with WorkspaceProcessContext(instance, load_target) as process_context:
        yield process_context.create_request_context()
def test_terminate_after_shutdown():
    """Check a running run can still be terminated after its server got a shutdown request.

    The test launches ``sleepy_pipeline``, asks the location's gRPC server to
    shut down, verifies that launching a second run now fails with
    ``DagsterLaunchFailedError``, and finally terminates the first run.
    """

    def _create_run(instance, workspace, pipeline_name, pipeline_def):
        # Resolve the external pipeline through the workspace, then register
        # a run for it on the instance.
        external = (
            workspace.get_repository_location("test")
            .get_repository("nope")
            .get_full_external_pipeline(pipeline_name)
        )
        return instance.create_run_for_pipeline(
            pipeline_def=pipeline_def,
            run_config=None,
            external_pipeline_origin=external.get_external_origin(),
            pipeline_code_origin=external.get_python_origin(),
        )

    with instance_for_test() as instance:
        file_target = PythonFileTarget(
            python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
            attribute="nope",
            working_directory=None,
            location_name="test",
        )
        with WorkspaceProcessContext(instance, file_target) as process_context:
            workspace = process_context.create_request_context()

            running = _create_run(instance, workspace, "sleepy_pipeline", sleepy_pipeline)
            instance.launch_run(running.run_id, workspace)
            poll_for_step_start(instance, running.run_id)

            # Tell the server to shut down once executions finish
            location = workspace.get_repository_location("test")
            endpoint = location.grpc_server_registry.get_grpc_endpoint(location.origin)
            endpoint.create_client().shutdown_server()

            # A new launch against the same location is expected to fail now.
            doomed = _create_run(instance, workspace, "math_diamond", math_diamond)
            with pytest.raises(DagsterLaunchFailedError):
                instance.launch_run(doomed.run_id, workspace)

            # Can terminate the run even after the shutdown event has been received
            run_launcher = instance.run_launcher
            assert run_launcher.can_terminate(running.run_id)
            assert run_launcher.terminate(running.run_id)
Ejemplo n.º 3
0
def get_workspace_process_context_from_kwargs(
        instance: DagsterInstance, version: str, read_only: bool,
        kwargs: Dict[str, str]) -> "WorkspaceProcessContext":
    """Build a ``WorkspaceProcessContext`` whose load target is derived from ``kwargs``.

    Args:
        instance: Passed through as the first argument of the context.
        version: Forwarded as the ``version`` keyword.
        read_only: Forwarded as the ``read_only`` keyword.
        kwargs: Handed to ``get_workspace_load_target`` to obtain the target.

    Returns:
        The newly constructed (not yet entered) ``WorkspaceProcessContext``.
    """
    # Imported inside the function, as in the original (presumably to avoid
    # an import cycle -- TODO confirm).
    from dagster.core.workspace import WorkspaceProcessContext

    load_target = get_workspace_load_target(kwargs)
    return WorkspaceProcessContext(
        instance,
        load_target,
        version=version,
        read_only=read_only,
    )
def test_run_always_finishes():  # pylint: disable=redefined-outer-name
    """Check that a launched run completes even though its gRPC server was asked to stop.

    A run of ``slow_pipeline`` is launched against an explicitly started
    ``GrpcServerProcess``. After the ephemeral client context is left, the run
    must still be unfinished and the server process still alive; the run is
    then expected to succeed and the server process to exit within 5 seconds.
    """
    with instance_for_test() as instance:
        target_origin = LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="nope",
            python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
        )
        grpc_server = GrpcServerProcess(loadable_target_origin=target_origin, max_workers=4)

        # Leaving the ephemeral-client context sends the shutdown to the server.
        with grpc_server.create_ephemeral_client():
            server_target = GrpcServerTarget(
                host="localhost",
                socket=grpc_server.socket,
                port=grpc_server.port,
                location_name="test",
            )
            with WorkspaceProcessContext(instance, server_target) as process_context:
                workspace = process_context.create_request_context()

                repository = workspace.get_repository_location("test").get_repository("nope")
                external = repository.get_full_external_pipeline("slow_pipeline")

                created_run = instance.create_run_for_pipeline(
                    pipeline_def=slow_pipeline,
                    run_config=None,
                    external_pipeline_origin=external.get_external_origin(),
                    pipeline_code_origin=external.get_python_origin(),
                )
                run_id = created_run.run_id

                assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED

                instance.launch_run(run_id=run_id, workspace=workspace)

        # Server process now receives shutdown event, run has not finished yet
        in_flight_run = instance.get_run_by_id(run_id)
        assert not in_flight_run.is_finished
        assert grpc_server.server_process.poll() is None

        # Server should wait until run finishes, then shutdown
        finished_run = poll_for_finished_run(instance, run_id)
        assert finished_run.status == PipelineRunStatus.SUCCESS

        # Verify server process cleans up eventually
        wait_began = time.time()
        while grpc_server.server_process.poll() is None:
            time.sleep(0.05)
            elapsed = time.time() - wait_began
            assert elapsed < 5

        grpc_server.wait()
Ejemplo n.º 5
0
def create_asgi_client(instance):
    """Yield a ``TestClient`` bound to a Dagit ASGI app for the sibling ``workspace.yaml``.

    Generator with one ``yield`` inside the ``with`` block: the read-only
    ``WorkspaceProcessContext`` is exited only after the generator is resumed
    or closed.
    """
    workspace_target = WorkspaceFileTarget(
        paths=[file_relative_path(__file__, "./workspace.yaml")]
    )

    with WorkspaceProcessContext(
        instance=instance,
        workspace_load_target=workspace_target,
        version="",
        read_only=True,
    ) as ctx:
        asgi_app = DagitWebserver(ctx).create_asgi_app()
        yield TestClient(asgi_app)
Ejemplo n.º 6
0
def get_main_workspace(instance):
    """Yield a request context for the repository defined in the sibling ``setup.py``.

    The repository attribute and location name come from ``main_repo_name()``
    and ``main_repo_location_name()``. The process context stays entered until
    the generator is resumed or closed.
    """
    main_target = PythonFileTarget(
        python_file=file_relative_path(__file__, "setup.py"),
        attribute=main_repo_name(),
        working_directory=None,
        location_name=main_repo_location_name(),
    )
    with WorkspaceProcessContext(instance, main_target) as process_context:
        yield process_context.create_request_context()
Ejemplo n.º 7
0
def create_subscription_context(instance):
    """Yield a ``GeventConnectionContext`` backed by a mock websocket.

    The workspace is loaded read-only from the sibling ``workspace.yaml``; the
    process context is exited only after the generator is resumed or closed.
    """
    fake_socket = mock.Mock()
    workspace_target = WorkspaceFileTarget(
        paths=[file_relative_path(__file__, "./workspace.yaml")]
    )

    with WorkspaceProcessContext(
        instance=instance,
        workspace_load_target=workspace_target,
        version="",
        read_only=True,
    ) as ctx:
        yield GeventConnectionContext(fake_socket, ctx)
Ejemplo n.º 8
0
def get_bar_workspace(instance):
    """Yield a request context for ``bar_repo`` loaded from ``api_tests_repo.py``.

    The location is registered under the name ``bar_repo_location``. The
    process context stays entered until the generator is resumed or closed.
    """
    bar_target = PythonFileTarget(
        python_file=file_relative_path(__file__, "api_tests_repo.py"),
        attribute="bar_repo",
        working_directory=None,
        location_name="bar_repo_location",
    )
    with WorkspaceProcessContext(instance, bar_target) as process_context:
        yield process_context.create_request_context()
Ejemplo n.º 9
0
def in_process_test_workspace(instance,
                              loadable_target_origin,
                              container_image=None):
    """Yield a request context for an in-process location of ``loadable_target_origin``.

    ``container_image`` is forwarded unchanged to
    ``InProcessRepositoryLocationOrigin``. The process context stays entered
    until the generator is resumed or closed.
    """
    location_origin = InProcessRepositoryLocationOrigin(
        loadable_target_origin,
        container_image=container_image,
    )
    load_target = InProcessTestWorkspaceLoadTarget(location_origin)
    with WorkspaceProcessContext(instance, load_target) as process_context:
        yield process_context.create_request_context()
Ejemplo n.º 10
0
async def graphql_http_endpoint(
    schema: Schema,
    process_context: WorkspaceProcessContext,
    app_path_prefix: str,
    request: Request,
):
    """
    fork of starlette GraphQLApp to allow for
        * our context type (crucial)
        * our GraphiQL playground (could change)

    Request handling:
        * ``GET`` with ``text/html`` in ``Accept`` renders the GraphiQL
          template; any other ``GET`` reads the query from the URL parameters.
        * ``POST`` accepts ``application/json``, ``application/graphql`` or a
          ``query`` URL parameter; anything else yields 415.
        * Other methods yield 405.

    Returns:
        * 400 (plain text) when the JSON body cannot be parsed or no ``query``
          is present.
        * Otherwise a JSON response with ``data`` (and ``errors`` when the
          execution reported any). Status is 400 if there were errors, else 200.
    """

    if request.method == "GET":
        # render graphiql
        if "text/html" in request.headers.get("Accept", ""):
            text = TEMPLATE.replace("{{ app_path_prefix }}", app_path_prefix)
            return HTMLResponse(text)

        data: Union[Dict[str, str], QueryParams] = request.query_params

    elif request.method == "POST":
        content_type = request.headers.get("Content-Type", "")

        if "application/json" in content_type:
            try:
                data = await request.json()
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError are both
                # ValueError subclasses; a malformed body is a client error,
                # not a server failure.
                return PlainTextResponse(
                    "GraphQL request is invalid JSON",
                    status_code=status.HTTP_400_BAD_REQUEST,
                )
            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, string, number...)
                # cannot carry a query; let the "no query" check below answer
                # with 400 instead of failing on the lookups.
                data = {}
        elif "application/graphql" in content_type:
            body = await request.body()
            text = body.decode()
            data = {"query": text}
        elif "query" in request.query_params:
            data = request.query_params
        else:
            return PlainTextResponse(
                "Unsupported Media Type",
                status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            )

    else:
        return PlainTextResponse(
            "Method Not Allowed",
            status_code=status.HTTP_405_METHOD_NOT_ALLOWED)

    if "query" not in data:
        return PlainTextResponse(
            "No GraphQL query found in the request",
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    query = data["query"]
    variables = data.get("variables")
    operation_name = data.get("operationName")

    # context manager? scoping?
    context = process_context.create_request_context()

    result = await run_in_threadpool(  # threadpool = aio event loop
        schema.execute,
        query,
        variables=variables,
        operation_name=operation_name,
        context=context,
    )

    error_data = [format_graphql_error(err)
                  for err in result.errors] if result.errors else None
    response_data = {"data": result.data}
    if error_data:
        response_data["errors"] = error_data
    status_code = status.HTTP_400_BAD_REQUEST if result.errors else status.HTTP_200_OK

    return JSONResponse(response_data, status_code=status_code)
Ejemplo n.º 11
0
async def graphql_ws_endpoint(
    schema: Schema,
    process_context: WorkspaceProcessContext,
    scope: Scope,
    receive: Receive,
    send: Send,
):
    """
    Implementation of websocket ASGI endpoint for GraphQL.
    Once we are free of conflicting deps, we should be able to use an impl from
    strawberry-graphql or the like.

    Accepts the connection with the ``GraphQLWS.PROTOCOL`` subprotocol and then
    processes JSON messages until either side is disconnected:

        * ``CONNECTION_INIT``      -> replies with ``CONNECTION_ACK``
        * ``CONNECTION_TERMINATE`` -> closes the websocket
        * ``START``                -> executes the subscription and streams
                                      its results from a background task
        * ``STOP``                 -> disposes and cancels that operation

    Every still-registered operation is disposed and cancelled when the
    handler exits.
    """

    websocket = WebSocket(scope=scope, receive=receive, send=send)

    # Both dicts are keyed by the client-supplied operation id and are always
    # updated together.
    observables = {}
    tasks = {}

    await websocket.accept(subprotocol=GraphQLWS.PROTOCOL)

    try:
        while (websocket.client_state != WebSocketState.DISCONNECTED
               and websocket.application_state != WebSocketState.DISCONNECTED):
            message = await websocket.receive_json()
            operation_id = message.get("id")
            message_type = message.get("type")

            if message_type == GraphQLWS.CONNECTION_INIT:
                await websocket.send_json({"type": GraphQLWS.CONNECTION_ACK})

            elif message_type == GraphQLWS.CONNECTION_TERMINATE:
                await websocket.close()
            elif message_type == GraphQLWS.START:
                try:
                    data = message["payload"]
                    query = data["query"]
                    variables = data.get("variables")
                    # The payload key is camelCase ("operationName"), as in
                    # the HTTP endpoint; the snake_case key that was read
                    # before is kept as a fallback for compatibility.
                    operation_name = data.get("operationName")
                    if operation_name is None:
                        operation_name = data.get("operation_name")

                    # correct scoping?
                    request_context = process_context.create_request_context()
                    async_result = schema.execute(
                        query,
                        variables=variables,
                        operation_name=operation_name,
                        context=request_context,
                        allow_subscriptions=True,
                    )
                except GraphQLError as error:
                    payload = format_graphql_error(error)
                    await _send_message(websocket, GraphQLWS.ERROR, payload,
                                        operation_id)
                    continue

                if isinstance(async_result, ExecutionResult):
                    if not async_result.errors:
                        check.failed(
                            f"Only expect non-async result on error, got {async_result}"
                        )
                    payload = format_graphql_error(
                        async_result.errors[0])  # type: ignore
                    await _send_message(websocket, GraphQLWS.ERROR, payload,
                                        operation_id)
                    continue

                # in the future we should get back async gen directly, back compat for now
                disposable, async_gen = _disposable_and_async_gen_from_obs(
                    async_result)

                observables[operation_id] = disposable
                tasks[operation_id] = get_event_loop().create_task(
                    handle_async_results(async_gen, operation_id, websocket))
            elif message_type == GraphQLWS.STOP:
                if operation_id not in observables:
                    # Unknown or already-finished operation: ignore it rather
                    # than returning, which would tear down every other
                    # subscription on this connection.
                    continue

                observables[operation_id].dispose()
                del observables[operation_id]

                tasks[operation_id].cancel()
                del tasks[operation_id]

    except WebSocketDisconnect:
        pass
    finally:
        # Release whatever is still subscribed, however the loop ended.
        for operation_id in observables:
            observables[operation_id].dispose()
            tasks[operation_id].cancel()
def test_server_down():
    """Check ``can_terminate`` follows the gRPC info stored in the run tags.

    After launching ``sleepy_pipeline`` the run can be terminated. Once the
    ``GRPC_INFO_TAG`` is overwritten with a port obtained from
    ``find_free_port()``, ``can_terminate`` must return a falsy value;
    restoring the original tag makes ``terminate`` succeed again.
    """
    with instance_for_test() as instance:
        target_origin = LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="nope",
            python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
        )

        grpc_server = GrpcServerProcess(
            loadable_target_origin=target_origin, max_workers=4, force_port=True
        )

        with grpc_server.create_ephemeral_client() as api_client:
            server_target = GrpcServerTarget(
                location_name="test",
                port=api_client.port,
                socket=api_client.socket,
                host=api_client.host,
            )
            with WorkspaceProcessContext(instance, server_target) as process_context:
                workspace = process_context.create_request_context()

                repository = workspace.get_repository_location("test").get_repository("nope")
                external = repository.get_full_external_pipeline("sleepy_pipeline")

                created_run = instance.create_run_for_pipeline(
                    pipeline_def=sleepy_pipeline,
                    run_config=None,
                    external_pipeline_origin=external.get_external_origin(),
                    pipeline_code_origin=external.get_python_origin(),
                )
                run_id = created_run.run_id

                instance.launch_run(run_id, workspace)

                poll_for_step_start(instance, run_id)

                run_launcher = instance.run_launcher
                assert run_launcher.can_terminate(run_id)

                # Remember the real server info so it can be put back below.
                saved_grpc_info = instance.get_run_by_id(run_id).tags[GRPC_INFO_TAG]

                # Replace run tags with an invalid port
                bogus_grpc_info = seven.json.dumps(
                    merge_dicts({"host": "localhost"}, {"port": find_free_port()})
                )
                instance.add_run_tags(run_id, {GRPC_INFO_TAG: bogus_grpc_info})

                assert not run_launcher.can_terminate(run_id)

                # Restore the original tag; termination must work again.
                instance.add_run_tags(run_id, {GRPC_INFO_TAG: saved_grpc_info})

                assert run_launcher.terminate(run_id)

        grpc_server.wait()