예제 #1
0
def test_pytask_collect_task_teardown(tmp_path, depends_on, produces, platform,
                                      expectation):
    """Run the collection teardown on a dummy Stata task.

    Dependencies and products are built from the given file names inside
    ``tmp_path``. The ``expectation`` context manager (presumably supplied by
    parametrization) states whether the teardown should succeed or raise.
    """
    session = DummyClass()
    session.config = dict(
        stata="stata",
        stata_source_key="source",
        platform=platform,
    )

    def _nodes(file_names):
        # Key each node by the position of its file name.
        return {
            idx: FilePathNode.from_path(tmp_path / file_name)
            for idx, file_name in enumerate(file_names)
        }

    task = DummyClass()
    task.depends_on = _nodes(depends_on)
    task.produces = _nodes(produces)
    task.function = task_dummy
    task.name = "task_dummy"
    task.path = Path()

    stata_markers = [Mark("stata", (), {})]
    task.markers = stata_markers
    # The same marker list is also attached to the function object itself.
    task.function.pytaskmark = stata_markers

    with expectation:
        pytask_collect_task_teardown(session, task)
예제 #2
0
def pytask_execute_task_process_report(session: Session,
                                       report: ExecutionReport) -> bool:
    """Process the execution report of a task.

    On success, the states of the task's related nodes are updated in the
    database. On any other outcome, every descending task receives a
    ``skip_ancestor_failed`` marker and the failure counter is incremented.
    The session is flagged to stop once ``max_failures`` is reached or when
    the stored exception is an ``Exit``.

    Returns:
        Always ``True``.

    """
    task = report.task

    if report.outcome == TaskOutcome.SUCCESS:
        update_states_in_database(session.dag, task.name)
        return True

    skip_reason = f"Previous task {task.name!r} failed."
    for child_name in descending_tasks(task.name, session.dag):
        child = session.dag.nodes[child_name]["task"]
        child.markers.append(
            Mark("skip_ancestor_failed", (), {"reason": skip_reason}))

    session.n_tasks_failed += 1

    reached_limit = session.n_tasks_failed >= session.config["max_failures"]
    exited = bool(report.exc_info) and isinstance(report.exc_info[1], Exit)
    if reached_limit or exited:
        session.should_stop = True

    return True
예제 #3
0
def test_pytask_execute_task_setup_raise_error(stata, platform, expectation):
    """Check that ``pytask_execute_task_setup`` raises errors when expected.

    The task only carries a bare ``stata`` marker. The Stata executable and
    the platform under test are passed through ``session.config``, and the
    ``expectation`` context manager states the expected outcome.
    """
    stata_task = DummyClass()
    stata_task.markers = [Mark("stata", (), {})]

    session = DummyClass()
    session.config = dict(stata=stata, platform=platform)

    with expectation:
        pytask_execute_task_setup(session, stata_task)
예제 #4
0
    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        """Attach collection metadata to ``func`` and register it as a task.

        The metadata lives on the innermost function returned by
        ``inspect.unwrap``. That function is appended to ``COLLECTED_TASKS``
        under the absolute, resolved path of the file that defines it.

        Args:
            func: The decorated function, possibly wrapped by other decorators.

        Returns:
            The unwrapped function (not necessarily ``func`` itself).

        """
        unwrapped = inspect.unwrap(func)
        path = Path(inspect.getfile(unwrapped)).absolute().resolve()
        # ``kwargs``, ``name`` and ``id`` are the enclosing decorator's
        # arguments (closure variables).
        parsed_kwargs = {} if kwargs is None else kwargs
        parsed_name = name if isinstance(name, str) else func.__name__

        if hasattr(unwrapped, "pytask_meta"):
            # Metadata already exists (presumably set by an earlier
            # decorator); update it in place and keep its existing markers.
            unwrapped.pytask_meta.name = parsed_name
            unwrapped.pytask_meta.kwargs = parsed_kwargs
            unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
            unwrapped.pytask_meta.id_ = id
        else:
            unwrapped.pytask_meta = CollectionMetadata(
                name=parsed_name,
                kwargs=parsed_kwargs,
                markers=[Mark("task", (), {})],
                id_=id,
            )

        COLLECTED_TASKS[path].append(unwrapped)

        return unwrapped
예제 #5
0
def pytask_resolve_dependencies_select_execution_dag(dag: nx.DiGraph) -> None:
    """Select the tasks which need to be executed.

    Task names are visited in the scheduler's static order. If a task or one
    of its neighbors has changed, the task and all of its descendants are
    recorded as visited and are not examined again. Every other task reached
    this way receives a ``skip_unchanged`` marker.

    Args:
        dag: The graph; task nodes carry their task object under the
            ``"task"`` attribute.

    """
    scheduler = TopologicalSorter.from_dag(dag)
    # A set makes the membership test O(1); the previous list made the loop
    # quadratic in the number of nodes.
    visited_nodes = set()

    for task_name in scheduler.static_order():
        if task_name in visited_nodes:
            continue
        if _have_task_or_neighbors_changed(task_name, dag):
            visited_nodes.update(task_and_descending_tasks(task_name, dag))
        else:
            dag.nodes[task_name]["task"].markers.append(
                Mark("skip_unchanged", (), {}))
예제 #6
0
def pytask_execute_task_process_report(session: Session,
                                       report: ExecutionReport) -> bool | None:
    """Turn execution reports of skipped tasks into skip outcomes.

    The stored exception in ``report.exc_info`` selects the outcome:

    - ``SkippedUnchanged``: outcome becomes ``SKIP_UNCHANGED``.
    - ``Skipped``: outcome becomes ``SKIP`` and every descending task gets a
      ``skip`` marker.
    - ``SkippedAncestorFailed``: outcome becomes ``SKIP_PREVIOUS_FAILED`` and
      ``exc_info`` is replaced by a version without traceback.

    Returns:
        ``True`` if the report ends up describing one of these skips,
        otherwise ``None``.

    """
    task = report.task

    if not report.exc_info:
        return None

    exception = report.exc_info[1]
    if isinstance(exception, SkippedUnchanged):
        report.outcome = TaskOutcome.SKIP_UNCHANGED
    elif isinstance(exception, Skipped):
        report.outcome = TaskOutcome.SKIP
        for child_name in descending_tasks(task.name, session.dag):
            child = session.dag.nodes[child_name]["task"]
            reason = f"Previous task {task.name!r} was skipped."
            child.markers.append(Mark("skip", (), {"reason": reason}))
    elif isinstance(exception, SkippedAncestorFailed):
        report.outcome = TaskOutcome.SKIP_PREVIOUS_FAILED
        report.exc_info = remove_traceback_from_exc_info(report.exc_info)

    # Re-read ``exc_info``: the branch above may have replaced it.
    skip_types = (Skipped, SkippedUnchanged, SkippedAncestorFailed)
    if report.exc_info and isinstance(report.exc_info[1], skip_types):
        return True
    return None
예제 #7
0
        (None, []),
        ("some-arg", ["some-arg"]),
        (["arg1", "arg2"], ["arg1", "arg2"]),
    ],
)
def test_stata(stata_args, expected):
    """Check that ``stata`` turns the marker arguments into the expected list."""
    assert stata(stata_args) == expected


@pytest.mark.unit
@pytest.mark.parametrize(
    "marks, expected",
    [
        (
            [Mark("stata", ("a",), {}), Mark("stata", ("b",), {})],
            Mark("stata", ("a", "b"), {}),
        ),
        (
            [Mark("stata", ("a",), {}), Mark("stata", (), {"stata": "b"})],
            Mark("stata", ("a",), {"stata": "b"}),
        ),
    ],
)
def test_merge_all_markers(marks, expected):
    """Check that several ``stata`` marks on a task collapse into one mark.

    As the cases show, positional arguments are concatenated in order and
    keyword arguments are combined.
    """
    task = DummyClass()
    task.markers = marks
    assert _merge_all_markers(task) == expected