def test_audio_classification_demo(self, spec, dataset):
        """Runs the audio classification demo end-to-end and checks its export.

        Fetches the mini Speech Commands archive through
        ``download_speech_commands_dataset`` (with a cache directory obtained
        from ``test_util.get_cache_dir``). It then runs the demo for one epoch
        with batch size 1 using the parameterized ``spec`` and ``dataset``.
        Finally, it asserts that a non-empty ``model.tflite`` was written to the
        temporary directory.
        """
        with patch_data_loader(), tempfile.TemporaryDirectory() as export_dir:
            # Use cached training data if it exists.
            archive_cache = test_util.get_cache_dir(export_dir,
                                                    'mini_speech_commands.zip')
            speech_data_dir = (
                audio_classification_demo.download_speech_commands_dataset(
                    cache_dir=archive_cache,
                    file_hash='4b8a67bae2973844e84fa7ac988d1a44'))

            audio_classification_demo.run(
                spec, speech_data_dir, dataset, export_dir,
                epochs=1, batch_size=1)

            model_path = os.path.join(export_dir, 'model.tflite')
            self.assertTrue(tf.io.gfile.exists(model_path))
            self.assertGreater(os.path.getsize(model_path), 0)
Ejemplo n.º 2
0
  def test_text_classification_demo(self):
    """Runs the text classification demo end-to-end and checks its export.

    Fetches the SST-2 archive through ``download_demo_data``. It then runs
    the demo with the ``average_word_vec`` spec for one epoch with batch
    size 1. Finally, it asserts that a non-empty ``model.tflite`` exists in
    the temporary directory, while ``labels.txt`` and ``vocab`` do not.
    """
    with patch_data_loader(), tempfile.TemporaryDirectory() as work_dir:
      # Use cached training data if it exists.
      sst2_dir = text_classification_demo.download_demo_data(
          cache_dir=test_util.get_cache_dir(work_dir, 'SST-2.zip'),
          file_hash='9f81648d4199384278b86e315dac217c')

      # TODO(b/150597348): The BERT model runs out of memory when exported to
      # TFLite, so this test uses the 'average_word_vec' spec instead. Switch
      # to a smaller BERT model (e.g. MobileBERT) for this unit test later.
      text_classification_demo.run(
          sst2_dir, work_dir, spec='average_word_vec', epochs=1, batch_size=1)

      model_path, labels_path, vocab_path = (
          os.path.join(work_dir, name)
          for name in ('model.tflite', 'labels.txt', 'vocab'))

      self.assertTrue(tf.io.gfile.exists(model_path))
      self.assertGreater(os.path.getsize(model_path), 0)

      # NOTE(review): presumably labels/vocab are not exported as separate
      # files by this demo (e.g. embedded in the model) — confirm.
      self.assertFalse(tf.io.gfile.exists(labels_path))
      self.assertFalse(tf.io.gfile.exists(vocab_path))
Ejemplo n.º 3
0
    def test_image_classification_demo(self):
        """Runs the image classification demo end-to-end and checks its export.

        Fetches the ``flower_photos.tgz`` archive through
        ``download_demo_data``. It then runs the demo with the
        ``efficientnet_lite0`` spec for one epoch with batch size 1. Finally,
        it asserts that a non-empty ``model.tflite`` was written and that no
        ``labels.txt`` file exists in the temporary directory.
        """
        with patch_data_loader(), tempfile.TemporaryDirectory() as out_dir:
            # Use cached training data if it exists.
            flowers_cache = test_util.get_cache_dir(out_dir,
                                                    'flower_photos.tgz')
            photos_dir = image_classification_demo.download_demo_data(
                cache_dir=flowers_cache,
                file_hash='6f87fb78e9cc9ab41eff2015b380011d')

            image_classification_demo.run(
                photos_dir, out_dir,
                spec='efficientnet_lite0', epochs=1, batch_size=1)

            model_file = os.path.join(out_dir, 'model.tflite')
            labels_file = os.path.join(out_dir, 'labels.txt')

            self.assertTrue(tf.io.gfile.exists(model_file))
            self.assertGreater(os.path.getsize(model_file), 0)
            self.assertFalse(tf.io.gfile.exists(labels_file))