def test_load_dataset(self):
        """Checks that COCODataLoader batches carry the configured shapes."""
        out_size = 1280
        boxes_cap = 100
        batch = 2
        cfg = coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=batch,
            output_size=(out_size, out_size),
            max_num_boxes=boxes_cap,
        )

        example_count = 10

        def fake_as_dataset(builder, *args, **kwargs):
            # Drop whatever options tfds passes; serve synthetic examples
            # generated by _gen_fn with the mocked builder's feature spec.
            del args, kwargs
            return tf.data.Dataset.from_generator(
                lambda: (_gen_fn() for _ in range(example_count)),
                output_types=builder.info.features.dtype,
                output_shapes=builder.info.features.shape,
            )

        with tfds.testing.mock_data(num_examples=example_count,
                                    as_dataset_fn=fake_as_dataset):
            images, labels = next(iter(coco.COCODataLoader(cfg).load()))
            self.assertEqual(images.shape,
                             (batch, out_size, out_size, 3))
            self.assertEqual(labels['classes'].shape, (batch, boxes_cap))
            self.assertEqual(labels['boxes'].shape, (batch, boxes_cap, 4))
            self.assertEqual(labels['id'].shape, (batch, ))
            self.assertEqual(labels['image_info'].shape, (batch, 4, 2))
            self.assertEqual(labels['is_crowd'].shape, (batch, boxes_cap))
    def test_preprocess(self, is_training):
        """Checks preprocess() output shapes for a single example.

        Args:
            is_training: Parameterized flag; the eval-only label fields
                ('id', 'image_info', 'is_crowd') are only checked when False.
        """
        output_size = 1280
        max_num_boxes = 100
        batch_size = 2
        data_config = coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=is_training,
            global_batch_size=batch_size,
            output_size=(output_size, output_size),
            max_num_boxes=max_num_boxes,
        )

        dl = coco.COCODataLoader(data_config)
        inputs = _gen_fn()
        image, label = dl.preprocess(inputs)
        self.assertEqual(image.shape, (output_size, output_size, 3))
        # (max_num_boxes,) — explicit 1-tuples; the original `(max_num_boxes)`
        # was just the bare int because the trailing comma was missing,
        # inconsistent with the tuple comparisons used elsewhere in this test.
        self.assertEqual(label['classes'].shape, (max_num_boxes,))
        self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
        if not is_training:
            self.assertDTypeEqual(label['id'], int)
            self.assertEqual(label['image_info'].shape, (4, 2))
            self.assertEqual(label['is_crowd'].shape, (max_num_boxes,))
# Example #3 (separator carried over from the original source listing)
    def build_inputs(
            self,
            params,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the detection input dataset described by `params`."""
        # COCO-specific configs are served by their dedicated loader.
        if isinstance(params, coco.COCODataConfig):
            return coco.COCODataLoader(params).load(input_context)

        # Pick a decoder: TFDS-provided, or one built from the decoder config.
        if params.tfds_name:
            decoder = tfds_factory.get_detection_decoder(params.tfds_name)
        else:
            decoder_cfg = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                    label_map=decoder_cfg.label_map,
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    decoder_type))

        parser = detr_input.Parser(
            class_offset=self._task_config.losses.class_offset,
            output_size=self._task_config.model.input_size[:2],
        )

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
 def build_inputs(self, params, input_context=None):
     """Builds the input dataset by delegating to the COCO data loader."""
     loader = coco.COCODataLoader(params)
     return loader.load(input_context)