Exemple #1
0
def test_reload_on_request_context():
    """A request context created before a reload keeps serving the pre-reload repository.

    The location is reloaded through the process context. Afterwards the process
    context must report a repository with a different name, while the request
    context created earlier must still report the original name -- confirming the
    old repository location is still usable through it.
    """

    def first_repo_name(context):
        # Name of the first repository in the context's first repository location.
        location = context.repository_locations[0]
        return list(location.get_repositories().values())[0].name

    with instance_for_test() as instance:
        with define_out_of_process_workspace(__file__, "get_repo") as workspace:
            process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
            assert len(process_context.repository_locations) == 1

            location_name = process_context.repository_locations[0].name
            repo_name = first_repo_name(process_context)

            # Take the request-context snapshot *before* reloading.
            request_context = process_context.create_request_context()

            process_context.reload_repository_location(location_name)
            new_repo_name = first_repo_name(process_context)
            request_context_repo_name = first_repo_name(request_context)

            # The process context sees the reloaded repository...
            assert repo_name != new_repo_name
            # ...while the older request context still sees the original one.
            assert repo_name == request_context_repo_name
Exemple #2
0
def test_can_reload_on_external_repository_error():
    """A location whose repository fetch failed can be recovered by reloading it.

    While the gRPC repository fetch is patched to raise, loading the workspace must
    emit a UserWarning carrying the failure message, record a location error rather
    than a handle, and expose no repository locations. After the patch is removed,
    reloading the location must produce a handle and one visible location.
    """
    # Patch the name in the module where it is *looked up* (the handle module),
    # not where it is defined. See
    # https://docs.python.org/3/library/unittest.mock.html#where-to-patch
    patch_target = (
        "dagster.core.host_representation.handle."
        "sync_get_streaming_external_repositories_grpc"
    )
    failure_message = "get_external_repo_failure"

    with instance_for_test() as instance, ExitStack() as exit_stack:
        with mock.patch(patch_target, side_effect=Exception(failure_message)):
            with pytest.warns(UserWarning, match=re.escape(failure_message)):
                # Registered on the outer ExitStack so the workspace outlives the patch.
                workspace = exit_stack.enter_context(
                    define_out_of_process_workspace(__file__, "get_repo")
                )

            assert not workspace.has_repository_location_handle(main_repo_location_name())
            assert workspace.has_repository_location_error(main_repo_location_name())
            process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
            assert len(process_context.repository_locations) == 0

        # The patch has been lifted, so reloading should now succeed.
        workspace.reload_repository_location(main_repo_location_name())
        assert workspace.has_repository_location_handle(main_repo_location_name())
        process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
        assert len(process_context.repository_locations) == 1
Exemple #3
0
def test_handle_cleaup_by_gc_with_dangling_request_reference():
    """A reloaded location's handle is cleaned up only after the last request context drops it.

    The initial location handle's ``cleanup`` is replaced with a spy. A request
    context created before the reload keeps a reference to the old location (and
    so its handle), so neither the reload nor a forced ``gc.collect()`` may call
    ``cleanup``. Once that reference is dropped, a second collection must call it.

    NOTE(review): the outcome depends on exactly which locals hold references when
    ``gc.collect()`` runs; avoid introducing new locals that capture the old
    repository location.
    """
    # Mutable container so the nested spy can record that it was called.
    called = {"yup": False}

    def call_me():
        # Spy installed in place of the handle's cleanup method.
        called["yup"] = True

    with instance_for_test() as instance:
        with define_out_of_process_workspace(__file__,
                                             "get_repo") as workspace:
            # Create a process context
            process_context = WorkspaceProcessContext(workspace=workspace,
                                                      instance=instance)
            # Swap in the spy on the current (pre-reload) location's handle.
            process_context.repository_locations[  # pylint: disable=protected-access
                0]._handle.cleanup = call_me

            assert len(process_context.repository_locations) == 1

            assert not called["yup"]

            # The request context maintains a reference to the location handle through the
            # repository location
            request_context = (  # pylint: disable=unused-variable
                process_context.create_request_context())

            # Even though we reload, verify the handle isn't cleaned up
            # ("test_location" is presumably the name of the location loaded from
            # get_repo -- verify against the workspace definition).
            process_context.reload_repository_location("test_location")
            gc.collect()
            assert not called["yup"]

            # Free reference, make sure handle is cleaned up
            request_context = None
            gc.collect()
            assert called["yup"]
Exemple #4
0
def create_app_from_workspace(workspace: Workspace,
                              instance: DagsterInstance,
                              path_prefix: str = ""):
    """Build the web application serving ``workspace`` backed by ``instance``.

    Args:
        workspace: The loaded workspace whose locations the app serves.
        instance: The Dagster instance passed to the process context.
        path_prefix: Optional URL prefix. An empty string skips validation; a
            non-empty prefix must start with "/" and must not end with "/".

    Returns:
        The application built by ``instantiate_app_with_views`` for a fresh
        ``WorkspaceProcessContext`` and GraphQL schema.

    Raises:
        Exception: If a non-empty ``path_prefix`` lacks a leading "/" or has a
            trailing "/".
    """
    check.inst_param(workspace, "workspace", Workspace)
    check.inst_param(instance, "instance", DagsterInstance)
    check.str_param(path_prefix, "path_prefix")

    # Only a non-empty prefix is validated: leading slash first, then trailing slash.
    if path_prefix and not path_prefix.startswith("/"):
        raise Exception(
            f'The path prefix should begin with a leading "/": got {path_prefix}'
        )
    if path_prefix and path_prefix.endswith("/"):
        raise Exception(
            f'The path prefix should not include a trailing "/": got {path_prefix}'
        )

    warn_if_compute_logs_disabled()
    print("Loading repository...")  # pylint: disable=print-call

    process_context = WorkspaceProcessContext(
        instance=instance,
        workspace=workspace,
        version=__version__,
    )
    log_workspace_stats(instance, process_context)

    # The schema is built after the stats are logged, matching the original order.
    return instantiate_app_with_views(process_context, create_schema(), path_prefix)
Exemple #5
0
def test_reload_on_request_context_2():
    """Reloading through a request context yields a new request context.

    This is similar to ``test_reload_on_request_context``, but calls reload on the
    request context instead of the process context. The old request context must
    keep the original repository, the returned request context must see a
    different one, and the process context must agree with the returned context.
    """

    def first_repo_name(context):
        # Name of the first repository in the context's first repository location.
        location = context.repository_locations[0]
        return list(location.get_repositories().values())[0].name

    with instance_for_test() as instance:
        with define_out_of_process_workspace(__file__, "get_repo") as workspace:
            process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
            assert len(process_context.repository_locations) == 1

            location_name = process_context.repository_locations[0].name
            repo_name = first_repo_name(process_context)

            request_context = process_context.create_request_context()
            new_request_context = request_context.reload_repository_location(location_name)

            # Read back in order: process context, new request context, old request context.
            new_repo_name_process_context = first_repo_name(process_context)
            new_request_context_repo_name = first_repo_name(new_request_context)
            request_context_repo_name = first_repo_name(request_context)

            assert repo_name == request_context_repo_name
            assert request_context_repo_name != new_request_context_repo_name
            assert new_repo_name_process_context == new_request_context_repo_name
Exemple #6
0
def test_reload_on_process_context():
    """Reloading a location on the process context changes the repository it reports.

    The test expects the reloaded location's first repository to have a different
    name from the one seen before the reload.
    """
    with instance_for_test() as instance:
        with define_out_of_process_workspace(__file__, "get_repo") as workspace:
            process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
            assert len(process_context.repository_locations) == 1

            def current_repo_name():
                # Name of the first repository in the process context's first location.
                location = process_context.repository_locations[0]
                return list(location.get_repositories().values())[0].name

            name_before = current_repo_name()
            process_context.reload_repository_location(
                process_context.repository_locations[0].name
            )
            name_after = current_repo_name()

            assert name_before != name_after
Exemple #7
0
def test_handle_cleaup_by_gc_without_request_context():
    """A location's handle is cleaned up by GC once a reload replaces it.

    The initial location handle's ``cleanup`` is replaced with a spy. With no
    request context holding the old location, reloading it on the process context
    and forcing ``gc.collect()`` must call the spy.

    NOTE(review): the outcome depends on no local variable still referencing the
    old repository location when ``gc.collect()`` runs.
    """

    # Mutable container so the nested spy can record that it was called.
    called = {"yup": False}

    def call_me():
        # Spy installed in place of the handle's cleanup method.
        called["yup"] = True

    with instance_for_test() as instance:
        with define_out_of_process_workspace(__file__, "get_repo") as workspace:
            # Create a process context
            process_context = WorkspaceProcessContext(workspace=workspace, instance=instance)
            assert len(process_context.repository_locations) == 1
            # Swap in the spy on the current (pre-reload) location's handle.
            process_context.repository_locations[  # pylint: disable=protected-access
                0
            ]._handle.cleanup = call_me

            # Reload the location on the process context; no request context exists,
            # so nothing else keeps the old location alive.
            # ("test_location" is presumably the name of the location loaded from
            # get_repo -- verify against the workspace definition.)
            assert not called["yup"]
            process_context.reload_repository_location("test_location")

            # There are no more references to the location, so it should be GC'd
            gc.collect()
            assert called["yup"]
Exemple #8
0
def test_log_workspace_stats(caplog):
    """``log_workspace_stats`` logs one JSON stats message per expected location.

    For the multi-environment telemetry workspace fixture, every captured record
    must be a JSON object whose ``action`` is ``UPDATE_REPO_STATS`` and whose keys
    are exactly ``EXPECTED_KEYS``, and exactly two records must be captured.
    """
    workspace_yaml = file_relative_path(__file__, "./multi_env_telemetry_workspace.yaml")

    with instance_for_test() as instance:
        with load_workspace_from_yaml_paths([workspace_yaml]) as workspace:
            # Keep the context bound for the whole block so it is not collected early.
            context = WorkspaceProcessContext(instance=instance, workspace=workspace)
            log_workspace_stats(instance, context)

            # Lazily parse each record so parsing and checking stay interleaved.
            parsed_messages = (json.loads(record.getMessage()) for record in caplog.records)
            for message in parsed_messages:
                assert message.get("action") == UPDATE_REPO_STATS
                assert set(message) == EXPECTED_KEYS

            assert len(caplog.records) == 2
Exemple #9
0
def execute_query(workspace,
                  query,
                  variables=None,
                  use_sync_executor=False,
                  instance=None):
    """Execute a GraphQL ``query`` against ``workspace`` and return the result as a dict.

    Args:
        workspace: Workspace whose locations back the request context.
        query: GraphQL document. Surrounding quotes and whitespace are stripped
            before execution.
        variables: Optional dict of GraphQL variable values.
        use_sync_executor: Run with ``SyncExecutor`` when True, otherwise with
            ``GeventExecutor``.
        instance: Dagster instance to use; ``DagsterInstance.get()`` is used when
            it is None.

    Returns:
        ``result.to_dict()``. If the response has errors, every error dict whose
        underlying GraphQL error carries a truthy ``original_error`` (typically the
        exception a resolver raised) gets a ``"stack_trace"`` entry built from that
        exception, to ease debugging.
    """
    check.inst_param(workspace, "workspace", Workspace)
    check.str_param(query, "query")
    check.opt_dict_param(variables, "variables")
    # Compare against None explicitly: a falsy-but-non-None value must be rejected by
    # the type check rather than silently replaced by the default instance.
    instance = (
        check.inst_param(instance, "instance", DagsterInstance)
        if instance is not None
        else DagsterInstance.get()
    )
    check.bool_param(use_sync_executor, "use_sync_executor")

    query = query.strip("'\" \n\t")

    context = WorkspaceProcessContext(
        workspace=workspace, instance=instance,
        version=__version__).create_request_context()

    executor = SyncExecutor() if use_sync_executor else GeventExecutor()

    result = graphql(
        request_string=query,
        schema=create_schema(),
        context_value=context,
        variable_values=variables,
        executor=executor,
    )

    result_dict = result.to_dict()

    # For error responses, pair each serialized error with its GraphQLError and,
    # when the latter wraps an original exception, attach that exception's stack trace.
    if "errors" in result_dict:
        check.invariant(len(result_dict["errors"]) == len(result.errors))
        for python_error, error_dict in zip(result.errors, result_dict["errors"]):
            original_error = getattr(python_error, "original_error", None)
            if original_error:
                error_dict["stack_trace"] = get_stack_trace_array(original_error)

    return result_dict