def test_cached_execution(self):
  """Verifies that an available cached execution is reused instead of rerun."""
  # Simulate a finished ExampleGen run in MLMD.
  example_gen_execution = otu.fake_example_gen_run(
      self._mlmd_connection, self._example_gen, 1, 1)

  # Once ExampleGen is done, task generation must emit exactly one
  # ExecNodeTask, and it must target StatsGen.
  [stats_gen_task] = self._generate_and_test(
      False,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1)
  self.assertEqual('my_statistics_gen', stats_gen_task.node_uid.node_id)

  # Drive the StatsGen execution to completion.
  otu.fake_execute_node(self._mlmd_connection, stats_gen_task)

  # Build a second instance of the pipeline under a brand-new run id.
  new_run_id = str(uuid.uuid4())
  second_pipeline = self._make_pipeline(self._pipeline_root, new_run_id)

  with self._mlmd_connection as m:
    example_gen_contexts = m.store.get_contexts_by_execution(
        example_gen_execution.id)
    # For simplicity, the ExampleGen node context doubles as the cache context.
    matching_contexts = [
        ctx for ctx in example_gen_contexts if ctx.name == 'my_example_gen'
    ]
    cache_context = matching_contexts[0]

  # Record a cached ExampleGen execution for the second pipeline. This helper
  # manages its own MLMD connection, so it runs outside the `with` above.
  otu.fake_cached_execution(
      self._mlmd_connection, cache_context,
      otu.get_node(second_pipeline, 'my_example_gen'))
  stats_gen = otu.get_node(second_pipeline, 'my_statistics_gen')

  # Expected outcome of task generation for the second pipeline:
  #   1. StatsGen is satisfied from cache. Its execution is recorded as
  #      successful with state "CACHED", and no ExecNodeTask is emitted for it.
  #   2. SchemaGen (downstream of StatsGen) gets an ExecNodeTask backed by a
  #      new active execution in MLMD.
  [schema_gen_task] = self._generate_and_test(
      False,
      pipeline=second_pipeline,
      num_initial_executions=3,
      num_tasks_generated=1,
      num_new_executions=2,
      num_active_executions=1)
  self.assertEqual('my_schema_gen', schema_gen_task.node_uid.node_id)

  # StatsGen must have exactly one execution for the new pipeline: a
  # successful one whose last known state is CACHED.
  with self._mlmd_connection as m:
    stats_gen_executions = task_gen_utils.get_executions(m, stats_gen)
    self.assertLen(stats_gen_executions, 1)
    [cached_execution] = stats_gen_executions
    self.assertTrue(execution_lib.is_execution_successful(cached_execution))
    self.assertEqual(metadata_store_pb2.Execution.CACHED,
                     cached_execution.last_known_state)
def setUp(self):
  """Creates a per-test MLMD store, the test pipeline and mocked services."""
  super().setUp()
  outputs_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  self._pipeline_root = os.path.join(outputs_dir, self.id())

  # Derive the sqlite path from the per-test pipeline root so that every
  # connection opened during one test reaches the same MLMD instance.
  self._metadata_path = os.path.join(self._pipeline_root, 'metadata',
                                     'metadata.db')
  mlmd_config = metadata.sqlite_metadata_connection_config(
      self._metadata_path)
  mlmd_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

  # Build the pipeline under test with a fresh run id and cache its nodes.
  self._pipeline = self._make_pipeline(self._pipeline_root,
                                       str(uuid.uuid4()))
  p = self._pipeline
  self._example_gen = test_utils.get_node(p, 'my_example_gen')
  self._stats_gen = test_utils.get_node(p, 'my_statistics_gen')
  self._schema_gen = test_utils.get_node(p, 'my_schema_gen')
  self._transform = test_utils.get_node(p, 'my_transform')
  self._example_validator = test_utils.get_node(p, 'my_example_validator')
  self._trainer = test_utils.get_node(p, 'my_trainer')
  self._evaluator = test_utils.get_node(p, 'my_evaluator')
  self._chore_a = test_utils.get_node(p, 'chore_a')
  self._chore_b = test_utils.get_node(p, 'chore_b')

  self._task_queue = tq.TaskQueue()

  # ExampleGen is modelled as a pure service node and Transform as a mixed
  # service node. The callbacks read the node attributes at call time, so
  # tests that replace those attributes are still honoured.
  def _is_pure_service_node(_, node_id):
    return node_id == self._example_gen.node_info.id

  def _is_mixed_service_node(_, node_id):
    return node_id == self._transform.node_info.id

  def _default_ensure_node_services(unused_pipeline_state, node_id):
    # Only the two service-backed nodes should ever request services.
    self.assertIn(
        node_id,
        (self._example_gen.node_info.id, self._transform.node_info.id))
    return service_jobs.ServiceStatus.SUCCESS

  job_manager = mock.create_autospec(
      service_jobs.ServiceJobManager, instance=True)
  self._mock_service_job_manager = job_manager
  job_manager.is_pure_service_node.side_effect = _is_pure_service_node
  job_manager.is_mixed_service_node.side_effect = _is_mixed_service_node
  job_manager.ensure_node_services.side_effect = (
      _default_ensure_node_services)
def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
                   version: int):
  """Records fake, already-finished executions for the pipeline's nodes.

  ExampleGen gets a fake run at `version` (the literal `1` is passed as
  the helper's third argument). StatsGen, SchemaGen, Transform,
  ExampleValidator and Trainer each get a fake, inactive component output,
  in that order.

  Args:
    handle: Open MLMD handle used to write the fake executions.
    pipeline: Pipeline whose nodes are looked up by id.
    version: Forwarded to `fake_example_gen_run_with_handle`.
  """
  example_gen = test_utils.get_node(pipeline, 'my_example_gen')
  # Resolve every downstream node before writing anything, in the same
  # order the writes happen below.
  downstream_nodes = [
      test_utils.get_node(pipeline, node_id)
      for node_id in ('my_statistics_gen', 'my_schema_gen', 'my_transform',
                      'my_example_validator', 'my_trainer')
  ]
  test_utils.fake_example_gen_run_with_handle(handle, example_gen, 1, version)
  for node in downstream_nodes:
    test_utils.fake_component_output_with_handle(handle, node, active=False)