Ejemplo n.º 1
0
  def setUp(self):
    """Prepares one TFRecord fixture per vision task exercised by the tests.

    For each task a single example is built with `tfexample_utils` and handed
    to `self._create_test_tfrecord` together with the destination path and
    `num_samples=10`. The destination paths are stored on the instance as
    `test_tfrecord_file_cls`, `test_tfrecord_file_det` and
    `test_tfrecord_file_seg`.
    """
    super().setUp()

    # Image classification: the helper output is parsed into a
    # `tf.train.Example` before being handed over (224x224 image).
    cls_path = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
    self.test_tfrecord_file_cls = cls_path
    cls_serialized = tfexample_utils.create_classification_example(
        image_height=224, image_width=224)
    cls_example = tf.train.Example.FromString(cls_serialized)
    self._create_test_tfrecord(
        tfrecord_file=cls_path, example=cls_example, num_samples=10)

    # Object detection: 128x128x3 image with 10 instances; the helper result
    # is passed through unchanged.
    det_path = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
    self.test_tfrecord_file_det = det_path
    det_example = tfexample_utils.create_detection_test_example(
        image_height=128, image_width=128, image_channel=3, num_instances=10)
    self._create_test_tfrecord(
        tfrecord_file=det_path, example=det_example, num_samples=10)

    # Semantic segmentation: 512x512x3 image; the helper result is passed
    # through unchanged.
    seg_path = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
    self.test_tfrecord_file_seg = seg_path
    seg_example = tfexample_utils.create_segmentation_test_example(
        image_height=512, image_width=512, image_channel=3)
    self._create_test_tfrecord(
        tfrecord_file=seg_path, example=seg_example, num_samples=10)
    def test_export_tflite_image_classification(self, experiment, quant_type,
                                                input_image_size):
        """Exports a classification module and converts it to a TFLite model.

        Args:
            experiment: Experiment name passed to `exp_factory.get_exp_config`.
            quant_type: Forwarded to `export_tflite_lib.convert_tflite_model`.
            input_image_size: Indexable whose first two entries are used as
                the image height and width of the generated example; also
                forwarded unchanged to the serving module.
        """
        # Write a small TFRecord whose example matches the requested size.
        record_path = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
        serialized = tfexample_utils.create_classification_example(
            image_height=input_image_size[0],
            image_width=input_image_size[1])
        self._create_test_tfrecord(
            tfrecord_file=record_path,
            example=tf.train.Example.FromString(serialized),
            num_samples=10)

        # Point both the validation and training inputs at that record.
        params = exp_factory.get_exp_config(experiment)
        params.task.validation_data.input_path = record_path
        params.task.train_data.input_path = record_path

        export_root = self.get_temp_dir()
        serving_module = image_classification_serving.ClassificationModule(
            params=params,
            batch_size=1,
            input_image_size=input_image_size,
            input_type='tflite')

        # The same directory is written by the export and read by the
        # converter, so name it once.
        saved_model_dir = os.path.join(export_root, 'saved_model')
        self._export_from_module(
            module=serving_module,
            input_type='tflite',
            saved_model_dir=saved_model_dir)

        converted = export_tflite_lib.convert_tflite_model(
            saved_model_dir=saved_model_dir,
            quant_type=quant_type,
            params=params,
            calibration_steps=5)

        self.assertIsInstance(converted, bytes)
 def _create_test_tfrecord(self, test_tfrecord_file, num_samples,
                            input_image_size):
     """Dumps `num_samples` copies of one classification example to a file.

     Args:
         test_tfrecord_file: Passed to `dump_to_tfrecord` as `record_file`.
         num_samples: Number of times the single example is repeated.
         input_image_size: Indexable whose first two entries are used as the
             image height and width of the generated example.
     """
     serialized = tfexample_utils.create_classification_example(
         image_height=input_image_size[0],
         image_width=input_image_size[1])
     parsed = tf.train.Example.FromString(serialized)
     # Every list entry refers to the same parsed example object.
     tfexample_utils.dump_to_tfrecord(
         record_file=test_tfrecord_file,
         tf_examples=[parsed] * num_samples)
    def test_decoder(self, image_height, image_width, num_instances):
        """Decodes one classification example and checks its keys and label.

        Args:
            image_height: Height passed to the example builder.
            image_width: Width passed to the example builder.
            num_instances: Not used by the body; kept as part of the test
                signature.
        """
        decoder = classification_input.Decoder(
            image_field_key=IMAGE_FIELD_KEY,
            label_field_key=LABEL_FIELD_KEY)

        serialized = tfexample_utils.create_classification_example(
            image_height, image_width)
        decoded = decoder.decode(tf.convert_to_tensor(serialized))

        # Convert every decoded tensor to its numpy value before asserting.
        as_numpy = tf.nest.map_structure(
            lambda tensor: tensor.numpy(), decoded)

        self.assertCountEqual(
            [IMAGE_FIELD_KEY, LABEL_FIELD_KEY], as_numpy.keys())
        self.assertEqual(0, as_numpy[LABEL_FIELD_KEY])
Ejemplo n.º 5
0
    def test_task(self, config_name):
        """Runs one train step and one validation step of the task.

        After each step the results of the task metrics are copied into the
        returned logs, which must then contain the 'loss', 'accuracy' and
        'top_5_accuracy' keys.

        Args:
            config_name: Experiment name passed to
                `exp_factory.get_exp_config`.
        """
        expected_keys = ('loss', 'accuracy', 'top_5_accuracy')
        image_size = [224, 224]

        # Write a small TFRecord and point both data configs at it.
        record_path = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
        serialized = tfexample_utils.create_classification_example(
            image_height=image_size[0], image_width=image_size[1])
        self._create_test_tfrecord(
            tfrecord_file=record_path,
            example=tf.train.Example.FromString(serialized),
            num_samples=10)

        config = exp_factory.get_exp_config(config_name)
        config.task.train_data.global_batch_size = 2
        config.task.validation_data.input_path = record_path
        config.task.train_data.input_path = record_path

        task = img_cls_task.ImageClassificationTask(config.task)
        model = task.build_model()
        metrics = task.build_metrics()
        strategy = tf.distribute.get_strategy()

        dataset = orbit.utils.make_distributed_dataset(
            strategy, task.build_inputs, config.task.train_data)
        batches = iter(dataset)

        opt_factory = optimization.OptimizerFactory(
            config.trainer.optimizer_config)
        learning_rate = opt_factory.build_learning_rate()
        optimizer = opt_factory.build_optimizer(learning_rate)

        # Training step.
        train_logs = task.train_step(
            next(batches), model, optimizer, metrics=metrics)
        for metric in metrics:
            train_logs[metric.name] = metric.result()
        for key in expected_keys:
            self.assertIn(key, train_logs)

        # Validation step, fed from the same iterator.
        eval_logs = task.validation_step(
            next(batches), model, metrics=metrics)
        for metric in metrics:
            eval_logs[metric.name] = metric.result()
        for key in expected_keys:
            self.assertIn(key, eval_logs)
    def test_parser(self, output_size, dtype, is_training, aug_name,
                    is_multilabel, decode_jpeg_only, image_format):
        """Decodes and parses one example, then checks shape, label and dtype.

        Args:
            output_size: Expected shape of the parsed image; its first two
                entries size the generated example and are passed to the
                parser as `output_size`.
            dtype: Forwarded to the parser; for 'float32', 'float16' and
                'bfloat16' the image dtype is checked against the matching
                TensorFlow dtype.
            is_training: Passed to `parser.parse_fn`.
            aug_name: 'randaug' or 'autoaug' builds an augmentation config;
                any other value leaves `aug_type` as None.
            is_multilabel: Forwarded to the example builder, decoder and
                parser; also selects which label assertion runs.
            decode_jpeg_only: Forwarded to the parser.
            image_format: Forwarded to the example builder.
        """
        serialized_example = tfexample_utils.create_classification_example(
            output_size[0], output_size[1], image_format, is_multilabel)

        # Only the two named policies get an augmentation config.
        aug_type = None
        if aug_name == 'randaug':
            aug_type = common.Augmentation(
                type=aug_name,
                randaug=common.RandAugment(magnitude=10))
        elif aug_name == 'autoaug':
            aug_type = common.Augmentation(
                type=aug_name,
                autoaug=common.AutoAugment(augmentation_name='test'))

        decoder = classification_input.Decoder(
            image_field_key=IMAGE_FIELD_KEY,
            label_field_key=LABEL_FIELD_KEY,
            is_multilabel=is_multilabel)
        parser = classification_input.Parser(
            output_size=output_size[:2],
            num_classes=10,
            image_field_key=IMAGE_FIELD_KEY,
            label_field_key=LABEL_FIELD_KEY,
            is_multilabel=is_multilabel,
            decode_jpeg_only=decode_jpeg_only,
            aug_rand_hflip=False,
            aug_type=aug_type,
            dtype=dtype)

        parse = parser.parse_fn(is_training)
        image, label = parse(decoder.decode(serialized_example))

        self.assertAllEqual(image.numpy().shape, output_size)

        if is_multilabel:
            self.assertAllEqual(label.numpy().shape, [10])
        else:
            self.assertAllEqual(label, 0)

        # At most one entry matches; any other dtype value skips the check.
        expected_dtypes = (
            ('float32', tf.float32),
            ('float16', tf.float16),
            ('bfloat16', tf.bfloat16),
        )
        for dtype_name, tf_dtype in expected_dtypes:
            if dtype == dtype_name:
                self.assertAllEqual(image.dtype, tf_dtype)
                break