예제 #1
0
    def testIrisPipelineNativeKeras(self):
        """Runs the native Keras iris pipeline three times in a row.

        After every run the metadata store at ``self._metadata_path`` is
        inspected. The first run must record the expected number of
        executions, the second run must add exactly three artifacts, and the
        third run must add none. Each run is expected to add one full set of
        executions.
        """

        def launch_pipeline():
            # A new runner and a new pipeline object are built for each run.
            BeamDagRunner().run(
                iris_pipeline_native_keras._create_pipeline(
                    pipeline_name=self._pipeline_name,
                    data_root=self._data_root,
                    module_file=self._module_file,
                    serving_model_dir=self._serving_model_dir,
                    pipeline_root=self._pipeline_root,
                    metadata_path=self._metadata_path,
                    direct_num_workers=1))

        # First run.
        launch_pipeline()

        # Both the serving model directory and the metadata file must exist.
        for produced_path in (self._serving_model_dir, self._metadata_path):
            self.assertTrue(tf.io.gfile.exists(produced_path))

        executions_per_run = 9  # 8 components + 1 resolver
        store_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(store_config) as handle:
            artifact_count = len(handle.store.get_artifacts())
            execution_count = len(handle.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(executions_per_run, execution_count)

        self.assertPipelineExecution()

        # Second run.
        launch_pipeline()

        # All executions but Evaluator and Pusher are cached.
        with metadata.Metadata(store_config) as handle:
            # Evaluator and Pusher account for three additional artifacts.
            artifacts_now = handle.store.get_artifacts()
            self.assertEqual(artifact_count + 3, len(artifacts_now))
            # Refresh the baseline used by the third-run check.
            artifact_count = len(handle.store.get_artifacts())
            executions_now = handle.store.get_executions()
            self.assertEqual(executions_per_run * 2, len(executions_now))

        # Third run.
        launch_pipeline()

        # Asserts cache execution.
        with metadata.Metadata(store_config) as handle:
            # No new artifacts are expected this time.
            artifacts_now = handle.store.get_artifacts()
            self.assertEqual(artifact_count, len(artifacts_now))
            executions_now = handle.store.get_executions()
            self.assertEqual(executions_per_run * 3, len(executions_now))
    def testIrisPipelineNativeKeras(self):
        """Runs the native Keras iris pipeline once and checks the metadata.

        Verifies that the metadata file exists, that the store holds at least
        as many artifacts as executions, and that exactly seven executions
        were recorded, then delegates to ``assertPipelineExecution``.
        """
        runner = BeamDagRunner()
        runner.run(
            iris_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path))

        metadata_file_present = tf.io.gfile.exists(self._metadata_path)
        self.assertTrue(metadata_file_present)

        expected_executions = 7
        store_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(store_config) as handle:
            artifact_total = len(handle.store.get_artifacts())
            execution_total = len(handle.store.get_executions())
            self.assertGreaterEqual(artifact_total, execution_total)
            self.assertEqual(expected_executions, execution_total)

        self.assertPipelineExecution()
예제 #3
0
    def testIrisPipelineNativeKerasWithTuner(self):
        """Runs the native Keras iris pipeline once with tuning enabled.

        Verifies that the serving model directory and the metadata file
        exist, that the store holds at least as many artifacts as executions,
        and that the expected number of executions was recorded, then
        delegates to ``assertPipelineExecution(True)``.
        """
        runner = BeamDagRunner()
        runner.run(
            iris_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                enable_tuning=True,
                direct_num_workers=1))

        serving_dir_present = tf.io.gfile.exists(self._serving_model_dir)
        self.assertTrue(serving_dir_present)
        metadata_file_present = tf.io.gfile.exists(self._metadata_path)
        self.assertTrue(metadata_file_present)

        executions_per_run = 10  # 9 components + 1 resolver
        store_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(store_config) as handle:
            artifact_total = len(handle.store.get_artifacts())
            execution_total = len(handle.store.get_executions())
            self.assertGreaterEqual(artifact_total, execution_total)
            self.assertEqual(executions_per_run, execution_total)

        self.assertPipelineExecution(True)