def testResolveInputsArtifacts(self):
  pipeline = self.load_pipeline_proto(
      'pipeline_for_input_resolver_test.pbtxt')
  my_example_gen = pipeline.nodes[0].pipeline_node
  another_example_gen = pipeline.nodes[1].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  my_trainer = pipeline.nodes[3].pipeline_node

  with self.get_metadata() as m:
    # Publishes the first ExampleGen with two output channels.
    # `output_examples` will be consumed by the downstream Transform.
    output_example = self.make_examples(uri='my_examples_uri')
    side_examples = self.make_examples(uri='side_examples_uri')
    output_artifacts = self.fake_execute(
        m,
        my_example_gen,
        input_map=None,
        output_map={
            'output_examples': [output_example],
            'another_examples': [side_examples]
        })
    output_example = output_artifacts['output_examples'][0]
    side_examples = output_artifacts['another_examples'][0]

    # Publishes the second ExampleGen with one output channel that uses the
    # same output key as the first ExampleGen. It is not consumed by any
    # downstream node.
    another_output_example = self.make_examples(uri='another_examples_uri')
    self.fake_execute(
        m,
        another_example_gen,
        input_map=None,
        output_map={'output_examples': [another_output_example]})

    # Gets inputs for Transform. Should get back what the first ExampleGen
    # published in the `output_examples` channel.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertArtifactMapEqual({'examples': [output_example]},
                                transform_inputs)

    # Tries to resolve inputs for Trainer. The Trainer requires min_count for
    # both input channels (from ExampleGen and from Transform), but nothing
    # was published from Transform, so resolution returns nothing.
    self.assertIsNone(
        inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))

    # Tries to resolve inputs for Transform after adding a new context query
    # to the input spec that refers to a non-existent context. Inputs cannot
    # be resolved in this case.
    context_query = my_transform.inputs.inputs['examples'].channels[
        0].context_queries.add()
    context_query.type.name = 'non_existent_context'
    context_query.name.field_value.string_value = 'non_existent_context'
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertIsNone(transform_inputs)

def testResolveInputArtifacts_NonDictArg(self):
  self._setup_pipeline_for_input_resolver_test()
  self._append_resolver_step(self._my_transform, DuplicateOp)
  self._append_resolver_step(self._my_transform, IdentityStrategy)
  with self.assertRaisesRegex(TypeError, 'Invalid argument type'):
    inputs_utils.resolve_input_artifacts(self._metadata_handler,
                                         self._my_transform.inputs)

def testResolveInputArtifacts_NonDictOutput(self):
  self._setup_pipeline_for_input_resolver_test()
  self._append_resolver_step(self._my_transform, BadOutputStrategy)
  with self.assertRaisesRegex(TypeError, 'Invalid input resolution result'):
    inputs_utils.resolve_input_artifacts(self._metadata_handler,
                                         self._my_transform.inputs)

def testResolverWithLatestArtifactStrategy(self):
  pipeline = self.load_pipeline_proto(
      'pipeline_for_input_resolver_test.pbtxt')
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node

  with self.get_metadata() as m:
    # Publishes the first ExampleGen with two artifacts in its
    # `output_examples` channel, which will be consumed by the downstream
    # Transform.
    output_example_1 = self.make_examples(uri='my_examples_uri_1')
    output_example_2 = self.make_examples(uri='my_examples_uri_2')
    output_artifacts = self.fake_execute(
        m,
        my_example_gen,
        input_map=None,
        output_map={
            'output_examples': [output_example_1, output_example_2]
        })
    output_example_1 = output_artifacts['output_examples'][0]
    output_example_2 = output_artifacts['output_examples'][1]

    transform_resolver = (
        my_transform.inputs.resolver_config.resolver_steps.add())
    transform_resolver.class_path = (
        'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
        '.LatestArtifactStrategy')
    transform_resolver.config_json = '{}'

    # Gets inputs for Transform. The LatestArtifactStrategy should keep only
    # the most recently created artifact in the `output_examples` channel.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertArtifactMapEqual({'examples': [output_example_2]},
                                transform_inputs)

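# A minimal sketch (not the TFX implementation) of what a latest-artifact
# strategy such as the one configured above is expected to do with a resolved
# input map: keep only the newest artifact per channel. The ordering key
# below (`create_time_since_epoch`, an MLMD artifact field) is an assumption
# for illustration, and `_keep_latest` is a hypothetical helper name.
def _keep_latest(input_map):
  return {
      key: [max(artifacts, key=lambda a: a.create_time_since_epoch)]
      for key, artifacts in input_map.items()
      if artifacts  # Skip empty channels; max() on an empty list would raise.
  }
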
def testResolveInputArtifacts_SkippingStrategy(self):
  self._setup_pipeline_for_input_resolver_test()
  self._append_resolver_step(self._my_transform, SkippingStrategy)
  result = inputs_utils.resolve_input_artifacts(
      self._metadata_handler, self._my_transform.inputs)
  self.assertIsNone(result)

def run(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput:
  # Fakes a constant span number. In production this is usually calculated
  # from the date.
  span = 2
  with self._mlmd_connection as m:
    previous_output = inputs_utils.resolve_input_artifacts(
        m, self._self_output)

    # Version should be the max existing version + 1 if the span exists,
    # otherwise 0.
    version = 0
    if previous_output:
      version = max([
          artifact.get_int_custom_property('version')
          for artifact in previous_output['examples']
          if artifact.get_int_custom_property('span') == span
      ] or [-1]) + 1

  output_example = copy.deepcopy(
      output_dict['output_examples'][0].mlmd_artifact)
  output_example.custom_properties['span'].int_value = span
  output_example.custom_properties['version'].int_value = version

  result = driver_output_pb2.DriverOutput()
  result.output_artifacts['output_examples'].artifacts.append(output_example)
  result.exec_properties['span'].int_value = span
  result.exec_properties['version'].int_value = version
  return result

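# A standalone sketch of the version bookkeeping above, with artifacts reduced
# to simple (span, version) tuples: the next version for a span is
# max(existing versions for that span, defaulting to -1) + 1, so the first
# artifact in a span gets version 0. `_next_version` is a hypothetical helper,
# not part of the driver API.
def _next_version(existing, span):
  return max([v for s, v in existing if s == span] or [-1]) + 1

assert _next_version([], span=2) == 0                       # No prior output.
assert _next_version([(2, 0), (2, 1), (3, 4)], span=2) == 2  # Max in span + 1.
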
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
  """Returns a `ResolvedInfo` object for executing the node.

  Args:
    metadata_handler: A handler to access the MLMD db.
    node: The pipeline node for which to generate resolved info.

  Returns:
    A `ResolvedInfo` with input resolutions.
  """
  # Registers node contexts.
  contexts = context_lib.register_contexts_if_not_exists(
      metadata_handler=metadata_handler, node_contexts=node.contexts)

  # Resolves execution properties.
  exec_properties = inputs_utils.resolve_parameters(
      node_parameters=node.parameters)

  # Resolves inputs.
  input_artifacts = inputs_utils.resolve_input_artifacts(
      metadata_handler=metadata_handler, node_inputs=node.inputs)

  return ResolvedInfo(
      contexts=contexts,
      exec_properties=exec_properties,
      input_artifacts=input_artifacts)

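# A hedged usage sketch for `generate_resolved_info`: since unsatisfied
# `min_count` constraints make `resolve_input_artifacts` return None, a caller
# typically checks `input_artifacts` before triggering an execution.
# `metadata_handler` and `node` are assumed to come from the surrounding
# orchestration code, and `trigger_execution` is a hypothetical callback.
resolved = generate_resolved_info(metadata_handler, node)
if resolved.input_artifacts is None:
  logging.info('Inputs not yet available for node %s; skipping.',
               node.node_info.id)
else:
  trigger_execution(resolved)
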
def testResolveInputArtifacts_MixedStrategyAndOp(self):
  self._setup_pipeline_for_input_resolver_test()
  self._append_resolver_step(self._my_transform, IdentityStrategy)
  self._append_resolver_step(self._my_transform, IdentityOp)
  result = inputs_utils.resolve_input_artifacts(
      self._metadata_handler, self._my_transform.inputs)
  self.assertArtifactMapEqual({'examples': self._examples}, result)

def testLatestArtifacts_withInputKeys(self):
  pipeline = self.load_pipeline_proto(
      'pipeline_for_input_resolver_test.pbtxt')
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  my_trainer = pipeline.nodes[3].pipeline_node

  # Uses LatestArtifactStrategy for the TransformGraph channel only.
  resolver = my_trainer.inputs.resolver_config.resolver_steps.add()
  resolver.class_path = (
      'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
      '.LatestArtifactStrategy')
  resolver.config_json = '{}'
  resolver.input_keys.append('transform_graph')

  with self.get_metadata() as m:
    ex1 = self.make_examples(uri='examples/1')
    ex2 = self.make_examples(uri='examples/2')
    tf1 = self.make_transform_graph(uri='transform_graph/1')
    tf2 = self.make_transform_graph(uri='transform_graph/2')
    output_artifacts = self.fake_execute(
        m, my_example_gen, input_map=None,
        output_map={'output_examples': [ex1]})
    ex1 = output_artifacts['output_examples'][0]
    output_artifacts = self.fake_execute(
        m, my_example_gen, input_map=None,
        output_map={'output_examples': [ex2]})
    ex2 = output_artifacts['output_examples'][0]
    output_artifacts = self.fake_execute(
        m, my_transform, input_map={'examples': [ex1]},
        output_map={'transform_graph': [tf1]})
    tf1 = output_artifacts['transform_graph'][0]
    output_artifacts = self.fake_execute(
        m, my_transform, input_map={'examples': [ex2]},
        output_map={'transform_graph': [tf2]})
    tf2 = output_artifacts['transform_graph'][0]

    result = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=my_trainer.inputs)
    # The "examples" input channel does not go through the resolver, so its
    # order is nondeterministic. Sort artifacts by ID for testing convenience.
    result['examples'].sort(key=lambda a: a.id)
    self.assertArtifactMapEqual(
        {
            'examples': [ex1, ex2],
            'transform_graph': [tf2]
        }, result)

def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Resolver-specific logic.

  Args:
    mlmd_connection: ML Metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as a resolver node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=pipeline_node.inputs)

    # 3. Registers the execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Publishes the execution as a cached execution with the resolved
    # input artifacts as the output artifacts.
    execution_publish_utils.publish_internal_execution(
        metadata_handler=m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=input_artifacts)

    return data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict=input_artifacts,
        output_dict=input_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)

def testResolverWithResolverPolicy(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes the first ExampleGen with two artifacts in its
    # `output_examples` channel, which will be consumed by the downstream
    # Transform.
    output_example_1 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_1.uri = 'my_examples_uri_1'
    output_example_2 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_2.uri = 'my_examples_uri_2'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example_1, output_example_2],
        })

    my_transform.inputs.resolver_config.resolver_policy = (
        pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

    # Gets inputs for Transform. With the LATEST_ARTIFACT policy, only the
    # most recently created artifact in the `output_examples` channel should
    # be returned.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example_2.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

def testLatestUnprocessedArtifacts_NoneIfEverythingProcessed(self):
  pipeline = self.load_pipeline_proto(
      'pipeline_for_input_resolver_test.pbtxt')
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  resolver1 = my_transform.inputs.resolver_config.resolver_steps.add()
  resolver1.class_path = (
      'tfx.dsl.resolvers.unprocessed_artifacts_resolver'
      '.UnprocessedArtifactsResolver')
  resolver1.config_json = '{"execution_type_name": "Transform"}'
  resolver2 = my_transform.inputs.resolver_config.resolver_steps.add()
  resolver2.class_path = (
      'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
      '.LatestArtifactStrategy')
  resolver2.config_json = '{}'

  with self.get_metadata() as m:
    ex1 = self.make_examples(uri='a')
    ex2 = self.make_examples(uri='b')
    output_artifacts = self.fake_execute(
        m, my_example_gen, input_map=None,
        output_map={'output_examples': [ex1]})
    ex1 = output_artifacts['output_examples'][0]
    output_artifacts = self.fake_execute(
        m, my_example_gen, input_map=None,
        output_map={'output_examples': [ex2]})
    ex2 = output_artifacts['output_examples'][0]
    # Both examples have already been processed by Transform, so nothing is
    # left for the UnprocessedArtifactsResolver to return.
    self.fake_execute(
        m, my_transform, input_map={'examples': [ex1]}, output_map=None)
    self.fake_execute(
        m, my_transform, input_map={'examples': [ex2]}, output_map=None)

    result = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=my_transform.inputs)
    self.assertIsNone(result)

def testResolveInputArtifactsOutputKeyUnset(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
      pipeline)
  my_trainer = pipeline.nodes[0].pipeline_node
  my_pusher = pipeline.nodes[1].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes the Trainer with one output channel. The `model` output will
    # be consumed by the Pusher in a different run.
    output_model = types.Artifact(
        my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model.uri = 'my_output_model_uri'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model],
        })

    # Gets inputs for the Pusher. Should get back what the Trainer published
    # in the `model` channel.
    pusher_inputs = inputs_utils.resolve_input_artifacts(m, my_pusher.inputs)
    self.assertEqual(len(pusher_inputs), 1)
    self.assertEqual(len(pusher_inputs['model']), 1)
    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        pusher_inputs['model'][0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

def testResolveInputArtifactsOutputKeyUnset(self):
  pipeline = self.load_pipeline_proto(
      'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
  my_trainer = pipeline.nodes[0].pipeline_node
  my_pusher = pipeline.nodes[1].pipeline_node

  with self.get_metadata() as m:
    # Publishes the Trainer with one output channel. The `model` output will
    # be consumed by the Pusher in a different run.
    output_model = self.make_model(uri='my_output_model_uri')
    output_artifacts = self.fake_execute(
        m, my_trainer, input_map=None, output_map={'model': [output_model]})
    output_model = output_artifacts['model'][0]

    # Gets inputs for the Pusher. Should get back what the Trainer published
    # in the `model` channel.
    pusher_inputs = inputs_utils.resolve_input_artifacts(m, my_pusher.inputs)
    self.assertArtifactMapEqual({'model': [output_model]}, pusher_inputs)

def _prepare_execution(self) -> _PrepareExecutionResult:
  """Prepares inputs, outputs and execution properties for actual execution."""
  # TODO(b/150979622): handle the edge case that the component gets evicted
  # between a successful publish and the stateful working dir being cleaned
  # up. Otherwise the following retries will keep failing because of
  # duplicate publishes.
  with self._mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=m, node_contexts=self._pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=self._pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=self._pipeline_node.inputs)

    # 3. If not all required inputs are met, returns an ExecutionInfo with
    # is_execution_needed being false. No publish will happen, so downstream
    # nodes won't be triggered.
    if input_artifacts is None:
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(),
          contexts=contexts,
          is_execution_needed=False)

    # 4. Registers the execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=self._pipeline_node.node_info.type,
        contexts=contexts,
        input_artifacts=input_artifacts,
        exec_properties=exec_properties)

    # 5. Resolves outputs.
    output_artifacts = self._output_resolver.generate_output_artifacts(
        execution.id)

  # If there is a custom driver, runs it.
  if self._driver_operator:
    driver_output = self._driver_operator.run_driver(
        data_types.ExecutionInfo(
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver
            .get_driver_output_uri()))
    self._update_with_driver_output(driver_output, exec_properties,
                                    output_artifacts)

  # We reconnect to MLMD here because the custom driver closes the MLMD
  # connection on returning.
  with self._mlmd_connection as m:
    # 6. Checks for a cached result.
    cache_context = cache_utils.get_cache_context(
        metadata_handler=m,
        pipeline_node=self._pipeline_node,
        pipeline_info=self._pipeline_info,
        input_artifacts=input_artifacts,
        output_artifacts=output_artifacts,
        parameters=exec_properties)
    contexts.append(cache_context)
    cached_outputs = cache_utils.get_cached_outputs(
        metadata_handler=m, cache_context=cache_context)

    # 7. Should the cache be used?
    if (self._pipeline_node.execution_options.caching_options.enable_cache and
        cached_outputs):
      # Publishes the cached result.
      execution_publish_utils.publish_cached_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          output_artifacts=cached_outputs)
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(execution_id=execution.id),
          execution_metadata=execution,
          contexts=contexts,
          is_execution_needed=False)

    pipeline_run_id = (
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value)

    # 8. Going to trigger the executor.
    return _PrepareExecutionResult(
        execution_info=data_types.ExecutionInfo(
            execution_id=execution.id,
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver.get_executor_output_uri(
                execution.id),
            stateful_working_dir=(
                self._output_resolver.get_stateful_working_directory()),
            tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
            pipeline_node=self._pipeline_node,
            pipeline_info=self._pipeline_info,
            pipeline_run_id=pipeline_run_id),
        execution_metadata=execution,
        contexts=contexts,
        is_execution_needed=True)

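# A minimal sketch (names illustrative, not the TFX API) of the three-way
# outcome of the preparation flow above: the executor is only triggered when
# inputs resolved and no cached outputs were reusable.
def _execution_outcome(input_artifacts, enable_cache, cached_outputs):
  if input_artifacts is None:
    return 'skip'          # Step 3: required inputs not yet available.
  if enable_cache and cached_outputs:
    return 'use_cache'     # Step 7: publish cached outputs, no executor run.
  return 'run_executor'    # Step 8: trigger the executor.
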
def testResolverInputsArtifacts(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  another_example_gen = pipeline.nodes[1].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  my_trainer = pipeline.nodes[3].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes the first ExampleGen with two output channels.
    # `output_examples` will be consumed by the downstream Transform.
    output_example = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example.uri = 'my_examples_uri'
    side_examples = types.Artifact(
        my_example_gen.outputs.outputs['side_examples'].artifact_spec.type)
    side_examples.uri = 'side_examples_uri'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example],
            'another_examples': [side_examples]
        })

    # Publishes the second ExampleGen with one output channel that uses the
    # same output key as the first ExampleGen. It is not consumed by any
    # downstream node.
    another_output_example = types.Artifact(
        another_example_gen.outputs.outputs['output_examples'].artifact_spec
        .type)
    another_output_example.uri = 'another_examples_uri'
    contexts = context_lib.register_contexts_if_not_exists(
        m, another_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, another_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [another_output_example],
        })

    # Gets inputs for Transform. Should get back what the first ExampleGen
    # published in the `output_examples` channel.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # Tries to resolve inputs for Trainer. The Trainer requires min_count for
    # both input channels (from ExampleGen and from Transform), but nothing
    # was published from Transform, so resolution returns nothing.
    self.assertIsNone(
        inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))

def testSuccess(self):
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by the downstream resolver.
    output_model_1 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  handler = resolver_node_handler.ResolverNodeHandler()
  execution_metadata = handler.run(
      mlmd_connection=self._mlmd_connection,
      pipeline_node=self._resolver_node,
      pipeline_info=self._pipeline_info,
      pipeline_runtime_spec=self._pipeline_runtime_spec)

  with self._mlmd_connection as m:
    # There is no way to directly verify the output artifacts of the
    # resolver, so a fake downstream component is created which listens to
    # the resolver's output, and we verify its input.
    down_stream_node = text_format.Parse(
        """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
    downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=down_stream_node.inputs)
    downstream_input_model = downstream_input_artifacts['input_models']
    self.assertLen(downstream_input_model, 1)
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 5
        uri: "my_model_uri_2"
        state: LIVE""",
        downstream_input_model[0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    [execution] = m.store.get_executions_by_id([execution_metadata.id])
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 6
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

def testLauncher_ReEntry(self):
  # Some executors or runtime environments may reschedule the launcher job
  # before the launcher job can publish any results of the execution to MLMD.
  # The launcher should reuse the previous execution and proceed to a
  # successful execution.
  self.reloadPipelineWithNewRunId()
  LauncherTest.fakeUpstreamOutputs(self._mlmd_connection, self._example_gen,
                                   self._transform)

  def create_test_launcher(executor_operators):
    return launcher.Launcher(
        pipeline_node=self._trainer,
        mlmd_connection=self._mlmd_connection,
        pipeline_info=self._pipeline_info,
        pipeline_runtime_spec=self._pipeline_runtime_spec,
        executor_spec=self._trainer_executor_spec,
        custom_executor_operators=executor_operators)

  test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeCrashingExecutorOperator})

  # The first launch simulates the launcher being restarted by preventing the
  # publishing of any results to MLMD.
  with contextlib.ExitStack() as stack:
    stack.enter_context(
        mock.patch.object(test_launcher, '_publish_failed_execution'))
    stack.enter_context(
        mock.patch.object(test_launcher, '_clean_up_stateless_execution_info'))
    stack.enter_context(self.assertRaises(FakeError))
    test_launcher.launch()

  # Retrieves the execution from the first launch, which should be in the
  # RUNNING state.
  with self._mlmd_connection as m:
    contexts = context_lib.prepare_contexts(
        metadata_handler=m,
        node_contexts=test_launcher._pipeline_node.contexts)
    exec_properties = data_types_utils.build_parsed_value_dict(
        inputs_utils.resolve_parameters_with_schema(
            node_parameters=test_launcher._pipeline_node.parameters))
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=test_launcher._pipeline_node.inputs)
    first_execution = test_launcher._register_or_reuse_execution(
        metadata_handler=m,
        contexts=contexts,
        input_artifacts=input_artifacts,
        exec_properties=exec_properties)
    self.assertEqual(first_execution.last_known_state,
                     metadata_store_pb2.Execution.RUNNING)

  # Creates a second test launcher. It should reuse the previous execution.
  second_test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
  execution_info = second_test_launcher.launch()

  with self._mlmd_connection as m:
    self.assertEqual(first_execution.id, execution_info.execution_id)
    executions = m.store.get_executions_by_id([execution_info.execution_id])
    self.assertLen(executions, 1)
    self.assertEqual(executions.pop().last_known_state,
                     metadata_store_pb2.Execution.COMPLETE)

  # Creates a third test launcher. It should not require an execution.
  third_test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
  execution_preparation_result = third_test_launcher._prepare_execution()
  self.assertFalse(execution_preparation_result.is_execution_needed)