コード例 #1
0
def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Drives every stop-initiated pipeline towards cancellation.

  For each node of each pipeline:
    * if an ExecNodeTask for the node is already in `task_queue`, a
      CancelNodeTask is enqueued for that node;
    * otherwise, if `task_gen_utils.generate_task_from_active_execution`
      yields a task for the node's executions (built with
      `is_cancelled=True`), that task is enqueued. This covers an
      orchestrator restart that happened after the stop was initiated but
      before the schedulers could finish, giving the scheduler a chance to
      finish gracefully.

  If no task was enqueued for any node of a pipeline, a copy of that
  pipeline's execution is written to MLMD with `last_known_state` set to
  CANCELED.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Queue that cancellation and cancelled exec tasks are put on.
    pipeline_details: Details of the pipelines whose stop was initiated.
  """
  for detail in pipeline_details:
    state = detail.pipeline_state
    pipeline = state.pipeline
    execution = state.execution
    enqueued_any = False
    for node in _get_all_pipeline_nodes(pipeline):
      task_id = task_lib.exec_node_task_id_from_pipeline_node(pipeline, node)
      if task_queue.contains_task_id(task_id):
        # An ExecNodeTask is already queued; request its cancellation.
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(task_lib.CancelNodeTask(node_uid=node_uid))
        enqueued_any = True
        continue
      # No queued task: re-create one, marked cancelled, from an active
      # execution in MLMD if the helper produces one.
      node_executions = task_gen_utils.get_executions(mlmd_handle, node)
      cancelled_task = task_gen_utils.generate_task_from_active_execution(
          mlmd_handle, pipeline, node, node_executions, is_cancelled=True)
      if cancelled_task:
        task_queue.enqueue(cancelled_task)
        enqueued_any = True
    if enqueued_any:
      continue
    # Nothing was enqueued for any node, so mark the pipeline CANCELED.
    canceled_execution = copy.deepcopy(execution)
    canceled_execution.last_known_state = (
        metadata_store_pb2.Execution.CANCELED)
    mlmd_handle.store.put_executions([canceled_execution])
コード例 #2
0
ファイル: task_test.py プロジェクト: sycdesign/tfx
 def test_task_ids(self):
   """Task ids are expected to be (task type name, node uid) pairs."""
   node_uid = task_lib.NodeUid(
       pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
       node_id='Trainer')
   self.assertEqual(('ExecNodeTask', node_uid),
                    test_utils.create_exec_node_task(node_uid).task_id)
   self.assertEqual(('CancelNodeTask', node_uid),
                    task_lib.CancelNodeTask(node_uid=node_uid).task_id)
コード例 #3
0
ファイル: task_test.py プロジェクト: kp425/tfx
 def test_task_ids(self):
     """Each task id should pair the task type name with its node uid."""
     uid = task_lib.NodeUid(
         pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
     exec_task = task_lib.ExecNodeTask(node_uid=uid, execution_id=123)
     self.assertEqual(('ExecNodeTask', uid), exec_task.task_id)
     cancel_task = task_lib.CancelNodeTask(node_uid=uid)
     self.assertEqual(('CancelNodeTask', uid), cancel_task.task_id)
コード例 #4
0
ファイル: task_test.py プロジェクト: hamzamaiot/tfx
 def test_task_type_ids(self):
     """task_type_id() should name the task class on classes and instances."""
     for task_cls, type_id in ((task_lib.ExecNodeTask, 'ExecNodeTask'),
                               (task_lib.CancelNodeTask, 'CancelNodeTask')):
       self.assertEqual(type_id, task_cls.task_type_id())
     uid = task_lib.NodeUid(
         pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
     self.assertEqual('ExecNodeTask',
                      test_utils.create_exec_node_task(uid).task_type_id())
     self.assertEqual('CancelNodeTask',
                      task_lib.CancelNodeTask(node_uid=uid).task_type_id())
コード例 #5
0
ファイル: pipeline_ops.py プロジェクト: jay90099/tfx
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue,
                                     pause: bool = False) -> bool:
    """Enqueues a node cancellation task if not already stopped.

    Two situations lead to a task being enqueued:

    * An ExecNodeTask for the node is already in `task_queue`: a
      CancelNodeTask carrying `pause` is enqueued for the node.
    * Otherwise, and only when `pause` is False, a task built by
      `task_gen_utils.generate_task_from_active_execution` from the node's
      executions with `is_cancelled=True` is enqueued, if one is produced.
      This handles an orchestrator restart that happened after stopping was
      initiated but before the schedulers could finish, giving the scheduler
      a chance to finish gracefully.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline containing the node to cancel.
      node: The node to cancel.
      task_queue: A `TaskQueue` instance into which any cancellation tasks
        will be enqueued.
      pause: Whether the cancellation is to pause the node rather than
        cancelling the execution.

    Returns:
      `True` if a task was enqueued. `False` if the node is already stopped,
      if pausing a node that has no queued ExecNodeTask, or if no
      cancellation was required.
    """
    task_id = task_lib.exec_node_task_id_from_pipeline_node(pipeline, node)
    if task_queue.contains_task_id(task_id):
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(
            task_lib.CancelNodeTask(node_uid=node_uid, pause=pause))
        return True
    if pause:
        # When pausing, a node without a queued ExecNodeTask is left alone.
        return False
    node_executions = task_gen_utils.get_executions(mlmd_handle, node)
    cancelled_task = task_gen_utils.generate_task_from_active_execution(
        mlmd_handle, pipeline, node, node_executions, is_cancelled=True)
    if not cancelled_task:
        return False
    task_queue.enqueue(cancelled_task)
    return True
コード例 #6
0
ファイル: task_manager_test.py プロジェクト: suryaavala/tfx
def _test_cancel_node_task(node_id, pipeline_id):
    """Builds a CancelNodeTask for `node_id` in pipeline `pipeline_id`."""
    pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
    return task_lib.CancelNodeTask(
        node_uid=task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id))
コード例 #7
0
def _test_cancel_task(node_id, pipeline_id, pipeline_run_id=None):
    """Builds a CancelNodeTask for a node of a pipeline (and optional run)."""
    uid = task_lib.NodeUid(
        node_id=node_id,
        pipeline_id=pipeline_id,
        pipeline_run_id=pipeline_run_id)
    return task_lib.CancelNodeTask(node_uid=uid)