def test_image_classification_demo(self):
    """End-to-end smoke test of `image_classification_demo.run`.

    Trains `efficientnet_b0` for a single epoch on the demo dataset (with the
    data loader patched via `patch_data_loader`) inside a throwaway directory,
    then checks that both the TFLite model file and the label file exist.
    """
    with patch_data_loader(), tempfile.TemporaryDirectory() as workdir:
        # Output locations the demo is asked to write to.
        model_path = os.path.join(workdir, 'model.tflite')
        labels_path = os.path.join(workdir, 'label.txt')

        # Reuse a previously downloaded copy of the training data if present.
        demo_data = image_classification_demo.download_demo_data(
            cache_dir=get_cache_dir(),
            file_hash='6f87fb78e9cc9ab41eff2015b380011d')

        image_classification_demo.run(
            demo_data, model_path, labels_path,
            spec='efficientnet_b0', epochs=1)

        # Model first, then labels -- same order as the individual asserts.
        for artifact in (model_path, labels_path):
            self.assertTrue(tf.io.gfile.exists(artifact))
示例#2
0
文件: cli.py 项目: bqi1/PoseNet
  def image_classification(self, data_dir, export_dir,
                           spec='efficientnet_lite0', **kwargs):
    """Run Image classification.

    Args:
      data_dir: str, input directory of training data. (required)
      export_dir: str, output directory to export files. (required)
      spec: str, model_name. Valid: {MODELS}, default: efficientnet_lite0.
      **kwargs: --epochs: int, epoch num to run. More: see `create` function.
    """
    # NOTE: the docstring above is kept verbatim on purpose -- it carries a
    # `{MODELS}` placeholder and presumably doubles as CLI help text.

    # Normalize both path arguments to plain strings before delegating.
    data_dir, export_dir = str(data_dir), str(export_dir)

    # All training/export work lives in the demo module; any extra flags
    # (e.g. --epochs) are forwarded untouched through **kwargs.
    image_classification_demo.run(data_dir, export_dir, spec, **kwargs)
示例#3
0
    def test_image_classification_demo(self):
        """Smoke-test `image_classification_demo.run` with `efficientnet_lite0`.

        Runs one epoch with batch size 1 under `patch_data_loader`, exporting
        into a temporary directory, then checks that a non-empty
        `model.tflite` was produced and that no `labels.txt` was written.
        """
        with patch_data_loader(), tempfile.TemporaryDirectory() as workdir:
            model_path = os.path.join(workdir, 'model.tflite')
            labels_path = os.path.join(workdir, 'labels.txt')

            # Reuse cached training data when it exists; the cache location
            # is computed by `test_util.get_cache_dir` from `workdir`.
            cache = test_util.get_cache_dir(workdir, 'flower_photos.tgz')
            demo_data = image_classification_demo.download_demo_data(
                cache_dir=cache,
                file_hash='6f87fb78e9cc9ab41eff2015b380011d')

            image_classification_demo.run(
                demo_data, workdir,
                spec='efficientnet_lite0', epochs=1, batch_size=1)

            self.assertTrue(tf.io.gfile.exists(model_path))
            self.assertGreater(os.path.getsize(model_path), 0)

            # NOTE(review): presumably labels travel inside the exported model
            # rather than as a sidecar file -- confirm against `run`.
            self.assertFalse(tf.io.gfile.exists(labels_path))
示例#4
0
文件: cli.py 项目: lixiaoyue/examples
  def image_classification(self, data_dir, tflite_filename, label_filename,
                           spec='efficientnet_b0', **kwargs):
    """Run Image classification.

    Args:
      data_dir: str, input directory of training data. (required)
      tflite_filename: str, output path to export tflite file. (required)
      label_filename: str, output path to export label file. (required)
      spec: str, model_name. Valid: {MODELS}, default: efficientnet_b0.
      **kwargs: --epochs: int, epoch num to run. More: see `create` function.
    """
    # NOTE: the docstring above is kept verbatim on purpose -- it carries a
    # `{MODELS}` placeholder and presumably doubles as CLI help text.

    # Normalize all three path arguments to plain strings, in declaration
    # order, before delegating.
    data_dir, tflite_filename, label_filename = map(
        str, (data_dir, tflite_filename, label_filename))

    # Training and export are delegated to the demo module; extra flags
    # (e.g. --epochs) pass through **kwargs unchanged.
    image_classification_demo.run(
        data_dir, tflite_filename, label_filename, spec, **kwargs)