def _create_test_tfrecord(self, num_samples):
    """Write `num_samples` synthetic classification records to a TFRecord.

    The destination is `self._test_tfrecord_file`. Each record comes from
    `tfexample_utils.create_classification_example` with a 256x256 image,
    parsed back into a `tf.train.Example` before being dumped.

    Args:
      num_samples: how many examples to generate and write.
    """
    output_path = self._test_tfrecord_file
    examples = []
    for _ in range(num_samples):
        serialized = tfexample_utils.create_classification_example(
            image_height=256, image_width=256)
        examples.append(tf.train.Example.FromString(serialized))
    tfexample_utils.dump_to_tfrecord(output_path, examples)
Example #2
0
  def test_export_tflite_image_classification(self, experiment, quant_type,
                                              input_image_size):
    """Exports an image classification model and converts it to TFLite.

    Ten synthetic examples sized to `input_image_size` are written to a
    temporary TFRecord, which becomes both the train and validation input of
    the `experiment` config. A serving module is exported to a SavedModel
    and converted with `quant_type` and `calibration_steps=5` (presumably
    the calibration reads that TFRecord when quantization needs it; confirm
    in `export_tflite_lib`). The test passes when the result is `bytes`.
    """
    record_path = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
    serialized_example = tfexample_utils.create_classification_example(
        image_height=input_image_size[0], image_width=input_image_size[1])
    self._create_test_tfrecord(
        tfrecord_file=record_path,
        example=tf.train.Example.FromString(serialized_example),
        num_samples=10)

    params = exp_factory.get_exp_config(experiment)
    # Validation first, then train, the same order as the explicit assignments.
    for data_config in (params.task.validation_data, params.task.train_data):
      data_config.input_path = record_path

    temp_dir = self.get_temp_dir()
    saved_model_dir = os.path.join(temp_dir, 'saved_model')

    module = image_classification_serving.ClassificationModule(
        params=params,
        batch_size=1,
        input_image_size=input_image_size,
        input_type='tflite')
    self._export_from_module(
        module=module, input_type='tflite', saved_model_dir=saved_model_dir)

    tflite_model = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_dir,
        quant_type=quant_type,
        params=params,
        calibration_steps=5)

    self.assertIsInstance(tflite_model, bytes)
    def test_decoder(self, image_height, image_width, num_instances):
        """Decoding a serialized example yields exactly the image and label.

        A classification example of `image_height` x `image_width` is built,
        decoded, and checked to contain only the image and label keys, with
        the label equal to 0. `num_instances` comes from the test
        parameterization and is not used by this test.
        """
        decoder = classification_input.Decoder(
            image_field_key=IMAGE_FIELD_KEY, label_field_key=LABEL_FIELD_KEY)

        example_tensor = tf.convert_to_tensor(
            tfexample_utils.create_classification_example(image_height,
                                                          image_width))
        decoded = decoder.decode(example_tensor)

        # Turn every decoded tensor into a NumPy value for plain comparisons.
        numpy_results = tf.nest.map_structure(lambda t: t.numpy(), decoded)
        self.assertCountEqual([IMAGE_FIELD_KEY, LABEL_FIELD_KEY],
                              numpy_results.keys())
        self.assertEqual(0, numpy_results[LABEL_FIELD_KEY])
    def test_parser(self, output_size, dtype, is_training, aug_name,
                    is_multilabel, decode_jpeg_only, image_format):
        """Parses a synthetic classification example and checks the outputs.

        Args:
          output_size: expected shape of the parsed image. Its first two
            entries are also the height/width of the synthetic example, and
            `output_size[:2]` is the parser's output size.
          dtype: dtype name handed to the parser. 'float32', 'float16' and
            'bfloat16' are checked against the parsed image's dtype.
          is_training: selects which function `parser.parse_fn` returns.
          aug_name: 'randaug' or 'autoaug' to enable that augmentation; any
            other value disables augmentation.
          is_multilabel: passed to the example builder, decoder and parser.
          decode_jpeg_only: forwarded to the parser.
          image_format: forwarded to the example builder.
        """
        # Shared by the parser config and the multi-label shape assertion so
        # the two cannot drift apart.
        num_classes = 10

        serialized_example = tfexample_utils.create_classification_example(
            output_size[0], output_size[1], image_format, is_multilabel)

        if aug_name == 'randaug':
            aug_type = common.Augmentation(
                type=aug_name, randaug=common.RandAugment(magnitude=10))
        elif aug_name == 'autoaug':
            aug_type = common.Augmentation(
                type=aug_name,
                autoaug=common.AutoAugment(augmentation_name='test'))
        else:
            aug_type = None

        decoder = classification_input.Decoder(image_field_key=IMAGE_FIELD_KEY,
                                               label_field_key=LABEL_FIELD_KEY,
                                               is_multilabel=is_multilabel)
        parser = classification_input.Parser(output_size=output_size[:2],
                                             num_classes=num_classes,
                                             image_field_key=IMAGE_FIELD_KEY,
                                             label_field_key=LABEL_FIELD_KEY,
                                             is_multilabel=is_multilabel,
                                             decode_jpeg_only=decode_jpeg_only,
                                             aug_rand_hflip=False,
                                             aug_type=aug_type,
                                             dtype=dtype)

        decoded_tensors = decoder.decode(serialized_example)
        image, label = parser.parse_fn(is_training)(decoded_tensors)

        self.assertAllEqual(image.numpy().shape, output_size)

        if is_multilabel:
            # A multi-label target has one entry per class.
            self.assertAllEqual(label.numpy().shape, [num_classes])
        else:
            self.assertAllEqual(label, 0)

        # A DType is a single value, not an array. assertEqual compares it
        # directly and reports a readable mismatch; assertAllEqual would wrap
        # it in a 0-d object array first.
        if dtype == 'float32':
            self.assertEqual(image.dtype, tf.float32)
        elif dtype == 'float16':
            self.assertEqual(image.dtype, tf.float16)
        elif dtype == 'bfloat16':
            self.assertEqual(image.dtype, tf.bfloat16)