コード例 #1
0
async def get_or_create_mapped_task_run_children(
        flow_run_id: str, task_id: str, max_map_index: int) -> List[str]:
    """
    Creates and/or retrieves mapped child task runs for a given flow run and task.

    A task run is upserted for every map index from 0 through `max_map_index`,
    and every task run of this flow run / task that has no state yet is given
    an initial `Pending` state.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the highest map index to create; indices 0 through
            max_map_index (inclusive) are covered, e.g., a value of 2 yields 3
            mapped children

    Returns:
        - List[str]: the ids of the mapped children, ordered by map index

    Raises:
        - ValueError: if no task with the given ID can be found
    """
    # grab task info
    task = await models.Task.where(id=task_id
                                   ).first({"cache_key", "tenant_id"})
    # fail with a clear error instead of an AttributeError on `task.tenant_id`
    if not task:
        raise ValueError("Invalid task ID")

    # generate task runs to upsert
    task_runs = [
        models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=i,
            cache_key=task.cache_key,
        ) for i in range(max_map_index + 1)
    ]
    # upsert the mapped children; rows that hit the unique constraint are kept
    # and only have their cache_key updated
    task_runs = (await models.TaskRun().insert_many(
        objects=task_runs,
        on_conflict=dict(
            constraint="task_run_unique_identifier_key",
            update_columns=["cache_key"],
        ),
        selection_set={"returning": {"id", "map_index"}},
    ))["returning"]
    # callers receive the ids in map-index order
    task_runs.sort(key=lambda task_run: task_run.map_index)

    # get task runs without states
    stateless_runs = await models.TaskRun.where({
        "flow_run_id": {
            "_eq": flow_run_id
        },
        "task_id": {
            "_eq": task_id
        },
        # this syntax indicates "where there are no states"
        "_not": {
            "states": {}
        },
    }).get({"id", "map_index", "version"})
    # create and insert states for stateless task runs
    task_run_states = [
        models.TaskRunState(
            tenant_id=task.tenant_id,
            task_run_id=task_run.id,
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")),
        ) for task_run in stateless_runs
    ]
    await models.TaskRunState().insert_many(task_run_states)

    # return the task run ids
    return [task_run.id for task_run in task_runs]
コード例 #2
0
    async def test_get_or_create_mapped_children_handles_partial_children(
        self, flow_id, flow_run_id
    ):
        """
        When some mapped children already exist (one without a state, one with
        a state), requesting children 0-10 returns 11 ordered children that all
        have a state, and the child that already had a state gets no extra one.
        """
        # pick a task that belongs to the flow
        task_query = models.Task.where({"flow_id": {"_eq": flow_id}})
        task = await task_query.first({"id", "cache_key"})

        # pre-create a child at index 3 without any state ...
        stateless_run = models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=3,
            cache_key=task.cache_key,
        )
        await stateless_run.insert()

        # ... and a child at index 6 that already carries a Pending state
        pending_state = models.TaskRunState(
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")
            ),
        )
        stateful_child = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=6,
            cache_key=task.cache_key,
            states=[pending_state],
        ).insert()

        # request the mapped children for indices 0 through 10
        child_ids = await api.runs.get_or_create_mapped_task_run_children(
            flow_run_id=flow_run_id, task_id=task.id, max_map_index=10
        )
        assert len(child_ids) == 11

        # every child must have a state, and the ids must come back in
        # ascending map-index order
        seen_indices = []
        for child_id in child_ids:
            # select the map index and the state with the highest version
            newest_state = with_args(
                "states",
                {"order_by": {"version": EnumValue("desc")}, "limit": 1},
            )
            selection = {"map_index": True, newest_state: {"id"}}
            child_run = await models.TaskRun.where(id=child_id).first(selection)
            seen_indices.append(child_run.map_index)
            assert child_run.states[0] is not None
        assert seen_indices == sorted(seen_indices)

        # the child that was created with a state must still have exactly one
        states_query = models.TaskRunState.where(
            {"task_run_id": {"_eq": stateful_child}}
        )
        existing_states = await states_query.get()
        assert len(existing_states) == 1
コード例 #3
0
async def get_or_create_task_run(flow_run_id: str,
                                 task_id: str,
                                 map_index: int = None) -> str:
    """
    Since some task runs are created dynamically (when tasks are mapped, for example)
    we don't know if a task run exists the first time we query it. This function will take
    key information about a task run and create it if it doesn't already exist, returning its id.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task the task run belongs to
        - map_index (int, optional): the map index of the task run; `None` is
            treated as -1

    Returns:
        - str: the id of the existing or newly created task run

    Raises:
        - ValueError: if loading the task or inserting the new task run fails;
            the underlying error is attached as the exception's cause
    """

    if map_index is None:
        map_index = -1

    # try to load an existing task run
    task_run = await models.TaskRun.where({
        "flow_run_id": {
            "_eq": flow_run_id
        },
        "task_id": {
            "_eq": task_id
        },
        "map_index": {
            "_eq": map_index
        },
    }).first({"id"})

    if task_run:
        return task_run.id

    try:
        # load the tenant ID and cache_key
        task = await models.Task.where(id=task_id
                                       ).first({"cache_key", "tenant_id"})
        # create the task run together with its initial Pending state
        return await models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=map_index,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    tenant_id=task.tenant_id,
                    state="Pending",
                    timestamp=pendulum.now(),
                    message="Task run created",
                    serialized_state=Pending(
                        message="Task run created").serialize(),
                )
            ],
        ).insert()

    except Exception as exc:
        # any failure here (e.g. the task lookup returned nothing) is reported as
        # an invalid ID; chain the original error so its details are not lost
        raise ValueError("Invalid ID") from exc
コード例 #4
0
ファイル: test_runs.py プロジェクト: vitasiku/server
    async def test_get_or_create_mapped_task_run_children_with_partial_children(
        self, run_query, flow_run_id, flow_id
    ):
        """
        When task runs already exist for some map indices, the mutation still
        returns all 11 children (indices 0-10) in map-index order, and the
        pre-existing task runs are among them.
        """
        # `cache_key` must be part of the selection because it is read below
        # when the pre-existing task runs are built
        task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
            {"id", "cache_key"}
        )
        # create a couple of children: one without a state ...
        preexisting_run_1 = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=3,
            cache_key=task.cache_key,
        ).insert()
        # ... and one that already has a Pending state
        preexisting_run_2 = await models.TaskRun(
            flow_run_id=flow_run_id,
            task_id=task.id,
            map_index=6,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    **models.TaskRunState.fields_from_state(
                        Pending(message="Task run created")
                    ),
                )
            ],
        ).insert()
        # call the route
        result = await run_query(
            query=self.mutation,
            variables=dict(
                input=dict(flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)
            ),
        )
        mapped_children = result.data.get_or_create_mapped_task_run_children.ids
        # should have 11 children, indices 0-10
        assert len(mapped_children) == 11

        # confirm the preexisting task runs were included in the results
        assert preexisting_run_1 in mapped_children
        assert preexisting_run_2 in mapped_children

        # confirm the results are ordered
        map_indices = []
        for child in mapped_children:
            map_indices.append(
                (await models.TaskRun.where(id=child).first({"map_index"})).map_index
            )
        assert map_indices == sorted(map_indices)
コード例 #5
0
ファイル: runs.py プロジェクト: kad-schoom/server
async def get_or_create_task_run_info(flow_run_id: str,
                                      task_id: str,
                                      map_index: int = None) -> dict:
    """
    Look up the task run identified by a flow_run_id, task_id, and map_index and
    return details about it, creating the task run first if it does not exist.

    A `map_index` of `None` is treated as -1.

    Returns:
        - dict: the task run's `id`, `version`, `state`, and `serialized_state`

    Raises:
        - ValueError: if the task run must be created but no task with the
            given ID is found
    """
    lookup_index = -1 if map_index is None else map_index

    # first, see whether the task run is already in the database
    existing = await models.TaskRun.where({
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        "map_index": {"_eq": lookup_index},
    }).first({"id", "version", "state", "serialized_state"})

    if existing:
        return {
            "id": existing.id,
            "version": existing.version,
            "state": existing.state,
            "serialized_state": existing.serialized_state,
        }

    # not found: the owning task supplies the tenant and cache key for the new run
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    if not task:
        raise ValueError("Invalid task ID")

    # the new run starts out with a single Pending state
    initial_state = models.TaskRunState(
        tenant_id=task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message="Task run created",
        serialized_state=Pending(message="Task run created").serialize(),
    )
    new_run = models.TaskRun(
        tenant_id=task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=lookup_index,
        cache_key=task.cache_key,
        version=0,
    )
    new_run.states = [initial_state]

    # insert the run; a conflict on the unique identifier constraint updates
    # the cache_key of the existing row instead of failing
    inserted = await new_run.insert(
        on_conflict={
            "constraint": "task_run_unique_identifier_key",
            "update_columns": ["cache_key"],
        },
        selection_set={"returning": {"id"}},
    )

    return {
        "id": inserted.returning.id,
        "version": new_run.version,
        "state": "Pending",
        "serialized_state": initial_state.serialized_state,
    }
コード例 #6
0
ファイル: states.py プロジェクト: vitasiku/server
async def set_task_run_state(
        task_run_id: str,
        state: State,
        version: int = None,
        flow_run_version: int = None) -> models.TaskRunState:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState

    Raises:
        - ValueError: if `task_run_id` is None, if no task run matches the ID
            (and, when version locking is enabled, the supplied versions), or if
            a running state is provided while the flow run is not running
    """

    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    where = {
        "id": {
            "_eq": task_run_id
        },
        "_or": [
            {
                # EITHER version locking is enabled and the versions match
                "version": {
                    "_eq": version
                },
                "flow_run": {
                    "version": {
                        "_eq": flow_run_version
                    },
                    "flow": {
                        "flow_group": {
                            "settings": {
                                "_contains": {
                                    "version_locking_enabled": True
                                }
                            }
                        }
                    },
                },
            },
            # OR version locking is not enabled
            {
                "flow_run": {
                    "flow": {
                        "flow_group": {
                            "_not": {
                                "settings": {
                                    "_contains": {
                                        "version_locking_enabled": True
                                    }
                                }
                            }
                        }
                    }
                }
            },
        ],
    }

    task_run = await models.TaskRun.where(where).first({
        "id": True,
        "tenant_id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {
            "id": True,
            "state": True
        },
    })

    # an empty result means either the ID is unknown or the version-lock
    # conditions above were not met
    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state.")

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    # `serialized_state` may be empty for a task run that has no state yet, so
    # fall back to an empty dict rather than calling `.get` on None
    if not state.cached_inputs and (task_run.serialized_state
                                    or {}).get("cached_inputs"):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        # a task run without a version yet is treated as version 0
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state