def setUp(self):
    """Builds an MLMD-backed test pipeline and generates its first task.

    Side effects on the test instance:
      * `_metadata_path`: SQLite MLMD database file under the pipeline root.
      * `_mlmd_connection`: `metadata.Metadata` handle over that database.
      * `_pipeline`: pipeline built by `self._make_pipeline` with a fresh
        random run id (the helper itself is defined elsewhere).
      * `_importer_node`: the first node of the pipeline.
      * `_task_queue`: a fresh `tq.TaskQueue`.
      * `_importer_task`: the single task produced by one run of the sync
        pipeline task generator, which is expected to target
        `_importer_node`.
    """
    super().setUp()
    # Prefer the test runner's undeclared-outputs dir when provided; note the
    # temp-dir fallback is evaluated eagerly, as a plain argument.
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    pipeline_root = os.path.join(output_root, self.id())

    self._metadata_path = os.path.join(pipeline_root, 'metadata',
                                       'metadata.db')
    mlmd_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    # Mark the `sqlite` sub-message as present in the connection config.
    mlmd_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

    self._pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
    self._importer_node = self._pipeline.nodes[0].pipeline_node
    self._task_queue = tq.TaskQueue()

    # Exactly one task (for the importer) must come out of the first
    # generator run; the list-unpacking enforces that.
    [self._importer_task] = test_utils.run_generator_and_test(
        test_case=self,
        mlmd_connection=self._mlmd_connection,
        generator_class=sptg.SyncPipelineTaskGenerator,
        pipeline=self._pipeline,
        task_queue=self._task_queue,
        use_task_queue=True,
        service_job_manager=None,
        num_initial_executions=0,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._importer_node])
Exemplo n.º 2
0
 def _generate_and_test(self,
                        use_task_queue,
                        num_initial_executions,
                        num_tasks_generated,
                        num_new_executions,
                        num_active_executions,
                        pipeline=None,
                        expected_exec_nodes=None,
                        ignore_update_node_state_tasks=False,
                        fail_fast=False):
     """Runs the sync pipeline task generator once and checks its effects.

     Delegates to `test_utils.run_generator_and_test`, supplying this
     fixture's MLMD connection, task queue and mock service job manager.

     Args:
       use_task_queue: Forwarded unchanged.
       num_initial_executions: Forwarded unchanged.
       num_tasks_generated: Forwarded unchanged.
       num_new_executions: Forwarded unchanged.
       num_active_executions: Forwarded unchanged.
       pipeline: Pipeline to run; when falsy (e.g. the default None) the
         fixture's `self._pipeline` is used instead.
       expected_exec_nodes: Forwarded unchanged.
       ignore_update_node_state_tasks: Forwarded unchanged.
       fail_fast: Forwarded unchanged.

     Returns:
       The value returned by `test_utils.run_generator_and_test`
       (presumably the list of generated tasks — confirm in test_utils).
     """
     target_pipeline = pipeline or self._pipeline
     return test_utils.run_generator_and_test(
         test_case=self,
         mlmd_connection=self._mlmd_connection,
         generator_class=sptg.SyncPipelineTaskGenerator,
         pipeline=target_pipeline,
         task_queue=self._task_queue,
         use_task_queue=use_task_queue,
         service_job_manager=self._mock_service_job_manager,
         num_initial_executions=num_initial_executions,
         num_tasks_generated=num_tasks_generated,
         num_new_executions=num_new_executions,
         num_active_executions=num_active_executions,
         expected_exec_nodes=expected_exec_nodes,
         ignore_update_node_state_tasks=ignore_update_node_state_tasks,
         fail_fast=fail_fast)
Exemplo n.º 3
0
 def _generate_and_test(self, use_task_queue, num_initial_executions,
                        num_tasks_generated, num_new_executions,
                        num_active_executions):
   """Runs the sync pipeline task generator once and checks its effects.

   The fixture-owned arguments (MLMD connection, generator class, pipeline,
   task queue, mock service job manager) are passed positionally to
   `otu.run_generator_and_test`; the expected counts go by keyword.

   Returns:
     Whatever `otu.run_generator_and_test` returns.
   """
   fixture_args = (self, self._mlmd_connection,
                   sptg.SyncPipelineTaskGenerator, self._pipeline,
                   self._task_queue, use_task_queue,
                   self._mock_service_job_manager)
   return otu.run_generator_and_test(
       *fixture_args,
       num_initial_executions=num_initial_executions,
       num_tasks_generated=num_tasks_generated,
       num_new_executions=num_new_executions,
       num_active_executions=num_active_executions)
Exemplo n.º 4
0
    def setUp(self):
        """Builds a version-pinned MLMD-backed test pipeline and its first task.

        Side effects on the test instance:
          * `tfx.version.__version__` is patched to '0.123.4.dev' for the
            duration of the test (undone by a registered cleanup).
          * `_mlmd_connection`: `metadata.Metadata` over a SQLite database
            under the pipeline root.
          * `_pipeline` / `_importer_node`: pipeline from
            `self._make_pipeline` (defined elsewhere) and its first node.
          * `_task_queue`: a fresh `tq.TaskQueue`.
          * `_importer_task`: the single task produced by one run of the sync
            pipeline task generator, expected to target `_importer_node`.
        """
        super().setUp()

        # Set a constant version for artifact version tag. Passing the value
        # as `new` installs it directly (rather than a MagicMock that then has
        # to be overwritten), and the cleanup is registered immediately after
        # start() so the patch is always undone.
        patcher = mock.patch('tfx.version.__version__', '0.123.4.dev')
        patcher.start()
        self.addCleanup(patcher.stop)

        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        # Mark the `sqlite` sub-message as present in the connection config.
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
        self._pipeline = pipeline
        self._importer_node = self._pipeline.nodes[0].pipeline_node

        self._task_queue = tq.TaskQueue()
        # Exactly one task (for the importer) must be produced; the
        # list-unpacking enforces that.
        [importer_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=self._task_queue,
            use_task_queue=True,
            service_job_manager=None,
            num_initial_executions=0,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._importer_node],
            ignore_update_node_state_tasks=True)
        self._importer_task = importer_task
 def _generate_and_test(self,
                        use_task_queue,
                        num_initial_executions,
                        num_tasks_generated,
                        num_new_executions,
                        num_active_executions,
                        expected_exec_nodes=None,
                        ignore_node_ids=None):
     """Runs the async pipeline task generator once and checks its effects.

     The fixture-owned arguments are passed positionally to
     `otu.run_generator_and_test`; the expectation arguments are forwarded
     by keyword, unchanged.

     Returns:
       Whatever `otu.run_generator_and_test` returns.
     """
     expectations = dict(
         num_initial_executions=num_initial_executions,
         num_tasks_generated=num_tasks_generated,
         num_new_executions=num_new_executions,
         num_active_executions=num_active_executions,
         expected_exec_nodes=expected_exec_nodes,
         ignore_node_ids=ignore_node_ids)
     return otu.run_generator_and_test(
         self, self._mlmd_connection, asptg.AsyncPipelineTaskGenerator,
         self._pipeline, self._task_queue, use_task_queue,
         self._mock_service_job_manager, **expectations)
Exemplo n.º 6
0
    def test_resolver_task_scheduler(self):
        """Checks resolver scheduling end to end through the sync generator.

        Flow:
          1. Publish one succeeded trainer execution with two model artifacts.
          2. Expect the generator to emit a single task for the resolver node.
          3. Run `ResolverTaskScheduler` on it and expect only the model with
             uri 'my_model_uri_2' under 'resolved_model'; publish the result.
          4. Expect the generator to emit a single consumer-node task whose
             'resolved_model' input is that same single model.
        """
        with self._mlmd_connection as m:
            # Two trainer outputs for the downstream resolver to choose from.
            model_type = self._trainer.outputs.outputs['model'].artifact_spec.type
            trainer_models = []
            for uri in ('my_model_uri_1', 'my_model_uri_2'):
                model = types.Artifact(model_type)
                model.uri = uri
                trainer_models.append(model)

            trainer_contexts = context_lib.prepare_contexts(
                m, self._trainer.contexts)
            trainer_execution = execution_publish_utils.register_execution(
                m, self._trainer.node_info.type, trainer_contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, trainer_execution.id, trainer_contexts,
                {'model': trainer_models})

        task_queue = tq.TaskQueue()

        def generate_single_task(num_initial_executions, expected_node):
            # One generator pass that must yield exactly one task, for
            # `expected_node`; the list-unpacking enforces the count.
            [task] = test_utils.run_generator_and_test(
                test_case=self,
                mlmd_connection=self._mlmd_connection,
                generator_class=sptg.SyncPipelineTaskGenerator,
                pipeline=self._pipeline,
                task_queue=task_queue,
                use_task_queue=False,
                service_job_manager=None,
                num_initial_executions=num_initial_executions,
                num_tasks_generated=1,
                num_new_executions=1,
                num_active_executions=1,
                expected_exec_nodes=[expected_node],
                ignore_update_node_state_tasks=True)
            return task

        # The resolver is the next node to run after the trainer.
        resolver_task = generate_single_task(1, self._resolver_node)

        with self._mlmd_connection as m:
            scheduler = resolver_task_scheduler.ResolverTaskScheduler(
                mlmd_handle=m, pipeline=self._pipeline, task=resolver_task)
            ts_result = scheduler.schedule()
            self.assertEqual(status_lib.Code.OK, ts_result.status.code)
            self.assertIsInstance(ts_result.output,
                                  task_scheduler.ResolverNodeOutput)
            resolved = ts_result.output.resolved_input_artifacts
            self.assertCountEqual(['resolved_model'], resolved.keys())
            resolved_models = resolved['resolved_model']
            self.assertLen(resolved_models, 1)
            self.assertEqual('my_model_uri_2',
                             resolved_models[0].mlmd_artifact.uri)
            tm._publish_execution_results(m, resolver_task, ts_result)

        # The consumer must now receive the resolver's output as its input.
        consumer_task = generate_single_task(2, self._consumer_node)
        consumer_inputs = consumer_task.input_artifacts
        self.assertCountEqual(['resolved_model'], consumer_inputs.keys())
        input_models = consumer_inputs['resolved_model']
        self.assertLen(input_models, 1)
        self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)