Example #1
def fake_cached_execution(mlmd_connection, cache_context, component):
    """Writes cached execution; MLMD must have previous execution associated with cache_context."""
    with mlmd_connection as m:
        cached_outputs = cache_utils.get_cached_outputs(
            m, cache_context=cache_context)
        contexts = context_lib.prepare_contexts(m, component.contexts)
        execution = execution_publish_utils.register_execution(
            m, component.node_info.type, contexts)
        execution_publish_utils.publish_cached_execution(
            m,
            contexts=contexts,
            execution_id=execution.id,
            output_artifacts=cached_outputs)
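For context, a minimal sketch of how a helper like this might be exercised from a test. The fixtures `self._mlmd_connection`, `self._cache_context`, and `self._example_gen` are hypothetical placeholders, not names taken from the example above:

    def testReusesCachedOutputs(self):
        # Hypothetical fixtures; only fake_cached_execution and the MLMD calls
        # below come from the examples in this section.
        fake_cached_execution(self._mlmd_connection, self._cache_context,
                              self._example_gen)
        with self._mlmd_connection as m:
            [execution] = m.store.get_executions()
            self.assertEqual(execution.last_known_state,
                             metadata_store_pb2.Execution.CACHED)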
Example #2
    def testPublishCachedExecution(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id
            output_example = standard_artifacts.Examples()
            execution_publish_utils.publish_cached_execution(
                m,
                contexts,
                execution_id,
                output_artifacts={'examples': [output_example]})
            [execution] = m.store.get_executions()
            self.assertProtoPartiallyEquals(
                """
                id: 1
                type_id: 3
                last_known_state: CACHED
                """,
                execution,
                ignored_fields=[
                    'create_time_since_epoch',
                    'last_update_time_since_epoch'
                ])
            [event] = m.store.get_events_by_execution_ids([execution.id])
            self.assertProtoPartiallyEquals(
                """
                artifact_id: 1
                execution_id: 1
                path {
                  steps {
                    key: 'examples'
                  }
                  steps {
                    index: 0
                  }
                }
                type: OUTPUT
                """,
                event,
                ignored_fields=['milliseconds_since_epoch'])
            # Verifies the context-execution edges are set up.
            self.assertCountEqual(
                [c.id for c in contexts],
                [c.id for c in m.store.get_contexts_by_execution(execution.id)])
            self.assertCountEqual(
                [c.id for c in contexts],
                [c.id
                 for c in m.store.get_contexts_by_artifact(output_example.id)])
Example #3
    def _cache_and_publish(self,
                           existing_execution: metadata_store_pb2.Execution):
        """Updates MLMD."""
        cached_execution_contexts = self._get_cached_execution_contexts(
            existing_execution)
        # Check if there are any previous attempts to cache and publish.
        prev_cache_executions = (
            execution_lib.get_executions_associated_with_all_contexts(
                self._mlmd, contexts=cached_execution_contexts))
        if not prev_cache_executions:
            new_execution = execution_publish_utils.register_execution(
                self._mlmd,
                execution_type=metadata_store_pb2.ExecutionType(
                    id=existing_execution.type_id),
                contexts=cached_execution_contexts)
        else:
            if len(prev_cache_executions) > 1:
                logging.warning(
                    'More than one previous cache execution seen when attempting '
                    'reuse_node_outputs: %s', prev_cache_executions)

            if (prev_cache_executions[-1].last_known_state ==
                    metadata_store_pb2.Execution.CACHED):
                return
            else:
                new_execution = prev_cache_executions[-1]

        output_artifacts = execution_lib.get_artifacts_dict(
            self._mlmd,
            existing_execution.id,
            event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))

        execution_publish_utils.publish_cached_execution(
            self._mlmd,
            contexts=cached_execution_contexts,
            execution_id=new_execution.id,
            output_artifacts=output_artifacts)
Example #4
    def _prepare_execution(self) -> _PrepareExecutionResult:
        """Prepares inputs, outputs and execution properties for actual execution."""
        # TODO(b/150979622): handle the edge case where the component gets evicted
        # between a successful publish and the stateful working dir being cleaned
        # up. Otherwise, subsequent retries will keep failing because of duplicate
        # publishes.
        with self._mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=self._pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=self._pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=self._pipeline_node.inputs)
            # 3. If not all required inputs are available, return ExecutionInfo with
            # is_execution_needed being False. No publish will happen, so downstream
            # nodes won't be triggered.
            if input_artifacts is None:
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(),
                    contexts=contexts,
                    is_execution_needed=False)

            # 4. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=self._pipeline_node.node_info.type,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)

            # 5. Resolves output artifacts.
            output_artifacts = self._output_resolver.generate_output_artifacts(
                execution.id)

        # If there is a custom driver, runs it.
        if self._driver_operator:
            driver_output = self._driver_operator.run_driver(
                data_types.ExecutionInfo(
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_driver_output_uri()))
            self._update_with_driver_output(driver_output, exec_properties,
                                            output_artifacts)

        # We reconnect to MLMD here because the custom driver closes the MLMD
        # connection on returning.
        with self._mlmd_connection as m:
            # 6. Checks for a cached result.
            cache_context = cache_utils.get_cache_context(
                metadata_handler=m,
                pipeline_node=self._pipeline_node,
                pipeline_info=self._pipeline_info,
                input_artifacts=input_artifacts,
                output_artifacts=output_artifacts,
                parameters=exec_properties)
            contexts.append(cache_context)
            cached_outputs = cache_utils.get_cached_outputs(
                metadata_handler=m, cache_context=cache_context)

            # 7. Should cache be used?
            if (self._pipeline_node.execution_options.caching_options.
                    enable_cache and cached_outputs):
                # Publishes the cached result.
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(
                        execution_id=execution.id),
                    execution_metadata=execution,
                    contexts=contexts,
                    is_execution_needed=False)

            pipeline_run_id = (self._pipeline_runtime_spec.pipeline_run_id.
                               field_value.string_value)

            # 8. Going to trigger executor.
            return _PrepareExecutionResult(
                execution_info=data_types.ExecutionInfo(
                    execution_id=execution.id,
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_executor_output_uri(execution.id),
                    stateful_working_dir=(self._output_resolver.
                                          get_stateful_working_directory()),
                    tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
                    pipeline_node=self._pipeline_node,
                    pipeline_info=self._pipeline_info,
                    pipeline_run_id=pipeline_run_id),
                execution_metadata=execution,
                contexts=contexts,
                is_execution_needed=True)
Example #5
    def _maybe_generate_task(
        self,
        node: pipeline_pb2.PipelineNode,
        node_executions: Sequence[metadata_store_pb2.Execution],
        successful_node_ids: MutableSet[str],
    ) -> Optional[task_lib.Task]:
        """Generates a task to execute or `None` if no action is required.

    If the node is executable, an `ExecNodeTask` is returned.

    If node execution is infeasible due to unsatisfied preconditions, such as
    missing inputs or a service job failure, a task to abort the pipeline is
    returned.

    If caching is enabled and previously computed outputs are found, those outputs
    are used to finish the execution. Since node execution can be elided, `None`
    is returned after adding the node_id to the `successful_node_ids` set.

    Args:
      node: The pipeline node for which to generate a task.
      node_executions: Node executions fetched from MLMD.
      successful_node_ids: Set that tracks successful node ids.

    Returns:
      An `ExecNodeTask` if the node can be executed. If an error occurs,
      a `FinalizePipelineTask` is returned to abort the pipeline execution.
    """
        result = task_gen_utils.generate_task_from_active_execution(
            self._mlmd_handle, self._pipeline, node, node_executions)
        if result:
            return result

        node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
        resolved_info = task_gen_utils.generate_resolved_info(
            self._mlmd_handle, node)
        if resolved_info.input_artifacts is None:
            return self._abort_task(
                f'failure to resolve inputs; node uid {node_uid}')

        execution = execution_publish_utils.register_execution(
            metadata_handler=self._mlmd_handle,
            execution_type=node.node_info.type,
            contexts=resolved_info.contexts,
            input_artifacts=resolved_info.input_artifacts,
            exec_properties=resolved_info.exec_properties)
        outputs_resolver = outputs_utils.OutputsResolver(
            node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
            self._pipeline.execution_mode)
        output_artifacts = outputs_resolver.generate_output_artifacts(
            execution.id)

        # Check if we can elide node execution by reusing previously computed
        # outputs for the node.
        cache_context = cache_utils.get_cache_context(
            self._mlmd_handle,
            pipeline_node=node,
            pipeline_info=self._pipeline.pipeline_info,
            executor_spec=_get_executor_spec(self._pipeline,
                                             node.node_info.id),
            input_artifacts=resolved_info.input_artifacts,
            output_artifacts=output_artifacts,
            parameters=resolved_info.exec_properties)
        contexts = resolved_info.contexts + [cache_context]
        if node.execution_options.caching_options.enable_cache:
            cached_outputs = cache_utils.get_cached_outputs(
                self._mlmd_handle, cache_context=cache_context)
            if cached_outputs is not None:
                logging.info(
                    'Eliding node execution, using cached outputs; node uid: %s',
                    node_uid)
                execution_publish_utils.publish_cached_execution(
                    self._mlmd_handle,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                successful_node_ids.add(node.node_info.id)
                pstate.record_state_change_time()
                return None

        # For mixed service nodes, we ensure node services and check service
        # status; pipeline is aborted if the service jobs have failed.
        service_status = self._ensure_node_services_if_mixed(node.node_info.id)
        if service_status == service_jobs.ServiceStatus.FAILED:
            return self._abort_task(
                f'associated service job failed; node uid: {node_uid}')

        return task_lib.ExecNodeTask(
            node_uid=node_uid,
            execution_id=execution.id,
            contexts=contexts,
            input_artifacts=resolved_info.input_artifacts,
            exec_properties=resolved_info.exec_properties,
            output_artifacts=output_artifacts,
            executor_output_uri=outputs_resolver.get_executor_output_uri(
                execution.id),
            stateful_working_dir=outputs_resolver.
            get_stateful_working_directory(execution.id),
            pipeline=self._pipeline)
Example #6
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Importer specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as an importer node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves execution properties. Note that importers have no
            # inputs.
            exec_properties = data_types_utils.build_parsed_value_dict(
                inputs_utils.resolve_parameters_with_schema(
                    node_parameters=pipeline_node.parameters))

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Generates output artifacts to represent the imported artifacts.
            output_spec = pipeline_node.outputs.outputs[
                importer.IMPORT_RESULT_KEY]
            properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_properties)
            custom_properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_custom_properties)
            output_artifact_class = types.Artifact(
                output_spec.artifact_spec.type).type
            output_artifacts = importer.generate_output_dict(
                metadata_handler=m,
                uri=str(exec_properties[importer.SOURCE_URI_KEY]),
                properties=properties,
                custom_properties=custom_properties,
                reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
                output_artifact_class=output_artifact_class,
                mlmd_artifact_type=output_spec.artifact_spec.type)

            result = data_types.ExecutionInfo(execution_id=execution.id,
                                              input_dict={},
                                              output_dict=output_artifacts,
                                              exec_properties=exec_properties,
                                              pipeline_node=pipeline_node,
                                              pipeline_info=pipeline_info)

            # TODO(b/182316162): consider letting the launcher level do the publish
            # for system nodes, so that the version tagging logic doesn't need to be
            # handled per system node.
            outputs_utils.tag_output_artifacts_with_version(result.output_dict)

            # 5. Publishes the output artifacts. If artifacts are reimported, the
            # execution is published as CACHED; otherwise it is published as COMPLETE.
            if _is_artifact_reimported(output_artifacts):
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=output_artifacts)
            else:
                execution_publish_utils.publish_succeeded_execution(
                    metadata_handler=m,
                    execution_id=execution.id,
                    contexts=contexts,
                    output_artifacts=output_artifacts)

            return result
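Across all of the examples above, the call pattern is the same: prepare or register the node's contexts, register an execution against them, and then publish that execution as CACHED together with the previously computed output artifacts. A minimal sketch of that shared pattern, where `m`, `execution_type`, `node_contexts`, and `cached_outputs` are placeholders rather than objects from any single example:

# Sketch of the shared pattern only; `m`, `node_contexts`, `execution_type`, and
# `cached_outputs` are placeholders for whatever the surrounding code provides.
contexts = context_lib.prepare_contexts(
    metadata_handler=m, node_contexts=node_contexts)
execution = execution_publish_utils.register_execution(
    metadata_handler=m,
    execution_type=execution_type,
    contexts=contexts)
execution_publish_utils.publish_cached_execution(
    metadata_handler=m,
    contexts=contexts,
    execution_id=execution.id,
    output_artifacts=cached_outputs)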