Beispiel #1
0
def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Drives pipelines whose stop was initiated towards the CANCELED state.

  Every node of each pipeline is asked to stop. Once a pass over a pipeline
  finds no node that still needs cancelling, a copy of the pipeline's
  execution is marked CANCELED and written to MLMD.

  Args:
    mlmd_handle: Handle used to look up node executions and to store the
      updated pipeline execution.
    task_queue: Queue that is checked for pending ExecNodeTasks and into which
      cancellation tasks are enqueued.
    pipeline_details: Details of the pipelines to process; each one supplies
      `pipeline_state.pipeline` and `pipeline_state.execution`.
  """
  for detail in pipeline_details:
    pipeline = detail.pipeline_state.pipeline
    pipeline_execution = detail.pipeline_state.execution
    any_node_active = False
    for node in _get_all_pipeline_nodes(pipeline):
      queued_task_id = task_lib.exec_node_task_id_from_pipeline_node(
          pipeline, node)
      if not task_queue.contains_task_id(queued_task_id):
        # Nothing is enqueued for this node, yet MLMD may still hold an active
        # execution for it, e.g. when the orchestrator restarted after the
        # stop was initiated but before the schedulers could finish. Enqueue
        # an ExecNodeTask flagged as cancelled so the scheduler gets a chance
        # to finish gracefully.
        node_executions = task_gen_utils.get_executions(mlmd_handle, node)
        cancelled_exec_task = (
            task_gen_utils.generate_task_from_active_execution(
                mlmd_handle, pipeline, node, node_executions,
                is_cancelled=True))
        if not cancelled_exec_task:
          continue
        task_queue.enqueue(cancelled_exec_task)
      else:
        # An ExecNodeTask for the node is in the queue: issue a cancellation.
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(task_lib.CancelNodeTask(node_uid=node_uid))
      any_node_active = True
    if any_node_active:
      continue
    # No node needed cancelling, so record the pipeline itself as cancelled.
    cancelled_execution = copy.deepcopy(pipeline_execution)
    cancelled_execution.last_known_state = (
        metadata_store_pb2.Execution.CANCELED)
    mlmd_handle.store.put_executions([cancelled_execution])
Beispiel #2
0
 def test_task_ids(self):
   """Checks that a task id pairs the task type name with the node uid."""
   trainer_uid = task_lib.NodeUid(
       pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
       node_id='Trainer')

   # ExecNodeTask created through the shared test helper.
   exec_task = test_utils.create_exec_node_task(trainer_uid)
   self.assertEqual(('ExecNodeTask', trainer_uid), exec_task.task_id)

   # CancelNodeTask addressed to the same node.
   cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
   self.assertEqual(('CancelNodeTask', trainer_uid), cancel_task.task_id)
Beispiel #3
0
 def test_task_ids(self):
     """Checks that a task id pairs the task type name with the node uid."""
     trainer_uid = task_lib.NodeUid(
         pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')

     # ExecNodeTask built directly with an explicit execution id.
     exec_task = task_lib.ExecNodeTask(
         node_uid=trainer_uid, execution_id=123)
     expected_exec_id = ('ExecNodeTask', trainer_uid)
     self.assertEqual(expected_exec_id, exec_task.task_id)

     # CancelNodeTask addressed to the same node.
     cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
     expected_cancel_id = ('CancelNodeTask', trainer_uid)
     self.assertEqual(expected_cancel_id, cancel_task.task_id)
Beispiel #4
0
 def test_task_type_ids(self):
     """Checks task_type_id() on the task classes and on task instances."""
     # Called on the classes themselves.
     exec_cls = task_lib.ExecNodeTask
     self.assertEqual('ExecNodeTask', exec_cls.task_type_id())
     cancel_cls = task_lib.CancelNodeTask
     self.assertEqual('CancelNodeTask', cancel_cls.task_type_id())

     # Called on instances that target one and the same node.
     trainer_uid = task_lib.NodeUid(
         pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
     exec_task = test_utils.create_exec_node_task(trainer_uid)
     self.assertEqual('ExecNodeTask', exec_task.task_type_id())
     cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
     self.assertEqual('CancelNodeTask', cancel_task.task_type_id())
Beispiel #5
0
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue,
                                     pause: bool = False) -> bool:
    """Requests cancellation of a node unless nothing is left to cancel.

    Two situations lead to a task being enqueued:

    * The task queue already holds the node's ExecNodeTask. A CancelNodeTask
      carrying the `pause` flag is enqueued.
    * Nothing is queued for the node, `pause` is False, and MLMD still has an
      active execution for it. This can happen when the orchestrator restarts
      after stopping was initiated but before the schedulers could finish. An
      ExecNodeTask with is_cancelled set is enqueued so that the scheduler
      gets a chance to finish gracefully.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline that contains the node.
      node: The node to cancel.
      task_queue: The `TaskQueue` that is inspected and that receives any
        cancellation task.
      pause: If True, the node is being paused rather than having its
        execution cancelled; the MLMD lookup is skipped in that case.

    Returns:
      `True` if a task was enqueued, `False` if no cancellation was needed.
    """
    queued_task_id = task_lib.exec_node_task_id_from_pipeline_node(
        pipeline, node)

    if not task_queue.contains_task_id(queued_task_id):
        if pause:
            # Pausing only ever acts on a task that is already in the queue.
            return False
        node_executions = task_gen_utils.get_executions(mlmd_handle, node)
        cancelled_exec_task = (
            task_gen_utils.generate_task_from_active_execution(
                mlmd_handle, pipeline, node, node_executions,
                is_cancelled=True))
        if not cancelled_exec_task:
            return False
        task_queue.enqueue(cancelled_exec_task)
        return True

    node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
    task_queue.enqueue(task_lib.CancelNodeTask(node_uid=node_uid, pause=pause))
    return True
Beispiel #6
0
def _test_cancel_node_task(node_id, pipeline_id):
    """Builds a CancelNodeTask for node `node_id` of pipeline `pipeline_id`."""
    owning_pipeline = task_lib.PipelineUid(pipeline_id=pipeline_id)
    target_node = task_lib.NodeUid(
        pipeline_uid=owning_pipeline, node_id=node_id)
    return task_lib.CancelNodeTask(node_uid=target_node)
Beispiel #7
0
def _test_cancel_task(node_id, pipeline_id, pipeline_run_id=None):
    """Builds a CancelNodeTask for the node identified by the given ids.

    `pipeline_run_id` is optional and is passed on as None when omitted.
    """
    uid_fields = {
        'pipeline_id': pipeline_id,
        'pipeline_run_id': pipeline_run_id,
        'node_id': node_id,
    }
    target_node = task_lib.NodeUid(**uid_fields)
    return task_lib.CancelNodeTask(node_uid=target_node)