async def test_version_order_determines_timestamp(self, task_run_id):
        """
        Check that run timing follows state versions, not insertion order.

        The version-5 "Success" state is deliberately inserted after the
        version-8 state; the run's end time is still expected to equal the
        timestamp of the version-9 "Success" state.
        """
        shared = {"task_run_id": task_run_id, "serialized_state": {}}

        started = pendulum.now("UTC")
        finished = pendulum.now("UTC").add(days=1)
        midway = pendulum.now("UTC").add(hours=1)

        def new_state(version, name, **extra):
            # every state shares the same task run id and serialized state
            return models.TaskRunState(
                **shared, version=version, state=name, **extra)

        history = [
            new_state(0, "Pending"),
            new_state(1, "Running", timestamp=started),
            new_state(2, "Failed"),
            new_state(3, "Retrying"),
            new_state(4, "Running"),
            new_state(6, "Failed"),
            new_state(7, "Retrying"),
            new_state(8, "Running"),
            # out of version order on purpose
            new_state(5, "Success", timestamp=midway),
            new_state(9, "Success", timestamp=finished),
        ]
        await models.TaskRunState.insert_many(history)

        wanted = {"start_time", "end_time", "duration", "run_count"}
        run = await models.TaskRun.where(id=task_run_id).first(wanted)

        assert run.start_time == started
        assert run.end_time == finished
        assert run.duration == (finished - started).as_timedelta()
        # three "Running" states were inserted above
        assert run.run_count == 3
Beispiel #2
0
    async def test_no_end_if_running_is_last_state(self, task_run_id):
        """
        Check that a run whose latest state is "Running" has a start time but
        neither an end time nor a duration.
        """
        shared = {"task_run_id": task_run_id, "serialized_state": {}}
        started = pendulum.now("UTC")

        pending = models.TaskRunState(**shared, version=0, state="Pending")
        running = models.TaskRunState(
            **shared, version=1, state="Running", timestamp=started)
        await models.TaskRunState.insert_many([pending, running])

        wanted = {"start_time", "end_time", "duration"}
        run = await models.TaskRun.where(id=task_run_id).first(wanted)

        assert run.start_time == started
        assert run.end_time is None
        assert run.duration is None
    async def test_start_and_end_from_running_state(self, task_run_id):
        """
        Check start/end/duration for a run that ran once and then failed.

        The "Failed" state carries the expected end timestamp; the two
        "Retrying" states inserted after it are expected not to change it.
        """
        shared = {"task_run_id": task_run_id, "serialized_state": {}}

        started = pendulum.now("UTC")
        finished = pendulum.now("UTC").add(days=1)

        history = [
            models.TaskRunState(**shared, version=0, state="Pending"),
            models.TaskRunState(
                **shared, version=1, state="Running", timestamp=started),
            models.TaskRunState(
                **shared, version=2, state="Failed", timestamp=finished),
            models.TaskRunState(**shared, version=3, state="Retrying"),
            models.TaskRunState(**shared, version=4, state="Retrying"),
        ]
        await models.TaskRunState.insert_many(history)

        wanted = {"start_time", "end_time", "duration", "run_count"}
        run = await models.TaskRun.where(id=task_run_id).first(wanted)

        assert run.start_time == started
        assert run.end_time == finished
        assert run.duration == (finished - started).as_timedelta()
        # only one "Running" state was inserted
        assert run.run_count == 1
    async def test_inserting_running_state_has_effect(self, task_run_id):
        """
        Check that inserting states without explicit timestamps still leaves
        the run with a start time, an end time, a duration and a run count
        of 1.
        """
        shared = {"task_run_id": task_run_id, "serialized_state": {}}
        sequence = ["Pending", "Running", "Failed", "Retrying", "Retrying"]

        # versions are assigned 0..4 in list order
        history = [
            models.TaskRunState(**shared, version=version, state=name)
            for version, name in enumerate(sequence)
        ]
        await models.TaskRunState.insert_many(history)

        wanted = {"start_time", "end_time", "duration", "run_count"}
        run = await models.TaskRun.where(id=task_run_id).first(wanted)

        assert run.start_time is not None
        assert run.end_time is not None
        assert run.duration is not None
        assert run.run_count == 1
Beispiel #5
0
async def set_task_run_state(task_run_id: str, state: State, force=False) -> None:
    """
    Record a new state for a task run.

    Args:
        - task_run_id (str): the id of the task run to update
        - state (State): the state to record
        - force (bool): if True, skips the check that the associated flow run
            is running

    Raises:
        - ValueError: if `task_run_id` is None, if no matching task run is
            found, or if a running state is provided (without `force`) while
            the flow run's state is not "Running"
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    selection = {
        "id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {"id": True, "state": True},
    }
    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(
        selection
    )

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # a running task run requires a running flow run, unless forced
    if not force:
        if state.is_running() and task_run.flow_run.state != "Running":
            raise ValueError(
                f"State update failed for task run ID {task_run_id}: provided "
                f"a running state but associated flow run {task_run.flow_run.id} is not "
                "in a running state."
            )

    # carry cached inputs over from the previous state when the new state
    # does not supply any of its own
    if not state.cached_inputs:
        if task_run.serialized_state.get("cached_inputs", None):
            previous_state = state_schema.load(task_run.serialized_state)
            state.cached_inputs = previous_state.cached_inputs

    # build the database row for the new state; its version is one past the
    # task run's current version
    row = dict(
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )
    new_state = models.TaskRunState(**row)

    await new_state.insert()
Beispiel #6
0
async def get_or_create_mapped_task_run_children(
        flow_run_id: str, task_id: str, max_map_index: int) -> List[str]:
    """
    Upsert the mapped child task runs of a task and return their ids.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the highest map index to create; children are
            generated for indices 0 through max_map_index inclusive, so a
            value of 2 yields 3 mapped children

    Returns:
        - List[str]: the child task run ids, ordered by map index
    """
    # the task supplies the tenant and cache key shared by all children
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})

    children = [
        models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=index,
            cache_key=task.cache_key,
        ) for index in range(max_map_index + 1)
    ]

    # upsert: on a conflict with the unique-identifier constraint, only the
    # cache key is updated
    conflict_policy = dict(
        constraint="task_run_unique_identifier_key",
        update_columns=["cache_key"],
    )
    upserted = await models.TaskRun().insert_many(
        objects=children,
        on_conflict=conflict_policy,
        selection_set={"returning": {"id", "map_index"}},
    )
    ordered_runs = sorted(upserted["returning"], key=lambda run: run.map_index)

    # find this task's runs in the flow run that have no states at all
    # ("_not": {"states": {}} indicates "where there are no states")
    without_states = {
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        "_not": {"states": {}},
    }
    stateless_runs = await models.TaskRun.where(without_states).get(
        {"id", "map_index", "version"})

    # give each stateless run an initial Pending state
    initial_states = [
        models.TaskRunState(
            tenant_id=task.tenant_id,
            task_run_id=run.id,
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")),
        ) for run in stateless_runs
    ]
    await models.TaskRunState().insert_many(initial_states)

    return [run.id for run in ordered_runs]
Beispiel #7
0
    async def test_get_or_create_mapped_children_handles_partial_children(
            self, flow_id, flow_run_id):
        """
        Check that mapped children are created around pre-existing ones.

        Two children exist beforehand (map index 3 without a state, map index
        6 with a Pending state). After requesting indices 0-10, every child
        must have a state, the ids must come back ordered by map index, and
        the child that already had a state must still have exactly one.
        """
        # pick a task belonging to the flow
        flow_filter = {"flow_id": {"_eq": flow_id}}
        task = await models.Task.where(flow_filter).first({"id", "cache_key"})

        # pre-existing child without any state
        await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=3,
            cache_key=task.cache_key,
        ).insert()

        # pre-existing child that already has a Pending state
        pending_fields = models.TaskRunState.fields_from_state(
            Pending(message="Task run created"))
        stateful_child = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=6,
            cache_key=task.cache_key,
            states=[models.TaskRunState(**pending_fields)],
        ).insert()

        mapped_children = await api.runs.get_or_create_mapped_task_run_children(
            flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)
        assert len(mapped_children) == 11

        # each child must have at least one state, and the ids must be
        # returned in map-index order
        seen_indices = []
        for child_id in mapped_children:
            latest_state = with_args(
                "states",
                {"order_by": {"version": EnumValue("desc")}, "limit": 1},
            )
            task_run = await models.TaskRun.where(id=child_id).first({
                "map_index": True,
                latest_state: {"id"},
            })
            seen_indices.append(task_run.map_index)
            assert task_run.states[0] is not None
        assert seen_indices == sorted(seen_indices)

        # the child created with a state must not have gained another one
        child_states = await models.TaskRunState.where(
            {"task_run_id": {"_eq": stateful_child}}).get()
        assert len(child_states) == 1
Beispiel #8
0
async def get_or_create_task_run(flow_run_id: str,
                                 task_id: str,
                                 map_index: int = None) -> str:
    """
    Return the id of a task run, creating the task run if it does not exist.

    Since some task runs are created dynamically (when tasks are mapped, for
    example) we don't know if a task run exists the first time we query it.
    This function takes the key information about a task run and creates it,
    with an initial Pending state, if it doesn't already exist.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task the task run is an instance of
        - map_index (int): the map index of the task run; None is treated
            as -1

    Returns:
        - str: the task run id

    Raises:
        - ValueError: if loading the task or inserting the task run fails;
            the underlying error is attached as the exception's cause
    """

    if map_index is None:
        map_index = -1

    # try to load an existing task run
    task_run = await models.TaskRun.where({
        "flow_run_id": {
            "_eq": flow_run_id
        },
        "task_id": {
            "_eq": task_id
        },
        "map_index": {
            "_eq": map_index
        },
    }).first({"id"})

    if task_run:
        return task_run.id

    try:
        # load the tenant ID and cache_key
        task = await models.Task.where(id=task_id
                                       ).first({"cache_key", "tenant_id"})
        # create the task run
        return await models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=map_index,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    tenant_id=task.tenant_id,
                    **models.TaskRunState.fields_from_state(
                        Pending(message="Task run created")),
                )
            ],
        ).insert()

    except Exception as exc:
        # callers see the same ValueError as before; chaining keeps the real
        # failure visible as the explicit cause
        raise ValueError("Invalid ID") from exc
Beispiel #9
0
    async def test_get_or_create_mapped_task_run_children_with_partial_children(
            self, run_query, flow_run_id, flow_id):
        """
        Check the mapped-children mutation when some children already exist.

        Two children (map indices 3 and 6) are inserted beforehand; the
        mutation is then asked for indices 0-10 and must return 11 ids that
        include both pre-existing runs and are ordered by map index.
        """
        # select cache_key as well as id: task.cache_key is read below when
        # building the pre-existing task runs
        task = await models.Task.where({
            "flow_id": {
                "_eq": flow_id
            }
        }).first({"id", "cache_key"})
        # create a couple of children
        preexisting_run_1 = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=3,
            cache_key=task.cache_key,
        ).insert()
        preexisting_run_2 = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=6,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    **models.TaskRunState.fields_from_state(
                        Pending(message="Task run created")), )
            ],
        ).insert()
        # call the route
        result = await run_query(
            query=self.mutation,
            variables=dict(input=dict(
                flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)),
        )
        mapped_children = result.data.get_or_create_mapped_task_run_children.ids
        # should have 11 children, indices 0-10
        assert len(mapped_children) == 11

        # confirm the preexisting task runs were included in the results
        assert preexisting_run_1 in mapped_children
        assert preexisting_run_2 in mapped_children

        # confirm the results are ordered
        map_indices = []
        for child in mapped_children:
            child_run = await models.TaskRun.where(id=child).first({"map_index"})
            map_indices.append(child_run.map_index)
        assert map_indices == sorted(map_indices)
Beispiel #10
0
async def set_task_run_state(task_run_id: str, state: State,) -> None:
    """
    Record a new state for a task run.

    Args:
        - task_run_id (str): the id of the task run to update
        - state (State): the state to record

    Raises:
        - ValueError: if `task_run_id` is None or no matching task run is found
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    selection = {
        "id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
    }
    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(
        selection
    )

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # carry cached inputs over from the previous state when the new state
    # does not supply any of its own
    if not state.cached_inputs:
        if task_run.serialized_state.get("cached_inputs", None):
            previous_state = state_schema.load(task_run.serialized_state)
            state.cached_inputs = previous_state.cached_inputs

    # build the database row for the new state; its version is one past the
    # task run's current version
    row = dict(
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )
    new_state = models.TaskRunState(**row)

    await new_state.insert()
Beispiel #11
0
async def set_task_run_state(
        task_run_id: str,
        state: State,
        version: int = None,
        flow_run_version: int = None) -> models.TaskRunState:
    """
    Record a new state for a task run.

    Args:
        - task_run_id (str): the id of the task run to update
        - state (State): the state to record
        - version (int): a task run version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState: the state object that was inserted

    Raises:
        - ValueError: if `task_run_id` is None, if no task run matches the id
            and the version-locking filter, or if a running state is provided
            while the flow run's state is not "Running"
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    def locking_enabled():
        # filter matching flow groups whose settings enable version locking
        return {"settings": {"_contains": {"version_locking_enabled": True}}}

    # EITHER version locking is enabled and the versions match
    locked_and_matching = {
        "version": {"_eq": version},
        "flow_run": {
            "version": {"_eq": flow_run_version},
            "flow": {"flow_group": locking_enabled()},
        },
    }
    # OR version locking is not enabled
    not_locked = {
        "flow_run": {"flow": {"flow_group": {"_not": locking_enabled()}}}
    }
    where = {
        "id": {"_eq": task_run_id},
        "_or": [locked_and_matching, not_locked],
    }

    selection = {
        "id": True,
        "tenant_id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {"id": True, "state": True},
    }
    task_run = await models.TaskRun.where(where).first(selection)

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # a running task run requires a running flow run
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state.")

    # carry cached inputs over from the previous state when the new state
    # does not supply any of its own
    if not state.cached_inputs:
        if task_run.serialized_state.get("cached_inputs", None):
            previous_state = state_schema.load(task_run.serialized_state)
            state.cached_inputs = previous_state.cached_inputs

    # build the database row for the new state; its version is one past the
    # task run's current version
    row = dict(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )
    task_run_state = models.TaskRunState(**row)

    await task_run_state.insert()

    # downstream update: running states also refresh the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state