Example 1
def _get_executor_spec(pipeline: pipeline_pb2.Pipeline,
                       node_id: str) -> Optional[any_pb2.Any]:
    """Returns executor spec for given node_id if it exists in pipeline IR, None otherwise."""
    if not pipeline.deployment_config.Is(
            pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR):
        return None
    depl_config = pipeline_pb2.IntermediateDeploymentConfig()
    pipeline.deployment_config.Unpack(depl_config)
    return depl_config.executor_specs.get(node_id)
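
For reference, the lookup above relies on the standard google.protobuf.Any
Pack/Is/Unpack API. A minimal usage sketch; the node id 'Trainer' and the
class path are made up for illustration:

from tfx.proto.orchestration import pipeline_pb2

pipeline = pipeline_pb2.Pipeline()
depl_config = pipeline_pb2.IntermediateDeploymentConfig()
executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
    class_path='fake.ClassPath')  # hypothetical class path
depl_config.executor_specs['Trainer'].Pack(executor_spec)
pipeline.deployment_config.Pack(depl_config)

assert _get_executor_spec(pipeline, 'Trainer') is not None  # the packed Any
assert _get_executor_spec(pipeline, 'Unknown') is None      # unknown node id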
Example 2
 def main(argv):
   del argv
   with mlmd_connection_func(FLAGS.path) as m:
     depl_config = pipeline_pb2.IntermediateDeploymentConfig()
     executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
         class_path='fake.ClassPath')
     depl_config.executor_specs['arg1'].Pack(executor_spec)
     depl_config.executor_specs['arg2'].Pack(executor_spec)
     create_sample_pipeline(m, FLAGS.pipeline_id, FLAGS.pipeline_run_num,
                            FLAGS.export_ir_dir, FLAGS.ir_file, depl_config,
                            execute_nodes_func)
Example 3
    def test_queue_multiplexing(self, mock_publish):
        # Create a pipeline IR containing deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        deployment_config.executor_specs['Transform'].Pack(executor_spec)
        deployment_config.executor_specs['Evaluator'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        collector = _Collector()

        # Register a fake task scheduler that blocks Trainer and Transform.
        ts.TaskSchedulerRegistry.register(
            deployment_config.executor_specs['Trainer'].type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={'Trainer', 'Transform'},
                              collector=collector))

        task_queue = tq.TaskQueue()

        # Enqueue some tasks.
        trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
        task_queue.enqueue(trainer_exec_task)
        task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

        with tm.TaskManager(mock.Mock(),
                            pipeline,
                            task_queue,
                            max_active_task_schedulers=1000,
                            max_dequeue_wait_secs=0.1,
                            process_all_queued_tasks_before_exit=True):
            # Enqueue more tasks after task manager starts.
            transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
            task_queue.enqueue(transform_exec_task)
            evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
            task_queue.enqueue(evaluator_exec_task)
            task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

        # Ensure that all exec and cancellation tasks were processed correctly.
        self.assertCountEqual(
            [trainer_exec_task, transform_exec_task, evaluator_exec_task],
            collector.scheduled_tasks)
        self.assertCountEqual([trainer_exec_task, transform_exec_task],
                              collector.cancelled_tasks)
        mock_publish.assert_has_calls([
            mock.call(mock.ANY, pipeline, trainer_exec_task, mock.ANY),
            mock.call(mock.ANY, pipeline, transform_exec_task, mock.ANY),
            mock.call(mock.ANY, pipeline, evaluator_exec_task, mock.ANY)
        ],
                                      any_order=True)
Example 4
def replace_executor_with_stub(pipeline: pipeline_pb2.Pipeline,
                               test_data_dir: str,
                               test_component_ids: List[str]):
    """Replace executors in pipeline IR with the stub executor.

  This function replaces the IR in place.
  For example,

  pipeline_mock.replace_executor_with_stub(
      pipeline_ir,
      test_data_dir,
      test_component_ids=['Trainer', 'Transform'])

  Then you can pass the modified `pipeline_ir` into a dag runner to execute
  the stubbed pipeline.

  Args:
    pipeline: The pipeline to alter.
    test_data_dir: The directory where pipeline outputs are recorded
      (pipeline_recorder.py).
    test_component_ids: List of ids of components that are to be tested. In
      other words, executors of components other than those specified by this
      list will be replaced with a BaseStubExecutor.

  Returns:
    None
  """
    deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
    if not pipeline.deployment_config.Unpack(deployment_config):
        raise NotImplementedError(
            'Unexpected pipeline.deployment_config type "{}". Currently only '
            'IntermediateDeploymentConfig is supported.'.format(
                pipeline.deployment_config.type_url))

    for component_id in deployment_config.executor_specs:
        if component_id not in test_component_ids:
            executable_spec = deployment_config.executor_specs[component_id]
            if not executable_spec.Is(
                    executable_spec_pb2.PythonClassExecutableSpec.DESCRIPTOR):
                raise NotImplementedError(
                    'Unexpected executable_spec type "{}". Currently only '
                    'PythonClassExecutableSpec is supported.'.format(
                        executable_spec.type_url))
            stub_executor_class_spec = executor_spec.ExecutorClassSpec(
                base_stub_executor.BaseStubExecutor)
            stub_executor_class_spec.add_extra_flags(
                (base_stub_executor.TEST_DATA_DIR_FLAG + '=' + test_data_dir,
                 base_stub_executor.COMPONENT_ID_FLAG + '=' + component_id))
            stub_executor_spec = stub_executor_class_spec.encode()
            executable_spec.Pack(stub_executor_spec)
    pipeline.deployment_config.Pack(deployment_config)
Example 5
  def compile(self, tfx_pipeline: pipeline.Pipeline) -> pipeline_pb2.Pipeline:
    """Compiles a tfx pipeline into uDSL proto.

    Args:
      tfx_pipeline: A TFX pipeline.

    Returns:
      A Pipeline proto that encodes all necessary information of the pipeline.
    """
    _validate_pipeline(tfx_pipeline)
    context = _CompilerContext(tfx_pipeline)
    pipeline_pb = pipeline_pb2.Pipeline()
    pipeline_pb.pipeline_info.id = context.pipeline_info.pipeline_name
    pipeline_pb.execution_mode = context.execution_mode
    if isinstance(context.pipeline_info.pipeline_root, placeholder.Placeholder):
      pipeline_pb.runtime_spec.pipeline_root.placeholder.CopyFrom(
          context.pipeline_info.pipeline_root.encode())
    else:
      compiler_utils.set_runtime_parameter_pb(
          pipeline_pb.runtime_spec.pipeline_root.runtime_parameter,
          constants.PIPELINE_ROOT_PARAMETER_NAME, str,
          context.pipeline_info.pipeline_root)
    if pipeline_pb.execution_mode == pipeline_pb2.Pipeline.ExecutionMode.SYNC:
      compiler_utils.set_runtime_parameter_pb(
          pipeline_pb.runtime_spec.pipeline_run_id.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

    deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
    if tfx_pipeline.metadata_connection_config:
      deployment_config.metadata_connection_config.Pack(
          tfx_pipeline.metadata_connection_config)
    for node in tfx_pipeline.components:
      # In ASYNC mode, Resolver nodes are merged into the downstream node as a
      # ResolverConfig.
      if compiler_utils.is_resolver(node) and context.is_async_mode:
        continue
      node_pb = self._compile_node(node, context, deployment_config,
                                   tfx_pipeline.enable_cache)
      pipeline_or_node = pipeline_pb.PipelineOrNode()
      pipeline_or_node.pipeline_node.CopyFrom(node_pb)
      # TODO(b/158713812): Support sub-pipeline.
      pipeline_pb.nodes.append(pipeline_or_node)
      context.node_pbs[node.id] = node_pb

    if tfx_pipeline.platform_config:
      deployment_config.pipeline_level_platform_config.Pack(
          tfx_pipeline.platform_config)
    pipeline_pb.deployment_config.Pack(deployment_config)
    return pipeline_pb
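
A hedged usage sketch of the compiler above, assuming the enclosing class is
tfx.dsl.compiler.compiler.Compiler and that my_pipeline is an existing DSL
pipeline.Pipeline instance:

from tfx.dsl.compiler import compiler

c = compiler.Compiler()
pipeline_ir = c.compile(my_pipeline)  # my_pipeline: hypothetical DSL pipeline
print(pipeline_ir.pipeline_info.id)   # the pipeline name carried into the IR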
Example 6
 def setUp(self):
   super().setUp()
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = 'pipeline'
   pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
   pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
   importer_node = pipeline.nodes.add().pipeline_node
   importer_node.node_info.id = 'Importer'
   importer_node.node_info.type.name = constants.IMPORTER_NODE_TYPE
   deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
   executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
       class_path='trainer.TrainerExecutor')
   deployment_config.executor_specs['Trainer'].Pack(executor_spec)
   pipeline.deployment_config.Pack(deployment_config)
   self._spec_type_url = deployment_config.executor_specs['Trainer'].type_url
   self._pipeline = pipeline
   ts.TaskSchedulerRegistry.clear()
Example 7
    def compile(self,
                tfx_pipeline: pipeline.Pipeline) -> pipeline_pb2.Pipeline:
        """Compiles a tfx pipeline into uDSL proto.

    Args:
      tfx_pipeline: A TFX pipeline.

    Returns:
      A Pipeline proto that encodes all necessary information of the pipeline.
    """
        context = _CompilerContext(
            tfx_pipeline.pipeline_info,
            compiler_utils.resolve_execution_mode(tfx_pipeline))
        pipeline_pb = pipeline_pb2.Pipeline()
        pipeline_pb.pipeline_info.id = context.pipeline_info.pipeline_name
        pipeline_pb.execution_mode = context.execution_mode
        compiler_utils.set_runtime_parameter_pb(
            pipeline_pb.runtime_spec.pipeline_root.runtime_parameter,
            constants.PIPELINE_ROOT_PARAMETER_NAME, str,
            context.pipeline_info.pipeline_root)
        if pipeline_pb.execution_mode == pipeline_pb2.Pipeline.ExecutionMode.SYNC:
            compiler_utils.set_runtime_parameter_pb(
                pipeline_pb.runtime_spec.pipeline_run_id.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        assert compiler_utils.ensure_topological_order(
            tfx_pipeline.components), (
                "Pipeline components are not topologically sorted.")
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        if tfx_pipeline.metadata_connection_config:
            deployment_config.metadata_connection_config.Pack(
                tfx_pipeline.metadata_connection_config)
        for node in tfx_pipeline.components:
            node_pb = self._compile_node(node, context, deployment_config,
                                         tfx_pipeline.enable_cache)
            pipeline_or_node = pipeline_pb.PipelineOrNode()
            pipeline_or_node.pipeline_node.CopyFrom(node_pb)
            # TODO(b/158713812): Support sub-pipeline.
            pipeline_pb.nodes.append(pipeline_or_node)
            context.node_pbs[node.id] = node_pb

        if tfx_pipeline.platform_config:
            deployment_config.pipeline_level_platform_config.Pack(
                tfx_pipeline.platform_config)
        pipeline_pb.deployment_config.Pack(deployment_config)
        return pipeline_pb
Example 8
    def setUp(self):
        super().setUp()

        # Create a pipeline IR containing deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        deployment_config.executor_specs['Transform'].Pack(executor_spec)
        deployment_config.executor_specs['Evaluator'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        ts.TaskSchedulerRegistry.clear()

        self._deployment_config = deployment_config
        self._pipeline = pipeline
        self._type_url = deployment_config.executor_specs['Trainer'].type_url
Example 9
    def _dehydrate_tfx_ir(self, original_pipeline: pipeline_pb2.Pipeline,
                          node_id: str) -> pipeline_pb2.Pipeline:
        pipeline = copy.deepcopy(original_pipeline)
        for node in pipeline.nodes:
            if (node.WhichOneof('node') == 'pipeline_node'
                    and node.pipeline_node.node_info.id == node_id):
                del pipeline.nodes[:]
                pipeline.nodes.extend([node])
                break

        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        pipeline.deployment_config.Unpack(deployment_config)
        self._del_unused_field(node_id, deployment_config.executor_specs)
        self._del_unused_field(node_id, deployment_config.custom_driver_specs)
        self._del_unused_field(node_id,
                               deployment_config.node_level_platform_configs)
        pipeline.deployment_config.Pack(deployment_config)
        return pipeline
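
The helper _del_unused_field is not shown in this excerpt. A plausible sketch,
assuming it simply keeps only the map entry for the given node_id (typing
imports assumed):

    def _del_unused_field(self, node_id: str,
                          config_map: MutableMapping[str, Any]) -> None:
        # Collect keys first; a proto map must not be mutated while iterating.
        for key in [k for k in config_map if k != node_id]:
            del config_map[key]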
Example 10
def extract_local_deployment_config(
    pipeline: pipeline_pb2.Pipeline
) -> local_deployment_config_pb2.LocalDeploymentConfig:
    """Extracts the proto.Any pipeline.deployment_config to LocalDeploymentConfig."""

    if not pipeline.HasField('deployment_config'):
        raise ValueError('deployment_config is not available in the pipeline.')

    result = local_deployment_config_pb2.LocalDeploymentConfig()
    if pipeline.deployment_config.Unpack(result):
        return result

    result = pipeline_pb2.IntermediateDeploymentConfig()
    if pipeline.deployment_config.Unpack(result):
        return _to_local_deployment(result)

    raise ValueError('deployment_config {} of type {} is not supported'.format(
        pipeline.deployment_config, type(pipeline.deployment_config)))
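
Both accepted encodings in a minimal sketch (empty configs used only for
illustration; real ones would carry executor specs and connection configs):

pipeline = pipeline_pb2.Pipeline()

# Case 1: deployment_config holds a LocalDeploymentConfig; Unpack succeeds.
pipeline.deployment_config.Pack(
    local_deployment_config_pb2.LocalDeploymentConfig())
config = extract_local_deployment_config(pipeline)

# Case 2: it holds an IntermediateDeploymentConfig; the function falls
# through to the _to_local_deployment conversion path.
pipeline.deployment_config.Pack(pipeline_pb2.IntermediateDeploymentConfig())
config = extract_local_deployment_config(pipeline)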
Example 11
    def test_registration_and_creation(self):
        # Create a pipeline IR containing deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        # Register a fake task scheduler.
        spec_type_url = deployment_config.executor_specs['Trainer'].type_url
        ts.TaskSchedulerRegistry.register(spec_type_url, _FakeTaskScheduler)

        # Create a task and verify that the correct scheduler is instantiated.
        task = test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
            pipeline_id='pipeline', pipeline_run_id=None, node_id='Trainer'))
        task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
            mock.Mock(), pipeline, task)
        self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
Example 12
def _fix_deployment_config(
    input_pipeline: p_pb2.Pipeline,
    node_ids_to_keep: Collection[str]) -> Optional[any_pb2.Any]:
  """Filter per-node deployment configs.

  Unpacks the deployment config from its Any proto into an
  IntermediateDeploymentConfig, then filters each of the three per-node map
  fields down to node_ids_to_keep. This works because those fields don't
  contain references to other nodes.

  Args:
    input_pipeline: The input Pipeline IR proto.
    node_ids_to_keep: Set of node_ids to keep.

  Returns:
    If the deployment_config field is set in the input_pipeline, returns the
    deployment config with per-node configs filtered, packed back into an Any
    proto. If the field is unset, returns None.
  """
  if not input_pipeline.HasField('deployment_config'):
    return None

  deployment_config = p_pb2.IntermediateDeploymentConfig()
  input_pipeline.deployment_config.Unpack(deployment_config)

  def _fix_per_node_config(config_map: MutableMapping[str, Any]):
    # We have to make two passes because we cannot modify the dictionary while
    # iterating over it.
    node_ids_to_delete = [
        node_id for node_id in config_map if node_id not in node_ids_to_keep
    ]
    for node_id_to_delete in node_ids_to_delete:
      del config_map[node_id_to_delete]

  _fix_per_node_config(deployment_config.executor_specs)
  _fix_per_node_config(deployment_config.custom_driver_specs)
  _fix_per_node_config(deployment_config.node_level_platform_configs)

  result = any_pb2.Any()
  result.Pack(deployment_config)
  return result
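
A minimal calling sketch; pipeline_ir is a hypothetical Pipeline IR with a
packed IntermediateDeploymentConfig:

filtered = _fix_deployment_config(pipeline_ir, node_ids_to_keep={'Trainer'})
if filtered is not None:
    fixed_config = p_pb2.IntermediateDeploymentConfig()
    filtered.Unpack(fixed_config)
    # Only the Trainer's per-node entries survive the filtering.
    assert set(fixed_config.executor_specs) <= {'Trainer'}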
Example 13
 def _scheduler_class_for_executor_spec(
         cls: Type[T], pipeline: pipeline_pb2.Pipeline,
         task: task_lib.ExecNodeTask) -> Type[TaskScheduler]:
     """Returns scheduler class for executor spec url if feasible, raises error otherwise."""
     if not pipeline.deployment_config.Is(
             pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR):
         raise ValueError('No deployment config found in pipeline IR')
     depl_config = pipeline_pb2.IntermediateDeploymentConfig()
     pipeline.deployment_config.Unpack(depl_config)
     node_id = task.node_uid.node_id
     if node_id not in depl_config.executor_specs:
         raise ValueError(f'Executor spec not found for node id: {node_id}')
     executor_spec_type_url = depl_config.executor_specs[node_id].type_url
     scheduler_class = cls._task_scheduler_registry.get(
         executor_spec_type_url)
     if scheduler_class is None:
         raise ValueError(
             f'No task scheduler registered for executor spec type url: '
             f'{executor_spec_type_url}')
     return scheduler_class
Example 14
    def create_task_scheduler(cls: Type[T], mlmd_handle: metadata.Metadata,
                              pipeline: pipeline_pb2.Pipeline,
                              task: task_lib.Task) -> TaskScheduler:
        """Creates a task scheduler for the given task.

    Note that this assumes deployment_config packed in the pipeline IR is of
    type `IntermediateDeploymentConfig`. This detail may change in the future.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline IR.
      task: The task that needs to be scheduled.

    Returns:
      An instance of `TaskScheduler` for the given task.

    Raises:
      NotImplementedError: Raised if not an `ExecNodeTask`.
      ValueError: Deployment config not present in the IR proto or if executor
        spec for the node corresponding to `task` not configured in the IR.
    """
        if not task_lib.is_exec_node_task(task):
            raise NotImplementedError(
                'Can create a task scheduler only for an `ExecNodeTask`.')
        task = typing.cast(task_lib.ExecNodeTask, task)
        # TODO(b/170383494): Decide which DeploymentConfig to use.
        if not pipeline.deployment_config.Is(
                pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR):
            raise ValueError('No deployment config found in pipeline IR.')
        depl_config = pipeline_pb2.IntermediateDeploymentConfig()
        pipeline.deployment_config.Unpack(depl_config)
        node_id = task.node_uid.node_id
        if node_id not in depl_config.executor_specs:
            raise ValueError(
                'Executor spec for node id `{}` not found in pipeline IR.'.
                format(node_id))
        executor_spec_type_url = depl_config.executor_specs[node_id].type_url
        return cls._task_scheduler_registry[executor_spec_type_url](
            mlmd_handle=mlmd_handle, pipeline=pipeline, task=task)
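
Caller-side handling of the error contract documented above, as a sketch
(exec_task is a hypothetical ExecNodeTask; mock.Mock() stands in for the MLMD
handle, as in the tests above):

try:
    scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
        mock.Mock(), pipeline, exec_task)
except (ValueError, NotImplementedError):
    # No deployment config, no executor spec/scheduler for the node, or the
    # task was not an ExecNodeTask.
    scheduler = None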
Example 15
      }
    }
    custom_driver_specs {
      key: "my_example_gen"
      value {
        [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] {
          class_path: "tfx.components.example_gen_driver"
        }
      }
    }
    metadata_connection_config {
      [type.googleapis.com/ml_metadata.ConnectionConfig] {
        fake_database {}
      }
    }
""", pipeline_pb2.IntermediateDeploymentConfig())

_executed_components = []
_component_executors = {}
_component_drivers = {}
_component_to_pipeline_run = {}


# TODO(b/162980675): Once PythonExecutorOperator is implemented, we won't
# need to fake the whole _FakeComponentAsDoFn; instead, just fake or mock
# executors.
class _FakeComponentAsDoFn(beam_dag_runner._PipelineNodeAsDoFn):
    def __init__(self, pipeline_node: pipeline_pb2.PipelineNode,
                 mlmd_connection: metadata.Metadata,
                 pipeline_info: pipeline_pb2.PipelineInfo,
                 pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec,
Example 16
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce a task to run transform.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, pipeline_state, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate()
        self.assertLen(tasks, 1)
        self._task = tasks[0]
        self.assertEqual('my_transform', self._task.node_uid.node_id)
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id
Example 17
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = test_async_pipeline.create_pipeline()

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce two tasks for transform. The first one is
        # UpdateNodeStateTask and the second one is ExecNodeTask.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState.new(m, self._pipeline)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate(pipeline_state)
        self.assertLen(tasks, 2)
        self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
        self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
        self.assertEqual('my_transform', tasks[0].node_uid.node_id)
        self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
        self.assertEqual('my_transform', tasks[1].node_uid.node_id)
        self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
        self.assertTrue(os.path.exists(tasks[1].tmp_dir))
        self._task = tasks[1]
        self._output_artifact_uri = self._task.output_artifacts[
            'transform_graph'][0].uri
        self.assertTrue(os.path.exists(self._output_artifact_uri))
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id