Example #1
def _generate_and_test(self, use_task_queue, num_initial_executions,
                       num_tasks_generated, num_new_executions,
                       num_active_executions):
    """Generates tasks and tests the effects."""
    with self._mlmd_connection as m:
        executions = m.store.get_executions()
        self.assertLen(
            executions, num_initial_executions,
            'Expected {} execution(s) in MLMD.'.format(
                num_initial_executions))
        task_gen = sptg.SyncPipelineTaskGenerator(
            m, self._pipeline, self._task_queue.contains_task_id)
        tasks = task_gen.generate()
        self.assertLen(
            tasks, num_tasks_generated,
            'Expected {} task(s) to be generated.'.format(
                num_tasks_generated))
        executions = m.store.get_executions()
        num_total_executions = num_initial_executions + num_new_executions
        self.assertLen(
            executions, num_total_executions,
            'Expected {} execution(s) in MLMD.'.format(
                num_total_executions))
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(
            active_executions, num_active_executions,
            'Expected {} active execution(s) in MLMD.'.format(
                num_active_executions))
        if use_task_queue:
            for task in tasks:
                self._task_queue.enqueue(task)
        return tasks, active_executions
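For reference, a test case might drive this helper as follows. This is a minimal sketch: the test name and the expected counts are illustrative assumptions, not taken from the excerpt.

def test_tasks_generated_for_new_pipeline(self):
    # Illustrative expectations: an empty MLMD store, one generated task,
    # and one new execution that is still active (RUNNING). The helper
    # performs the MLMD assertions itself and returns the generated tasks
    # and active executions for any further checks.
    tasks, active_executions = self._generate_and_test(
        use_task_queue=True,
        num_initial_executions=0,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)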
Example #2
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
    """Scans MLMD and returns pipeline details."""
    result = []

    contexts = mlmd_handle.store.get_contexts_by_type(
        _ORCHESTRATOR_RESERVED_ID)

    for context in contexts:
        active_executions = [
            e for e in mlmd_handle.store.get_executions_by_context(context.id)
            if execution_lib.is_execution_active(e)
        ]
        if len(active_executions) > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=(
                    f'Expected 1 but found {len(active_executions)} active '
                    f'executions for context named: {context.name}'))
        if not active_executions:
            continue
        execution = active_executions[0]

        # TODO(goutham): Instead of parsing the pipeline IR each time, we could
        # cache the parsed pipeline IR in `initiate_pipeline_start` and reuse it.
        pipeline_ir_b64 = common_utils.get_metadata_value(
            execution.properties[_PIPELINE_IR])
        pipeline = pipeline_pb2.Pipeline()
        pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))

        stop_initiated = _is_stop_initiated(execution)

        if stop_initiated:
            generator = None
        else:
            if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
                generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
                    mlmd_handle, pipeline, task_queue.contains_task_id)
            elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
                generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
                    mlmd_handle, pipeline, task_queue.contains_task_id)
            else:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.FAILED_PRECONDITION,
                    message=(
                        f'Only SYNC and ASYNC pipeline execution modes supported; '
                        f'found pipeline with execution mode: {pipeline.execution_mode}'
                    ))

        result.append(
            _PipelineDetail(
                context=context,
                execution=execution,
                pipeline=pipeline,
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                stop_initiated=stop_initiated,
                generator=generator))

    return result
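The _PipelineDetail container is not part of the excerpt. Judging from the keyword arguments in the constructor call above, it could be a simple frozen attrs class; the field types below are assumptions inferred from usage.

from typing import Any, Optional

import attr


@attr.s(auto_attribs=True, frozen=True)
class _PipelineDetail:
    """Details of one active pipeline (sketch; fields inferred from usage)."""
    context: metadata_store_pb2.Context
    execution: metadata_store_pb2.Execution
    pipeline: pipeline_pb2.Pipeline
    pipeline_uid: task_lib.PipelineUid
    stop_initiated: bool
    # Sync or async task generator; None when a stop has been initiated.
    generator: Optional[Any]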
Example #3
def _process_active_pipelines(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: Optional[service_jobs.ServiceJobManager],
        pipeline_states: Sequence[pstate.PipelineState]) -> None:
    """Processes active pipelines."""
    for pipeline_state in pipeline_states:
        pipeline = pipeline_state.pipeline
        execution = pipeline_state.execution
        assert execution.last_known_state in (
            metadata_store_pb2.Execution.NEW,
            metadata_store_pb2.Execution.RUNNING)
        if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
            updated_execution = copy.deepcopy(execution)
            updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
            mlmd_handle.store.put_executions([updated_execution])

        if service_job_manager is not None:
            # Ensure all the required services are running.
            _ensure_services(service_job_manager, pipeline_state)
            pure_service_node_ids = _get_pure_service_node_ids(
                service_job_manager, pipeline_state)
        else:
            pure_service_node_ids = set()

        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            if node.node_info.id not in pure_service_node_ids:
                _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                 task_queue)

        ignore_node_ids = set(
            n.node_info.id
            for n in stop_initiated_nodes) | pure_service_node_ids

        # Initialize task generator for the pipeline.
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
                mlmd_handle, pipeline, task_queue.contains_task_id,
                ignore_node_ids)
        elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
            generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
                mlmd_handle, pipeline, task_queue.contains_task_id,
                ignore_node_ids)
        else:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.FAILED_PRECONDITION,
                message=(
                    f'Only SYNC and ASYNC pipeline execution modes supported; '
                    f'found pipeline with execution mode: {pipeline.execution_mode}'
                ))

        # TODO(goutham): Consider concurrent task generation.
        tasks = generator.generate()
        for task in tasks:
            task_queue.enqueue(task)
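_get_pure_service_node_ids is referenced above but not defined in the excerpt. A minimal sketch, assuming the pstate.get_all_pipeline_nodes helper seen in Example #7 and the is_pure_service_node predicate seen in Examples #6 and #7; the exact body is an assumption.

from typing import Set


def _get_pure_service_node_ids(
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> Set[str]:
    """Sketch: ids of nodes that run entirely as service jobs."""
    return {
        node.node_info.id
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
        if service_job_manager.is_pure_service_node(pipeline_state,
                                                    node.node_info.id)
    }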
Example #4
def _test_no_tasks_generated_when_new(self):
    with self._mlmd_connection as m:
        task_gen = sptg.SyncPipelineTaskGenerator(m, self._pipeline,
                                                  lambda _: False)
        tasks = task_gen.generate()
        self.assertEmpty(
            tasks,
            'Expected no task generation since ExampleGen is ignored for task '
            'generation and dependent downstream nodes are not ready.')
        self.assertEmpty(
            m.store.get_executions(),
            'There must not be any registered executions since no tasks were '
            'generated.')
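The third constructor argument is a task-queue membership callback: given a task id, it reports whether that task is already enqueued. lambda _: False simulates an empty queue, while Example #1 passes the bound method of a live queue. The two invocations side by side (names as in the excerpts):

# With a live task queue (Example #1): tasks already enqueued are recognized
# and not generated again.
task_gen = sptg.SyncPipelineTaskGenerator(
    m, self._pipeline, self._task_queue.contains_task_id)

# Without one (this example): a constant-False callback means no task is
# currently considered enqueued.
task_gen = sptg.SyncPipelineTaskGenerator(m, self._pipeline, lambda _: False)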
Example #5
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
  """Scans MLMD and returns pipeline details."""
  result = []

  contexts = pstate.get_orchestrator_contexts(mlmd_handle)

  for context in contexts:
    try:
      pipeline_state = pstate.PipelineState.load_from_orchestrator_context(
          mlmd_handle, context)
    except status_lib.StatusNotOkError as e:
      if e.code == status_lib.Code.NOT_FOUND:
        continue
      # Re-raise unexpected errors; otherwise `pipeline_state` would be
      # referenced below without having been assigned.
      raise

    if pipeline_state.is_stop_initiated():
      generator = None
    else:
      pipeline = pipeline_state.pipeline
      if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    result.append(
        _PipelineDetail(pipeline_state=pipeline_state, generator=generator))

  return result
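Here _PipelineDetail can shrink to two fields, since PipelineState now carries the context, execution, pipeline, and uid itself. A sketch inferred from the constructor call above; the generator type is deliberately loose:

from typing import Any, Optional

import attr


@attr.s(auto_attribs=True, frozen=True)
class _PipelineDetail:
  """Sketch: pipeline state paired with its task generator (or None)."""
  pipeline_state: pstate.PipelineState
  # Sync or async task generator; None when a stop has been initiated.
  generator: Optional[Any]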
Example #6
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline."""
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
        updated_execution = copy.deepcopy(execution)
        updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
        mlmd_handle.store.put_executions([updated_execution])

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            if service_job_manager.is_pure_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
            elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                  task_queue):
                # A cancellation task was enqueued; the node will stop
                # asynchronously, so nothing more to do here.
                pass
            elif service_job_manager.is_mixed_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager,
            set(n.node_info.id for n in stop_initiated_nodes))
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate()

    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            elif task_lib.is_finalize_node_task(task):
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(task.node_uid, task.status)
            else:
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
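_maybe_enqueue_cancellation_task, called in Examples #3, #6, and #7, is outside the excerpt. The sketch below shows one plausible shape: return True if the node has an active execution whose cancellation is now (or was already) enqueued. task_lib.CancelNodeTask and task_gen_utils.get_executions do not appear in the excerpts and are assumptions.

def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue) -> bool:
    """Sketch: enqueues a cancellation task if the node is actively executing."""
    # Assumption: task_gen_utils.get_executions returns the node's executions.
    executions = task_gen_utils.get_executions(mlmd_handle, node)
    if not any(execution_lib.is_execution_active(e) for e in executions):
        # Nothing to cancel; the caller may treat the node as already stopped.
        return False
    # Assumption: task_lib.CancelNodeTask exists, keyed by node_uid with a
    # derived task_id.
    cancel_task = task_lib.CancelNodeTask(
        node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node))
    if not task_queue.contains_task_id(cancel_task.task_id):
        task_queue.enqueue(cancel_task)
    return True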
Example #7
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline."""
    pipeline = pipeline_state.pipeline
    with pipeline_state:
        assert pipeline_state.is_active()
        if pipeline_state.get_pipeline_execution_state() != (
                metadata_store_pb2.Execution.RUNNING):
            pipeline_state.set_pipeline_execution_state(
                metadata_store_pb2.Execution.RUNNING)
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                and deadline_secs > 0 and time.time() -
                pipeline_state.pipeline_creation_time_secs_since_epoch() >
                deadline_secs):
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    def _filter_by_state(node_infos: List[_NodeInfo],
                         state_str: str) -> List[_NodeInfo]:
        return [n for n in node_infos if n.state.state == state_str]

    node_infos = _get_node_infos(pipeline_state)
    stopping_node_infos = _filter_by_state(node_infos,
                                           pstate.NodeState.STOPPING)

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        if service_job_manager.is_pure_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            # A cancellation task was enqueued; the node will be marked
            # STOPPED in a later orchestration iteration once it has stopped.
            pass
        elif service_job_manager.is_mixed_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        else:
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                task = typing.cast(task_lib.UpdateNodeStateTask, task)
                with pipeline_state.node_state_update_context(
                        task.node_uid) as node_state:
                    node_state.update(task.state, task.status)

        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            else:
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
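Finally, _NodeInfo and _get_node_infos are also outside the excerpt. A sketch consistent with their usage above; pipeline_state.get_node_state is an assumption (the excerpt only shows the node_state_update_context accessor):

import attr


@attr.s(auto_attribs=True, frozen=True)
class _NodeInfo:
    """Sketch: a pipeline node paired with its recorded state."""
    node: pipeline_pb2.PipelineNode
    state: pstate.NodeState


def _get_node_infos(
        pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
    """Sketch: snapshots the current state of every node in the pipeline."""
    result = []
    with pipeline_state:
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            result.append(
                _NodeInfo(node=node,
                          state=pipeline_state.get_node_state(node_uid)))
    return result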