async def get_or_create_mapped_task_run_children(
    flow_run_id: str, task_id: str, max_map_index: int
) -> List[str]:
    """
    Creates and/or retrieves mapped child task runs for a given flow run and task.

    One child task run is upserted for every map index from 0 through
    `max_map_index`. Afterwards, every task run of this task in this flow run
    that has no states yet receives an initial `Pending` state.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the highest map index to create; indices are
            zero-based and inclusive, e.g., a value of 2 yields 3 mapped children

    Returns:
        - List[str]: the IDs of the upserted child task runs, ordered by map index

    Raises:
        - ValueError: if no task with the given ID exists
    """
    # grab task info
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    # fail with the same error as `get_or_create_task_run_info` instead of an
    # AttributeError on `task.tenant_id` below
    if not task:
        raise ValueError("Invalid task ID")

    # generate one task run per map index in [0, max_map_index]
    children = [
        models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=i,
            cache_key=task.cache_key,
        )
        for i in range(max_map_index + 1)
    ]

    # upsert the mapped children; on a uniqueness conflict only the cache key is
    # rewritten, so the existing row's id and map_index are what get returned
    upsert_result = await models.TaskRun().insert_many(
        objects=children,
        on_conflict=dict(
            constraint="task_run_unique_identifier_key",
            update_columns=["cache_key"],
        ),
        selection_set={"returning": {"id", "map_index"}},
    )
    task_runs = upsert_result["returning"]
    task_runs.sort(key=lambda task_run: task_run.map_index)

    # get task runs without states
    stateless_runs = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            # this syntax indicates "where there are no states"
            "_not": {"states": {}},
        }
    ).get({"id", "map_index", "version"})

    # create and insert states for stateless task runs
    task_run_states = [
        models.TaskRunState(
            tenant_id=task.tenant_id,
            task_run_id=task_run.id,
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")
            ),
        )
        for task_run in stateless_runs
    ]
    # every run already has a state on repeat calls; skip the empty insert then
    if task_run_states:
        await models.TaskRunState().insert_many(task_run_states)

    # return the task run ids
    return [task_run.id for task_run in task_runs]
async def test_get_or_create_task_run_info_hits_db(self, tenant_id, flow_run_id, task_id):
    """
    A task run that is already stored must be returned as-is: the info dict
    has to carry the stored run's id, version, state, and serialized state.
    """
    # seed a run with distinctive values so a freshly created run cannot match
    existing_run = models.TaskRun(
        id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=12,
        version=17,
        state="Success",
        serialized_state=dict(message="hi"),
    )
    await existing_run.insert()

    info = await api.runs.get_or_create_task_run_info(
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=existing_run.map_index,
    )

    # each reported field must equal what was inserted above
    for field in ("id", "version", "state", "serialized_state"):
        assert info[field] == getattr(existing_run, field)
async def get_or_create_task_run_info(
    flow_run_id: str, task_id: str, map_index: int = None
) -> dict:
    """
    Look up the task run identified by a flow run, a task, and a map index, and
    report its details. When no such task run exists, one is inserted together
    with an initial Pending state.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task the task run is for
        - map_index (int, optional): the map index of the task run; `None` is
            treated as -1

    Returns:
        - dict: the task run's `id`, `version`, `state`, and `serialized_state`

    Raises:
        - ValueError: if the task run has to be created but the task ID matches
            no task
    """
    # a missing map index is looked up (and stored) as -1
    map_index = -1 if map_index is None else map_index

    lookup = {
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        "map_index": {"_eq": map_index},
    }
    existing = await models.TaskRun.where(lookup).first(
        {"id", "version", "state", "serialized_state"}
    )
    if existing:
        return {
            "id": existing.id,
            "version": existing.version,
            "state": existing.state,
            "serialized_state": existing.serialized_state,
        }

    # not found: the task supplies the tenant and cache key for the new run
    parent_task = await models.Task.where(id=task_id).first(
        {"cache_key", "tenant_id"}
    )
    if not parent_task:
        raise ValueError("Invalid task ID")

    new_run = models.TaskRun(
        tenant_id=parent_task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=map_index,
        cache_key=parent_task.cache_key,
        version=0,
    )
    initial_state = models.TaskRunState(
        tenant_id=parent_task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message="Task run created",
        serialized_state=Pending(message="Task run created").serialize(),
    )
    # the state is attached to the run so both go in with the single insert below
    new_run.states = [initial_state]

    # on a uniqueness conflict only the cache key is rewritten
    conflict_policy = {
        "constraint": "task_run_unique_identifier_key",
        "update_columns": ["cache_key"],
    }
    inserted = await new_run.insert(
        on_conflict=conflict_policy,
        selection_set={"returning": {"id"}},
    )

    return {
        "id": inserted.returning.id,
        "version": new_run.version,
        "state": "Pending",
        "serialized_state": initial_state.serialized_state,
    }