async def get_or_create_mapped_task_run_children(
    flow_run_id: str, task_id: str, max_map_index: int
) -> List[str]:
    """
    Creates and/or retrieves mapped child task runs for a given flow run and task.

    Every map index from 0 through `max_map_index` is upserted, so children that
    already exist are returned alongside newly created ones. Afterwards, any task
    run of this flow run / task that has no state yet is given a `Pending` state.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the highest map index to create; the number of
            mapped children is `max_map_index + 1` (e.g., a value of 2 yields 3
            mapped children)

    Returns:
        - List[str]: the ids of the mapped children, ordered by map index

    Raises:
        - ValueError: if no task with the given ID is found
    """
    # grab task info
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    if not task:
        # fail with a clear message instead of an AttributeError on `None` below
        raise ValueError("Invalid task ID")

    # generate task runs to upsert
    task_runs = [
        models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=i,
            cache_key=task.cache_key,
        )
        for i in range(max_map_index + 1)
    ]

    # upsert the mapped children; on conflict the existing row is kept (only its
    # cache_key is rewritten) so that its id is still returned
    task_runs = (
        await models.TaskRun().insert_many(
            objects=task_runs,
            on_conflict=dict(
                constraint="task_run_unique_identifier_key",
                update_columns=["cache_key"],
            ),
            selection_set={"returning": {"id", "map_index"}},
        )
    )["returning"]
    task_runs.sort(key=lambda task_run: task_run.map_index)

    # get task runs without states
    stateless_runs = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            # this syntax indicates "where there are no states"
            "_not": {"states": {}},
        }
    ).get({"id", "map_index", "version"})

    # create and insert states for stateless task runs
    task_run_states = [
        models.TaskRunState(
            tenant_id=task.tenant_id,
            task_run_id=task_run.id,
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")
            ),
        )
        for task_run in stateless_runs
    ]
    await models.TaskRunState().insert_many(task_run_states)

    # return the task run ids
    return [task_run.id for task_run in task_runs]
async def test_get_or_create_mapped_children_handles_partial_children(
    self, flow_id, flow_run_id
):
    """
    Pre-existing mapped children (with or without a state) are kept, the missing
    ones are created, and every returned child ends up with a state.
    """
    # any task belonging to the flow will do
    flow_task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
        {"id", "cache_key"}
    )

    # pre-create one child that has no state at all ...
    await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=flow_task.id,
        map_index=3,
        cache_key=flow_task.cache_key,
    ).insert()

    # ... and one child that already carries a Pending state
    pending_fields = models.TaskRunState.fields_from_state(
        Pending(message="Task run created")
    )
    child_with_state = await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=flow_task.id,
        map_index=6,
        cache_key=flow_task.cache_key,
        states=[models.TaskRunState(**pending_fields)],
    ).insert()

    # ask for children 0..10
    child_ids = await api.runs.get_or_create_mapped_task_run_children(
        flow_run_id=flow_run_id, task_id=flow_task.id, max_map_index=10
    )
    assert len(child_ids) == 11

    # every child must have a state, and the ids must come back in map-index order
    seen_indices = []
    for child_id in child_ids:
        latest_state = with_args(
            "states",
            {"order_by": {"version": EnumValue("desc")}, "limit": 1},
        )
        child_run = await models.TaskRun.where(id=child_id).first(
            {"map_index": True, latest_state: {"id"}}
        )
        seen_indices.append(child_run.map_index)
        assert child_run.states[0] is not None
    assert seen_indices == sorted(seen_indices)

    # the child that was created with a state must not have received a second one
    states_of_stateful_child = await models.TaskRunState.where(
        {"task_run_id": {"_eq": child_with_state}}
    ).get()
    assert len(states_of_stateful_child) == 1
async def get_or_create_task_run(
    flow_run_id: str, task_id: str, map_index: int = None
) -> str:
    """
    Since some task runs are created dynamically (when tasks are mapped, for example)
    we don't know if a task run exists the first time we query it. This function will take
    key information about a task run and create it if it doesn't already exist, returning its id.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task the task run belongs to
        - map_index (int, optional): the map index of the task run; `None` is
            treated as -1

    Returns:
        - str: the id of the existing or newly created task run

    Raises:
        - ValueError: if loading the task or inserting the new task run fails for
            any reason; the underlying error is attached as the cause
    """
    if map_index is None:
        map_index = -1

    # try to load an existing task run
    task_run = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            "map_index": {"_eq": map_index},
        }
    ).first({"id"})

    if task_run:
        return task_run.id

    try:
        # load the tenant ID and cache_key
        task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
        # create the task run together with its initial Pending state
        return await models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=map_index,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    tenant_id=task.tenant_id,
                    state="Pending",
                    timestamp=pendulum.now(),
                    message="Task run created",
                    serialized_state=Pending(message="Task run created").serialize(),
                )
            ],
        ).insert()
    except Exception as exc:
        # keep the public error type/message, but preserve the real failure
        # (missing task, rejected insert, ...) as the explicit cause
        raise ValueError("Invalid ID") from exc
async def test_get_or_create_mapped_task_run_children_with_partial_children(
    self, run_query, flow_run_id, flow_id
):
    """
    The mutation returns all 11 children (indices 0-10) in map-index order and
    includes the task runs that existed before the call.
    """
    # select cache_key as well: it is read below when creating the children
    task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
        {"id", "cache_key"}
    )

    # create a couple of children: one without a state ...
    preexisting_run_1 = await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=task.id,
        map_index=3,
        cache_key=task.cache_key,
    ).insert()
    # ... and one with a Pending state
    preexisting_run_2 = await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=task.id,
        map_index=6,
        cache_key=task.cache_key,
        states=[
            models.TaskRunState(
                **models.TaskRunState.fields_from_state(
                    Pending(message="Task run created")
                ),
            )
        ],
    ).insert()

    # call the route
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)
        ),
    )
    mapped_children = result.data.get_or_create_mapped_task_run_children.ids

    # should have 11 children, indices 0-10
    assert len(mapped_children) == 11

    # confirm the preexisting task runs were included in the results
    assert preexisting_run_1 in mapped_children
    assert preexisting_run_2 in mapped_children

    # confirm the results are ordered
    map_indices = []
    for child in mapped_children:
        child_run = await models.TaskRun.where(id=child).first({"map_index"})
        map_indices.append(child_run.map_index)
    assert map_indices == sorted(map_indices)
async def get_or_create_task_run_info(
    flow_run_id: str, task_id: str, map_index: int = None
) -> dict:
    """
    Look up the task run identified by (flow_run_id, task_id, map_index) and
    describe it, creating it with an initial Pending state when it is missing.

    A `map_index` of `None` is treated as -1.

    Returns:
        - dict: the task run's `id`, `version`, `state` and `serialized_state`

    Raises:
        - ValueError: if the task run has to be created but the task is not found
    """
    effective_index = -1 if map_index is None else map_index

    lookup = {
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        "map_index": {"_eq": effective_index},
    }
    existing = await models.TaskRun.where(lookup).first(
        {"id", "version", "state", "serialized_state"}
    )
    if existing:
        return {
            "id": existing.id,
            "version": existing.version,
            "state": existing.state,
            "serialized_state": existing.serialized_state,
        }

    # nothing stored yet: the owning task supplies tenant and cache key
    parent_task = await models.Task.where(id=task_id).first(
        {"cache_key", "tenant_id"}
    )
    if not parent_task:
        raise ValueError("Invalid task ID")

    created_message = "Task run created"
    initial_state = models.TaskRunState(
        tenant_id=parent_task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message=created_message,
        serialized_state=Pending(message=created_message).serialize(),
    )
    new_run = models.TaskRun(
        tenant_id=parent_task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=effective_index,
        cache_key=parent_task.cache_key,
        version=0,
    )
    new_run.states = [initial_state]

    # upsert: a conflict on the unique identifier rewrites cache_key only, and
    # the row's id is returned either way
    upsert_policy = dict(
        constraint="task_run_unique_identifier_key",
        update_columns=["cache_key"],
    )
    inserted = await new_run.insert(
        on_conflict=upsert_policy,
        selection_set={"returning": {"id"}},
    )

    return {
        "id": inserted.returning.id,
        "version": new_run.version,
        "state": "Pending",
        "serialized_state": initial_state.serialized_state,
    }
async def set_task_run_state(
    task_run_id: str, state: State, version: int = None, flow_run_version: int = None
) -> models.TaskRunState:
    """
    Updates a task run state.

    If the new state carries no cached inputs but the task run's current
    serialized state does, they are copied onto `state` (the passed-in object
    is modified) before it is stored.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState

    Raises:
        - ValueError: if `task_run_id` is None, if no task run matches the ID and
            the version-locking conditions, or if a running state is provided
            while the associated flow run is not in a running state
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    where = {
        "id": {"_eq": task_run_id},
        "_or": [
            {
                # EITHER version locking is enabled and the versions match
                "version": {"_eq": version},
                "flow_run": {
                    "version": {"_eq": flow_run_version},
                    "flow": {
                        "flow_group": {
                            "settings": {
                                "_contains": {"version_locking_enabled": True}
                            }
                        }
                    },
                },
            },
            # OR version locking is not enabled
            {
                "flow_run": {
                    "flow": {
                        "flow_group": {
                            "_not": {
                                "settings": {
                                    "_contains": {"version_locking_enabled": True}
                                }
                            }
                        }
                    }
                }
            },
        ],
    }

    task_run = await models.TaskRun.where(where).first(
        {
            "id": True,
            "tenant_id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    # a task run without a serialized state yet has nothing to carry forward;
    # treat it as empty rather than failing on `None.get`
    previous_serialized_state = task_run.serialized_state or {}
    if not state.cached_inputs and previous_serialized_state.get(
        "cached_inputs", None
    ):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        # the new state's version is one past the task run's current version
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state