def test_load_dataset(self):
  """Checks batched shapes of a COCO dataset built over mocked TFDS data."""
  output_size = 1280
  max_num_boxes = 100
  batch_size = 2
  cfg = coco.COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=False,
      global_batch_size=batch_size,
      output_size=(output_size, output_size),
      max_num_boxes=max_num_boxes,
  )
  num_examples = 10

  # Replacement `as_dataset` for the mocked TFDS builder: `self` here is the
  # builder instance injected by tfds.testing.mock_data, not the test case.
  def as_dataset(self, *args, **kwargs):
    del args
    del kwargs
    return tf.data.Dataset.from_generator(
        lambda: (_gen_fn() for _ in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  with tfds.testing.mock_data(
      num_examples=num_examples, as_dataset_fn=as_dataset):
    loader = coco.COCODataLoader(cfg)
    images, labels = next(iter(loader.load()))
    self.assertEqual(
        images.shape, (batch_size, output_size, output_size, 3))
    self.assertEqual(labels['classes'].shape, (batch_size, max_num_boxes))
    self.assertEqual(labels['boxes'].shape, (batch_size, max_num_boxes, 4))
    self.assertEqual(labels['id'].shape, (batch_size,))
    self.assertEqual(labels['image_info'].shape, (batch_size, 4, 2))
    self.assertEqual(labels['is_crowd'].shape, (batch_size, max_num_boxes))
def test_preprocess(self, is_training):
  """Checks per-example shapes produced by COCODataLoader.preprocess.

  Args:
    is_training: Parameterized flag; eval-only labels ('id', 'image_info',
      'is_crowd') are asserted only when False.
  """
  output_size = 1280
  max_num_boxes = 100
  batch_size = 2
  data_config = coco.COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=is_training,
      global_batch_size=batch_size,
      output_size=(output_size, output_size),
      max_num_boxes=max_num_boxes,
  )
  dl = coco.COCODataLoader(data_config)
  inputs = _gen_fn()
  image, label = dl.preprocess(inputs)
  self.assertEqual(image.shape, (output_size, output_size, 3))
  # BUG FIX: `(max_num_boxes)` was a parenthesized int, not a tuple — it only
  # compared correctly because TensorShape coerces ints. Use explicit 1-tuples.
  self.assertEqual(label['classes'].shape, (max_num_boxes,))
  self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
  if not is_training:
    self.assertDTypeEqual(label['id'], int)
    self.assertEqual(label['image_info'].shape, (4, 2))
    self.assertEqual(label['is_crowd'].shape, (max_num_boxes,))
def build_inputs(
    self, params, input_context: Optional[tf.distribute.InputContext] = None):
  """Build input dataset."""
  # COCO-specific configs bypass the generic reader pipeline entirely.
  if isinstance(params, coco.COCODataConfig):
    return coco.COCODataLoader(params).load(input_context)

  # Pick a decoder: TFDS-backed when a dataset name is given, otherwise one of
  # the TFRecord example decoders selected by the config.
  if params.tfds_name:
    decoder = tfds_factory.get_detection_decoder(params.tfds_name)
  else:
    decoder_cfg = params.decoder.get()
    decoder_type = params.decoder.type
    if decoder_type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    elif decoder_type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

  parser = detr_input.Parser(
      class_offset=self._task_config.losses.class_offset,
      output_size=self._task_config.model.input_size[:2],
  )
  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)
def build_inputs(self, params, input_context=None):
  """Build input dataset by delegating to the COCO data loader."""
  loader = coco.COCODataLoader(params)
  return loader.load(input_context)