def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                        example_gen: pipeline_pb2.PipelineNode,
                        transform: pipeline_pb2.PipelineNode):

  with mlmd_connection as m:
    if example_gen:
      # Publishes ExampleGen output.
      output_example = types.Artifact(
          example_gen.outputs.outputs['output_examples'].artifact_spec.type)
      output_example.uri = 'my_examples_uri'
      contexts = context_lib.register_contexts_if_not_exists(
          m, example_gen.contexts)
      execution = execution_publish_utils.register_execution(
          m, example_gen.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'output_examples': [output_example],
          })

    if transform:
      # Publishes Transform output.
      output_transform_graph = types.Artifact(
          transform.outputs.outputs['transform_graph'].artifact_spec.type)
      output_transform_graph.uri = 'my_transform_graph_uri'
      contexts = context_lib.register_contexts_if_not_exists(
          m, transform.contexts)
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'transform_graph': [output_transform_graph],
          })
def testRegisterExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    input_example = standard_artifacts.Examples()
    execution_publish_utils.register_execution(
        m,
        self._execution_type,
        contexts,
        input_artifacts={'examples': [input_example]},
        exec_properties={
            'p1': 1,
        })
    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 3
        custom_properties {
          key: 'p1'
          value {int_value: 1}
        }
        last_known_state: RUNNING
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertProtoPartiallyEquals(
        """
        artifact_id: 1
        execution_id: 1
        path {
          steps {
            key: 'examples'
          }
          steps {
            index: 0
          }
        }
        type: INPUT
        """,
        event,
        ignored_fields=['milliseconds_since_epoch'])
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_artifact(input_example.id)])
def testPublishFailedExecutionDropsEmptyResult(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    executor_output = text_format.Parse(
        """
        execution_result {
          code: 0
        }
        """, execution_result_pb2.ExecutorOutput())
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    execution_publish_utils.publish_failed_execution(m, contexts, execution_id,
                                                     executor_output)
    [execution] = m.store.get_executions_by_id([execution_id])
    self.assertProtoPartiallyEquals(
        """
        id: 1
        last_known_state: FAILED
        """,
        execution,
        ignored_fields=[
            'type_id', 'create_time_since_epoch',
            'last_update_time_since_epoch'
        ])
def setUp(self):
  super().setUp()
  # Set up MLMD connection.
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  connection_config = metadata.sqlite_metadata_connection_config(metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)
  with self._mlmd_connection as m:
    self._execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=metadata_store_pb2.ExecutionType(
            name='test_execution_type'),
        contexts=[],
        input_artifacts=[])
  # Set up gRPC stub.
  port = portpicker.pick_unused_port()
  self.sidecar = execution_watcher.ExecutionWatcher(
      port,
      mlmd_connection=self._mlmd_connection,
      execution=self._execution,
      creds=grpc.local_server_credentials())
  self.sidecar.start()
  self.stub = execution_watcher.generate_service_stub(
      self.sidecar.address, grpc.local_channel_credentials())
def testPublishFailedExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    execution_publish_utils.publish_failed_execution(m, contexts, execution_id)
    [execution] = m.store.get_executions_by_id([execution_id])
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 3
        last_known_state: FAILED
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    # No events because there is no artifact published.
    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertEmpty(events)
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
def _generate_task(
    self, metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> Optional[task_pb2.Task]:
  """Generates a node execution task.

  If a node execution is not feasible, `None` is returned.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate a task.

  Returns:
    Returns a `Task` or `None` if task generation is deemed infeasible.
  """
  if not task_gen_utils.is_feasible_node(node):
    return None

  executions = task_gen_utils.get_executions(metadata_handler, node)
  result = task_gen_utils.generate_task_from_active_execution(
      self._pipeline, node, executions)
  if result:
    return result

  resolved_info = task_gen_utils.generate_resolved_info(metadata_handler, node)
  if resolved_info.input_artifacts is None:
    logging.info(
        'Task cannot be generated for node %s since no input artifacts '
        'are resolved.', node.node_info.id)
    return None

  # If the latest successful execution had the same resolved input artifacts,
  # the component should not be triggered, so task is not generated.
  # TODO(b/170231077): This logic should be handled by the resolver when it's
  # implemented. Also, currently only the artifact ids of previous execution
  # are checked to decide if a new execution is warranted but it may also be
  # necessary to factor in the difference of execution properties.
  latest_exec = task_gen_utils.get_latest_successful_execution(executions)
  if latest_exec:
    artifact_ids_by_event_type = (
        execution_lib.get_artifact_ids_by_event_type_for_execution_id(
            metadata_handler, latest_exec.id))
    latest_exec_input_artifact_ids = artifact_ids_by_event_type.get(
        metadata_store_pb2.Event.INPUT, set())
    current_exec_input_artifact_ids = set(
        a.id for a in itertools.chain(*resolved_info.input_artifacts.values()))
    if latest_exec_input_artifact_ids == current_exec_input_artifact_ids:
      return None

  execution = execution_publish_utils.register_execution(
      metadata_handler=metadata_handler,
      execution_type=node.node_info.type,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties)
  return task_gen_utils.create_task(self._pipeline, node, execution)
def fake_execute(self, metadata_handler, pipeline_node, input_map, output_map):
  contexts = context_lib.prepare_contexts(metadata_handler,
                                          pipeline_node.contexts)
  execution = execution_publish_utils.register_execution(
      metadata_handler, pipeline_node.node_info.type, contexts, input_map)
  return execution_publish_utils.publish_succeeded_execution(
      metadata_handler, execution.id, contexts, output_map)
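# A minimal usage sketch for `fake_execute` above. `m` (an open MLMD handle)
# and `node` (a `pipeline_pb2.PipelineNode` loaded from a test pbtxt) are
# assumed to exist in the enclosing test and are hypothetical names; the
# artifact types are illustrative. The input/output maps use the same
# `{key: [artifacts]}` shape the publish utilities expect.
def example_fake_execute_usage(self, m, node):
  input_example = standard_artifacts.Examples()
  output_model = standard_artifacts.Model()
  # Registers a RUNNING execution with the input, then publishes it as
  # succeeded with the given output map.
  return self.fake_execute(
      m,
      node,
      input_map={'examples': [input_example]},
      output_map={'model': [output_model]})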
def _generate_task(
    self, node: pipeline_pb2.PipelineNode,
    node_executions: Sequence[metadata_store_pb2.Execution]) -> task_lib.Task:
  """Generates a node execution task.

  Args:
    node: The pipeline node for which to generate a task.
    node_executions: Node executions fetched from MLMD.

  Returns:
    Returns an `ExecNodeTask` if node can be executed. If an error occurs,
    a `FinalizePipelineTask` is returned to abort the pipeline execution.
  """
  result = task_gen_utils.generate_task_from_active_execution(
      self._mlmd_handle, self._pipeline, node, node_executions)
  if result:
    return result

  node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
  resolved_info = task_gen_utils.generate_resolved_info(self._mlmd_handle, node)
  if resolved_info.input_artifacts is None:
    return task_lib.FinalizePipelineTask(
        pipeline_uid=self._pipeline_state.pipeline_uid,
        status=status_lib.Status(
            code=status_lib.Code.ABORTED,
            message=(f'Aborting pipeline execution due to failure to resolve '
                     f'inputs; problematic node uid: {node_uid}')))

  execution = execution_publish_utils.register_execution(
      metadata_handler=self._mlmd_handle,
      execution_type=node.node_info.type,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties)
  outputs_resolver = outputs_utils.OutputsResolver(
      node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
      self._pipeline.execution_mode)

  return task_lib.ExecNodeTask(
      node_uid=node_uid,
      execution=execution,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties,
      output_artifacts=outputs_resolver.generate_output_artifacts(
          execution.id),
      executor_output_uri=outputs_resolver.get_executor_output_uri(
          execution.id),
      stateful_working_dir=outputs_resolver.get_stateful_working_directory(
          execution.id),
      pipeline=self._pipeline)
def _generate_task(
    self, metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> Optional[task_lib.Task]:
  """Generates a node execution task.

  If node execution is not feasible, `None` is returned.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate a task.

  Returns:
    Returns a `Task` or `None` if task generation is deemed infeasible.
  """
  if not task_gen_utils.is_feasible_node(node):
    return None

  executions = task_gen_utils.get_executions(metadata_handler, node)
  result = task_gen_utils.generate_task_from_active_execution(
      metadata_handler, self._pipeline, node, executions)
  if result:
    return result

  resolved_info = task_gen_utils.generate_resolved_info(metadata_handler, node)
  if resolved_info.input_artifacts is None:
    # TODO(goutham): If the pipeline can't make progress, there should be a
    # standard mechanism to surface it to the user.
    logging.warning(
        'Task cannot be generated for node %s since no input artifacts '
        'are resolved.', node.node_info.id)
    return None

  execution = execution_publish_utils.register_execution(
      metadata_handler=metadata_handler,
      execution_type=node.node_info.type,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties)
  outputs_resolver = outputs_utils.OutputsResolver(
      node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
      self._pipeline.execution_mode)
  return task_lib.ExecNodeTask(
      node_uid=task_lib.NodeUid.from_pipeline_node(self._pipeline, node),
      execution=execution,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties,
      output_artifacts=outputs_resolver.generate_output_artifacts(
          execution.id),
      executor_output_uri=outputs_resolver.get_executor_output_uri(
          execution.id),
      stateful_working_dir=outputs_resolver.get_stateful_working_directory(
          execution.id))
def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                        example_gen: pipeline_pb2.PipelineNode,
                        transform: pipeline_pb2.PipelineNode,
                        cached: bool = False):
  publish_execution = (
      execution_publish_utils.publish_cached_execution
      if cached else execution_publish_utils.publish_succeeded_execution)
  with mlmd_connection as m:
    if example_gen:
      # Publishes ExampleGen output.
      output_example = types.Artifact(
          example_gen.outputs.outputs['output_examples'].artifact_spec.type)
      output_example.uri = 'my_examples_uri'
      contexts = context_lib.prepare_contexts(m, example_gen.contexts)
      execution = execution_publish_utils.register_execution(
          m, example_gen.node_info.type, contexts)
      publish_execution(
          metadata_handler=m,
          execution_id=execution.id,
          contexts=contexts,
          output_artifacts={
              'output_examples': [output_example],
          })

    if transform:
      # Publishes Transform output.
      output_transform_graph = types.Artifact(
          transform.outputs.outputs['transform_graph'].artifact_spec.type)
      output_transform_graph.uri = 'my_transform_graph_uri'
      contexts = context_lib.prepare_contexts(m, transform.contexts)
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
      publish_execution(
          metadata_handler=m,
          execution_id=execution.id,
          contexts=contexts,
          output_artifacts={
              'transform_graph': [output_transform_graph],
          })
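# Hedged usage sketch for the `fakeUpstreamOutputs` variant above. The
# `pipeline` object and node indices are hypothetical; in the surrounding
# tests the pipeline IR is typically loaded from a pbtxt file. With
# `cached=True`, outputs are recorded via `publish_cached_execution` rather
# than `publish_succeeded_execution`.
def exampleFakeUpstreamOutputsUsage(mlmd_connection, pipeline):
  example_gen = pipeline.nodes[0].pipeline_node  # hypothetical node index
  transform = pipeline.nodes[1].pipeline_node  # hypothetical node index
  fakeUpstreamOutputs(mlmd_connection, example_gen, transform, cached=True)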
def testPublishSuccessExecutionFailChangedType(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    executor_output = execution_result_pb2.ExecutorOutput()
    executor_output.output_artifacts['examples'].artifacts.add().type_id = 10

    with self.assertRaisesRegex(RuntimeError, 'change artifact type'):
      execution_publish_utils.publish_succeeded_execution(
          m, execution_id, contexts,
          {'examples': [standard_artifacts.Examples()]}, executor_output)
def fake_cached_execution(mlmd_connection, cache_context, component):
  """Writes a cached execution to MLMD.

  MLMD must already have a previous successful execution associated with
  `cache_context`.
  """
  with mlmd_connection as m:
    cached_outputs = cache_utils.get_cached_outputs(
        m, cache_context=cache_context)
    contexts = context_lib.prepare_contexts(m, component.contexts)
    execution = execution_publish_utils.register_execution(
        m, component.node_info.type, contexts)
    execution_publish_utils.publish_cached_execution(
        m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=cached_outputs)
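# Hedged usage sketch for `fake_cached_execution` above (all names
# hypothetical): the cache context must already exist and be associated with
# a previous successful execution, e.g. registered the same way as in the
# cache tests further below.
def example_fake_cached_execution_usage(mlmd_connection, component):
  with mlmd_connection as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
  # Replays the previously published outputs as a new cached execution.
  fake_cached_execution(mlmd_connection, cache_context, component)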
def _register_execution(
    self, metadata_handler: metadata.Metadata,
    contexts: List[metadata_store_pb2.Context],
    input_artifacts: MutableMapping[str, Sequence[types.Artifact]],
    exec_properties: Mapping[str, types.Property]
) -> metadata_store_pb2.Execution:
  """Registers an execution in MLMD."""
  return execution_publish_utils.register_execution(
      metadata_handler=metadata_handler,
      execution_type=self._pipeline_node.node_info.type,
      contexts=contexts,
      input_artifacts=input_artifacts,
      exec_properties=exec_properties)
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Resolver specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as a resolver node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=pipeline_node.inputs)

    # 3. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Publishes the execution as an internal execution, with the resolved
    #    input artifacts recorded as its output artifacts.
    execution_publish_utils.publish_internal_execution(
        metadata_handler=m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=input_artifacts)

    return data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict=input_artifacts,
        output_dict=input_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)
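# Hedged usage sketch for the resolver `run` method above. `handler`,
# `mlmd_connection`, `resolver_node`, and `pipeline` are hypothetical names
# that would come from a compiled pipeline IR in practice. Because the
# resolver republishes its resolved inputs as its outputs, the returned
# ExecutionInfo carries identical input and output dicts.
def example_resolver_run_usage(handler, mlmd_connection, resolver_node,
                               pipeline):
  execution_info = handler.run(
      mlmd_connection,
      pipeline_node=resolver_node,
      pipeline_info=pipeline.pipeline_info,
      pipeline_runtime_spec=pipeline.runtime_spec)
  assert execution_info.input_dict == execution_info.output_dict
  return execution_info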
def fake_transform_output(mlmd_connection, transform, execution=None):
  """Writes fake transform output and execution to MLMD."""
  with mlmd_connection as m:
    output_transform_graph = types.Artifact(
        transform.outputs.outputs['transform_graph'].artifact_spec.type)
    output_transform_graph.uri = 'my_transform_graph_uri'
    contexts = context_lib.prepare_contexts(m, transform.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'transform_graph': [output_transform_graph],
        })
def fake_trainer_output(mlmd_connection, trainer, execution=None):
  """Writes fake trainer output and execution to MLMD."""
  with mlmd_connection as m:
    output_trainer_model = types.Artifact(
        trainer.outputs.outputs['model'].artifact_spec.type)
    output_trainer_model.uri = 'my_trainer_model_uri'
    contexts = context_lib.prepare_contexts(m, trainer.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_trainer_model],
        })
def fake_example_gen_run_with_handle(mlmd_handle, example_gen, span, version):
  """Writes fake example_gen output and successful execution to MLMD."""
  output_example = types.Artifact(
      example_gen.outputs.outputs['output_examples'].artifact_spec.type)
  output_example.set_int_custom_property('span', span)
  output_example.set_int_custom_property('version', version)
  output_example.uri = 'my_examples_uri'
  contexts = context_lib.prepare_contexts(mlmd_handle, example_gen.contexts)
  execution = execution_publish_utils.register_execution(
      mlmd_handle, example_gen.node_info.type, contexts)
  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, execution.id, contexts, {
          'output_examples': [output_example],
      })
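# Usage sketch for the helper above (handle and node names hypothetical):
# simulates two ExampleGen runs with successive spans, e.g. before exercising
# downstream input resolution over multiple spans.
def example_fake_example_gen_runs(mlmd_handle, example_gen):
  fake_example_gen_run_with_handle(mlmd_handle, example_gen, span=1, version=1)
  fake_example_gen_run_with_handle(mlmd_handle, example_gen, span=2, version=1)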
def fake_component_output_with_handle(mlmd_handle,
                                      component,
                                      execution=None,
                                      active=False):
  """Writes fake component output and execution to MLMD."""
  output_key, output_value = next(iter(component.outputs.outputs.items()))
  output = types.Artifact(output_value.artifact_spec.type)
  output.uri = str(uuid.uuid4())
  contexts = context_lib.prepare_contexts(mlmd_handle, component.contexts)
  if not execution:
    execution = execution_publish_utils.register_execution(
        mlmd_handle, component.node_info.type, contexts)
  if not active:
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle, execution.id, contexts, {output_key: [output]})
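# Usage sketch (hypothetical names): with `active=True` the execution is
# registered but never published as succeeded, so it stays in the RUNNING
# state in MLMD, which is useful for simulating an in-flight component.
def example_fake_active_component(mlmd_handle, component):
  fake_component_output_with_handle(mlmd_handle, component, active=True)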
def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                         example_gen: pipeline_pb2.PipelineNode, span: int,
                         version: int):
  with mlmd_connection as m:
    output_example = types.Artifact(
        example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example.set_int_custom_property('span', span)
    output_example.set_int_custom_property('version', version)
    output_example.uri = 'my_examples_uri'
    contexts = context_lib.prepare_contexts(m, example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example],
        })
def testPublishInternalExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    output_example = standard_artifacts.Examples()
    execution_publish_utils.publish_internal_execution(
        m,
        contexts,
        execution_id,
        output_artifacts={'examples': [output_example]})
    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 3
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertProtoPartiallyEquals(
        """
        artifact_id: 1
        execution_id: 1
        path {
          steps {
            key: 'examples'
          }
          steps {
            index: 0
          }
        }
        type: INTERNAL_OUTPUT
        """,
        event,
        ignored_fields=['milliseconds_since_epoch'])
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_artifact(output_example.id)])
def testResolverWithResolverPolicy(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes an ExampleGen execution with two artifacts in the
    # `output_examples` channel, which will be consumed by the downstream
    # Transform.
    output_example_1 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_1.uri = 'my_examples_uri_1'
    output_example_2 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_2.uri = 'my_examples_uri_2'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example_1, output_example_2],
        })

    my_transform.inputs.resolver_config.resolver_policy = (
        pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

    # Gets inputs for Transform. With the LATEST_ARTIFACT policy, only the
    # latest of the two artifacts published in the `output_examples` channel
    # should come back.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example_2.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testPublishSuccessExecutionUpdatesCustomProperties(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    executor_output = text_format.Parse(
        """
        execution_properties {
          key: "int"
          value {
            int_value: 1
          }
        }
        execution_properties {
          key: "string"
          value {
            string_value: "string_value"
          }
        }
        """, execution_result_pb2.ExecutorOutput())
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {}, executor_output)
    [execution] = m.store.get_executions_by_id([execution_id])
    self.assertProtoPartiallyEquals(
        """
        id: 1
        last_known_state: COMPLETE
        custom_properties {
          key: "int"
          value {
            int_value: 1
          }
        }
        custom_properties {
          key: "string"
          value {
            string_value: "string_value"
          }
        }
        """,
        execution,
        ignored_fields=[
            'type_id', 'create_time_since_epoch',
            'last_update_time_since_epoch'
        ])
def testGetCachedOutputArtifactsForNodesWithNoOutput(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    # No succeeded execution is associated with this context yet, so the
    # cached output is None.
    self.assertIsNone(cached_output)

    execution_one = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m, execution_one.id, [cache_context])
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    # A succeeded execution is now associated with this context, so the
    # cached output is not None, but an empty dict.
    self.assertIsNotNone(cached_output)
    self.assertEmpty(cached_output)
def testPublishSuccessExecutionFailChangedUriDir(self):
  output_example = standard_artifacts.Examples()
  output_example.uri = '/my/original_uri'
  output_dict = {'examples': [output_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    executor_output = execution_result_pb2.ExecutorOutput()
    new_example = executor_output.output_artifacts['examples'].artifacts.add()
    new_example.uri = '/my/new_uri/1'

    with self.assertRaisesRegex(
        RuntimeError,
        'When there is one artifact to publish, the URI of it should be '
        'identical to the URI of system generated artifact.'):
      execution_publish_utils.publish_succeeded_execution(
          m, execution_id, contexts, output_dict, executor_output)
def testPublishSuccessExecutionFailTooManyLayerOfSubDir(self):
  output_example = standard_artifacts.Examples()
  output_example.uri = '/my/original_uri'
  output_dict = {'examples': [output_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    executor_output = execution_result_pb2.ExecutorOutput()
    new_example = executor_output.output_artifacts['examples'].artifacts.add()
    new_example.uri = '/my/original_uri/1/1'

    with self.assertRaisesRegex(
        RuntimeError,
        'The URI of executor generated artifacts should either be identical '
        'to the URI of system generated artifact or be a direct sub-dir of '
        'it.'):
      execution_publish_utils.publish_succeeded_execution(
          m, execution_id, contexts, output_dict, executor_output)
def _register_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    contexts: List[metadata_store_pb2.Context],
    input_artifacts: MutableMapping[str, Sequence[types.Artifact]],
    exec_properties: Mapping[str, types.Property]
) -> metadata_store_pb2.Execution:
  """Registers an execution in MLMD."""
  kfp_pod_name = os.environ.get(_KFP_POD_NAME_ENV_KEY)
  execution_properties_copy = copy.deepcopy(exec_properties)
  execution_properties_copy = cast(MutableMapping[str, types.Property],
                                   execution_properties_copy)
  if kfp_pod_name:
    logging.info('Adding KFP pod name %s to execution', kfp_pod_name)
    execution_properties_copy[_KFP_POD_NAME_PROPERTY_KEY] = kfp_pod_name
  return execution_publish_utils.register_execution(
      metadata_handler=metadata_handler,
      execution_type=execution_type,
      contexts=contexts,
      input_artifacts=input_artifacts,
      exec_properties=execution_properties_copy)
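# Hedged sketch of how the KFP pod name reaches the execution properties via
# the wrapper above. The env var is normally set by the Kubeflow Pipelines
# system; setting it by hand here is purely illustrative, and `m`, the
# execution type name, and the property values are hypothetical.
def example_register_execution_with_pod_name(m):
  os.environ[_KFP_POD_NAME_ENV_KEY] = 'my-pipeline-pod-abc123'
  # The registered execution will carry the pod name under
  # _KFP_POD_NAME_PROPERTY_KEY among its properties.
  return _register_execution(
      metadata_handler=m,
      execution_type=metadata_store_pb2.ExecutionType(name='my_type'),
      contexts=[],
      input_artifacts={},
      exec_properties={'p1': 1})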
def _generate_task(
    self, metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> Optional[task_pb2.Task]:
  """Generates a node execution task.

  If node execution is not feasible, `None` is returned.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate a task.

  Returns:
    Returns a `Task` or `None` if task generation is deemed infeasible.
  """
  if not task_gen_utils.is_feasible_node(node):
    return None

  executions = task_gen_utils.get_executions(metadata_handler, node)
  task = task_gen_utils.generate_task_from_active_execution(
      self._pipeline, node, executions)
  if task:
    return task

  resolved_info = task_gen_utils.generate_resolved_info(metadata_handler, node)
  if resolved_info.input_artifacts is None:
    # TODO(goutham): If the pipeline can't make progress, there should be a
    # standard mechanism to surface it to the user.
    logging.warning(
        'Task cannot be generated for node %s since no input artifacts '
        'are resolved.', node.node_info.id)
    return None

  execution = execution_publish_utils.register_execution(
      metadata_handler=metadata_handler,
      execution_type=node.node_info.type,
      contexts=resolved_info.contexts,
      input_artifacts=resolved_info.input_artifacts,
      exec_properties=resolved_info.exec_properties)
  return task_gen_utils.create_task(self._pipeline, node, execution)
def testPublishFailedExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    executor_output = text_format.Parse(
        """
        execution_result {
          code: 1
          result_message: 'error message.'
        }
        """, execution_result_pb2.ExecutorOutput())
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    execution_publish_utils.publish_failed_execution(m, contexts, execution_id,
                                                     executor_output)
    [execution] = m.store.get_executions_by_id([execution_id])
    self.assertProtoPartiallyEquals(
        """
        id: 1
        last_known_state: FAILED
        custom_properties {
          key: '__execution_result__'
          value {
            string_value: '{\\n  "resultMessage": "error message.",\\n  "code": 1\\n}'
          }
        }
        """,
        execution,
        ignored_fields=[
            'type_id', 'create_time_since_epoch',
            'last_update_time_since_epoch'
        ])
    # No events because there is no artifact published.
    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertEmpty(events)
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
def testResolveInputArtifactsOutputKeyUnset(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
      pipeline)
  my_trainer = pipeline.nodes[0].pipeline_node
  my_pusher = pipeline.nodes[1].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes Trainer with one output channel. `model` will be consumed by
    # the Pusher in a different run.
    output_model = types.Artifact(
        my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model.uri = 'my_output_model_uri'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model],
        })

    # Gets inputs for the Pusher. Should get back the model the Trainer
    # published in the `model` channel.
    pusher_inputs = inputs_utils.resolve_input_artifacts(m, my_pusher.inputs)
    self.assertEqual(len(pusher_inputs), 1)
    self.assertEqual(len(pusher_inputs['model']), 1)
    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        pusher_inputs['model'][0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testPublishSuccessExecutionFailInvalidUri(self, invalid_uri):
  output_example = standard_artifacts.Examples()
  output_example.uri = '/my/original_uri'
  output_dict = {'examples': [output_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    executor_output = execution_result_pb2.ExecutorOutput()
    system_generated_artifact = executor_output.output_artifacts[
        'examples'].artifacts.add()
    system_generated_artifact.uri = '/my/original_uri/0'
    new_artifact = executor_output.output_artifacts['examples'].artifacts.add()
    new_artifact.uri = invalid_uri

    with self.assertRaisesRegex(
        RuntimeError,
        'When there are multiple artifacts to publish, their URIs should be '
        'direct sub-directories of the URI of the system generated artifact.'):
      execution_publish_utils.publish_succeeded_execution(
          m, execution_id, contexts, output_dict, executor_output)