Пример #1
0
    def new(cls, mlmd_handle: metadata.Metadata,
            pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call
        to be successful.

        On success, a new orchestrator execution in state NEW is written to MLMD
        and associated with the pipeline's orchestrator context. The execution
        stores the serialized pipeline IR and, for SYNC pipelines, the pipeline
        run id.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.

        Returns:
          A `PipelineState` object bound to the newly created execution.

        Raises:
          status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if an active
            orchestrator execution already exists for a pipeline with the same
            UID.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        # Only one active orchestrator execution per pipeline uid is allowed.
        # The predicate result itself is handed to `any` (rather than the
        # execution protos), so the check does not depend on proto truthiness.
        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(execution_lib.is_execution_active(e) for e in executions):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        # The pipeline IR is stored as a base64-encoded string exec property.
        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties={
                _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
            },
        )
        # Only SYNC pipelines record the run id on the orchestrator execution.
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Пример #2
0
def initiate_pipeline_start(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
    """Initiates a pipeline start operation.

    Upon success, MLMD is updated to signal that the given pipeline must be
    started. A new orchestrator execution in state NEW, holding the serialized
    pipeline IR, is written to MLMD and associated with the pipeline's
    orchestrator context.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline to start.

    Returns:
      The pipeline-level MLMD execution proto upon success.

    Raises:
      status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if an active
        orchestrator execution already exists for a pipeline with the same UID.
        Errors raised by the underlying MLMD helpers are not caught here.
    """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
        context_type_name=_ORCHESTRATOR_RESERVED_ID,
        context_name=_orchestrator_context_name(pipeline_uid))

    # Refuse to start if this pipeline uid already has an active orchestrator
    # execution. The predicate result itself is handed to `any` (rather than
    # the execution protos), so the check does not depend on proto truthiness.
    executions = mlmd_handle.store.get_executions_by_context(context.id)
    if any(execution_lib.is_execution_active(e) for e in executions):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.ALREADY_EXISTS,
            message=f'Pipeline with uid {pipeline_uid} already started.')

    # The pipeline IR is stored as a base64-encoded string exec property.
    execution = execution_lib.prepare_execution(
        mlmd_handle,
        _ORCHESTRATOR_EXECUTION_TYPE,
        metadata_store_pb2.Execution.NEW,
        exec_properties={
            _PIPELINE_IR:
            base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
        })
    execution = execution_lib.put_execution(mlmd_handle, execution, [context])
    logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
                 execution.id, pipeline_uid)
    return execution
Пример #3
0
def _fix_deployment_config(
    input_pipeline: p_pb2.Pipeline,
    node_ids_to_keep: Collection[str]) -> Union[any_pb2.Any, None]:
  """Restricts the per-node deployment configs to a subset of nodes.

  The pipeline's `deployment_config` (an `Any` proto) is unpacked into a fresh
  `IntermediateDeploymentConfig`. Entries whose node id is not in
  `node_ids_to_keep` are removed from each of its per-node maps:
  `executor_specs`, `custom_driver_specs` and `node_level_platform_configs`.
  Filtering by key alone is valid because those fields don't contain references
  to other nodes. The result is packed back into a new `Any` proto.

  Args:
    input_pipeline: The input Pipeline IR proto.
    node_ids_to_keep: Ids of the nodes whose per-node configs are retained.

  Returns:
    A new `Any` proto holding the filtered deployment config, or None when
    `deployment_config` is not set on `input_pipeline`.
  """
  if input_pipeline.HasField('deployment_config'):
    filtered_config = p_pb2.IntermediateDeploymentConfig()
    input_pipeline.deployment_config.Unpack(filtered_config)

    per_node_maps = (
        filtered_config.executor_specs,
        filtered_config.custom_driver_specs,
        filtered_config.node_level_platform_configs,
    )
    for per_node_map in per_node_maps:
      # Snapshot the stale keys before deleting, since a map must not be
      # mutated while it is being iterated.
      stale_node_ids = [
          node_id for node_id in per_node_map
          if node_id not in node_ids_to_keep
      ]
      for stale_node_id in stale_node_ids:
        del per_node_map[stale_node_id]

    packed = any_pb2.Any()
    packed.Pack(filtered_config)
    return packed

  return None