def test_export_tflite_with_post_processing(self):
    """Export efficientdet-lite0 with post-processing to FP32 and INT8 TFLite.

    Uses ``inference.ServingDriver`` with ``only_network=False`` and checks
    that each export writes its expected ``.tflite`` file into the
    saved-model directory. The INT8 export is given a fake TFRecord as its
    ``file_pattern`` and ``num_calibration_steps=1``.
    """
    export_dir = os.path.join(self.tmp_path, 'saved_model')
    serving_driver = inference.ServingDriver(
        'efficientdet-lite0', self.tmp_path, only_network=False)

    def export_and_verify(mode, expected_name, **export_kwargs):
        # Run a single export and assert the named TFLite file now exists.
        serving_driver.export(export_dir, tflite=mode, **export_kwargs)
        expected_path = os.path.join(export_dir, expected_name)
        self.assertTrue(tf.io.gfile.exists(expected_path))

    export_and_verify('FP32', 'fp32.tflite')
    # Remove the FP32 export before running the INT8 export into the same dir.
    tf.io.gfile.rmtree(export_dir)

    fake_record = test_util.make_fake_tfrecord(self.get_temp_dir())
    export_and_verify(
        'INT8', 'int8.tflite',
        file_pattern=[fake_record], num_calibration_steps=1)
def test_export_tflite_only_network(self):
    """Export the bare efficientdet-lite0 network to FP32, FP16 and INT8 TFLite.

    Uses ``infer_lib.KerasDriver`` with ``only_network=True``. The export
    directory is removed after each float export. The INT8 export is given
    a fake TFRecord as its ``file_pattern`` and ``num_calibration_steps=1``.
    """
    export_dir = os.path.join(self.lite_tmp_path, 'saved_model')
    keras_driver = infer_lib.KerasDriver(
        self.lite_tmp_path, False, 'efficientdet-lite0', only_network=True)

    def export_and_verify(mode, expected_name, **export_kwargs):
        # Run a single export and assert the named TFLite file now exists.
        keras_driver.export(export_dir, tflite=mode, **export_kwargs)
        self.assertTrue(
            tf.io.gfile.exists(os.path.join(export_dir, expected_name)))

    # FP32 then FP16, each exported with no extra arguments and then removed
    # so the next export writes into a fresh directory.
    for mode, expected_name in (('FP32', 'fp32.tflite'),
                                ('FP16', 'fp16.tflite')):
        export_and_verify(mode, expected_name)
        tf.io.gfile.rmtree(export_dir)

    fake_record = test_util.make_fake_tfrecord(self.get_temp_dir())
    export_and_verify(
        'INT8', 'int8.tflite',
        file_pattern=[fake_record], num_calibration_steps=1)
def test_parser(self):
    """Parse one fake TFRecord example and check the parser's output length.

    Builds the anchors, anchor labeler and example decoder from the
    efficientdet-d0 config. It then reads the first serialized record of a
    fake TFRecord, feeds it to ``InputReader.dataset_parser`` and expects an
    11-element result.
    """
    tf.random.set_seed(111111)
    config = hparams_config.get_detection_config('efficientdet-d0').as_dict()

    # Positional Anchors arguments, in the exact order the constructor takes.
    anchor_args = [config[key] for key in (
        'min_level', 'max_level', 'num_scales', 'aspect_ratios',
        'anchor_scale', 'image_size')]
    anchor_gen = anchors.Anchors(*anchor_args)
    labeler = anchors.AnchorLabeler(anchor_gen, config['num_classes'])
    decoder = tf_example_decoder.TfExampleDecoder(
        regenerate_source_id=config['regenerate_source_id'])

    record_file = test_util.make_fake_tfrecord(self.get_temp_dir())
    first_record = next(iter(tf.data.TFRecordDataset([record_file])))

    input_reader = dataloader.InputReader(record_file, True)
    parsed = input_reader.dataset_parser(first_record, decoder, labeler, config)
    self.assertEqual(len(parsed), 11)