Example #1
def testGetCachedOutputArtifactsForNodesWithNoOutput(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
        cache_context = context_lib.register_context_if_not_exists(
            m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
        cached_output = cache_utils.get_cached_outputs(m, cache_context)
        # No succeeded execution is associated with this context yet, so the
        # cached output is None.
        self.assertIsNone(cached_output)
        execution_one = execution_publish_utils.register_execution(
            m, metadata_store_pb2.ExecutionType(name='my_type'),
            [cache_context])
        execution_publish_utils.publish_succeeded_execution(
            m, execution_one.id, [cache_context])
        cached_output = cache_utils.get_cached_outputs(m, cache_context)
        # A succeeded execution is now associated with this context, so the
        # cached output is not None but an empty dict.
        self.assertIsNotNone(cached_output)
        self.assertEmpty(cached_output)
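
The test above pins down a subtle contract of cache_utils.get_cached_outputs: None means no successful execution exists under the cache context, while an empty dict means a successful execution exists but produced no output artifacts. A minimal sketch of how a caller might branch on that distinction; reuse_or_run and run_node are hypothetical names, not part of the TFX API:

def reuse_or_run(m, cache_context):
    # Hypothetical helper; run_node is a stand-in for actually executing
    # the node.
    cached_outputs = cache_utils.get_cached_outputs(m, cache_context)
    if cached_outputs is None:
        # No successful execution under this cache context yet: must run.
        return run_node()
    # A prior execution succeeded (possibly with zero outputs): reuse it.
    return cached_outputs
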
Example #2
def fake_cached_execution(mlmd_connection, cache_context, component):
    """Writes cached execution; MLMD must have previous execution associated with cache_context."""
    with mlmd_connection as m:
        cached_outputs = cache_utils.get_cached_outputs(
            m, cache_context=cache_context)
        contexts = context_lib.prepare_contexts(m, component.contexts)
        execution = execution_publish_utils.register_execution(
            m, component.node_info.type, contexts)
        execution_publish_utils.publish_cached_execution(
            m,
            contexts=contexts,
            execution_id=execution.id,
            output_artifacts=cached_outputs)
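
A hedged usage sketch of the helper above in a test: component is expected to expose contexts and node_info.type as used in the body, and a prior successful execution must already be published under cache_context (as in Example #1). The fixture names below are assumptions, not part of the snippet above:

# Hypothetical test fixtures; only fake_cached_execution is defined above.
fake_cached_execution(
    self._mlmd_connection, self._cache_context, self._trainer)
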
Example #3
    def _prepare_execution(self) -> _PrepareExecutionResult:
        """Prepares inputs, outputs and execution properties for actual execution."""
        # TODO(b/150979622): Handle the edge case where the component gets
        # evicted between a successful publish and the stateful working dir
        # being cleaned up. Otherwise subsequent retries will keep failing
        # because of duplicate publishes.
        with self._mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=self._pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=self._pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=self._pipeline_node.inputs)
            # 3. If not all required inputs are available, returns an
            # ExecutionInfo with is_execution_needed set to False. No publish
            # will happen, so downstream nodes won't be triggered.
            if input_artifacts is None:
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(),
                    contexts=contexts,
                    is_execution_needed=False)

            # 4. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=self._pipeline_node.node_info.type,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)

            # 5. Resolves output artifacts.
            output_artifacts = self._output_resolver.generate_output_artifacts(
                execution.id)

        # If there is a custom driver, runs it.
        if self._driver_operator:
            driver_output = self._driver_operator.run_driver(
                data_types.ExecutionInfo(
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_driver_output_uri()))
            self._update_with_driver_output(driver_output, exec_properties,
                                            output_artifacts)

        # We reconnect to MLMD here because the custom driver closes the MLMD
        # connection on return.
        with self._mlmd_connection as m:
            # 6. Checks cached results.
            cache_context = cache_utils.get_cache_context(
                metadata_handler=m,
                pipeline_node=self._pipeline_node,
                pipeline_info=self._pipeline_info,
                input_artifacts=input_artifacts,
                output_artifacts=output_artifacts,
                parameters=exec_properties)
            contexts.append(cache_context)
            cached_outputs = cache_utils.get_cached_outputs(
                metadata_handler=m, cache_context=cache_context)

            # 7. Should cache be used?
            if (self._pipeline_node.execution_options.caching_options.
                    enable_cache and cached_outputs):
                # Publishes the cached result.
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(
                        execution_id=execution.id),
                    execution_metadata=execution,
                    contexts=contexts,
                    is_execution_needed=False)

            pipeline_run_id = (self._pipeline_runtime_spec.pipeline_run_id.
                               field_value.string_value)

            # 8. Going to trigger the executor.
            return _PrepareExecutionResult(
                execution_info=data_types.ExecutionInfo(
                    execution_id=execution.id,
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_executor_output_uri(execution.id),
                    stateful_working_dir=(self._output_resolver.
                                          get_stateful_working_directory()),
                    tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
                    pipeline_node=self._pipeline_node,
                    pipeline_info=self._pipeline_info,
                    pipeline_run_id=pipeline_run_id),
                execution_metadata=execution,
                contexts=contexts,
                is_execution_needed=True)
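
The method above builds _PrepareExecutionResult at three exit points, but the container itself is not shown. A minimal sketch of what it might look like, inferred purely from those call sites; the use of dataclasses, the field default, and the import paths are assumptions:

import dataclasses
from typing import List, Optional

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.portable import data_types


@dataclasses.dataclass
class _PrepareExecutionResult:
    """Sketch of the result container, inferred from the call sites above."""
    execution_info: data_types.ExecutionInfo
    contexts: List[metadata_store_pb2.Context]
    is_execution_needed: bool
    # Only populated once an execution has been registered in MLMD.
    execution_metadata: Optional[metadata_store_pb2.Execution] = None
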
Example #4
def testGetCachedOutputArtifacts(self):
    # Output artifacts that will be used by the first execution with the same
    # cache key.
    output_model_one = standard_artifacts.Model()
    output_model_one.uri = 'model_one'
    output_model_two = standard_artifacts.Model()
    output_model_two.uri = 'model_two'
    output_example_one = standard_artifacts.Examples()
    output_example_one.uri = 'example_one'
    # Output artifacts that will be used by the second execution with the same
    # cache key.
    output_model_three = standard_artifacts.Model()
    output_model_three.uri = 'model_three'
    output_model_four = standard_artifacts.Model()
    output_model_four.uri = 'model_four'
    output_example_two = standard_artifacts.Examples()
    output_example_two.uri = 'example_two'
    output_models_key = 'output_models'
    output_examples_key = 'output_examples'
    with metadata.Metadata(connection_config=self._connection_config) as m:
        cache_context = context_lib.register_context_if_not_exists(
            m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
        execution_one = execution_publish_utils.register_execution(
            m, metadata_store_pb2.ExecutionType(name='my_type'),
            [cache_context])
        execution_publish_utils.publish_succeeded_execution(
            m,
            execution_one.id, [cache_context],
            output_artifacts={
                output_models_key: [output_model_one, output_model_two],
                output_examples_key: [output_example_one]
            })
        execution_two = execution_publish_utils.register_execution(
            m, metadata_store_pb2.ExecutionType(name='my_type'),
            [cache_context])
        execution_publish_utils.publish_succeeded_execution(
            m,
            execution_two.id, [cache_context],
            output_artifacts={
                output_models_key: [output_model_three, output_model_four],
                output_examples_key: [output_example_two]
            })
        # The cached outputs retrieved should be the artifacts produced by the
        # most recent execution under the given cache context.
        cached_output = cache_utils.get_cached_outputs(m, cache_context)
        self.assertLen(cached_output, 2)
        self.assertLen(cached_output[output_models_key], 2)
        self.assertLen(cached_output[output_examples_key], 1)
        self.assertProtoPartiallyEquals(
            cached_output[output_models_key][0].mlmd_artifact,
            output_model_three.mlmd_artifact,
            ignored_fields=[
                'create_time_since_epoch', 'last_update_time_since_epoch'
            ])
        self.assertProtoPartiallyEquals(
            cached_output[output_models_key][1].mlmd_artifact,
            output_model_four.mlmd_artifact,
            ignored_fields=[
                'create_time_since_epoch', 'last_update_time_since_epoch'
            ])
        self.assertProtoPartiallyEquals(
            cached_output[output_examples_key][0].mlmd_artifact,
            output_example_two.mlmd_artifact,
            ignored_fields=[
                'create_time_since_epoch', 'last_update_time_since_epoch'
            ])
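
The test relies on a "most recent execution wins" rule. A minimal sketch of how that selection could be implemented, assuming m.store exposes the underlying MLMD MetadataStore; filtering to successful executions is elided here:

# Sketch only: get_cached_outputs additionally filters to successful
# executions before picking the latest one.
executions = m.store.get_executions_by_context(cache_context.id)
latest = max(executions, key=lambda e: e.create_time_since_epoch)
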
Example #5
    def _maybe_generate_task(
        self,
        node: pipeline_pb2.PipelineNode,
        node_executions: Sequence[metadata_store_pb2.Execution],
        successful_node_ids: MutableSet[str],
    ) -> Optional[task_lib.Task]:
        """Generates a task to execute or `None` if no action is required.

    If node is executable, `ExecNodeTask` is returned.

    If node execution is infeasible due to unsatisfied preconditions such as
    missing inputs or service job failure, task to abort the pipeline is
    returned.

    If cache is enabled and previously computed outputs are found, those outputs
    are used to finish the execution. Since node execution can be elided, `None`
    is returned after adding the node_id to `successful_node_ids` set.

    Args:
      node: The pipeline node for which to generate a task.
      node_executions: Node executions fetched from MLMD.
      successful_node_ids: Set that tracks successful node ids.

    Returns:
      Returns an `ExecNodeTask` if node can be executed. If an error occurs,
      a `FinalizePipelineTask` is returned to abort the pipeline execution.
    """
        result = task_gen_utils.generate_task_from_active_execution(
            self._mlmd_handle, self._pipeline, node, node_executions)
        if result:
            return result

        node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
        resolved_info = task_gen_utils.generate_resolved_info(
            self._mlmd_handle, node)
        if resolved_info.input_artifacts is None:
            return self._abort_task(
                f'failure to resolve inputs; node uid: {node_uid}')

        execution = execution_publish_utils.register_execution(
            metadata_handler=self._mlmd_handle,
            execution_type=node.node_info.type,
            contexts=resolved_info.contexts,
            input_artifacts=resolved_info.input_artifacts,
            exec_properties=resolved_info.exec_properties)
        outputs_resolver = outputs_utils.OutputsResolver(
            node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
            self._pipeline.execution_mode)
        output_artifacts = outputs_resolver.generate_output_artifacts(
            execution.id)

        # Check if we can elide node execution by reusing previously computed
        # outputs for the node.
        cache_context = cache_utils.get_cache_context(
            self._mlmd_handle,
            pipeline_node=node,
            pipeline_info=self._pipeline.pipeline_info,
            executor_spec=_get_executor_spec(self._pipeline,
                                             node.node_info.id),
            input_artifacts=resolved_info.input_artifacts,
            output_artifacts=output_artifacts,
            parameters=resolved_info.exec_properties)
        contexts = resolved_info.contexts + [cache_context]
        if node.execution_options.caching_options.enable_cache:
            cached_outputs = cache_utils.get_cached_outputs(
                self._mlmd_handle, cache_context=cache_context)
            if cached_outputs is not None:
                logging.info(
                    'Eliding node execution, using cached outputs; node uid: %s',
                    node_uid)
                execution_publish_utils.publish_cached_execution(
                    self._mlmd_handle,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                successful_node_ids.add(node.node_info.id)
                pstate.record_state_change_time()
                return None

        # For mixed service nodes, we ensure node services and check the
        # service status; the pipeline is aborted if the service jobs have
        # failed.
        service_status = self._ensure_node_services_if_mixed(node.node_info.id)
        if service_status == service_jobs.ServiceStatus.FAILED:
            return self._abort_task(
                f'associated service job failed; node uid: {node_uid}')

        return task_lib.ExecNodeTask(
            node_uid=node_uid,
            execution_id=execution.id,
            contexts=contexts,
            input_artifacts=resolved_info.input_artifacts,
            exec_properties=resolved_info.exec_properties,
            output_artifacts=output_artifacts,
            executor_output_uri=outputs_resolver.get_executor_output_uri(
                execution.id),
            stateful_working_dir=outputs_resolver.
            get_stateful_working_directory(execution.id),
            pipeline=self._pipeline)
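
The cached path above only fires when caching is enabled on the node. The flag lives on the node proto itself; a minimal sketch of enabling it before task generation, with the surrounding pipeline construction assumed:

# node is a pipeline_pb2.PipelineNode; this is the flag checked by
# _maybe_generate_task above.
node.execution_options.caching_options.enable_cache = True
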