async def test_version_order_determines_timestamp(self, task_run_id):
    """
    Start and end times must follow state versions, not insertion order.

    A "Success" state carrying version 5 is inserted *after* the version 8
    state; the run's end time is still expected to come from the version 9
    "Success" state, and the start time from the version 1 "Running" state.
    """
    common = {"task_run_id": task_run_id, "serialized_state": {}}
    started_at = pendulum.now("UTC")
    finished_at = pendulum.now("UTC").add(days=1)
    decoy_finished_at = pendulum.now("UTC").add(hours=1)

    # (version, state name, extra keyword arguments) in insertion order;
    # version 5 is deliberately listed out of sequence
    state_specs = [
        (0, "Pending", {}),
        (1, "Running", {"timestamp": started_at}),
        (2, "Failed", {}),
        (3, "Retrying", {}),
        (4, "Running", {}),
        (6, "Failed", {}),
        (7, "Retrying", {}),
        (8, "Running", {}),
        (5, "Success", {"timestamp": decoy_finished_at}),
        (9, "Success", {"timestamp": finished_at}),
    ]
    states = [
        models.TaskRunState(**common, version=version, state=name, **extra)
        for version, name, extra in state_specs
    ]
    await models.TaskRunState.insert_many(states)

    selection = {"start_time", "end_time", "duration", "run_count"}
    run = await models.TaskRun.where(id=task_run_id).first(selection)

    expected_duration = (finished_at - started_at).as_timedelta()
    assert run.start_time == started_at
    assert run.end_time == finished_at
    assert run.duration == expected_duration
    # three "Running" states were inserted above
    assert run.run_count == 3
async def test_no_end_if_running_is_last_state(self, task_run_id):
    """
    A run whose latest state is "Running" has a start time but no end time
    and no duration.
    """
    common = {"task_run_id": task_run_id, "serialized_state": {}}
    started_at = pendulum.now("UTC")

    pending = models.TaskRunState(**common, version=0, state="Pending")
    running = models.TaskRunState(
        **common, version=1, state="Running", timestamp=started_at
    )
    await models.TaskRunState.insert_many([pending, running])

    selection = {"start_time", "end_time", "duration"}
    run = await models.TaskRun.where(id=task_run_id).first(selection)

    assert run.start_time == started_at
    assert run.end_time is None
    assert run.duration is None
async def test_start_and_end_from_running_state(self, task_run_id):
    """
    The start time is taken from the "Running" state and the end time from
    the "Failed" state that follows it; the later "Retrying" states are
    expected to leave the end time, duration and run count unchanged.
    """
    common = {"task_run_id": task_run_id, "serialized_state": {}}
    started_at = pendulum.now("UTC")
    finished_at = pendulum.now("UTC").add(days=1)

    # (version, state name, extra keyword arguments)
    state_specs = [
        (0, "Pending", {}),
        (1, "Running", {"timestamp": started_at}),
        (2, "Failed", {"timestamp": finished_at}),
        (3, "Retrying", {}),
        (4, "Retrying", {}),
    ]
    states = [
        models.TaskRunState(**common, version=version, state=name, **extra)
        for version, name, extra in state_specs
    ]
    await models.TaskRunState.insert_many(states)

    selection = {"start_time", "end_time", "duration", "run_count"}
    run = await models.TaskRun.where(id=task_run_id).first(selection)

    expected_duration = (finished_at - started_at).as_timedelta()
    assert run.start_time == started_at
    assert run.end_time == finished_at
    assert run.duration == expected_duration
    assert run.run_count == 1
async def test_inserting_running_state_has_effect(self, task_run_id):
    """
    Inserting states without explicit timestamps still populates the run's
    start time, end time and duration, and counts the single "Running" state.
    """
    common = {"task_run_id": task_run_id, "serialized_state": {}}

    # versions are assigned 0..4 in list order
    state_names = ["Pending", "Running", "Failed", "Retrying", "Retrying"]
    states = [
        models.TaskRunState(**common, version=version, state=name)
        for version, name in enumerate(state_names)
    ]
    await models.TaskRunState.insert_many(states)

    selection = {"start_time", "end_time", "duration", "run_count"}
    run = await models.TaskRun.where(id=task_run_id).first(selection)

    assert run.start_time is not None
    assert run.end_time is not None
    assert run.duration is not None
    assert run.run_count == 1
async def set_task_run_state(task_run_id: str, state: State, force=False) -> None:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - force (bool): if True, avoids pipeline checks

    Raises:
        - ValueError: if `task_run_id` is None, if no task run has that ID, or if
            `force` is False and a running state is provided while the associated
            flow run's state is not "Running"
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(
        {
            "id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # (skipped entirely when `force` is set)
    # ------------------------------------------------------
    if not force and state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        task_run_id=task_run.id,
        # the new state's version is one past the task run's current version
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        # not every state class has a start_time attribute
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()
async def get_or_create_mapped_task_run_children(
    flow_run_id: str, task_id: str, max_map_index: int
) -> List[str]:
    """
    Upserts the mapped child task runs of a task within a flow run and makes
    sure each child that has no state yet receives an initial Pending state.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the highest map index to create; children are
            created for every index from 0 through this value, so a value of 2
            yields 3 mapped children

    Returns:
        - List[str]: the ids of the mapped children, ordered by map index
    """
    # fields copied from the task onto every child run
    parent_task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})

    # one candidate run per map index in [0, max_map_index]
    candidates = []
    for index in range(max_map_index + 1):
        candidates.append(
            models.TaskRun(
                tenant_id=parent_task.tenant_id,
                flow_run_id=flow_run_id,
                task_id=task_id,
                map_index=index,
                cache_key=parent_task.cache_key,
            )
        )

    # upsert: on a unique-key conflict only the cache key is updated
    conflict_policy = dict(
        constraint="task_run_unique_identifier_key",
        update_columns=["cache_key"],
    )
    upsert_result = await models.TaskRun().insert_many(
        objects=candidates,
        on_conflict=conflict_policy,
        selection_set={"returning": {"id", "map_index"}},
    )
    children = upsert_result["returning"]
    children.sort(key=lambda child: child.map_index)

    # find the runs of this task / flow run that have no states
    stateless_filter = {
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        # this syntax indicates "where there are no states"
        "_not": {"states": {}},
    }
    stateless_runs = await models.TaskRun.where(stateless_filter).get(
        {"id", "map_index", "version"}
    )

    # give each stateless run an initial Pending state
    initial_states = []
    for stateless_run in stateless_runs:
        state_fields = models.TaskRunState.fields_from_state(
            Pending(message="Task run created")
        )
        initial_states.append(
            models.TaskRunState(
                tenant_id=parent_task.tenant_id,
                task_run_id=stateless_run.id,
                **state_fields,
            )
        )
    await models.TaskRunState().insert_many(initial_states)

    return [child.id for child in children]
async def test_get_or_create_mapped_children_handles_partial_children(
    self, flow_id, flow_run_id
):
    """
    When some mapped children already exist (one without a state, one with a
    Pending state), the API call returns all 11 children ordered by map index,
    each with a state, and does not add a second state to the stateful child.
    """
    # get a task from the flow
    task_filter = {"flow_id": {"_eq": flow_id}}
    task = await models.Task.where(task_filter).first({"id", "cache_key"})

    child_kwargs = dict(
        flow_run_id=flow_run_id, task_id=task.id, cache_key=task.cache_key
    )

    # pre-existing child with no state at all
    await models.TaskRun(**child_kwargs, map_index=3).insert()

    # pre-existing child that already has a Pending state
    pending_fields = models.TaskRunState.fields_from_state(
        Pending(message="Task run created")
    )
    stateful_child = await models.TaskRun(
        **child_kwargs,
        map_index=6,
        states=[models.TaskRunState(**pending_fields)],
    ).insert()

    # retrieve mapped children
    mapped_children = await api.runs.get_or_create_mapped_task_run_children(
        flow_run_id=flow_run_id, task_id=task.id, max_map_index=10
    )
    assert len(mapped_children) == 11

    # confirm each of the mapped children has a state and is ordered properly
    map_indices = []
    for child in mapped_children:
        # select only the single highest-version state of the run
        latest_state_field = with_args(
            "states",
            {"order_by": {"version": EnumValue("desc")}, "limit": 1},
        )
        task_run = await models.TaskRun.where(id=child).first(
            {"map_index": True, latest_state_field: {"id"}}
        )
        map_indices.append(task_run.map_index)
        assert task_run.states[0] is not None

    assert map_indices == sorted(map_indices)

    # confirm the one child created with a state only has the one state
    child_states = await models.TaskRunState.where(
        {"task_run_id": {"_eq": stateful_child}}
    ).get()
    assert len(child_states) == 1
async def get_or_create_task_run(flow_run_id: str, task_id: str, map_index: int = None) -> str:
    """
    Since some task runs are created dynamically (when tasks are mapped, for example)
    we don't know if a task run exists the first time we query it. This function will take
    key information about a task run and create it if it doesn't already exist, returning its id.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task the task run belongs to
        - map_index (int, optional): the map index of the task run; `None` is
            treated as -1

    Returns:
        - str: the id of the existing or newly created task run

    Raises:
        - ValueError: if loading the task or creating the task run fails for any
            reason; the original exception is attached as the cause
    """
    if map_index is None:
        map_index = -1

    # try to load an existing task run
    task_run = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            "map_index": {"_eq": map_index},
        }
    ).first({"id"})

    if task_run:
        return task_run.id

    try:
        # load the tenant ID and cache_key
        task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
        # create the task run together with its initial Pending state
        return await models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=map_index,
            cache_key=task.cache_key,
            states=[
                models.TaskRunState(
                    tenant_id=task.tenant_id,
                    **models.TaskRunState.fields_from_state(
                        Pending(message="Task run created")
                    ),
                )
            ],
        ).insert()
    except Exception as exc:
        # chain explicitly so the underlying failure stays visible in tracebacks
        raise ValueError("Invalid ID") from exc
async def test_get_or_create_mapped_task_run_children_with_partial_children(
    self, run_query, flow_run_id, flow_id
):
    """
    When some mapped children already exist, the mutation returns all 11
    children (indices 0-10), includes the pre-existing runs, and orders the
    result by map index.
    """
    # `cache_key` must be selected because it is read below when creating
    # the pre-existing runs
    task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
        {"id", "cache_key"}
    )

    # create a couple of children: one without a state ...
    preexisting_run_1 = await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=task.id,
        map_index=3,
        cache_key=task.cache_key,
    ).insert()
    # ... and one that already has a Pending state
    preexisting_run_2 = await models.TaskRun(
        flow_run_id=flow_run_id,
        task_id=task.id,
        map_index=6,
        cache_key=task.cache_key,
        states=[
            models.TaskRunState(
                **models.TaskRunState.fields_from_state(
                    Pending(message="Task run created")
                ),
            )
        ],
    ).insert()

    # call the route
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)
        ),
    )
    mapped_children = result.data.get_or_create_mapped_task_run_children.ids

    # should have 11 children, indices 0-10
    assert len(mapped_children) == 11

    # confirm the preexisting task runs were included in the results
    assert preexisting_run_1 in mapped_children
    assert preexisting_run_2 in mapped_children

    # confirm the results are ordered
    map_indices = []
    for child in mapped_children:
        child_run = await models.TaskRun.where(id=child).first({"map_index"})
        map_indices.append(child_run.map_index)
    assert map_indices == sorted(map_indices)
async def set_task_run_state(task_run_id: str, state: State,) -> None:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state

    Raises:
        - ValueError: if `task_run_id` is None or no task run has that ID
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(
        {"id": True, "version": True, "state": True, "serialized_state": True}
    )

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        task_run_id=task_run.id,
        # the new state's version is one past the task run's current version
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        # not every state class has a start_time attribute
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()
async def set_task_run_state(
    task_run_id: str, state: State, version: int = None, flow_run_version: int = None
) -> models.TaskRunState:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState

    Raises:
        - ValueError: if `task_run_id` is None; if no task run matches the lookup
            (unknown ID, or version locking is enabled for the flow group and the
            provided versions do not match); or if a running state is provided
            while the associated flow run's state is not "Running"
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    where = {
        "id": {"_eq": task_run_id},
        "_or": [
            {
                # EITHER version locking is enabled and the versions match
                "version": {"_eq": version},
                "flow_run": {
                    "version": {"_eq": flow_run_version},
                    "flow": {
                        "flow_group": {
                            "settings": {
                                "_contains": {"version_locking_enabled": True}
                            }
                        }
                    },
                },
            },
            # OR version locking is not enabled
            {
                "flow_run": {
                    "flow": {
                        "flow_group": {
                            "_not": {
                                "settings": {
                                    "_contains": {"version_locking_enabled": True}
                                }
                            }
                        }
                    }
                }
            },
        ],
    }

    task_run = await models.TaskRun.where(where).first(
        {
            "id": True,
            "tenant_id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        # the new state's version is one past the task run's current version
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        # not every state class has a start_time attribute
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state