Example no. 1
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_node_ids=None):
    """Runs generator.generate() and tests the effects."""
    if service_job_manager is None:
        service_job_manager = service_jobs.DummyServiceJobManager()
    with mlmd_connection as m:
        executions = m.store.get_executions()
        test_case.assertLen(
            executions, num_initial_executions,
            f'Expected {num_initial_executions} execution(s) in MLMD.')
        pipeline_state = pstate.PipelineState.new(m, pipeline)
        generator_params = dict(
            mlmd_handle=m,
            pipeline_state=pipeline_state,
            is_task_id_tracked_fn=task_queue.contains_task_id,
            service_job_manager=service_job_manager)
        if generator_class == asptg.AsyncPipelineTaskGenerator:
            generator_params['ignore_node_ids'] = ignore_node_ids
        task_gen = generator_class(**generator_params)
        tasks = task_gen.generate()
        test_case.assertLen(
            tasks, num_tasks_generated,
            f'Expected {num_tasks_generated} task(s) to be generated.')
        executions = m.store.get_executions()
        num_total_executions = num_initial_executions + num_new_executions
        test_case.assertLen(
            executions, num_total_executions,
            f'Expected {num_total_executions} execution(s) in MLMD.')
        active_executions = [
            e for e in executions if execution_lib.is_execution_active(e)
        ]
        test_case.assertLen(
            active_executions, num_active_executions,
            f'Expected {num_active_executions} active execution(s) in MLMD.')
        if expected_exec_nodes:
            for i, task in enumerate(tasks):
                _verify_exec_node_task(test_case, pipeline,
                                       expected_exec_nodes[i],
                                       active_executions[i].id, task)
        if use_task_queue:
            for task in tasks:
                if task_lib.is_exec_node_task(task):
                    task_queue.enqueue(task)
        return tasks
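A minimal usage sketch of the helper above; every name and count below (`self._pipeline`, `self._transform`, the expected execution counts) is an illustrative assumption, not taken from a real test:

# Hypothetical call: exercises the async generator and asserts one new
# active execution. Passing service_job_manager=None makes the helper
# fall back to DummyServiceJobManager.
tasks = run_generator_and_test(
    test_case=self,
    mlmd_connection=self._mlmd_connection,
    generator_class=asptg.AsyncPipelineTaskGenerator,
    pipeline=self._pipeline,
    task_queue=self._task_queue,
    use_task_queue=True,
    service_job_manager=None,
    num_initial_executions=1,
    num_tasks_generated=1,
    num_new_executions=1,
    num_active_executions=1,
    expected_exec_nodes=[self._transform],
    ignore_node_ids=None)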
Example no. 2
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_update_node_state_tasks=False,
                           fail_fast=None):
    """Runs generator.generate() and tests the effects."""
    if service_job_manager is None:
        service_job_manager = service_jobs.DummyServiceJobManager()
    with mlmd_connection as m:
        executions = get_non_orchestrator_executions(m)
        test_case.assertLen(
            executions, num_initial_executions,
            f'Expected {num_initial_executions} execution(s) in MLMD.')
    tasks = run_generator(
        mlmd_connection,
        generator_class,
        pipeline,
        task_queue,
        use_task_queue,
        service_job_manager,
        ignore_update_node_state_tasks=ignore_update_node_state_tasks,
        fail_fast=fail_fast)
    with mlmd_connection as m:
        test_case.assertLen(
            tasks, num_tasks_generated,
            f'Expected {num_tasks_generated} task(s) to be generated.')
        executions = get_non_orchestrator_executions(m)
        num_total_executions = num_initial_executions + num_new_executions
        test_case.assertLen(
            executions, num_total_executions,
            f'Expected {num_total_executions} execution(s) in MLMD.')
        active_executions = [
            e for e in executions if execution_lib.is_execution_active(e)
        ]
        test_case.assertLen(
            active_executions, num_active_executions,
            f'Expected {num_active_executions} active execution(s) in MLMD.')
        if expected_exec_nodes:
            for i, task in enumerate(t for t in tasks
                                     if task_lib.is_exec_node_task(t)):
                _verify_exec_node_task(test_case, pipeline,
                                       expected_exec_nodes[i],
                                       active_executions[i].id, task)
        return tasks
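Example no. 2 delegates generator construction to a separate `run_generator` helper and counts only non-orchestrator executions. That helper is not shown; the sketch below is one plausible shape inferred from the call site and from Example no. 1, not the actual implementation:

def run_generator(mlmd_connection, generator_class, pipeline, task_queue,
                  use_task_queue, service_job_manager,
                  ignore_update_node_state_tasks=False, fail_fast=None):
  """Plausible shape of the helper; inferred, not the real implementation."""
  with mlmd_connection as m:
    pipeline_state = pstate.PipelineState.new(m, pipeline)
    generator_params = dict(
        mlmd_handle=m,
        is_task_id_tracked_fn=task_queue.contains_task_id,
        service_job_manager=service_job_manager)
    # fail_fast is assumed to apply only to the sync generator.
    if fail_fast is not None:
      generator_params['fail_fast'] = fail_fast
    tasks = generator_class(**generator_params).generate(pipeline_state)
  if use_task_queue:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task_queue.enqueue(task)
  if ignore_update_node_state_tasks:
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]
  return tasks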
Example no. 3
  def test_no_tasks_generated_when_no_inputs(self, min_count):
    """Tests no tasks are generated when there are no inputs, regardless of min_count."""
    for node in self._pipeline.nodes:
      for v in node.pipeline_node.inputs.inputs.values():
        v.min_count = min_count

    with self._mlmd_connection as m:
      pipeline_state = test_utils.get_or_create_pipeline_state(
          m, self._pipeline)
      task_gen = asptg.AsyncPipelineTaskGenerator(
          m, lambda _: False, service_jobs.DummyServiceJobManager())
      tasks = task_gen.generate(pipeline_state)
      self.assertEmpty(tasks, 'Expected no task generation when no inputs.')
      self.assertEmpty(
          test_utils.get_non_orchestrator_executions(m),
          'There must not be any registered executions since no tasks were '
          'generated.')
Example no. 4
  def test_no_tasks_generated_when_no_inputs(self, min_count):
    """Tests no tasks are generated when there are no inputs, regardless of min_count."""
    for node in self._pipeline.nodes:
      for v in node.pipeline_node.inputs.inputs.values():
        v.min_count = min_count

    with self._mlmd_connection as m:
      pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
      task_gen = asptg.AsyncPipelineTaskGenerator(
          m,
          pipeline_state,
          lambda _: False,
          service_jobs.DummyServiceJobManager(),
          ignore_node_ids=set([self._example_gen.node_info.id]))
      tasks = task_gen.generate()
      self.assertEmpty(tasks, 'Expected no task generation when no inputs.')
      self.assertEmpty(
          m.store.get_executions(),
          'There must not be any registered executions since no tasks were '
          'generated.')
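The `min_count` argument in Examples no. 3 and 4 implies a parameterized test; a minimal sketch of the decorator presumably attached to it, assuming absl's `parameterized` module (the parameter values 0 and 1 are illustrative):

from absl.testing import parameterized

class NoInputsTest(parameterized.TestCase):

  # Each listed value is passed to the test method as `min_count`.
  @parameterized.parameters(0, 1)
  def test_no_tasks_generated_when_no_inputs(self, min_count):
    self.assertIn(min_count, (0, 1))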
Example no. 5
    def test_handling_finalize_node_task(self, task_gen):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            finalize_reason = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foo bar')
            task_gen.return_value.generate.side_effect = [
                [
                    test_utils.create_exec_node_task(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Transform')),
                    task_lib.FinalizeNodeTask(node_uid=task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Trainer'),
                                              status=finalize_reason)
                ],
            ]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(m, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                task.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    finalize_reason,
                    pipeline_state.node_stop_initiated_reason(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Trainer')))
Example no. 6
  def test_handling_finalize_pipeline_task(self, task_gen):
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
      pipeline_ops.initiate_pipeline_start(m, pipeline)
      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
      finalize_reason = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      task_gen.return_value.generate.side_effect = [
          [
              task_lib.FinalizePipelineTask(
                  pipeline_uid=pipeline_uid, status=finalize_reason)
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())
      task_gen.return_value.generate.assert_called_once()
      self.assertTrue(task_queue.is_empty())

      # Load pipeline state and verify stop initiation.
      with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
        self.assertEqual(finalize_reason,
                         pipeline_state.stop_initiated_reason())
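Examples no. 5 and 6 receive a `task_gen` argument, which indicates the task generator class is replaced by a mock (e.g. via `mock.patch`; the exact patch target is not shown in these snippets). What the tests rely on is that `side_effect` yields one task list per `generate()` call; a self-contained illustration:

from unittest import mock

# `gen_cls` stands in for the patched generator class; the mocked
# instance is gen_cls.return_value, so its generate.side_effect supplies
# the task list returned by each successive generate() call.
gen_cls = mock.MagicMock()
gen_cls.return_value.generate.side_effect = [['task-a'], ['task-b']]

gen = gen_cls()
assert gen.generate() == ['task-a']  # first orchestration pass
assert gen.generate() == ['task-b']  # second orchestration pass
assert gen_cls.return_value.generate.call_count == 2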
Example no. 7
  def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                        mock_sync_task_gen):
    with self._mlmd_connection as m:
      # Sync and async active pipelines.
      async_pipelines = [
          _test_pipeline('pipeline1'),
          _test_pipeline('pipeline2'),
      ]
      sync_pipelines = [
          _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
          _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
      ]

      for pipeline in async_pipelines + sync_pipelines:
        pipeline_ops.initiate_pipeline_start(m, pipeline)

      # Active executions for active async pipelines.
      mock_async_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[0]),
                      node_id='Transform'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[1]),
                      node_id='Trainer'))
          ],
      ]

      # Active executions for active sync pipelines.
      mock_sync_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[0]),
                      node_id='Trainer'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[1]),
                      node_id='Validator'))
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())

      self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
      self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

      # Verify that tasks are enqueued in the expected order.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
      self.assertTrue(task_queue.is_empty())
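Example no. 7 takes two mocks, implying two stacked patches. With stacked `mock.patch` decorators the innermost (bottom) patch maps to the first mock parameter, which is why `mock_async_task_gen` precedes `mock_sync_task_gen`. A runnable illustration with made-up names:

from unittest import mock

class _Generators:
  async_gen = 'real-async'
  sync_gen = 'real-sync'

@mock.patch.object(_Generators, 'sync_gen')   # outer patch -> second argument
@mock.patch.object(_Generators, 'async_gen')  # inner patch -> first argument
def _check(mock_async_gen, mock_sync_gen):
  # The bottom decorator is applied first and binds to the first parameter.
  assert _Generators.async_gen is mock_async_gen
  assert _Generators.sync_gen is mock_sync_gen

_check()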
Example no. 8
    def setUp(self):
        super(TaskManagerE2ETest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce a task to run transform.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, pipeline_state, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate()
        self.assertLen(tasks, 1)
        self._task = tasks[0]
        self.assertEqual('my_transform', self._task.node_uid.node_id)
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id
Example no. 9
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = test_async_pipeline.create_pipeline()

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce two tasks for transform: the first an
        # UpdateNodeStateTask and the second an ExecNodeTask.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState.new(m, self._pipeline)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate(pipeline_state)
        self.assertLen(tasks, 2)
        self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
        self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
        self.assertEqual('my_transform', tasks[0].node_uid.node_id)
        self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
        self.assertEqual('my_transform', tasks[1].node_uid.node_id)
        self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
        self.assertTrue(os.path.exists(tasks[1].tmp_dir))
        self._task = tasks[1]
        self._output_artifact_uri = self._task.output_artifacts[
            'transform_graph'][0].uri
        self.assertTrue(os.path.exists(self._output_artifact_uri))
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id
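Both setUp variants isolate MLMD by pointing it at a per-test SQLite file. A minimal standalone sketch of the same idea using ml-metadata directly (the path is illustrative):

import os
import tempfile

from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

# File-backed SQLite store; READWRITE_OPENCREATE creates the database
# file on first use, mirroring what sqlite_metadata_connection_config
# sets up in the tests above.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = os.path.join(
    tempfile.mkdtemp(), 'metadata.db')
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
store = metadata_store.MetadataStore(connection_config)
assert store.get_executions() == []  # fresh store starts empty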