def setUp(self):
    super().setUp()
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())

    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    self._metadata_path = metadata_path
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(
        connection_config=connection_config)

    pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
    self._pipeline = pipeline
    self._importer_node = self._pipeline.nodes[0].pipeline_node

    self._task_queue = tq.TaskQueue()
    [importer_task] = test_utils.run_generator_and_test(
        test_case=self,
        mlmd_connection=self._mlmd_connection,
        generator_class=sptg.SyncPipelineTaskGenerator,
        pipeline=self._pipeline,
        task_queue=self._task_queue,
        use_task_queue=True,
        service_job_manager=None,
        num_initial_executions=0,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._importer_node])
    self._importer_task = importer_task
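
For reference, a sketch of roughly what `metadata.sqlite_metadata_connection_config` is assumed to produce here, based on the ml-metadata protos (details may vary across versions; `metadata_path` is the variable from the example above):

from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = metadata_path
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
# The explicit `SetInParent()` call in the examples marks the `sqlite` oneof
# field as set even when all of its subfields keep their default values.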
Example #2
    def setUp(self):
        super(AsyncPipelineTaskGeneratorTest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())
        self._pipeline_root = pipeline_root

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        self._task_queue = tq.TaskQueue()
        self._ignore_node_ids = set([self._example_gen.node_info.id])
Example #3
    def test_task_queue_operations(self):
        t1 = _test_task(node_id='trainer', pipeline_id='my_pipeline')
        t2 = _test_task(node_id='transform',
                        pipeline_id='my_pipeline',
                        pipeline_run_id='run_0')
        tq = task_queue.TaskQueue()

        # Enqueueing new tasks is successful.
        self.assertTrue(tq.enqueue(t1))
        self.assertTrue(tq.enqueue(t2))

        # Re-enqueueing the same tasks fails.
        self.assertFalse(tq.enqueue(t1))
        self.assertFalse(tq.enqueue(t2))

        # Dequeueing returns tasks in FIFO order; once the queue is empty,
        # `dequeue` returns `None`.
        self.assertEqual(t1, tq.dequeue())
        self.assertEqual(t2, tq.dequeue())
        self.assertIsNone(tq.dequeue())
        self.assertIsNone(tq.dequeue(0.1))

        # Re-enqueueing the same tasks fails as `task_done` has not been called.
        self.assertFalse(tq.enqueue(t1))
        self.assertFalse(tq.enqueue(t2))

        tq.task_done(t1)
        tq.task_done(t2)

        # Re-enqueueing is allowed after `task_done` has been called.
        self.assertTrue(tq.enqueue(t1))
        self.assertTrue(tq.enqueue(t2))
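
The assertions above pin down the queue contract: `enqueue` rejects duplicates, `dequeue` is FIFO and returns `None` when empty, and a task may be re-enqueued only after `task_done`. A minimal sketch of that contract (assumed semantics; not the real TFX TaskQueue):

import collections
import threading


class SimpleTaskQueue:
  """Minimal sketch of the contract exercised above."""

  def __init__(self):
    self._lock = threading.Lock()
    self._queue = collections.deque()  # Tasks awaiting dequeue, FIFO.
    self._pending = {}  # task_id -> task, from enqueue until task_done.
    self._dequeued = set()  # task_ids handed out by dequeue.

  def enqueue(self, task):
    with self._lock:
      if task.task_id in self._pending:
        return False  # Duplicate: task_done has not been called yet.
      self._pending[task.task_id] = task
      self._queue.append(task)
      return True

  def dequeue(self, max_wait_secs=None):
    with self._lock:
      if not self._queue:
        return None  # The real queue would block up to max_wait_secs.
      task = self._queue.popleft()
      self._dequeued.add(task.task_id)
      return task

  def task_done(self, task):
    with self._lock:
      if task.task_id not in self._pending:
        raise RuntimeError('Task not present.')
      if task.task_id not in self._dequeued:
        raise RuntimeError('Must call `dequeue` before `task_done`.')
      self._dequeued.discard(task.task_id)
      del self._pending[task.task_id]

The error messages mirror the ones asserted in the `test_invalid_task_done_raises_errors` example further below.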
Example #4
  def setUp(self):
    super(AsyncPipelineTaskGeneratorTest, self).setUp()
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())
    self._pipeline_root = pipeline_root

    # Makes sure multiple connections within a test always connect to the same
    # MLMD instance.
    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    self._metadata_path = metadata_path
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(
        connection_config=connection_config)

    # Sets up the pipeline.
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    self._pipeline = pipeline
    self._pipeline_info = pipeline.pipeline_info
    self._pipeline_runtime_spec = pipeline.runtime_spec
    self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
        pipeline_root)

    # Extracts components.
    self._example_gen = pipeline.nodes[0].pipeline_node
    self._transform = pipeline.nodes[1].pipeline_node
    self._trainer = pipeline.nodes[2].pipeline_node

    self._task_queue = tq.TaskQueue()

    self._mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)

    def _is_pure_service_node(unused_pipeline_state, node_id):
      return node_id == self._example_gen.node_info.id

    def _is_mixed_service_node(unused_pipeline_state, node_id):
      return node_id == self._transform.node_info.id

    self._mock_service_job_manager.is_pure_service_node.side_effect = (
        _is_pure_service_node)
    self._mock_service_job_manager.is_mixed_service_node.side_effect = (
        _is_mixed_service_node)

    def _default_ensure_node_services(unused_pipeline_state, node_id):
      self.assertIn(
          node_id,
          (self._example_gen.node_info.id, self._transform.node_info.id))
      return service_jobs.ServiceStatus.RUNNING

    self._mock_service_job_manager.ensure_node_services.side_effect = (
        _default_ensure_node_services)
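
For context, the ServiceJobManager surface being mocked here, with signatures inferred from how the mocks are used in these tests (a sketch, not the authoritative API):

class ServiceJobManager:
  """Sketch of the mocked interface, inferred from usage in these tests."""

  def is_pure_service_node(self, pipeline_state, node_id):
    """True if the node runs entirely as service jobs (e.g. ExampleGen)."""

  def is_mixed_service_node(self, pipeline_state, node_id):
    """True if the node needs service jobs plus an executor (e.g. Transform)."""

  def ensure_node_services(self, pipeline_state, node_id):
    """Starts or checks the node's service jobs; returns a ServiceStatus."""

  def stop_node_services(self, pipeline_state, node_id):
    """Stops the node's service jobs."""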
Example #5
    def test_queue_multiplexing(self, mock_publish):
        # Create a pipeline IR containing deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        deployment_config.executor_specs['Transform'].Pack(executor_spec)
        deployment_config.executor_specs['Evaluator'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        collector = _Collector()

        # Register a fake task scheduler. All three executor specs share the
        # same type_url, so this single registration covers Trainer, Transform
        # and Evaluator.
        ts.TaskSchedulerRegistry.register(
            deployment_config.executor_specs['Trainer'].type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={'Trainer', 'Transform'},
                              collector=collector))

        task_queue = tq.TaskQueue()

        # Enqueue some tasks.
        trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
        task_queue.enqueue(trainer_exec_task)
        task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

        with tm.TaskManager(mock.Mock(),
                            pipeline,
                            task_queue,
                            max_active_task_schedulers=1000,
                            max_dequeue_wait_secs=0.1,
                            process_all_queued_tasks_before_exit=True):
            # Enqueue more tasks after task manager starts.
            transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
            task_queue.enqueue(transform_exec_task)
            evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
            task_queue.enqueue(evaluator_exec_task)
            task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

        # Ensure that all exec and cancellation tasks were processed correctly.
        self.assertCountEqual(
            [trainer_exec_task, transform_exec_task, evaluator_exec_task],
            collector.scheduled_tasks)
        self.assertCountEqual([trainer_exec_task, transform_exec_task],
                              collector.cancelled_tasks)
        mock_publish.assert_has_calls([
            mock.call(mock.ANY, pipeline, trainer_exec_task, mock.ANY),
            mock.call(mock.ANY, pipeline, transform_exec_task, mock.ANY),
            mock.call(mock.ANY, pipeline, evaluator_exec_task, mock.ANY)
        ],
                                      any_order=True)
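
`_Collector` and `_FakeTaskScheduler` are local test helpers not shown in these excerpts; a hypothetical reconstruction consistent with how the examples use them, reusing the `ts` and `status_lib` aliases from the surrounding tests (attribute and parameter names are assumptions):

import threading


class _Collector:
  """Records which tasks were scheduled and cancelled."""

  def __init__(self):
    self.scheduled_tasks = []
    self.cancelled_tasks = []


class _FakeTaskScheduler(ts.TaskScheduler):
  """Blocks on nodes listed in `block_nodes` until cancelled."""

  def __init__(self, mlmd_handle, pipeline, task, block_nodes, collector):
    super().__init__(mlmd_handle, pipeline, task)
    self._block_nodes = block_nodes
    self._collector = collector
    self._cancel = threading.Event()

  def schedule(self):
    self._collector.scheduled_tasks.append(self.task)
    if self.task.node_uid.node_id in self._block_nodes:
      self._cancel.wait()  # Parked here until cancel() fires.
      return ts.TaskSchedulerResult(status=status_lib.Status(
          code=status_lib.Code.CANCELLED, message='_FakeTaskScheduler result'))
    return ts.TaskSchedulerResult(status=status_lib.Status(
        code=status_lib.Code.OK, message='_FakeTaskScheduler result'))

  def cancel(self):
    self._collector.cancelled_tasks.append(self.task)
    self._cancel.set()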
Example #6
  def test_task_handling(self, mock_publish):
    collector = _Collector()

    # Register a fake task scheduler.
    ts.TaskSchedulerRegistry.register(
        self._type_url,
        functools.partial(
            _FakeTaskScheduler,
            block_nodes={'Trainer', 'Transform'},
            collector=collector))

    task_queue = tq.TaskQueue()

    # Enqueue some tasks.
    trainer_exec_task = _test_exec_node_task('Trainer', 'test-pipeline',
                                             pipeline=self._pipeline)
    task_queue.enqueue(trainer_exec_task)
    task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

    with self._task_manager(task_queue) as task_manager:
      # Enqueue more tasks after task manager starts.
      transform_exec_task = _test_exec_node_task('Transform', 'test-pipeline',
                                                 pipeline=self._pipeline)
      task_queue.enqueue(transform_exec_task)
      evaluator_exec_task = _test_exec_node_task('Evaluator', 'test-pipeline',
                                                 pipeline=self._pipeline)
      task_queue.enqueue(evaluator_exec_task)
      task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))

    self.assertTrue(task_manager.done())
    self.assertIsNone(task_manager.exception())

    # Ensure that all exec and cancellation tasks were processed correctly.
    self.assertCountEqual(
        [trainer_exec_task, transform_exec_task, evaluator_exec_task],
        collector.scheduled_tasks)
    self.assertCountEqual([trainer_exec_task, transform_exec_task],
                          collector.cancelled_tasks)
    mock_publish.assert_has_calls([
        mock.call(
            mlmd_handle=mock.ANY,
            task=trainer_exec_task,
            result=mock.ANY),
        mock.call(
            mlmd_handle=mock.ANY,
            task=transform_exec_task,
            result=mock.ANY),
        mock.call(
            mlmd_handle=mock.ANY,
            task=evaluator_exec_task,
            result=mock.ANY),
    ],
                                  any_order=True)
Example #7
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())
        self._pipeline_root = pipeline_root

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = self._make_pipeline(self._pipeline_root, str(uuid.uuid4()))
        self._pipeline = pipeline

        # Extracts components.
        self._example_gen = test_utils.get_node(pipeline, 'my_example_gen')
        self._stats_gen = test_utils.get_node(pipeline, 'my_statistics_gen')
        self._schema_gen = test_utils.get_node(pipeline, 'my_schema_gen')
        self._transform = test_utils.get_node(pipeline, 'my_transform')
        self._example_validator = test_utils.get_node(pipeline,
                                                      'my_example_validator')
        self._trainer = test_utils.get_node(pipeline, 'my_trainer')
        self._evaluator = test_utils.get_node(pipeline, 'my_evaluator')
        self._chore_a = test_utils.get_node(pipeline, 'chore_a')
        self._chore_b = test_utils.get_node(pipeline, 'chore_b')

        self._task_queue = tq.TaskQueue()

        self._mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)

        self._mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == self._example_gen.node_info.id)
        self._mock_service_job_manager.is_mixed_service_node.side_effect = (
            lambda _, node_id: node_id == self._transform.node_info.id)

        def _default_ensure_node_services(unused_pipeline_state, node_id):
            self.assertIn(
                node_id,
                (self._example_gen.node_info.id, self._transform.node_info.id))
            return service_jobs.ServiceStatus.SUCCESS

        self._mock_service_job_manager.ensure_node_services.side_effect = (
            _default_ensure_node_services)
Example #8
  def test_exceptions_are_surfaced(self, mock_publish):

    def _publish(**kwargs):
      task = kwargs['task']
      assert task_lib.is_exec_node_task(task)
      if task.node_uid.node_id == 'Transform':
        raise ValueError('test error')
      return mock.DEFAULT

    mock_publish.side_effect = _publish

    collector = _Collector()

    # Register a fake task scheduler.
    ts.TaskSchedulerRegistry.register(
        self._type_url,
        functools.partial(
            _FakeTaskScheduler, block_nodes={}, collector=collector))

    task_queue = tq.TaskQueue()

    with self._task_manager(task_queue) as task_manager:
      transform_task = _test_exec_node_task('Transform', 'test-pipeline',
                                            pipeline=self._pipeline)
      trainer_task = _test_exec_node_task('Trainer', 'test-pipeline',
                                          pipeline=self._pipeline)
      task_queue.enqueue(transform_task)
      task_queue.enqueue(trainer_task)

    self.assertTrue(task_manager.done())
    exception = task_manager.exception()
    self.assertIsNotNone(exception)
    self.assertIsInstance(exception, tm.TasksProcessingError)
    self.assertLen(exception.errors, 1)
    self.assertEqual('test error', str(exception.errors[0]))

    self.assertCountEqual([transform_task, trainer_task],
                          collector.scheduled_tasks)
    mock_publish.assert_has_calls([
        mock.call(
            mlmd_handle=mock.ANY,
            task=transform_task,
            result=mock.ANY),
        mock.call(
            mlmd_handle=mock.ANY,
            task=trainer_task,
            result=mock.ANY),
    ],
                                  any_order=True)
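
The assertions above imply that `tm.TasksProcessingError` aggregates the individual failures raised while processing tasks; a plausible shape, as a hypothetical sketch:

class TasksProcessingError(Exception):
  """Raised by the task manager after collecting per-task errors."""

  def __init__(self, errors):
    super().__init__(f'{len(errors)} error(s) while processing tasks.')
    self.errors = errors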
Example #9
  def setUp(self):
    super(SyncPipelineTaskGeneratorTest, self).setUp()
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())
    self._pipeline_root = pipeline_root

    # Makes sure multiple connections within a test always connect to the same
    # MLMD instance.
    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    self._metadata_path = metadata_path
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(
        connection_config=connection_config)

    # Sets up the pipeline.
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt'),
        pipeline)
    self._pipeline_run_id = str(uuid.uuid4())
    runtime_parameter_utils.substitute_runtime_parameter(
        pipeline, {
            'pipeline_root': pipeline_root,
            'pipeline_run_id': self._pipeline_run_id
        })
    self._pipeline = pipeline

    # Extracts components.
    self._example_gen = _get_node(pipeline, 'my_example_gen')
    self._stats_gen = _get_node(pipeline, 'my_statistics_gen')
    self._schema_gen = _get_node(pipeline, 'my_schema_gen')
    self._transform = _get_node(pipeline, 'my_transform')
    self._example_validator = _get_node(pipeline, 'my_example_validator')
    self._trainer = _get_node(pipeline, 'my_trainer')

    self._task_queue = tq.TaskQueue()

    self._mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)

    self._mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == self._example_gen.node_info.id)
    self._mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == self._transform.node_info.id)
Example #10
    def test_invalid_task_done_raises_errors(self):
        t1 = _test_task(node_id='trainer', pipeline_id='my_pipeline')
        t2 = _test_task(node_id='transform', pipeline_id='my_pipeline')
        tq = task_queue.TaskQueue()

        # Enqueue t1, but calling `task_done` raises an error since t1 has not
        # been dequeued.
        self.assertTrue(tq.enqueue(t1))
        with self.assertRaisesRegex(RuntimeError, 'Must call `dequeue`'):
            tq.task_done(t1)

        # `task_done` succeeds after dequeueing.
        self.assertEqual(t1, tq.dequeue())
        tq.task_done(t1)

        # Error since t2 is not in the queue.
        with self.assertRaisesRegex(RuntimeError, 'Task not present'):
            tq.task_done(t2)
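
`_test_task` is another helper not shown here; a hypothetical version consistent with the calls above, reusing the `test_utils.create_exec_node_task` and `task_lib` constructors that appear in the other examples. Since the queue deduplicates on task id, the helper presumably derives the id from the node and pipeline ids, which is why re-enqueueing the same task fails:

def _test_task(node_id, pipeline_id, pipeline_run_id=None):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(
          pipeline_id=pipeline_id, pipeline_run_id=pipeline_run_id),
      node_id=node_id)
  return test_utils.create_exec_node_task(node_uid=node_uid)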
Example #11
    def setUp(self):
        super().setUp()

        # Set a constant version for the artifact version tag.
        patcher = mock.patch('tfx.version.__version__')
        patcher.start()
        tfx_version.__version__ = '0.123.4.dev'
        self.addCleanup(patcher.stop)

        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
        self._pipeline = pipeline
        self._importer_node = self._pipeline.nodes[0].pipeline_node

        self._task_queue = tq.TaskQueue()
        [importer_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=self._task_queue,
            use_task_queue=True,
            service_job_manager=None,
            num_initial_executions=0,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._importer_node],
            ignore_update_node_state_tasks=True)
        self._importer_task = importer_task
Example #12
    def test_handling_finalize_node_task(self, task_gen):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            finalize_reason = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foo bar')
            task_gen.return_value.generate.side_effect = [
                [
                    test_utils.create_exec_node_task(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Transform')),
                    task_lib.FinalizeNodeTask(node_uid=task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Trainer'),
                                              status=finalize_reason)
                ],
            ]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(m, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                task.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    finalize_reason,
                    pipeline_state.node_stop_initiated_reason(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Trainer')))
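
`_test_pipeline` is a local helper as well; a hypothetical version consistent with calls such as `_test_pipeline('pipeline1')` and `_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)` in these examples (field choices are assumptions):

def _test_pipeline(pipeline_id, execution_mode=pipeline_pb2.Pipeline.ASYNC):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = pipeline_id
  pipeline.execution_mode = execution_mode
  return pipeline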
Example #13
  def test_handling_finalize_pipeline_task(self, task_gen):
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
      pipeline_ops.initiate_pipeline_start(m, pipeline)
      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
      finalize_reason = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      task_gen.return_value.generate.side_effect = [
          [
              task_lib.FinalizePipelineTask(
                  pipeline_uid=pipeline_uid, status=finalize_reason)
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())
      task_gen.return_value.generate.assert_called_once()
      self.assertTrue(task_queue.is_empty())

      # Load pipeline state and verify stop initiation.
      with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
        self.assertEqual(finalize_reason,
                         pipeline_state.stop_initiated_reason())
Example #14
  def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                    mock_async_task_gen, mock_sync_task_gen):
    with self._mlmd_connection as m:
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      mock_service_job_manager.is_mixed_service_node.side_effect = (
          lambda _, node_id: node_id == 'Transform')

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        pipeline_state.initiate_stop(
            status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

      task_queue = tq.TaskQueue()

      # For the stop-initiated pipeline, "Transform" execution task is in queue,
      # "Trainer" has an active execution in MLMD but no task in queue,
      # "Evaluator" has no active execution.
      task_queue.enqueue(
          test_utils.create_exec_node_task(
              task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Transform')))
      transform_task = task_queue.dequeue()  # simulates task being processed
      mock_gen_task_from_active.side_effect = [
          test_utils.create_exec_node_task(
              node_uid=task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Trainer'),
              is_cancelled=True), None, None, None, None
      ]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

      # There are no active pipelines so these shouldn't be called.
      mock_async_task_gen.assert_not_called()
      mock_sync_task_gen.assert_not_called()

      # stop_node_services should be called for ExampleGen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')
      mock_service_job_manager.reset_mock()

      task_queue.task_done(transform_task)  # Pop out transform task.

      # CancelNodeTask for the "Transform" ExecNodeTask should be next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual('Transform', task.node_uid.node_id)

      # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual('Trainer', task.node_uid.node_id)
      self.assertTrue(task.is_cancelled)

      self.assertTrue(task_queue.is_empty())

      mock_gen_task_from_active.assert_has_calls([
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[2].pipeline_node,
              mock.ANY,
              is_cancelled=True),
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[3].pipeline_node,
              mock.ANY,
              is_cancelled=True)
      ])
      self.assertEqual(2, mock_gen_task_from_active.call_count)

      # Pipeline execution should continue to be active since active node
      # executions were found in the last call to `orchestrate`.
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertTrue(execution_lib.is_execution_active(execution))

      # Call `orchestrate` again; this time there are no more active node
      # executions so the pipeline should be marked as cancelled.
      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertTrue(task_queue.is_empty())
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                       execution.last_known_state)

      # stop_node_services should be called on both ExampleGen and Transform
      # which are service nodes.
      mock_service_job_manager.stop_node_services.assert_has_calls(
          [mock.call(mock.ANY, 'ExampleGen'),
           mock.call(mock.ANY, 'Transform')],
          any_order=True)
Example #15
  def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                        mock_sync_task_gen):
    with self._mlmd_connection as m:
      # Sync and async active pipelines.
      async_pipelines = [
          _test_pipeline('pipeline1'),
          _test_pipeline('pipeline2'),
      ]
      sync_pipelines = [
          _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
          _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
      ]

      for pipeline in async_pipelines + sync_pipelines:
        pipeline_ops.initiate_pipeline_start(m, pipeline)

      # Active executions for active async pipelines.
      mock_async_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[0]),
                      node_id='Transform'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[1]),
                      node_id='Trainer'))
          ],
      ]

      # Active executions for active sync pipelines.
      mock_sync_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[0]),
                      node_id='Trainer'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[1]),
                      node_id='Validator'))
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())

      self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
      self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

      # Verify that tasks are enqueued in the expected order.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
      self.assertTrue(task_queue.is_empty())
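
The repeated dequeue-and-verify pattern in this example could be factored into a small helper (hypothetical, for illustration):

def _dequeue_and_check_exec_task(test, task_queue, pipeline_id, node_id):
  task = task_queue.dequeue()
  task_queue.task_done(task)
  test.assertTrue(task_lib.is_exec_node_task(task))
  test.assertEqual(
      test_utils.create_node_uid(pipeline_id, node_id), task.node_uid)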
Example #16
    def setUp(self):
        super(TaskManagerE2ETest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce a task to run transform.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, pipeline_state, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate()
        self.assertLen(tasks, 1)
        self._task = tasks[0]
        self.assertEqual('my_transform', self._task.node_uid.node_id)
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id
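
`test_utils.fake_example_gen_run` simulates a completed ExampleGen execution so that downstream nodes have inputs to trigger on. A sketch of what it is assumed to do, built from the publish APIs used in the resolver example below (the output key and uri are placeholders; the real helper may differ):

def fake_example_gen_run(mlmd_connection, example_gen, span, version):
  with mlmd_connection as m:
    output_example = types.Artifact(
        example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example.set_int_custom_property('span', span)
    output_example.set_int_custom_property('version', version)
    output_example.uri = 'my_examples_uri'
    contexts = context_lib.prepare_contexts(m, example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {'output_examples': [output_example]})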
Example #17
    def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                            mock_async_task_gen,
                                            mock_sync_task_gen):
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
                m, task_lib.PipelineUid.from_pipeline(pipeline1))

            task_queue = tq.TaskQueue()

            # For the stop-initiated pipeline, "Transform" execution task is in queue,
            # "Trainer" has an active execution in MLMD but no task in queue,
            # "Evaluator" has no active execution.
            task_queue.enqueue(
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Transform')))
            transform_task = task_queue.dequeue()  # Simulates task being processed.
            mock_gen_task_from_active.side_effect = [
                test_utils.create_exec_node_task(
                    node_uid=task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid(
                            pipeline_id='pipeline1', pipeline_run_id=None),
                        node_id='Trainer'),
                    is_cancelled=True),
                None, None, None, None
            ]

            pipeline_ops.generate_tasks(m, task_queue)

            # There are no active pipelines so these shouldn't be called.
            mock_async_task_gen.assert_not_called()
            mock_sync_task_gen.assert_not_called()

            # Simulate finishing the "Transform" ExecNodeTask.
            task_queue.task_done(transform_task)

            # CancelNodeTask for the "Transform" ExecNodeTask should be next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual('Transform', task.node_uid.node_id)

            # ExecNodeTask for "Trainer" is next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual('Trainer', task.node_uid.node_id)

            self.assertTrue(task_queue.is_empty())

            mock_gen_task_from_active.assert_has_calls([
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[1].pipeline_node,
                          mock.ANY,
                          is_cancelled=True),
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[2].pipeline_node,
                          mock.ANY,
                          is_cancelled=True)
            ])
            self.assertEqual(2, mock_gen_task_from_active.call_count)

            # Pipeline execution should continue to be active since active node
            # executions were found in the last call to `generate_tasks`.
            [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertTrue(execution_lib.is_execution_active(execution))

            # Call `generate_tasks` again; this time there are no more active node
            # executions so the pipeline should be marked as cancelled.
            pipeline_ops.generate_tasks(m, task_queue)
            self.assertTrue(task_queue.is_empty())
            [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                             execution.last_known_state)
Example #18
    def test_resolver_task_scheduler(self):
        with self._mlmd_connection as m:
            # Publishes two models which will be consumed by the downstream
            # resolver.
            output_model_1 = types.Artifact(
                self._trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        task_queue = tq.TaskQueue()

        # Verify that resolver task is generated.
        [resolver_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=task_queue,
            use_task_queue=False,
            service_job_manager=None,
            num_initial_executions=1,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._resolver_node],
            ignore_update_node_state_tasks=True)

        with self._mlmd_connection as m:
            # Run resolver task scheduler and publish results.
            ts_result = resolver_task_scheduler.ResolverTaskScheduler(
                mlmd_handle=m, pipeline=self._pipeline,
                task=resolver_task).schedule()
            self.assertEqual(status_lib.Code.OK, ts_result.status.code)
            self.assertIsInstance(ts_result.output,
                                  task_scheduler.ResolverNodeOutput)
            self.assertCountEqual(
                ['resolved_model'],
                ts_result.output.resolved_input_artifacts.keys())
            models = ts_result.output.resolved_input_artifacts[
                'resolved_model']
            self.assertLen(models, 1)
            self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
            tm._publish_execution_results(m, resolver_task, ts_result)

        # Verify resolver node output is input to the downstream consumer node.
        [consumer_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=task_queue,
            use_task_queue=False,
            service_job_manager=None,
            num_initial_executions=2,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._consumer_node],
            ignore_update_node_state_tasks=True)
        self.assertCountEqual(['resolved_model'],
                              consumer_task.input_artifacts.keys())
        input_models = consumer_task.input_artifacts['resolved_model']
        self.assertLen(input_models, 1)
        self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)
Example #19
    def test_task_handling(self, mock_publish):
        collector = _Collector()

        # Register a fake task scheduler.
        ts.TaskSchedulerRegistry.register(
            self._type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={'Trainer', 'Transform', 'Pusher'},
                              collector=collector))

        task_queue = tq.TaskQueue()

        # Enqueue some tasks.
        trainer_exec_task = _test_exec_node_task('Trainer',
                                                 'test-pipeline',
                                                 pipeline=self._pipeline)
        task_queue.enqueue(trainer_exec_task)
        task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

        with self._task_manager(task_queue) as task_manager:
            # Enqueue more tasks after task manager starts.
            transform_exec_task = _test_exec_node_task('Transform',
                                                       'test-pipeline',
                                                       pipeline=self._pipeline)
            task_queue.enqueue(transform_exec_task)
            evaluator_exec_task = _test_exec_node_task('Evaluator',
                                                       'test-pipeline',
                                                       pipeline=self._pipeline)
            task_queue.enqueue(evaluator_exec_task)
            task_queue.enqueue(
                _test_cancel_node_task('Transform', 'test-pipeline'))
            pusher_exec_task = _test_exec_node_task('Pusher',
                                                    'test-pipeline',
                                                    pipeline=self._pipeline)
            task_queue.enqueue(pusher_exec_task)
            task_queue.enqueue(
                _test_cancel_node_task('Pusher', 'test-pipeline', pause=True))

        self.assertTrue(task_manager.done())
        self.assertIsNone(task_manager.exception())

        # Ensure that all exec and cancellation tasks were processed correctly.
        self.assertCountEqual([
            trainer_exec_task,
            transform_exec_task,
            evaluator_exec_task,
            pusher_exec_task,
        ], collector.scheduled_tasks)
        self.assertCountEqual([
            trainer_exec_task,
            transform_exec_task,
            pusher_exec_task,
        ], collector.cancelled_tasks)

        result_ok = ts.TaskSchedulerResult(status=status_lib.Status(
            code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
        result_cancelled = ts.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.CANCELLED,
                                     message='_FakeTaskScheduler result'))
        mock_publish.assert_has_calls([
            mock.call(mlmd_handle=mock.ANY,
                      task=trainer_exec_task,
                      result=result_cancelled),
            mock.call(mlmd_handle=mock.ANY,
                      task=transform_exec_task,
                      result=result_cancelled),
            mock.call(mlmd_handle=mock.ANY,
                      task=evaluator_exec_task,
                      result=result_ok),
        ],
                                      any_order=True)
        # Publish is not expected for Pusher because it was cancelled with
        # pause=True, so there should be only 3 calls.
        self.assertLen(mock_publish.mock_calls, 3)
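
`_test_cancel_node_task` with `pause=True` asks the task manager to stop the scheduler without publishing execution results, which is why only three publish calls are expected. A hypothetical shape for the helper, consistent with its use above (the `CancelNodeTask` constructor and its `pause` field are assumptions):

def _test_cancel_node_task(node_id, pipeline_id, pause=False):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(
          pipeline_id=pipeline_id, pipeline_run_id=None),
      node_id=node_id)
  return task_lib.CancelNodeTask(node_uid=node_uid, pause=pause)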
Example #20
    def test_active_pipelines_with_stop_initiated_nodes(
            self, mock_gen_task_from_active, mock_async_task_gen):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

            transform_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[0].pipeline_node)
            transform_task = test_utils.create_exec_node_task(
                node_uid=transform_node_uid)

            trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[1].pipeline_node)
            trainer_task = test_utils.create_exec_node_task(
                node_uid=trainer_node_uid)

            evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[2].pipeline_node)
            evaluator_task = test_utils.create_exec_node_task(
                node_uid=evaluator_node_uid)
            cancelled_evaluator_task = test_utils.create_exec_node_task(
                node_uid=evaluator_node_uid, is_cancelled=True)

            pipeline_ops.initiate_pipeline_start(m, pipeline)
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                # Stop trainer and evaluator.
                pipeline_state.initiate_node_stop(trainer_node_uid)
                pipeline_state.initiate_node_stop(evaluator_node_uid)

            task_queue = tq.TaskQueue()

            # Simulate a new transform execution being triggered.
            mock_async_task_gen.return_value.generate.return_value = [
                transform_task
            ]
            # Simulate ExecNodeTask for trainer already present in the task queue.
            task_queue.enqueue(trainer_task)
            # Simulate Evaluator having an active execution in MLMD.
            mock_gen_task_from_active.side_effect = [cancelled_evaluator_task]

            pipeline_ops.orchestrate(m, task_queue)
            self.assertEqual(
                1, mock_async_task_gen.return_value.generate.call_count)

            # Verify that tasks are enqueued in the expected order:

            # Pre-existing trainer task.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(trainer_task, task)

            # CancelNodeTask for trainer.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual(trainer_node_uid, task.node_uid)

            # ExecNodeTask with is_cancelled=True for evaluator.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(cancelled_evaluator_task, task)

            # ExecNodeTask for newly triggered transform node.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(transform_task, task)

            # No more tasks.
            self.assertTrue(task_queue.is_empty())
Example #21
  def test_active_pipelines_with_stop_initiated_nodes(self,
                                                      mock_gen_task_from_active,
                                                      mock_async_task_gen):
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline')
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[0].pipeline_node)

      transform_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[1].pipeline_node)
      transform_task = test_utils.create_exec_node_task(
          node_uid=transform_node_uid)

      trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[2].pipeline_node)
      trainer_task = test_utils.create_exec_node_task(node_uid=trainer_node_uid)

      evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[3].pipeline_node)
      evaluator_task = test_utils.create_exec_node_task(
          node_uid=evaluator_node_uid)
      cancelled_evaluator_task = test_utils.create_exec_node_task(
          node_uid=evaluator_node_uid, is_cancelled=True)

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        # Stop example-gen, trainer and evaluator.
        pipeline_state.initiate_node_stop(
            example_gen_node_uid,
            status_lib.Status(code=status_lib.Code.CANCELLED))
        pipeline_state.initiate_node_stop(
            trainer_node_uid, status_lib.Status(code=status_lib.Code.CANCELLED))
        pipeline_state.initiate_node_stop(
            evaluator_node_uid, status_lib.Status(code=status_lib.Code.ABORTED))

      task_queue = tq.TaskQueue()

      # Simulate a new transform execution being triggered.
      mock_async_task_gen.return_value.generate.return_value = [transform_task]
      # Simulate ExecNodeTask for trainer already present in the task queue.
      task_queue.enqueue(trainer_task)
      # Simulate Evaluator having an active execution in MLMD.
      mock_gen_task_from_active.side_effect = [cancelled_evaluator_task]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertEqual(1, mock_async_task_gen.return_value.generate.call_count)

      # stop_node_services should be called on example-gen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')

      # Verify that tasks are enqueued in the expected order:

      # Pre-existing trainer task.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(trainer_task, task)

      # CancelNodeTask for trainer.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual(trainer_node_uid, task.node_uid)

      # ExecNodeTask with is_cancelled=True for evaluator.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(cancelled_evaluator_task, task)

      # ExecNodeTask for newly triggered transform node.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(transform_task, task)

      # No more tasks.
      self.assertTrue(task_queue.is_empty())
Example #22
    def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                                   mock_sync_task_gen):
        with self._mlmd_connection as m:
            # One active pipeline.
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(m, pipeline1)

            # Another active pipeline (with previously completed execution).
            pipeline2 = _test_pipeline('pipeline2')
            execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
            execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution2])
            execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)

            # Inactive pipelines should be ignored.
            pipeline3 = _test_pipeline('pipeline3')
            execution3 = pipeline_ops.initiate_pipeline_start(m, pipeline3)
            execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution3])

            # For active pipelines pipeline1 and pipeline2, there are a couple of
            # active executions.
            def _exec_node_tasks():
                for pipeline_id in ('pipeline1', 'pipeline2'):
                    yield [
                        test_utils.create_exec_node_task(
                            node_uid=task_lib.NodeUid(
                                pipeline_uid=task_lib.PipelineUid(
                                    pipeline_id=pipeline_id,
                                    pipeline_run_id=None),
                                node_id='Transform')),
                        test_utils.create_exec_node_task(
                            node_uid=task_lib.NodeUid(
                                pipeline_uid=task_lib.PipelineUid(
                                    pipeline_id=pipeline_id,
                                    pipeline_run_id=None),
                                node_id='Trainer'))
                    ]

            mock_async_task_gen.return_value.generate.side_effect = (
                _exec_node_tasks())

            task_queue = tq.TaskQueue()
            pipeline_ops.generate_tasks(m, task_queue)

            self.assertEqual(
                2, mock_async_task_gen.return_value.generate.call_count)
            mock_sync_task_gen.assert_not_called()

            # Verify that tasks are enqueued in the expected order.
            for node_id in ('Transform', 'Trainer'):
                task = task_queue.dequeue()
                task_queue.task_done(task)
                self.assertTrue(task_lib.is_exec_node_task(task))
                self.assertEqual(node_id, task.node_uid.node_id)
                self.assertEqual('pipeline1',
                                 task.node_uid.pipeline_uid.pipeline_id)
            for node_id in ('Transform', 'Trainer'):
                task = task_queue.dequeue()
                task_queue.task_done(task)
                self.assertTrue(task_lib.is_exec_node_task(task))
                self.assertEqual(node_id, task.node_uid.node_id)
                self.assertEqual('pipeline2',
                                 task.node_uid.pipeline_uid.pipeline_id)
            self.assertTrue(task_queue.is_empty())
Example #23
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = test_async_pipeline.create_pipeline()

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce two tasks for transform. The first one is
        # UpdateNodeStateTask and the second one is ExecNodeTask.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState.new(m, self._pipeline)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate(pipeline_state)
        self.assertLen(tasks, 2)
        self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
        self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
        self.assertEqual('my_transform', tasks[0].node_uid.node_id)
        self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
        self.assertEqual('my_transform', tasks[1].node_uid.node_id)
        self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
        self.assertTrue(os.path.exists(tasks[1].tmp_dir))
        self._task = tasks[1]
        self._output_artifact_uri = self._task.output_artifacts[
            'transform_graph'][0].uri
        self.assertTrue(os.path.exists(self._output_artifact_uri))
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id