def test_audio_classification_demo(self, spec, dataset):
  """Runs the audio classification demo for one epoch and checks the export.

  Args:
    spec: Model spec, forwarded unchanged to `audio_classification_demo.run`.
    dataset: Dataset argument, forwarded unchanged to
      `audio_classification_demo.run`.
  """
  # Entering both managers in one `with` is equivalent to nesting them:
  # the data-loader patch is entered first and exited last.
  with patch_data_loader(), tempfile.TemporaryDirectory() as work_dir:
    # Reuse the cached copy of the training data when one is available.
    cache_dir = test_util.get_cache_dir(work_dir, 'mini_speech_commands.zip')
    data_dir = audio_classification_demo.download_speech_commands_dataset(
        cache_dir=cache_dir, file_hash='4b8a67bae2973844e84fa7ac988d1a44')

    audio_classification_demo.run(
        spec, data_dir, dataset, work_dir, epochs=1, batch_size=1)

    exported_model = os.path.join(work_dir, 'model.tflite')
    self.assertTrue(tf.io.gfile.exists(exported_model))
    self.assertGreater(os.path.getsize(exported_model), 0)
def test_text_classification_demo(self):
  """Runs the text classification demo and checks which files it exports.

  The demo runs with the 'average_word_vec' spec for one epoch. The test
  expects a non-empty `model.tflite` in the output directory and expects
  no `labels.txt` or `vocab` file there.
  """
  with patch_data_loader(), tempfile.TemporaryDirectory() as work_dir:
    # Reuse the cached copy of the training data when one is available.
    data_dir = text_classification_demo.download_demo_data(
        cache_dir=test_util.get_cache_dir(work_dir, 'SST-2.zip'),
        file_hash='9f81648d4199384278b86e315dac217c')

    # TODO(b/150597348): The BERT model runs out of memory when exported to
    # TFLite; switch this unit test to a smaller BERT variant such as
    # MobileBERT later.
    text_classification_demo.run(
        data_dir, work_dir, spec='average_word_vec', epochs=1, batch_size=1)

    exported_model = os.path.join(work_dir, 'model.tflite')
    self.assertTrue(tf.io.gfile.exists(exported_model))
    self.assertGreater(os.path.getsize(exported_model), 0)

    # Neither side file is expected to be written next to the model.
    for unexpected_name in ('labels.txt', 'vocab'):
      self.assertFalse(
          tf.io.gfile.exists(os.path.join(work_dir, unexpected_name)))
def test_image_classification_demo(self):
  """Runs the image classification demo and checks which files it exports.

  The demo runs with the 'efficientnet_lite0' spec for one epoch. The test
  expects a non-empty `model.tflite` in the output directory and expects
  no `labels.txt` there.
  """
  with patch_data_loader(), tempfile.TemporaryDirectory() as work_dir:
    # Reuse the cached copy of the training data when one is available.
    photos_cache = test_util.get_cache_dir(work_dir, 'flower_photos.tgz')
    data_dir = image_classification_demo.download_demo_data(
        cache_dir=photos_cache, file_hash='6f87fb78e9cc9ab41eff2015b380011d')

    image_classification_demo.run(
        data_dir, work_dir, spec='efficientnet_lite0', epochs=1, batch_size=1)

    exported_model = os.path.join(work_dir, 'model.tflite')
    unexpected_labels = os.path.join(work_dir, 'labels.txt')

    self.assertTrue(tf.io.gfile.exists(exported_model))
    self.assertGreater(os.path.getsize(exported_model), 0)
    self.assertFalse(tf.io.gfile.exists(unexpected_labels))