Example #1
0
    def test_result_shape(self, image_height, image_width, num_instances,
                          regenerate_source_id):
        """Decodes a synthetic detection example and verifies output shapes."""
        decoder = tf_example_decoder.TfExampleDecoder(
            include_mask=True, regenerate_source_id=regenerate_source_id)

        serialized = tfexample_utils.create_detection_test_example(
            image_height=image_height,
            image_width=image_width,
            image_channel=3,
            num_instances=num_instances).SerializeToString()
        decoded = decoder.decode(tf.convert_to_tensor(value=serialized))
        results = tf.nest.map_structure(lambda t: t.numpy(), decoded)

        # The canned source id is only preserved when it is not regenerated.
        if not regenerate_source_id:
            self.assertEqual(tfexample_utils.DUMP_SOURCE_ID,
                             results['source_id'])
        self.assertEqual(image_height, results['height'])
        self.assertEqual(image_width, results['width'])

        # Expected shape for every array-valued decoder output.
        expected_shapes = {
            'image': (image_height, image_width, 3),
            'groundtruth_classes': (num_instances,),
            'groundtruth_is_crowd': (num_instances,),
            'groundtruth_area': (num_instances,),
            'groundtruth_boxes': (num_instances, 4),
            'groundtruth_instance_masks':
                (num_instances, image_height, image_width),
            'groundtruth_instance_masks_png': (num_instances,),
        }
        for key, shape in expected_shapes.items():
            self.assertAllEqual(shape, results[key].shape)
Example #2
0
  def test_export_tflite_detection(self, experiment, quant_type,
                                   input_image_size):
    """Exports a detection model and converts it to a tflite flatbuffer."""
    temp_dir = self.get_temp_dir()
    tfrecord_path = os.path.join(temp_dir, 'det_test.tfrecord')
    example = tfexample_utils.create_detection_test_example(
        image_height=input_image_size[0],
        image_width=input_image_size[1],
        image_channel=3,
        num_instances=10)
    self._create_test_tfrecord(
        tfrecord_file=tfrecord_path, example=example, num_samples=10)

    # Point both data pipelines at the synthetic record.
    params = exp_factory.get_exp_config(experiment)
    params.task.validation_data.input_path = tfrecord_path
    params.task.train_data.input_path = tfrecord_path

    saved_model_dir = os.path.join(temp_dir, 'saved_model')
    module = detection_serving.DetectionModule(
        params=params,
        batch_size=1,
        input_image_size=input_image_size,
        input_type='tflite')
    self._export_from_module(
        module=module, input_type='tflite', saved_model_dir=saved_model_dir)

    tflite_model = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_dir,
        quant_type=quant_type,
        params=params,
        calibration_steps=5)

    self.assertIsInstance(tflite_model, bytes)
Example #3
0
    def test_retinanet_task(self, test_config, is_training):
        """RetinaNet task test for training and val using toy configs."""
        input_image_size = [384, 384]
        tfrecord_path = os.path.join(self.get_temp_dir(),
                                     'det_test.tfrecord')
        example = tfexample_utils.create_detection_test_example(
            image_height=input_image_size[0],
            image_width=input_image_size[1],
            image_channel=3,
            num_instances=10)
        self._create_test_tfrecord(tfrecord_file=tfrecord_path,
                                   example=example,
                                   num_samples=10)

        config = exp_factory.get_exp_config(test_config)
        # Shrink the experiment so a single step runs quickly in the test.
        config.task.model.input_size = [128, 128, 3]
        config.trainer.steps_per_loop = 1
        config.train_steps = 1
        config.task.train_data.global_batch_size = 1
        config.task.train_data.shuffle_buffer_size = 2
        config.task.train_data.input_path = tfrecord_path
        config.task.validation_data.global_batch_size = 1
        config.task.validation_data.shuffle_buffer_size = 2
        config.task.validation_data.input_path = tfrecord_path

        task = retinanet.RetinaNetTask(config.task)
        model = task.build_model()
        metrics = task.build_metrics(training=is_training)

        strategy = tf.distribute.get_strategy()
        if is_training:
            data_config = config.task.train_data
        else:
            data_config = config.task.validation_data
        dataset = orbit.utils.make_distributed_dataset(strategy,
                                                       task.build_inputs,
                                                       data_config)
        iterator = iter(dataset)

        opt_factory = optimization.OptimizerFactory(
            config.trainer.optimizer_config)
        optimizer = opt_factory.build_optimizer(
            opt_factory.build_learning_rate())

        if is_training:
            task.train_step(next(iterator), model, optimizer, metrics=metrics)
        else:
            task.validation_step(next(iterator), model, metrics=metrics)
  def test_result_shape(self, image_height, image_width, num_instances):
    """Decodes a synthetic example via the label-map decoder, checks shapes."""
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.csv')
    with open(label_map_path, 'w') as f:
      f.write(LABEL_MAP_CSV_CONTENT)

    decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
        label_map_path, include_mask=True)

    serialized = tfexample_utils.create_detection_test_example(
        image_height=image_height,
        image_width=image_width,
        image_channel=3,
        num_instances=num_instances).SerializeToString()
    decoded = decoder.decode(tf.convert_to_tensor(value=serialized))
    results = tf.nest.map_structure(lambda t: t.numpy(), decoded)

    self.assertEqual(tfexample_utils.DUMP_SOURCE_ID, results['source_id'])
    self.assertEqual(image_height, results['height'])
    self.assertEqual(image_width, results['width'])

    # Expected shape for every array-valued decoder output.
    expected_shapes = {
        'image': (image_height, image_width, 3),
        'groundtruth_classes': (num_instances,),
        'groundtruth_is_crowd': (num_instances,),
        'groundtruth_area': (num_instances,),
        'groundtruth_boxes': (num_instances, 4),
        'groundtruth_instance_masks':
            (num_instances, image_height, image_width),
        'groundtruth_instance_masks_png': (num_instances,),
    }
    for key, shape in expected_shapes.items():
      self.assertAllEqual(shape, results[key].shape)
Example #5
0
    def test_scan_and_generator_annotation_file(self):
        """Checks that scanning a tfrecord produces a COCO annotation file."""
        num_samples = 10
        example = tfexample_utils.create_detection_test_example(
            image_height=512,
            image_width=512,
            image_channel=3,
            num_instances=10)
        tf_examples = [example] * num_samples
        data_file = os.path.join(self.create_tempdir(), 'test.tfrecord')
        tfexample_utils.dump_to_tfrecord(record_file=data_file,
                                         tf_examples=tf_examples)
        annotation_file = os.path.join(self.create_tempdir(),
                                       'annotation.json')

        coco_utils.scan_and_generator_annotation_file(
            file_pattern=data_file,
            file_type='tfrecord',
            num_samples=num_samples,
            include_mask=True,
            annotation_file=annotation_file)
        self.assertTrue(
            tf.io.gfile.exists(annotation_file),
            # Bug fix: the message was a plain string, so '{annotation_file}'
            # was never interpolated; made it an f-string and fixed grammar.
            msg=f'Annotation file {annotation_file} does not exist.')