def test_result_shape(self, image_height, image_width, num_instances,
                      regenerate_source_id):
  """Checks the fields produced by `TfExampleDecoder.decode`.

  A synthetic detection example is serialized, decoded with mask decoding
  enabled, and each decoded field is compared against the expected shape
  (or value, for `height`, `width` and `source_id`).

  Args:
    image_height: Height passed to the synthetic example generator.
    image_width: Width passed to the synthetic example generator.
    num_instances: Number of object instances in the synthetic example.
    regenerate_source_id: Forwarded to the decoder; when True the decoded
      `source_id` is not compared against `tfexample_utils.DUMP_SOURCE_ID`.
  """
  decoder = tf_example_decoder.TfExampleDecoder(
      include_mask=True, regenerate_source_id=regenerate_source_id)
  serialized = tfexample_utils.create_detection_test_example(
      image_height=image_height,
      image_width=image_width,
      image_channel=3,
      num_instances=num_instances).SerializeToString()
  decoded = decoder.decode(tf.convert_to_tensor(value=serialized))
  # Pull every tensor out as a NumPy value so shapes can be compared directly.
  results = tf.nest.map_structure(lambda tensor: tensor.numpy(), decoded)

  self.assertAllEqual((image_height, image_width, 3), results['image'].shape)
  if not regenerate_source_id:
    # The source id is only compared when the decoder keeps the stored one.
    self.assertEqual(tfexample_utils.DUMP_SOURCE_ID, results['source_id'])
  self.assertEqual(image_height, results['height'])
  self.assertEqual(image_width, results['width'])

  # Checked in insertion order, matching one assertion per field.
  per_instance_shapes = {
      'groundtruth_classes': (num_instances,),
      'groundtruth_is_crowd': (num_instances,),
      'groundtruth_area': (num_instances,),
      'groundtruth_boxes': (num_instances, 4),
      'groundtruth_instance_masks': (num_instances, image_height, image_width),
      'groundtruth_instance_masks_png': (num_instances,),
  }
  for key, expected_shape in per_instance_shapes.items():
    self.assertAllEqual(expected_shape, results[key].shape)
def test_export_tflite_detection(self, experiment, quant_type,
                                 input_image_size):
  """Exports a detection module to SavedModel and converts it to TFLite.

  A synthetic TFRecord is used as both the train and validation input of the
  experiment config before export and conversion.

  Args:
    experiment: Name of the experiment config passed to
      `exp_factory.get_exp_config`.
    quant_type: Quantization mode forwarded to `convert_tflite_model`.
    input_image_size: Image size used for the synthetic example
      (`[0]` = height, `[1]` = width) and for the exported module.
  """
  tfrecord_path = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
  self._create_test_tfrecord(
      tfrecord_file=tfrecord_path,
      example=tfexample_utils.create_detection_test_example(
          image_height=input_image_size[0],
          image_width=input_image_size[1],
          image_channel=3,
          num_instances=10),
      num_samples=10)

  params = exp_factory.get_exp_config(experiment)
  # Both splits read the synthetic record written above.
  params.task.validation_data.input_path = tfrecord_path
  params.task.train_data.input_path = tfrecord_path

  saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  module = detection_serving.DetectionModule(
      params=params,
      batch_size=1,
      input_image_size=input_image_size,
      input_type='tflite')
  self._export_from_module(
      module=module, input_type='tflite', saved_model_dir=saved_model_dir)

  tflite_model = export_tflite_lib.convert_tflite_model(
      saved_model_dir=saved_model_dir,
      quant_type=quant_type,
      params=params,
      calibration_steps=5)
  self.assertIsInstance(tflite_model, bytes)
def test_retinanet_task(self, test_config, is_training):
  """RetinaNet task test for training and val using toy configs.

  Builds a small synthetic TFRecord, shrinks the experiment config so a
  single step is cheap, and runs exactly one train or validation step.

  Args:
    test_config: Name of the experiment config passed to
      `exp_factory.get_exp_config`.
    is_training: If True, run one `train_step` on the train split; otherwise
      run one `validation_step` on the validation split.
  """
  image_size = [384, 384]
  record_path = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
  toy_example = tfexample_utils.create_detection_test_example(
      image_height=image_size[0],
      image_width=image_size[1],
      image_channel=3,
      num_instances=10)
  self._create_test_tfrecord(
      tfrecord_file=record_path, example=toy_example, num_samples=10)

  config = exp_factory.get_exp_config(test_config)
  # Shrink the config so a single step runs quickly in a local test.
  config.task.model.input_size = [128, 128, 3]
  config.trainer.steps_per_loop = 1
  for split in (config.task.train_data, config.task.validation_data):
    split.global_batch_size = 1
    split.shuffle_buffer_size = 2
    split.input_path = record_path
  # NOTE(review): the trainer's step count normally lives at
  # `config.trainer.train_steps`; this top-level attribute may be unused.
  config.train_steps = 1

  task = retinanet.RetinaNetTask(config.task)
  model = task.build_model()
  metrics = task.build_metrics(training=is_training)

  strategy = tf.distribute.get_strategy()
  if is_training:
    data_config = config.task.train_data
  else:
    data_config = config.task.validation_data
  dataset = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, data_config)
  iterator = iter(dataset)

  # The optimizer is built for both modes even though only training uses it.
  opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

  batch = next(iterator)
  if is_training:
    task.train_step(batch, model, optimizer, metrics=metrics)
  else:
    task.validation_step(batch, model, metrics=metrics)
def test_result_shape(self, image_height, image_width, num_instances):
  """Checks the fields produced by `TfExampleDecoderLabelMap.decode`.

  Writes `LABEL_MAP_CSV_CONTENT` to `label_map.csv` in the test temp dir,
  builds the label-map decoder with mask decoding enabled, decodes a
  synthetic detection example, and compares each decoded field against the
  expected shape (or value, for `height`, `width` and `source_id`).

  Args:
    image_height: Height passed to the synthetic example generator.
    image_width: Width passed to the synthetic example generator.
    num_instances: Number of object instances in the synthetic example.
  """
  label_map_path = os.path.join(self.get_temp_dir(), 'label_map.csv')
  with open(label_map_path, 'w') as label_map_file:
    label_map_file.write(LABEL_MAP_CSV_CONTENT)

  decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
      label_map_path, include_mask=True)
  serialized = tfexample_utils.create_detection_test_example(
      image_height=image_height,
      image_width=image_width,
      image_channel=3,
      num_instances=num_instances).SerializeToString()
  decoded = decoder.decode(tf.convert_to_tensor(value=serialized))
  # Pull every tensor out as a NumPy value so shapes can be compared directly.
  results = tf.nest.map_structure(lambda tensor: tensor.numpy(), decoded)

  self.assertAllEqual((image_height, image_width, 3), results['image'].shape)
  self.assertEqual(tfexample_utils.DUMP_SOURCE_ID, results['source_id'])
  self.assertEqual(image_height, results['height'])
  self.assertEqual(image_width, results['width'])

  # Checked in insertion order, matching one assertion per field.
  per_instance_shapes = {
      'groundtruth_classes': (num_instances,),
      'groundtruth_is_crowd': (num_instances,),
      'groundtruth_area': (num_instances,),
      'groundtruth_boxes': (num_instances, 4),
      'groundtruth_instance_masks': (num_instances, image_height, image_width),
      'groundtruth_instance_masks_png': (num_instances,),
  }
  for key, expected_shape in per_instance_shapes.items():
    self.assertAllEqual(expected_shape, results[key].shape)
def test_scan_and_generator_annotation_file(self):
  """Generates a COCO annotation file from a synthetic TFRecord.

  Dumps `num_samples` copies of one synthetic detection example to a TFRecord,
  runs `coco_utils.scan_and_generator_annotation_file` on it with
  `include_mask=True`, and checks that the annotation JSON file exists.
  """
  num_samples = 10
  example = tfexample_utils.create_detection_test_example(
      image_height=512, image_width=512, image_channel=3, num_instances=10)
  tf_examples = [example] * num_samples
  data_file = os.path.join(self.create_tempdir(), 'test.tfrecord')
  tfexample_utils.dump_to_tfrecord(
      record_file=data_file, tf_examples=tf_examples)

  annotation_file = os.path.join(self.create_tempdir(), 'annotation.json')
  coco_utils.scan_and_generator_annotation_file(
      file_pattern=data_file,
      file_type='tfrecord',
      num_samples=num_samples,
      include_mask=True,
      annotation_file=annotation_file)

  # f-string so the failure message actually names the missing path (the
  # original plain string printed the literal text "{annotation_file}").
  self.assertTrue(
      tf.io.gfile.exists(annotation_file),
      msg=f'Annotation file {annotation_file} does not exist.')