Beispiel #1
0
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Opens a `dask.distributed.Client` and stores it on `self.client` while the
        managed block runs. If `self.address` is set, the client connects to that
        existing cluster; otherwise a cluster is created from `self.cluster_class`
        and `self.cluster_kwargs` (optionally made adaptive via `self.adapt_kwargs`)
        and closed again when the block exits.

        Nothing is yielded to the caller: use `self.client` inside the block.
        `self._pre_start_yield()` runs just before control is handed over and
        `self._post_start_yield()` always runs afterwards. `self.client` is reset
        to `None` on exit, whether or not an error occurred.
        """
        if sys.platform != "win32":
            # Fix for https://github.com/dask/distributed/issues/4168
            import multiprocessing.popen_spawn_posix  # noqa
        from distributed import Client, performance_report

        # Only record a performance report when a destination path is configured.
        if self.performance_report_path:
            report_ctx = performance_report(self.performance_report_path)
        else:
            report_ctx = nullcontext()

        try:
            if self.address is None:
                assert callable(self.cluster_class)  # mypy
                assert isinstance(self.cluster_kwargs, dict)  # mypy
                self.logger.info(
                    "Creating a new Dask cluster with `%s.%s`...",
                    self.cluster_class.__module__,
                    self.cluster_class.__qualname__,
                )
                with self.cluster_class(**self.cluster_kwargs) as cluster:
                    # Not every cluster class exposes a dashboard link.
                    dashboard_link = getattr(cluster, "dashboard_link", None)
                    if dashboard_link:
                        self.logger.info(
                            "The Dask dashboard is available at %s",
                            dashboard_link,
                        )
                    if self.adapt_kwargs:
                        cluster.adapt(**self.adapt_kwargs)
                    # The client is entered first so the report context is
                    # exited while the client is still open.
                    with Client(cluster, **self.client_kwargs) as client, report_ctx:
                        self.client = client
                        try:
                            self._pre_start_yield()
                            yield
                        finally:
                            self._post_start_yield()
            else:
                self.logger.info(
                    "Connecting to an existing Dask cluster at %s", self.address
                )
                with Client(self.address, **self.client_kwargs) as client, report_ctx:
                    self.client = client
                    try:
                        self._pre_start_yield()
                        yield
                    finally:
                        self._post_start_yield()
        finally:
            self.client = None
Beispiel #2
0
def test_setup_api_connection_runs_test_query(test_query_succeeds, cloud_api):
    """
    `Agent._setup_api_connection` must complete without error when the test
    query is stubbed to succeed, and must raise when it is left unstubbed.
    """
    agent = Agent()

    # Registration is not under test here, so stub it out.
    agent._register_agent = MagicMock()

    if test_query_succeeds:
        # Stub the GraphQL call so the test query returns a value.
        agent.client.graphql = MagicMock(return_value="Hello")
        expectation = nullcontext()
    else:
        expectation = pytest.raises(Exception)

    with expectation:
        agent._setup_api_connection()
    def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Runs the task and converts its outcome (return value, signal or error)
        into a state.

        The produced value is written through `self.result` only when the
        `checkpointing` context flag is `True`, `task.checkpoint` is not `False`
        and the value is not `None`; otherwise it is wrapped with
        `self.result.from_value`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: `TimedOut` on a timeout signal, `Failed` on an exception,
                the signal's own state for `LOOP` / `SUCCESS` signals, and
                `Success` otherwise

        Raises:
            - ENDRUN: if the incoming state is not a Running state
            - the timeout signal or task exception itself, if the
                `raise_on_exception` context flag is set
        """
        task_name = prefect.context.get("task_full_name", self.task.name)

        if not state.is_running():
            self.logger.debug(
                f"Task {task_name!r}: Can't run task because it's not in a Running "
                "state; ending run."
            )
            raise ENDRUN(state)

        # Unwrap the Result objects into the plain values passed to task.run()
        task_kwargs = {name: res.value for name, res in inputs.items()}
        value = None
        signal_state = None

        try:
            self.logger.debug(f"Task {task_name!r}: Calling task.run() method...")

            # Create a stdout redirect if the task has log_stdout enabled
            if getattr(self.task, "log_stdout", False):
                log_sink = prefect.utilities.logging.RedirectToLog(self.logger)
                log_context = redirect_stdout(
                    log_sink
                )  # type: AbstractContextManager
            else:
                log_context = nullcontext()

            with log_context:
                value = prefect.utilities.executors.run_task_with_timeout(
                    task=self.task,
                    args=(),
                    kwargs=task_kwargs,
                    logger=self.logger,
                )

        except TaskTimeoutSignal as exc:
            # Timeouts become a `TimedOut` state unless the caller asked to raise
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.", result=exc)

        except signals.LOOP as exc:
            # Loop signals carry their own `Looped` state; reuse it
            signal_state = exc.state
            assert isinstance(signal_state, Looped)
            value = signal_state.result
            signal_state.message = exc.state.message or "Task is looping ({})".format(
                signal_state.loop_count
            )

        except signals.SUCCESS as exc:
            # Success signals can be treated like a normal result
            signal_state = exc.state
            assert isinstance(signal_state, Success)
            value = signal_state.result

        except Exception as exc:
            # Any other exception from the task becomes a `Failed` state
            if prefect.context.get("raise_on_exception"):
                raise
            self.logger.error(
                f"Task {task_name!r}: Exception encountered during task execution!",
                exc_info=True,
            )
            return Failed(f"Error during execution of task: {exc!r}", result=exc)

        # checkpoint tasks if a result is present, except for when the user has opted
        # out by disabling checkpointing
        should_checkpoint = (
            prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False
            and value is not None
        )
        if not should_checkpoint:
            result = self.result.from_value(value=value)
        else:
            try:
                # Later entries win: task inputs override context keys, which
                # override flow parameters.
                template_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **prefect.context,
                    **task_kwargs,
                }
                result = self.result.write(value, **template_kwargs)
            except ResultNotImplementedError:
                # This result type cannot persist; keep the value in memory
                result = self.result.from_value(value=value)

        if signal_state is None:
            return Success(result=result, message="Task run succeeded.")

        signal_state.result = result
        return signal_state
    def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Runs the task and traps any signals or errors it raises.
        Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running the check

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the task is not ready to run
        """
        if not state.is_running():
            self.logger.debug(
                "Task '{name}': Can't run task because it's not in a "
                "Running state; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

            raise ENDRUN(state)

        value = None
        raw_inputs = {k: r.value for k, r in inputs.items()}
        new_state = None
        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

            # Create a stdout redirect if the task has log_stdout enabled
            log_context = (
                redirect_stdout(prefect.utilities.logging.RedirectToLog(self.logger))
                if getattr(self.task, "log_stdout", False)
                else nullcontext()
            )  # type: AbstractContextManager

            with log_context:
                value = prefect.utilities.executors.run_task_with_timeout(
                    task=self.task,
                    args=(),
                    kwargs=raw_inputs,
                    logger=self.logger,
                )

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            state = TimedOut("Task timed out during execution.", result=exc)
            return state

        except signals.LOOP as exc:
            new_state = exc.state
            assert isinstance(new_state, Looped)
            value = new_state.result
            new_state.message = exc.state.message or "Task is looping ({})".format(
                new_state.loop_count
            )

        # checkpoint tasks if a result is present, except for when the user has opted out by
        # disabling checkpointing
        if (
            prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False
            and value is not None
        ):
            try:
                formatting_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **prefect.context,
                    **raw_inputs,
                }
                result = self.result.write(value, **formatting_kwargs)