  def setUp(self):
    super().setUp()
    self._test_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    self._pipeline_name = 'imdb_stub_test'
    # This example assumes that the imdb data and imdb utility function are
    # stored in tfx/examples/imdb. Feel free to customize this as needed.
    imdb_root = os.path.dirname(imdb_pipeline_native_keras.__file__)
    self._data_root = os.path.join(imdb_root, 'data')
    self._module_file = os.path.join(imdb_root, 'imdb_utils_native_keras.py')
    self._serving_model_dir = os.path.join(self._test_dir, 'serving_model')
    self._pipeline_root = os.path.join(self._test_dir, 'pipelines',
                                       self._pipeline_name)
    # Metadata path for recording a successful pipeline run.
    self._recorded_mlmd_path = os.path.join(self._test_dir, 'record',
                                            'metadata.db')
    # Metadata path for the stub pipeline.
    self._metadata_path = os.path.join(self._test_dir, 'metadata',
                                       self._pipeline_name, 'metadata.db')
    self._recorded_output_dir = os.path.join(self._test_dir, 'testdata')

    record_imdb_pipeline = imdb_pipeline_native_keras._create_pipeline(  # pylint:disable=protected-access
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._recorded_mlmd_path,
        beam_pipeline_args=[])

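    # Runs the pipeline once with the real executors so that its outputs and
    # metadata can be recorded below.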
    BeamDagRunner().run(record_imdb_pipeline)

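    # Copies the output artifacts of the run above, looked up via the recorded
    # MLMD database, into self._recorded_output_dir for later replay by stub
    # executors.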
    pipeline_recorder_utils.record_pipeline(
        output_dir=self._recorded_output_dir,
        metadata_db_uri=self._recorded_mlmd_path,
        pipeline_name=self._pipeline_name)

    # Pipeline to be run with stub executors.
    self.imdb_pipeline = imdb_pipeline_native_keras._create_pipeline(  # pylint:disable=protected-access
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        beam_pipeline_args=[])
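
  # A minimal sketch of how the stub pipeline prepared in setUp could be
  # exercised. It assumes the tfx.dsl.compiler.compiler and
  # tfx.experimental.pipeline_testing.pipeline_mock modules are imported as
  # `compiler` and `pipeline_mock`; adapt to the installed TFX version.
  def testStubbedImdbPipeline(self):
    pipeline_ir = compiler.Compiler().compile(self.imdb_pipeline)
    # Swap every component's executor for a stub that replays the artifacts
    # recorded under self._recorded_output_dir.
    pipeline_mock.replace_executor_with_stub(pipeline_ir,
                                             self._recorded_output_dir, [])
    BeamDagRunner().run_with_ir(pipeline_ir)

    self.assertTrue(fileio.exists(self._metadata_path))
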
  def testImdbPipelineNativeKeras(self):
    pipeline = imdb_pipeline_native_keras._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        beam_pipeline_args=[])

    BeamDagRunner().run(pipeline)

    self.assertTrue(fileio.exists(self._serving_model_dir))
    self.assertTrue(fileio.exists(self._metadata_path))
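    # Per the imdb example pipeline, the 8 components are expected to be
    # ExampleGen, StatisticsGen, SchemaGen, ExampleValidator, Transform,
    # Trainer, Evaluator and Pusher, plus the latest-blessed-model Resolver.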
    expected_execution_count = 9  # 8 components + 1 resolver
    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
      artifact_count = len(m.store.get_artifacts())
      execution_count = len(m.store.get_executions())
      self.assertGreaterEqual(artifact_count, execution_count)
      self.assertEqual(expected_execution_count, execution_count)

    self.assertPipelineExecution()

    # Runs the pipeline a second time.
    BeamDagRunner().run(pipeline)

    # All executions but Evaluator and Pusher are cached.
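    # (Likely because the latest-blessed-model Resolver now resolves the model
    # blessed in the first run, Evaluator re-executes and Pusher follows.)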
    with metadata.Metadata(metadata_config) as m:
      # Artifact count increases by 3 due to the new Evaluator and Pusher
      # outputs.
      self.assertEqual(artifact_count + 3, len(m.store.get_artifacts()))
      artifact_count = len(m.store.get_artifacts())
      self.assertEqual(expected_execution_count * 2,
                       len(m.store.get_executions()))

    # Runs the pipeline a third time.
    BeamDagRunner().run(pipeline)

    # Asserts that all executions in the third run are cached.
    with metadata.Metadata(metadata_config) as m:
      # Artifact count is unchanged.
      self.assertEqual(artifact_count, len(m.store.get_artifacts()))
      self.assertEqual(expected_execution_count * 3,
                       len(m.store.get_executions()))