def _create_test_tfrecord(self, num_samples, tfrecord_file=None, example=None,
                          image_height=256, image_width=256):
  """Writes `num_samples` classification examples to a TFRecord file.

  Calling with only `num_samples` behaves as before. It writes to
  `self._test_tfrecord_file` and builds a new 256x256 example per sample.
  The optional keyword arguments also support callers that provide their own
  destination file and a pre-built example.

  Args:
    num_samples: Number of examples to write.
    tfrecord_file: Destination path. Defaults to `self._test_tfrecord_file`
      when None.
    example: Optional pre-built `tf.train.Example`, written `num_samples`
      times. When None, each sample is built by calling
      `tfexample_utils.create_classification_example` and parsing the result
      with `tf.train.Example.FromString`.
    image_height: Image height for generated examples. Ignored when `example`
      is given.
    image_width: Image width for generated examples. Ignored when `example` is
      given.
  """
  if tfrecord_file is None:
    tfrecord_file = self._test_tfrecord_file
  if example is not None:
    examples = [example] * num_samples
  else:
    # Call the generator once per sample, as the original helper did.
    examples = [
        tf.train.Example.FromString(
            tfexample_utils.create_classification_example(
                image_height=image_height, image_width=image_width))
        for _ in range(num_samples)
    ]
  tfexample_utils.dump_to_tfrecord(tfrecord_file, examples)
def test_export_tflite_image_classification(self, experiment, quant_type,
                                            input_image_size):
  """Exports a classification module to TFLite and checks the output type.

  The test runs in four steps:
  1. Write 10 copies of a synthetic classification example, sized to
     `input_image_size`, into a temp TFRecord file.
  2. Point the experiment's train and validation inputs at that file.
  3. Export a `ClassificationModule` with `input_type='tflite'` to
     `<temp>/saved_model`.
  4. Convert that directory with `quant_type` and 5 calibration steps, and
     assert the converter returns `bytes`.
  """
  record_path = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
  serialized = tfexample_utils.create_classification_example(
      image_height=input_image_size[0], image_width=input_image_size[1])
  self._create_test_tfrecord(
      tfrecord_file=record_path,
      example=tf.train.Example.FromString(serialized),
      num_samples=10)

  params = exp_factory.get_exp_config(experiment)
  for data_config in (params.task.validation_data, params.task.train_data):
    data_config.input_path = record_path

  saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  serving_module = image_classification_serving.ClassificationModule(
      params=params,
      batch_size=1,
      input_image_size=input_image_size,
      input_type='tflite')
  self._export_from_module(
      module=serving_module,
      input_type='tflite',
      saved_model_dir=saved_model_dir)

  converted = export_tflite_lib.convert_tflite_model(
      saved_model_dir=saved_model_dir,
      quant_type=quant_type,
      params=params,
      calibration_steps=5)
  self.assertIsInstance(converted, bytes)
def test_decoder(self, image_height, image_width, num_instances):
  """Decodes a serialized classification example and checks its fields.

  The decoded structure must contain exactly the image and label keys, and
  the label must equal 0. `num_instances` is accepted but not used.
  """
  del num_instances  # Not used by this test.
  decoder = classification_input.Decoder(
      image_field_key=IMAGE_FIELD_KEY, label_field_key=LABEL_FIELD_KEY)
  serialized = tf.convert_to_tensor(
      tfexample_utils.create_classification_example(image_height, image_width))
  decoded = decoder.decode(serialized)

  # Convert every decoded tensor to numpy before comparing.
  decoded_np = tf.nest.map_structure(lambda tensor: tensor.numpy(), decoded)
  self.assertCountEqual([IMAGE_FIELD_KEY, LABEL_FIELD_KEY], decoded_np.keys())
  self.assertEqual(0, decoded_np[LABEL_FIELD_KEY])
def test_parser(self, output_size, dtype, is_training, aug_name,
                is_multilabel, decode_jpeg_only, image_format):
  """Runs a serialized example through Decoder and Parser and checks outputs.

  The parsed image must have shape `output_size` and dtype `dtype`. For
  single-label data the label must equal 0. For multi-label data the label
  must have shape `[num_classes]`.
  """
  num_classes = 10
  serialized_example = tfexample_utils.create_classification_example(
      output_size[0], output_size[1], image_format, is_multilabel)

  if aug_name == 'randaug':
    aug_type = common.Augmentation(
        type=aug_name, randaug=common.RandAugment(magnitude=10))
  elif aug_name == 'autoaug':
    aug_type = common.Augmentation(
        type=aug_name, autoaug=common.AutoAugment(augmentation_name='test'))
  else:
    # Any other aug_name parses without augmentation.
    aug_type = None

  decoder = classification_input.Decoder(
      image_field_key=IMAGE_FIELD_KEY,
      label_field_key=LABEL_FIELD_KEY,
      is_multilabel=is_multilabel)
  parser = classification_input.Parser(
      output_size=output_size[:2],
      num_classes=num_classes,
      image_field_key=IMAGE_FIELD_KEY,
      label_field_key=LABEL_FIELD_KEY,
      is_multilabel=is_multilabel,
      decode_jpeg_only=decode_jpeg_only,
      aug_rand_hflip=False,
      aug_type=aug_type,
      dtype=dtype)

  decoded_tensors = decoder.decode(serialized_example)
  image, label = parser.parse_fn(is_training)(decoded_tensors)

  self.assertAllEqual(image.numpy().shape, output_size)
  if is_multilabel:
    self.assertAllEqual(label.numpy().shape, [num_classes])
  else:
    self.assertAllEqual(label, 0)
  # Resolve the requested dtype name so every parameterization is checked,
  # not just a hard-coded subset of names. Compare scalar DTypes with
  # assertEqual.
  self.assertEqual(image.dtype, tf.as_dtype(dtype))