Ejemplo n.º 1
0
 def test_export_tflite_with_post_processing(self):
     """Exports efficientdet-lite0 to FP32 and INT8 TFLite via ServingDriver.

     The driver is built with only_network=False (per the test name, the
     exported model is expected to include post-processing). After each
     export the test checks that the matching .tflite file exists in the
     SavedModel export directory.
     """
     export_dir = os.path.join(self.tmp_path, 'saved_model')
     serving_driver = inference.ServingDriver(
         'efficientdet-lite0', self.tmp_path, only_network=False)

     # FP32 export: no calibration inputs are passed.
     serving_driver.export(export_dir, tflite='FP32')
     fp32_file = os.path.join(export_dir, 'fp32.tflite')
     self.assertTrue(tf.io.gfile.exists(fp32_file))
     # Remove the export directory before exporting again to the same path.
     tf.io.gfile.rmtree(export_dir)

     # INT8 export: calibrated with one step over a fake TFRecord file.
     calibration_record = test_util.make_fake_tfrecord(self.get_temp_dir())
     serving_driver.export(
         export_dir,
         tflite='INT8',
         file_pattern=[calibration_record],
         num_calibration_steps=1)
     int8_file = os.path.join(export_dir, 'int8.tflite')
     self.assertTrue(tf.io.gfile.exists(int8_file))
Ejemplo n.º 2
0
 def test_export_tflite_only_network(self):
     """Exports efficientdet-lite0 (only_network=True) to FP32, FP16 and INT8.

     Uses KerasDriver and, after each export, checks that the matching
     .tflite file exists in the SavedModel export directory.
     """
     export_dir = os.path.join(self.lite_tmp_path, 'saved_model')
     keras_driver = infer_lib.KerasDriver(
         self.lite_tmp_path, False, 'efficientdet-lite0', only_network=True)

     # Float exports take no calibration inputs. The export directory is
     # removed after each one so the next export reuses the same path.
     for precision, tflite_name in (('FP32', 'fp32.tflite'),
                                    ('FP16', 'fp16.tflite')):
         keras_driver.export(export_dir, tflite=precision)
         self.assertTrue(
             tf.io.gfile.exists(os.path.join(export_dir, tflite_name)))
         tf.io.gfile.rmtree(export_dir)

     # INT8 export: calibrated with one step over a fake TFRecord file.
     calibration_record = test_util.make_fake_tfrecord(self.get_temp_dir())
     keras_driver.export(
         export_dir,
         tflite='INT8',
         file_pattern=[calibration_record],
         num_calibration_steps=1)
     int8_file = os.path.join(export_dir, 'int8.tflite')
     self.assertTrue(tf.io.gfile.exists(int8_file))
Ejemplo n.º 3
0
 def test_parser(self):
     """Parses one fake TFRecord example and checks the parser output size.

     Builds anchors, an anchor labeler and an example decoder from the
     efficientdet-d0 config, then runs InputReader.dataset_parser on the
     first record of a fake TFRecord file. The result must have 11 items.
     """
     tf.random.set_seed(111111)
     config = hparams_config.get_detection_config('efficientdet-d0').as_dict()

     # Positional arguments to anchors.Anchors, in the order it receives them.
     anchor_keys = ('min_level', 'max_level', 'num_scales', 'aspect_ratios',
                    'anchor_scale', 'image_size')
     anchor_boxes = anchors.Anchors(*[config[key] for key in anchor_keys])
     labeler = anchors.AnchorLabeler(anchor_boxes, config['num_classes'])
     decoder = tf_example_decoder.TfExampleDecoder(
         regenerate_source_id=config['regenerate_source_id'])

     record_path = test_util.make_fake_tfrecord(self.get_temp_dir())
     first_record = next(iter(tf.data.TFRecordDataset([record_path])))
     parsed = dataloader.InputReader(record_path, True).dataset_parser(
         first_record, decoder, labeler, config)
     self.assertEqual(len(parsed), 11)