def setUp(self):
  """Builds an importer-node pipeline fixture and runs the generator once."""
  super().setUp()

  # Root everything under the test-output dir (or a temp dir) per test id.
  base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                            self.get_temp_dir())
  pipeline_root = os.path.join(base_dir, self.id())

  # MLMD database lives under the pipeline root.
  self._metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  conn_config = metadata.sqlite_metadata_connection_config(
      self._metadata_path)
  conn_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(connection_config=conn_config)

  self._pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
  self._importer_node = self._pipeline.nodes[0].pipeline_node
  self._task_queue = tq.TaskQueue()

  # Run the task generator once so tests start with the importer task
  # already produced; exactly one task/execution is expected.
  [self._importer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=self._task_queue,
      use_task_queue=True,
      service_job_manager=None,
      num_initial_executions=0,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._importer_node])
def _generate_and_test(self,
                       use_task_queue,
                       num_initial_executions,
                       num_tasks_generated,
                       num_new_executions,
                       num_active_executions,
                       pipeline=None,
                       expected_exec_nodes=None,
                       ignore_update_node_state_tasks=False,
                       fail_fast=False):
  """Generates tasks and tests the effects.

  Thin wrapper around `test_utils.run_generator_and_test` that supplies the
  fixture's MLMD connection, task queue and mock service-job manager.
  `pipeline` defaults to the fixture pipeline when not given.
  """
  # All arguments are forwarded by keyword for readability; keyword names
  # match the helper's signature as used elsewhere in this file.
  return test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=pipeline or self._pipeline,
      task_queue=self._task_queue,
      use_task_queue=use_task_queue,
      service_job_manager=self._mock_service_job_manager,
      num_initial_executions=num_initial_executions,
      num_tasks_generated=num_tasks_generated,
      num_new_executions=num_new_executions,
      num_active_executions=num_active_executions,
      expected_exec_nodes=expected_exec_nodes,
      ignore_update_node_state_tasks=ignore_update_node_state_tasks,
      fail_fast=fail_fast)
def _generate_and_test(self, use_task_queue, num_initial_executions,
                       num_tasks_generated, num_new_executions,
                       num_active_executions):
  """Generates tasks and tests the effects.

  Delegates to `otu.run_generator_and_test` with the fixture's connection,
  pipeline, task queue and mock service-job manager.
  """
  # Group the expected-count assertions into one kwargs dict so the
  # delegation call reads as fixture-args + expectations.
  expectations = dict(
      num_initial_executions=num_initial_executions,
      num_tasks_generated=num_tasks_generated,
      num_new_executions=num_new_executions,
      num_active_executions=num_active_executions)
  return otu.run_generator_and_test(self, self._mlmd_connection,
                                    sptg.SyncPipelineTaskGenerator,
                                    self._pipeline, self._task_queue,
                                    use_task_queue,
                                    self._mock_service_job_manager,
                                    **expectations)
def setUp(self):
  """Pins the TFX version tag and builds an importer pipeline fixture."""
  super().setUp()

  # Set a constant version for artifact version tag so test expectations
  # are deterministic; the patch is undone automatically on cleanup.
  version_patcher = mock.patch('tfx.version.__version__')
  version_patcher.start()
  tfx_version.__version__ = '0.123.4.dev'
  self.addCleanup(version_patcher.stop)

  # Root everything under the test-output dir (or a temp dir) per test id.
  output_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
  pipeline_root = os.path.join(output_dir, self.id())
  db_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  conn_config = metadata.sqlite_metadata_connection_config(db_path)
  conn_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(connection_config=conn_config)

  self._pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
  self._importer_node = self._pipeline.nodes[0].pipeline_node
  self._task_queue = tq.TaskQueue()

  # Run the generator once so the importer task is ready for each test;
  # node-state-update tasks are ignored here.
  [self._importer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=self._task_queue,
      use_task_queue=True,
      service_job_manager=None,
      num_initial_executions=0,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._importer_node],
      ignore_update_node_state_tasks=True)
def _generate_and_test(self, use_task_queue, num_initial_executions,
                       num_tasks_generated, num_new_executions,
                       num_active_executions, expected_exec_nodes=None,
                       ignore_node_ids=None):
  """Generates tasks and tests the effects.

  Delegates to `otu.run_generator_and_test` using the async pipeline task
  generator with the fixture's connection, pipeline, task queue and mock
  service-job manager.
  """
  # Group the expected-count assertions into one kwargs dict so the
  # delegation call reads as fixture-args + expectations.
  expectations = dict(
      num_initial_executions=num_initial_executions,
      num_tasks_generated=num_tasks_generated,
      num_new_executions=num_new_executions,
      num_active_executions=num_active_executions)
  return otu.run_generator_and_test(self, self._mlmd_connection,
                                    asptg.AsyncPipelineTaskGenerator,
                                    self._pipeline, self._task_queue,
                                    use_task_queue,
                                    self._mock_service_job_manager,
                                    expected_exec_nodes=expected_exec_nodes,
                                    ignore_node_ids=ignore_node_ids,
                                    **expectations)
def test_resolver_task_scheduler(self):
  """End-to-end test of a resolver node: generate, schedule, publish, consume.

  Publishes two trainer model artifacts, verifies the sync task generator
  emits a task for the resolver node, runs the resolver task scheduler,
  publishes its result, and finally verifies the resolved model flows into
  the downstream consumer node's task as an input artifact.
  """
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by downstream resolver.
    output_model_1 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._trainer.node_info.type, contexts)
    # Record a successful trainer execution that produced both models.
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  task_queue = tq.TaskQueue()

  # Verify that resolver task is generated.
  # num_initial_executions=1 accounts for the trainer execution above.
  [resolver_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._resolver_node],
      ignore_update_node_state_tasks=True)

  with self._mlmd_connection as m:
    # Run resolver task scheduler and publish results.
    ts_result = resolver_task_scheduler.ResolverTaskScheduler(
        mlmd_handle=m, pipeline=self._pipeline,
        task=resolver_task).schedule()
    self.assertEqual(status_lib.Code.OK, ts_result.status.code)
    self.assertIsInstance(ts_result.output,
                          task_scheduler.ResolverNodeOutput)
    self.assertCountEqual(
        ['resolved_model'],
        ts_result.output.resolved_input_artifacts.keys())
    models = ts_result.output.resolved_input_artifacts['resolved_model']
    # Exactly one model resolved; uri_2 was published last — presumably the
    # resolver picks the latest model (TODO confirm resolver strategy).
    self.assertLen(models, 1)
    self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
    # NOTE(review): uses the task manager's private publish helper to make
    # the resolver output visible to the downstream node.
    tm._publish_execution_results(m, resolver_task, ts_result)

  # Verify resolver node output is input to the downstream consumer node.
  # num_initial_executions=2: trainer + resolver executions now recorded.
  [consumer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=2,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._consumer_node],
      ignore_update_node_state_tasks=True)
  self.assertCountEqual(['resolved_model'],
                        consumer_task.input_artifacts.keys())
  input_models = consumer_task.input_artifacts['resolved_model']
  self.assertLen(input_models, 1)
  self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)