Ejemplo n.º 1
0
 def _exec_node_tasks():
     """Yields, per pipeline, the ExecNodeTasks for its Transform and Trainer nodes."""
     for pid in ('pipeline1', 'pipeline2'):
         # One shared pipeline uid for both node tasks of this pipeline.
         pipeline_uid = task_lib.PipelineUid(
             pipeline_id=pid, pipeline_run_id=None)
         yield [
             test_utils.create_exec_node_task(
                 node_uid=task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                           node_id=node_id))
             for node_id in ('Transform', 'Trainer')
         ]
Ejemplo n.º 2
0
    def test_stop_pipeline_non_existent(self):
        """Stopping an unknown or already-inactive pipeline raises errors."""
        with self._mlmd_connection as m:
            # No pipeline has been started yet, so a stop must fail NOT_FOUND.
            with self.assertRaises(status_lib.StatusNotOkError) as ctx:
                pipeline_ops.stop_pipeline(
                    m,
                    task_lib.PipelineUid(pipeline_id='foo',
                                         pipeline_run_id=None))
            self.assertEqual(status_lib.Code.NOT_FOUND, ctx.exception.code)

            # Start a pipeline, initiate a stop, and mark its execution done.
            pipeline1 = _test_pipeline('pipeline1')
            execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline1)
            pipeline_ops._initiate_pipeline_stop(m, pipeline_uid)
            execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution])

            # A second stop attempt must now fail with ALREADY_EXISTS.
            with self.assertRaises(status_lib.StatusNotOkError) as ctx:
                pipeline_ops.stop_pipeline(m, pipeline_uid)
            self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                             ctx.exception.code)
Ejemplo n.º 3
0
 def test_task_ids(self):
   """task_id is a (type-name, node_uid) tuple for both task types."""
   pipeline_uid = task_lib.PipelineUid(pipeline_id='pipeline')
   node_uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')
   for task, type_name in (
       (test_utils.create_exec_node_task(node_uid), 'ExecNodeTask'),
       (task_lib.CancelNodeTask(node_uid=node_uid), 'CancelNodeTask')):
     self.assertEqual((type_name, node_uid), task.task_id)
Ejemplo n.º 4
0
def pipeline_uid_from_orchestrator_context(
        context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
    """Returns pipeline uid from orchestrator reserved context."""
    # Context names look like '<pipeline_id>' or '<pipeline_id>:<key>';
    # a missing key maps to the empty string.
    parts = context.name.split(':')
    return task_lib.PipelineUid(
        pipeline_id=parts[0], key=parts[1] if len(parts) > 1 else '')
Ejemplo n.º 5
0
 def test_sync_pipeline_run_id_runtime_parameter(self):
     """Starting a sync pipeline populates the pipeline_run_id parameter."""
     with self._mlmd_connection as m:
         pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
         pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
         run_id = (pipeline_state.pipeline.runtime_spec.pipeline_run_id
                   .field_value.string_value)
         self.assertNotEmpty(run_id)
         self.assertEqual(task_lib.PipelineUid(pipeline_id='pipeline1'),
                          pipeline_state.pipeline_uid)
Ejemplo n.º 6
0
 def test_scheduler_not_found(self):
     """Creating a scheduler for an unregistered node raises ValueError."""
     node_uid = task_lib.NodeUid(
         pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
         node_id='Transform')
     task = test_utils.create_exec_node_task(node_uid=node_uid,
                                             pipeline=self._pipeline)
     with self.assertRaisesRegex(ValueError, 'No task scheduler found'):
         ts.TaskSchedulerRegistry.create_task_scheduler(
             mock.Mock(), self._pipeline, task)
Ejemplo n.º 7
0
def _test_exec_node_task(node_id,
                         pipeline_id,
                         pipeline_run_id=None,
                         pipeline=None):
    """Builds an ExecNodeTask for the given node of the given pipeline."""
    pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id,
                                        pipeline_run_id=pipeline_run_id)
    return test_utils.create_exec_node_task(
        task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id),
        pipeline=pipeline)
Ejemplo n.º 8
0
 def test_node_uid_from_pipeline_node(self):
   """NodeUid is derived from the pipeline info id and node info id."""
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = 'pipeline'
   node = pipeline_pb2.PipelineNode()
   node.node_info.id = 'Trainer'
   expected = task_lib.NodeUid(
       pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
       node_id='Trainer')
   self.assertEqual(expected,
                    task_lib.NodeUid.from_pipeline_node(pipeline, node))
Ejemplo n.º 9
0
 def test_node_uid_from_pipeline_node(self):
     """NodeUid picks up the pipeline_run_id when the pipeline has one."""
     pipeline = pipeline_pb2.Pipeline()
     pipeline.pipeline_info.id = 'pipeline'
     pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
     node = pipeline_pb2.PipelineNode()
     node.node_info.id = 'Trainer'
     expected = task_lib.NodeUid(
         pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline',
                                           pipeline_run_id='run0'),
         node_id='Trainer')
     self.assertEqual(expected,
                      task_lib.NodeUid.from_pipeline_node(pipeline, node))
Ejemplo n.º 10
0
  def test_register_using_executor_spec_type_url(self):
    """A scheduler registered by executor spec type url is instantiated."""
    # Register a fake scheduler under the executor spec's type url.
    ts.TaskSchedulerRegistry.register(self._spec_type_url, _FakeTaskScheduler)

    # A task for a node using that executor spec resolves to the fake.
    node_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
        node_id='Trainer')
    task = test_utils.create_exec_node_task(node_uid=node_uid,
                                            pipeline=self._pipeline)
    scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
        mock.Mock(), self._pipeline, task)
    self.assertIsInstance(scheduler, _FakeTaskScheduler)
Ejemplo n.º 11
0
    def test_register_using_node_type_name(self):
        """A scheduler registered by node type name is instantiated."""
        # Register a fake scheduler for the importer node type.
        ts.TaskSchedulerRegistry.register(constants.IMPORTER_NODE_TYPE,
                                          _FakeTaskScheduler)

        # A task for an importer node resolves to the fake scheduler.
        node_uid = task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
            node_id='Importer')
        task = test_utils.create_exec_node_task(node_uid=node_uid,
                                                pipeline=self._pipeline)
        scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
            mock.Mock(), self._pipeline, task)
        self.assertIsInstance(scheduler, _FakeTaskScheduler)
Ejemplo n.º 12
0
    def test_registration_and_creation(self):
        """A scheduler registered by a pipeline's executor spec url is created."""
        # Build a pipeline IR whose deployment config carries an executor
        # spec for the 'Trainer' node.
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        # Register a fake scheduler under that spec's type url.
        ts.TaskSchedulerRegistry.register(
            deployment_config.executor_specs['Trainer'].type_url,
            _FakeTaskScheduler)

        # A Trainer task must resolve to the registered fake scheduler.
        node_uid = task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline',
                                              pipeline_run_id=None),
            node_id='Trainer')
        task = test_utils.create_exec_node_task(node_uid=node_uid)
        scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
            mock.Mock(), pipeline, task)
        self.assertIsInstance(scheduler, _FakeTaskScheduler)
Ejemplo n.º 13
0
    def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                            mock_async_task_gen,
                                            mock_sync_task_gen):
        """Verifies task generation for a stop-initiated pipeline.

        Checks that `generate_tasks` emits cancellation for an in-flight
        node task, re-generates (cancelled) tasks for nodes with active MLMD
        executions, skips regular task generation, and finally marks the
        pipeline execution CANCELED once no node executions remain active.
        """
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            # Immediately flip the pipeline to stop-initiated state.
            pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
                m, task_lib.PipelineUid.from_pipeline(pipeline1))

            task_queue = tq.TaskQueue()

            # For the stop-initiated pipeline, "Transform" execution task is in queue,
            # "Trainer" has an active execution in MLMD but no task in queue,
            # "Evaluator" has no active execution.
            task_queue.enqueue(
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Transform')))
            transform_task = task_queue.dequeue(
            )  # simulates task being processed
            # First call (Trainer) yields a cancelled exec task; the rest
            # (Evaluator etc., across both generate_tasks calls) yield None.
            mock_gen_task_from_active.side_effect = [
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Trainer'),
                                                 is_cancelled=True), None,
                None, None, None
            ]

            pipeline_ops.generate_tasks(m, task_queue)

            # There are no active pipelines so these shouldn't be called.
            mock_async_task_gen.assert_not_called()
            mock_sync_task_gen.assert_not_called()

            # Simulate finishing the "Transform" ExecNodeTask.
            task_queue.task_done(transform_task)

            # CancelNodeTask for the "Transform" ExecNodeTask should be next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual('Transform', task.node_uid.node_id)

            # ExecNodeTask for "Trainer" is next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual('Trainer', task.node_uid.node_id)

            self.assertTrue(task_queue.is_empty())

            # nodes[1] is Trainer, nodes[2] is Evaluator; both should have
            # been probed for active executions with is_cancelled=True.
            mock_gen_task_from_active.assert_has_calls([
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[1].pipeline_node,
                          mock.ANY,
                          is_cancelled=True),
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[2].pipeline_node,
                          mock.ANY,
                          is_cancelled=True)
            ])
            self.assertEqual(2, mock_gen_task_from_active.call_count)

            # Pipeline execution should continue to be active since active node
            # executions were found in the last call to `generate_tasks`.
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertTrue(execution_lib.is_execution_active(execution))

            # Call `generate_tasks` again; this time there are no more active node
            # executions so the pipeline should be marked as cancelled.
            pipeline_ops.generate_tasks(m, task_queue)
            self.assertTrue(task_queue.is_empty())
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                             execution.last_known_state)
Ejemplo n.º 14
0
def _pipeline_uid_from_context(
        context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
    """Returns pipeline uid from orchestrator reserved context.

    The context name is expected to have the form
    '<_ORCHESTRATOR_RESERVED_ID>_<pipeline_id>'; everything after the first
    occurrence of the reserved prefix is the pipeline id.
    """
    # maxsplit=1 ensures a pipeline id that itself contains the reserved
    # separator substring is not truncated at its second occurrence.
    pipeline_id = context.name.split(_ORCHESTRATOR_RESERVED_ID + '_', 1)[1]
    return task_lib.PipelineUid(pipeline_id=pipeline_id, pipeline_run_id=None)
Ejemplo n.º 15
0
def pipeline_uid_from_orchestrator_context(
        context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
    """Returns pipeline uid from orchestrator reserved context."""
    # Pass pipeline_id by keyword for consistency with the other PipelineUid
    # construction sites and robustness against field reordering.
    return task_lib.PipelineUid(pipeline_id=context.name)
Ejemplo n.º 16
0
def create_node_uid(pipeline_id, node_id):
  """Creates a NodeUid for `node_id` within the pipeline `pipeline_id`."""
  pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
  return task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id)
Ejemplo n.º 17
0
def _test_cancel_node_task(node_id, pipeline_id):
    """Builds a CancelNodeTask for the given node of the given pipeline."""
    pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
    return task_lib.CancelNodeTask(
        node_uid=task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id))
Ejemplo n.º 18
0
def _test_task(node_id, pipeline_id, key=''):
  """Creates an ExecNodeTask for the given node, pipeline and key."""
  pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id, key=key)
  return test_utils.create_exec_node_task(
      task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id))