def testMrpcPipelineNativeKeras(self):
  """Runs the BERT MRPC pipeline on Beam three times and verifies caching.

  The first run executes every component; the second re-executes only
  Evaluator and Pusher (their outputs are not cacheable); the third is
  fully cached, so the artifact count stays flat while the execution
  count keeps growing by the full component count each run.
  """
  mrpc_pipeline = bert_mrpc_pipeline._create_pipeline(
      pipeline_name=self._pipeline_name,
      data_root=self._data_root,
      module_file=self._module_file,
      serving_model_dir=self._serving_model_dir,
      pipeline_root=self._pipeline_root,
      metadata_path=self._metadata_path,
      beam_pipeline_args=['--direct_num_workers=1'])

  # First run: everything executes from scratch.
  BeamDagRunner().run(mrpc_pipeline)

  self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
  self.assertTrue(tf.io.gfile.exists(self._metadata_path))

  expected_execution_count = 9  # 8 components + 1 resolver
  db_config = metadata.sqlite_metadata_connection_config(self._metadata_path)
  with metadata.Metadata(db_config) as store_handle:
    artifact_count = len(store_handle.store.get_artifacts())
    execution_count = len(store_handle.store.get_executions())
    self.assertGreaterEqual(artifact_count, execution_count)
    self.assertEqual(expected_execution_count, execution_count)

  self.assertPipelineExecution()

  # Second run: all executions but Evaluator and Pusher are cached.
  BeamDagRunner().run(mrpc_pipeline)

  with metadata.Metadata(db_config) as store_handle:
    # Evaluator and Pusher re-run, which adds 3 artifacts.
    self.assertEqual(artifact_count + 3,
                     len(store_handle.store.get_artifacts()))
    artifact_count = len(store_handle.store.get_artifacts())
    self.assertEqual(expected_execution_count * 2,
                     len(store_handle.store.get_executions()))

  # Third run: every execution hits the cache.
  BeamDagRunner().run(mrpc_pipeline)

  with metadata.Metadata(db_config) as store_handle:
    # Artifact count is unchanged when everything is cached.
    self.assertEqual(artifact_count,
                     len(store_handle.store.get_artifacts()))
    self.assertEqual(expected_execution_count * 3,
                     len(store_handle.store.get_executions()))
def testMrpcPipelineNativeKeras(self):
  """Runs the BERT MRPC pipeline once on the local runner.

  Verifies that the serving model and metadata DB are produced, that the
  expected number of executions is recorded in MLMD, and that every
  component's outputs exist (via assertPipelineExecution).
  """
  mrpc_pipeline = bert_mrpc_pipeline._create_pipeline(
      pipeline_name=self._pipeline_name,
      data_root=self._data_root,
      module_file=self._module_file,
      serving_model_dir=self._serving_model_dir,
      pipeline_root=self._pipeline_root,
      metadata_path=self._metadata_path,
      beam_pipeline_args=[])

  LocalDagRunner().run(mrpc_pipeline)

  self.assertTrue(fileio.exists(self._serving_model_dir))
  self.assertTrue(fileio.exists(self._metadata_path))

  expected_execution_count = 9  # 8 components + 1 resolver
  db_config = metadata.sqlite_metadata_connection_config(self._metadata_path)
  with metadata.Metadata(db_config) as store_handle:
    artifact_count = len(store_handle.store.get_artifacts())
    execution_count = len(store_handle.store.get_executions())
    # Each execution yields at least one artifact.
    self.assertGreaterEqual(artifact_count, execution_count)
    self.assertEqual(expected_execution_count, execution_count)

  self.assertPipelineExecution()