Example #1
  def testPutParentContextIfNotExists(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
      parent_context = context_lib.register_context_if_not_exists(
          metadata_handler=m,
          context_type_name='my_context_type',
          context_name='parent_context_name')
      child_context = context_lib.register_context_if_not_exists(
          metadata_handler=m,
          context_type_name='my_context_type',
          context_name='child_context_name')
      context_lib.put_parent_context_if_not_exists(
          m, parent_id=parent_context.id, child_id=child_context.id)
      # Duplicated call should succeed.
      context_lib.put_parent_context_if_not_exists(
          m, parent_id=parent_context.id, child_id=child_context.id)
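The parent link, once written, can be read back from the same store. A minimal continuation sketch for the test above, assuming an ml-metadata release whose MetadataStore exposes get_parent_contexts_by_context and get_children_contexts_by_context (the availability of these two methods is an assumption):

      # Hypothetical continuation inside the same `with` block as above.
      parents = m.store.get_parent_contexts_by_context(child_context.id)
      children = m.store.get_children_contexts_by_context(parent_context.id)
      self.assertEqual([c.id for c in parents], [parent_context.id])
      self.assertEqual([c.id for c in children], [child_context.id])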
Example #2
  def testRegisterContextByTypeAndName(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
      context_lib.register_context_if_not_exists(
          metadata_handler=m,
          context_type_name='my_context_type',
          context_name='my_context')
      # Duplicated call should succeed.
      context = context_lib.register_context_if_not_exists(
          metadata_handler=m,
          context_type_name='my_context_type',
          context_name='my_context')

      self.assertProtoEquals(
          """
          id: 1
          name: 'my_context_type'
          """, m.store.get_context_type('my_context_type'))
      self.assertEqual(
          context,
          m.store.get_context_by_type_and_name('my_context_type', 'my_context'))
Example #3
    def new(cls, mlmd_handle: metadata.Metadata,
            pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call
        to be successful.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: If a pipeline with the same UID already
            exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties={
                _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
            },
        )
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
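A hedged sketch of how a caller might exercise the contract documented above, tolerating ALREADY_EXISTS when an active run with the same UID was started elsewhere. The import paths and the `code` attribute on StatusNotOkError are assumptions inferred from the snippet, not confirmed API:

from absl import logging

from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import task_lib
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import status as status_lib


def start_if_not_running(mlmd_handle: metadata.Metadata,
                         pipeline: pipeline_pb2.Pipeline):
  """Starts the pipeline unless an active run with the same UID exists."""
  try:
    return pstate.PipelineState.new(mlmd_handle, pipeline)
  except status_lib.StatusNotOkError as e:
    if e.code == status_lib.Code.ALREADY_EXISTS:
      uid = task_lib.PipelineUid.from_pipeline(pipeline)
      logging.info('Pipeline %s is already active; not starting again.', uid)
      return None
    raise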
Example #4
  def testGetCachedOutputArtifactsForNodesWithNoOuput(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
      cache_context = context_lib.register_context_if_not_exists(
          m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
      cached_output = cache_utils.get_cached_outputs(m, cache_context)
      # No succeeded execution is associated with this context yet, so the
      # cached output is None.
      self.assertIsNone(cached_output)
      execution_one = execution_publish_utils.register_execution(
          m, metadata_store_pb2.ExecutionType(name='my_type'),
          [cache_context])
      execution_publish_utils.publish_succeeded_execution(
          m, execution_one.id, [cache_context])
      cached_output = cache_utils.get_cached_outputs(m, cache_context)
      # A succeeded execution is associated with this context, so the cached
      # output is not None but an empty dict.
      self.assertIsNotNone(cached_output)
      self.assertEmpty(cached_output)
Example #5
def initiate_pipeline_start(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
    """Initiates a pipeline start operation.

    Upon success, MLMD is updated to signal that the given pipeline must be
    started.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline to start.

    Returns:
      The pipeline-level MLMD execution proto upon success.

    Raises:
      status_lib.StatusNotOkError: Failure to initiate pipeline start or if
        execution is not inactive after waiting `timeout_secs`.
    """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
        context_type_name=_ORCHESTRATOR_RESERVED_ID,
        context_name=_orchestrator_context_name(pipeline_uid))

    executions = mlmd_handle.store.get_executions_by_context(context.id)
    if any(e for e in executions if execution_lib.is_execution_active(e)):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.ALREADY_EXISTS,
            message=f'Pipeline with uid {pipeline_uid} already started.')

    execution = execution_lib.prepare_execution(
        mlmd_handle,
        _ORCHESTRATOR_EXECUTION_TYPE,
        metadata_store_pb2.Execution.NEW,
        exec_properties={
            _PIPELINE_IR:
            base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
        })
    execution = execution_lib.put_execution(mlmd_handle, execution, [context])
    logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
                 execution.id, pipeline_uid)
    return execution
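The execution registered above stores the full pipeline IR as a base64-encoded string in its exec properties. A minimal decoding sketch; the key 'pipeline_ir' is an assumed value for the _PIPELINE_IR constant, and both property maps are checked because the snippet does not show where prepare_execution places exec_properties:

import base64

from ml_metadata.proto import metadata_store_pb2
from tfx.proto.orchestration import pipeline_pb2

_PIPELINE_IR_KEY = 'pipeline_ir'  # Assumed value of _PIPELINE_IR.


def pipeline_from_orchestrator_execution(
    execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
  """Decodes the pipeline IR stored on an orchestrator execution."""
  for props in (execution.properties, execution.custom_properties):
    if _PIPELINE_IR_KEY in props:
      pipeline = pipeline_pb2.Pipeline()
      pipeline.ParseFromString(
          base64.b64decode(props[_PIPELINE_IR_KEY].string_value))
      return pipeline
  raise ValueError('Execution carries no pipeline IR property.')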
Example #6
    def _get_pipeline_run_context(
            self,
            run_id: str,
            register_if_not_found: bool = False) -> metadata_store_pb2.Context:
        """Gets the pipeline_run_context for a given pipeline run id.

        When called, this method first attempts to get the pipeline run context
        from the in-memory cache. If it is not found there and
        `register_if_not_found` is True, the pipeline_run_context is registered
        in MLMD, added to the in-memory cache, and returned; otherwise a
        LookupError is raised.

        Args:
          run_id: The pipeline_run_id whose Context to query.
          register_if_not_found: If True, register the pipeline_run_id in MLMD
            when it cannot be found there. If False, raise LookupError instead.
            Defaults to False.

        Returns:
          The requested pipeline run Context.

        Raises:
          LookupError: If `register_if_not_found` is False and the
            pipeline_run_id cannot be found in MLMD.
        """
        if run_id not in self._pipeline_run_contexts:
            if register_if_not_found:
                pipeline_run_context = context_lib.register_context_if_not_exists(
                    self._mlmd,
                    context_type_name=constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
                    context_name=run_id)
                self._pipeline_run_contexts[run_id] = pipeline_run_context
            else:
                raise LookupError(
                    f'pipeline_run_id {run_id} not found in MLMD.')
        return self._pipeline_run_contexts[run_id]
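A hedged, test-style sketch of the two code paths documented above; `self._handler` is a hypothetical fixture holding an instance of the class that defines _get_pipeline_run_context:

  def testGetPipelineRunContext(self):
    # Unknown run id without registration: the documented LookupError path.
    with self.assertRaises(LookupError):
      self._handler._get_pipeline_run_context('run-1')
    # With register_if_not_found=True the context is registered in MLMD and
    # cached in memory.
    context = self._handler._get_pipeline_run_context(
        'run-1', register_if_not_found=True)
    # A later plain lookup is served from the in-memory cache.
    self.assertEqual(
        context, self._handler._get_pipeline_run_context('run-1'))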
Example #7
  def testGetCachedOutputArtifacts(self):
    # Output artifacts that will be used by the first execution with the same
    # cache key.
    output_model_one = standard_artifacts.Model()
    output_model_one.uri = 'model_one'
    output_model_two = standard_artifacts.Model()
    output_model_two.uri = 'model_two'
    output_example_one = standard_artifacts.Examples()
    output_example_one.uri = 'example_one'
    # Output artifacts that will be used by the second execution with the same
    # cache key.
    output_model_three = standard_artifacts.Model()
    output_model_three.uri = 'model_three'
    output_model_four = standard_artifacts.Model()
    output_model_four.uri = 'model_four'
    output_example_two = standard_artifacts.Examples()
    output_example_two.uri = 'example_two'
    output_models_key = 'output_models'
    output_examples_key = 'output_examples'
    with metadata.Metadata(connection_config=self._connection_config) as m:
      cache_context = context_lib.register_context_if_not_exists(
          m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
      execution_one = execution_publish_utils.register_execution(
          m, metadata_store_pb2.ExecutionType(name='my_type'),
          [cache_context])
      execution_publish_utils.publish_succeeded_execution(
          m,
          execution_one.id, [cache_context],
          output_artifacts={
              output_models_key: [output_model_one, output_model_two],
              output_examples_key: [output_example_one]
          })
      execution_two = execution_publish_utils.register_execution(
          m, metadata_store_pb2.ExecutionType(name='my_type'),
          [cache_context])
      execution_publish_utils.publish_succeeded_execution(
          m,
          execution_two.id, [cache_context],
          output_artifacts={
              output_models_key: [output_model_three, output_model_four],
              output_examples_key: [output_example_two]
          })
      # The cached outputs should be the artifacts produced by the most recent
      # execution under the given cache context.
      cached_output = cache_utils.get_cached_outputs(m, cache_context)
      self.assertLen(cached_output, 2)
      self.assertLen(cached_output[output_models_key], 2)
      self.assertLen(cached_output[output_examples_key], 1)
      self.assertProtoPartiallyEquals(
          cached_output[output_models_key][0].mlmd_artifact,
          output_model_three.mlmd_artifact,
          ignored_fields=[
              'create_time_since_epoch', 'last_update_time_since_epoch'
          ])
      self.assertProtoPartiallyEquals(
          cached_output[output_models_key][1].mlmd_artifact,
          output_model_four.mlmd_artifact,
          ignored_fields=[
              'create_time_since_epoch', 'last_update_time_since_epoch'
          ])
      self.assertProtoPartiallyEquals(
          cached_output[output_examples_key][0].mlmd_artifact,
          output_example_two.mlmd_artifact,
          ignored_fields=[
              'create_time_since_epoch', 'last_update_time_since_epoch'
          ])
Example #8
def get_cache_context(
    metadata_handler: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    executor_spec: Optional[message.Message] = None,
    input_artifacts: Optional[Mapping[Text, Sequence[types.Artifact]]] = None,
    output_artifacts: Optional[Mapping[Text, Sequence[types.Artifact]]] = None,
    parameters: Optional[Mapping[Text,
                                 Any]] = None) -> metadata_store_pb2.Context:
    """Gets cache context for a potential node execution.

    The cache key is generated by applying the SHA-256 hash function to:
    - Serialized pipeline info.
    - Serialized node_info of the PipelineNode.
    - Serialized executor spec.
    - Serialized input artifacts if any.
    - Serialized output artifacts if any. The uri and name are cleared before
      hashing.
    - Serialized parameters if any.
    - Serialized module file content if a module file is present in parameters.

    Args:
      metadata_handler: A handler to access the MLMD store.
      pipeline_node: A pipeline_pb2.PipelineNode instance representing the node.
      pipeline_info: Information of the pipeline.
      executor_spec: A proto message representing the executor specification.
      input_artifacts: Input artifacts of the potential execution. The order of
        the artifacts under a key matters when calculating the cache key.
      output_artifacts: Output artifacts skeleton of the potential execution.
        The order of the artifacts under a key matters when calculating the
        cache key.
      parameters: Parameters of the potential execution.

    Returns:
      A metadata_store_pb2.Context for the cache key.
    """
    h = hashlib.sha256()
    h.update(pipeline_info.SerializeToString(deterministic=True))
    h.update(pipeline_node.node_info.SerializeToString(deterministic=True))
    if executor_spec:
        h.update(executor_spec.SerializeToString(deterministic=True))
    for key in sorted(input_artifacts or {}):
        h.update(key.encode())
        for artifact in input_artifacts[key]:
            h.update(
                artifact.mlmd_artifact.SerializeToString(deterministic=True))
    for key in sorted(output_artifacts or {}):
        h.update(key.encode())
        for artifact in output_artifacts[key]:
            stateless_artifact = copy.deepcopy(artifact)
            # Output uri and name should not contribute to the cache key.
            stateless_artifact.uri = ''
            stateless_artifact.name = ''
            h.update(
                stateless_artifact.mlmd_artifact.SerializeToString(
                    deterministic=True))
    parameters = parameters or {}
    for key, value in sorted(parameters.items()):
        h.update(key.encode())
        h.update(str(value).encode())

    # Special treatment for module files as they are used as part of the
    # processing logic. Currently this pattern is employed by Trainer and
    # Transform.
    if ('module_file' in parameters and parameters['module_file']
            and fileio.exists(parameters['module_file'])):
        with fileio.open(parameters['module_file'], 'r') as f:
            h.update(f.read().encode())

    return context_lib.register_context_if_not_exists(
        metadata_handler=metadata_handler,
        context_type_name=context_lib.CONTEXT_TYPE_EXECUTION_CACHE,
        context_name=h.hexdigest())
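Because output artifact uri and name are cleared before hashing, two prospective executions that differ only in where their outputs will be written map to the same cache context. A hedged, test-style sketch of that property, assuming get_cache_context lives in the same cache_utils module used in the earlier tests and that self._pipeline_node / self._pipeline_info are fixtures built elsewhere:

  def testCacheContextIgnoresOutputUri(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
      model_a = standard_artifacts.Model()
      model_a.uri = 'gs://bucket/run_a/model'
      model_b = standard_artifacts.Model()
      model_b.uri = 'gs://bucket/run_b/model'
      context_a = cache_utils.get_cache_context(
          m, self._pipeline_node, self._pipeline_info,
          output_artifacts={'model': [model_a]})
      context_b = cache_utils.get_cache_context(
          m, self._pipeline_node, self._pipeline_info,
          output_artifacts={'model': [model_b]})
      # Identical SHA-256 digests, hence the same registered context.
      self.assertEqual(context_a.id, context_b.id)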
Example #9
    def new(
        cls,
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
    ) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call
        to be successful.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.
          pipeline_run_metadata: Pipeline run metadata.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: If a pipeline with the same UID already
            exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
        if pipeline_run_metadata:
            exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
                pipeline_run_metadata)

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties=exec_properties)
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
            # Set the node state to COMPLETE for any nodes that are marked to be
            # skipped in a partial pipeline run.
            node_states_dict = {}
            for node in get_all_pipeline_nodes(pipeline):
                if node.execution_options.HasField('skip'):
                    logging.info('Node %s is skipped in this partial run.',
                                 node.node_info.id)
                    node_states_dict[node.node_info.id] = NodeState(
                        state=NodeState.COMPLETE)
            if node_states_dict:
                _save_node_states_dict(execution, node_states_dict)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)