Example #1
    def build_inputs(self, params, input_context=None):
        """Build input dataset."""
        decoder_cfg = params.decoder.get()
        if params.decoder.type == 'simple_decoder':
            decoder = tf_example_decoder.TfExampleDecoder(
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        elif params.decoder.type == 'label_map_decoder':
            decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                label_map=decoder_cfg.label_map,
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        else:
            raise ValueError('Unknown decoder type: {}!'.format(
                params.decoder.type))
        parser = retinanet_input.Parser(
            output_size=self.task_config.model.input_size[:2],
            min_level=self.task_config.model.min_level,
            max_level=self.task_config.model.max_level,
            num_scales=self.task_config.model.anchor.num_scales,
            aspect_ratios=self.task_config.model.anchor.aspect_ratios,
            anchor_size=self.task_config.model.anchor.anchor_size,
            dtype=params.dtype,
            match_threshold=params.parser.match_threshold,
            unmatched_threshold=params.parser.unmatched_threshold,
            aug_rand_hflip=params.parser.aug_rand_hflip,
            aug_scale_min=params.parser.aug_scale_min,
            aug_scale_max=params.parser.aug_scale_max,
            skip_crowd_during_training=(
                params.parser.skip_crowd_during_training),
            max_num_instances=params.parser.max_num_instances)

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        dataset = reader.read(input_context=input_context)

        return dataset
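These snippets lean on several modules that are never imported in the excerpts. A minimal import block, assuming the TensorFlow Model Garden layout (module paths vary between releases; older versions used the `official.vision.beta.*` prefix, so treat these as an assumption to verify):

import tensorflow as tf

from official.core import input_reader
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tf_example_label_map_decoder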
Example #2
    def test_result_shape(self, image_height, image_width, num_instances,
                          regenerate_source_id):
        decoder = tf_example_decoder.TfExampleDecoder(
            include_mask=True, regenerate_source_id=regenerate_source_id)

        serialized_example = tfexample_utils.create_detection_test_example(
            image_height=image_height,
            image_width=image_width,
            image_channel=3,
            num_instances=num_instances).SerializeToString()
        decoded_tensors = decoder.decode(
            tf.convert_to_tensor(value=serialized_example))

        results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)

        self.assertAllEqual((image_height, image_width, 3),
                            results['image'].shape)
        if not regenerate_source_id:
            self.assertEqual(tfexample_utils.DUMP_SOURCE_ID,
                             results['source_id'])
        self.assertEqual(image_height, results['height'])
        self.assertEqual(image_width, results['width'])
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_classes'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_is_crowd'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_area'].shape)
        self.assertAllEqual((num_instances, 4),
                            results['groundtruth_boxes'].shape)
        self.assertAllEqual((num_instances, image_height, image_width),
                            results['groundtruth_instance_masks'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_instance_masks_png'].shape)
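The test above receives its arguments from a parameterized decorator that the excerpt omits. A sketch of the usual wiring with `absl.testing.parameterized` (the class name and parameter tuples here are illustrative, not the original suite's):

from absl.testing import parameterized
import tensorflow as tf


class TfExampleDecoderTest(parameterized.TestCase, tf.test.TestCase):

    # Each tuple is (image_height, image_width, num_instances,
    # regenerate_source_id); the values below are made up for illustration.
    @parameterized.parameters(
        (100, 100, 0, False),
        (100, 200, 10, True),
    )
    def test_result_shape(self, image_height, image_width, num_instances,
                          regenerate_source_id):
        ...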
Example #3
    def _parse_single_example(self, example):
        """Parses a single serialized tf.Example proto.

    Args:
      example: a serialized tf.Example proto string.

    Returns:
      A dictionary of groundtruth with the following fields:
        source_id: a scalar tensor of int64 representing the image source_id.
        height: a scalar tensor of int64 representing the image height.
        width: a scalar tensor of int64 representing the image width.
        boxes: a float tensor of shape [K, 4], representing the groundtruth
          boxes in absolute coordinates with respect to the original image size.
        classes: a int64 tensor of shape [K], representing the class labels of
          each instances.
        is_crowds: a bool tensor of shape [K], indicating whether the instance
          is crowd.
        areas: a float tensor of shape [K], indicating the area of each
          instance.
        masks: a string tensor of shape [K], containing the bytes of the png
          mask of each instance.
    """
        decoder = tf_example_decoder.TfExampleDecoder(
            include_mask=self._include_mask)
        decoded_tensors = decoder.decode(example)

        image = decoded_tensors['image']
        image_size = tf.shape(image)[0:2]
        boxes = box_ops.denormalize_boxes(decoded_tensors['groundtruth_boxes'],
                                          image_size)
        groundtruths = {
            'source_id': tf.strings.to_number(
                decoded_tensors['source_id'], out_type=tf.int64),
            'height': decoded_tensors['height'],
            'width': decoded_tensors['width'],
            'num_detections': tf.shape(
                decoded_tensors['groundtruth_classes'])[0],
            'boxes': boxes,
            'classes': decoded_tensors['groundtruth_classes'],
            'is_crowds': decoded_tensors['groundtruth_is_crowd'],
            'areas': decoded_tensors['groundtruth_area'],
        }
        if self._include_mask:
            groundtruths['masks'] = decoded_tensors[
                'groundtruth_instance_masks_png']
        return groundtruths
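For readers unfamiliar with `box_ops.denormalize_boxes`: it scales normalized `[ymin, xmin, ymax, xmax]` boxes to absolute pixel coordinates. A standalone sketch of the computation (an approximation of the helper, not its exact source):

import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 0.5, 0.5]])  # normalized [ymin, xmin, ymax, xmax]
image_size = tf.constant([480, 640])         # [height, width]
height = tf.cast(image_size[0], tf.float32)
width = tf.cast(image_size[1], tf.float32)
absolute_boxes = boxes * tf.stack([height, width, height, width])
# -> [[0., 0., 240., 320.]]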
Example #4
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Build input dataset."""

        if params.tfds_name:
            if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
                decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
                    params.tfds_name]()
            else:
                raise ValueError('TFDS {} is not supported'.format(
                    params.tfds_name))
        else:
            decoder_cfg = params.decoder.get()
            if params.decoder.type == 'simple_decoder':
                decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            elif params.decoder.type == 'label_map_decoder':
                decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                    label_map=decoder_cfg.label_map,
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    params.decoder.type))

        parser = retinanet_input.Parser(
            output_size=self.task_config.model.input_size[:2],
            min_level=self.task_config.model.min_level,
            max_level=self.task_config.model.max_level,
            num_scales=self.task_config.model.anchor.num_scales,
            aspect_ratios=self.task_config.model.anchor.aspect_ratios,
            anchor_size=self.task_config.model.anchor.anchor_size,
            dtype=params.dtype,
            match_threshold=params.parser.match_threshold,
            unmatched_threshold=params.parser.unmatched_threshold,
            aug_rand_hflip=params.parser.aug_rand_hflip,
            aug_scale_min=params.parser.aug_scale_min,
            aug_scale_max=params.parser.aug_scale_max,
            skip_crowd_during_training=(
                params.parser.skip_crowd_during_training),
            max_num_instances=params.parser.max_num_instances)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        dataset = reader.read(input_context=input_context)

        return dataset
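The `input_context` argument is normally supplied by a `tf.distribute` strategy rather than passed by hand. A hedged usage sketch, where `task` and `params` stand for a configured task instance and its `exp_cfg.DataConfig`:

strategy = tf.distribute.MirroredStrategy()
# The strategy calls the function once per worker, passing an InputContext
# that build_inputs forwards to the reader for sharding.
dataset = strategy.distribute_datasets_from_function(
    lambda input_context: task.build_inputs(params, input_context))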
Example #5
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Build input dataset."""
        decoder_cfg = params.decoder.get()
        if params.decoder.type == 'simple_decoder':
            decoder = tf_example_decoder.TfExampleDecoder(
                include_mask=self._task_config.model.include_mask,
                regenerate_source_id=decoder_cfg.regenerate_source_id,
                mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
        elif params.decoder.type == 'label_map_decoder':
            decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                label_map=decoder_cfg.label_map,
                include_mask=self._task_config.model.include_mask,
                regenerate_source_id=decoder_cfg.regenerate_source_id,
                mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
        else:
            raise ValueError('Unknown decoder type: {}!'.format(
                params.decoder.type))

        parser = maskrcnn_input.Parser(
            output_size=self.task_config.model.input_size[:2],
            min_level=self.task_config.model.min_level,
            max_level=self.task_config.model.max_level,
            num_scales=self.task_config.model.anchor.num_scales,
            aspect_ratios=self.task_config.model.anchor.aspect_ratios,
            anchor_size=self.task_config.model.anchor.anchor_size,
            dtype=params.dtype,
            rpn_match_threshold=params.parser.rpn_match_threshold,
            rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
            rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
            rpn_fg_fraction=params.parser.rpn_fg_fraction,
            aug_rand_hflip=params.parser.aug_rand_hflip,
            aug_scale_min=params.parser.aug_scale_min,
            aug_scale_max=params.parser.aug_scale_max,
            skip_crowd_during_training=(
                params.parser.skip_crowd_during_training),
            max_num_instances=params.parser.max_num_instances,
            include_mask=self._task_config.model.include_mask,
            mask_crop_size=params.parser.mask_crop_size)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        dataset = reader.read(input_context=input_context)

        return dataset
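`params.decoder` in these tasks is a oneof-style config: `.type` names the active branch and `.get()` returns that branch's sub-config. A minimal sketch of the pattern, assuming the Model Garden's `hyperparams.OneOfConfig` helper (class and field names here are illustrative, not the exact Model Garden definitions):

import dataclasses
from typing import Optional

from official.modeling import hyperparams


@dataclasses.dataclass
class SimpleDecoder(hyperparams.Config):
    regenerate_source_id: bool = False
    mask_binarize_threshold: Optional[float] = None


@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
    type: Optional[str] = 'simple_decoder'
    simple_decoder: SimpleDecoder = dataclasses.field(
        default_factory=SimpleDecoder)


decoder_cfg = DataDecoder()
assert decoder_cfg.get() is decoder_cfg.simple_decoder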
Example #6
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Build input dataset."""
        if params.tfds_name:
            decoder = tfds_factory.get_detection_decoder(params.tfds_name)
        else:
            decoder_cfg = params.decoder.get()
            if params.decoder.type == 'simple_decoder':
                decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            elif params.decoder.type == 'label_map_decoder':
                decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                    label_map=decoder_cfg.label_map,
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    params.decoder.type))

        parser = centernet_input.CenterNetParser(
            output_height=self.task_config.model.input_size[0],
            output_width=self.task_config.model.input_size[1],
            max_num_instances=self.task_config.model.max_num_instances,
            bgr_ordering=params.parser.bgr_ordering,
            channel_means=params.parser.channel_means,
            channel_stds=params.parser.channel_stds,
            aug_rand_hflip=params.parser.aug_rand_hflip,
            aug_scale_min=params.parser.aug_scale_min,
            aug_scale_max=params.parser.aug_scale_max,
            aug_rand_hue=params.parser.aug_rand_hue,
            aug_rand_brightness=params.parser.aug_rand_brightness,
            aug_rand_contrast=params.parser.aug_rand_contrast,
            aug_rand_saturation=params.parser.aug_rand_saturation,
            odapi_augmentation=params.parser.odapi_augmentation,
            dtype=params.dtype)

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        dataset = reader.read(input_context=input_context)

        return dataset
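The `channel_means`/`channel_stds`/`bgr_ordering` knobs passed to the CenterNet parser amount to a per-channel normalization. A standalone sketch; the default values here are the common ImageNet RGB statistics, an assumption rather than the parser's actual defaults:

import tensorflow as tf


def normalize_image(image,
                    channel_means=(123.675, 116.28, 103.53),
                    channel_stds=(58.395, 57.12, 57.375),
                    bgr_ordering=False):
    """Standardizes an RGB uint8 image per channel."""
    image = tf.cast(image, tf.float32)
    if bgr_ordering:
        # Swap to BGR to match backbones trained on BGR inputs.
        image = image[..., ::-1]
        channel_means = channel_means[::-1]
        channel_stds = channel_stds[::-1]
    return (image - channel_means) / channel_stds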
Example #7
    def test_result_shape(self, image_height, image_width, num_instances,
                          regenerate_source_id):
        decoder = tf_example_decoder.TfExampleDecoder(
            include_mask=True, regenerate_source_id=regenerate_source_id)

        image = _encode_image(
            np.uint8(np.random.rand(image_height, image_width, 3) * 255),
            fmt='JPEG')
        if num_instances == 0:
            xmins = []
            xmaxs = []
            ymins = []
            ymaxs = []
            labels = []
            areas = []
            is_crowds = []
            masks = []
        else:
            xmins = list(np.random.rand(num_instances))
            xmaxs = list(np.random.rand(num_instances))
            ymins = list(np.random.rand(num_instances))
            ymaxs = list(np.random.rand(num_instances))
            labels = list(np.random.randint(100, size=num_instances))
            areas = [
                (xmax - xmin) * (ymax - ymin) * image_height * image_width
                for xmin, xmax, ymin, ymax in zip(xmins, xmaxs, ymins, ymaxs)
            ]
            is_crowds = [0] * num_instances
            masks = []
            for _ in range(num_instances):
                mask = _encode_image(
                    np.uint8(np.random.rand(image_height, image_width) * 255),
                    fmt='PNG')
                masks.append(mask)
        serialized_example = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image]))),
                'image/source_id': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[DUMP_SOURCE_ID]))),
                'image/height': (tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[image_height]))),
                'image/width': (tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[image_width]))),
                'image/object/bbox/xmin': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=xmins))),
                'image/object/bbox/xmax': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=xmaxs))),
                'image/object/bbox/ymin': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=ymins))),
                'image/object/bbox/ymax': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=ymaxs))),
                'image/object/class/label': (tf.train.Feature(
                    int64_list=tf.train.Int64List(value=labels))),
                'image/object/is_crowd': (tf.train.Feature(
                    int64_list=tf.train.Int64List(value=is_crowds))),
                'image/object/area': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=areas))),
                'image/object/mask': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=masks))),
            })).SerializeToString()
        decoded_tensors = decoder.decode(
            tf.convert_to_tensor(value=serialized_example))

        results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)

        self.assertAllEqual((image_height, image_width, 3),
                            results['image'].shape)
        if not regenerate_source_id:
            self.assertEqual(DUMP_SOURCE_ID, results['source_id'])
        self.assertEqual(image_height, results['height'])
        self.assertEqual(image_width, results['width'])
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_classes'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_is_crowd'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_area'].shape)
        self.assertAllEqual((num_instances, 4),
                            results['groundtruth_boxes'].shape)
        self.assertAllEqual((num_instances, image_height, image_width),
                            results['groundtruth_instance_masks'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_instance_masks_png'].shape)
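`_encode_image` is referenced but not shown in these tests. A plausible implementation using Pillow (an assumption, not the original helper):

import io

import numpy as np
from PIL import Image


def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
    """Encodes a uint8 array as JPEG or PNG bytes."""
    buffer = io.BytesIO()
    Image.fromarray(image_array).save(buffer, format=fmt)
    return buffer.getvalue()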
Example #8
    def test_handling_missing_fields(self):
        decoder = tf_example_decoder.TfExampleDecoder(include_mask=True)

        image_content = [[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                         [[0, 0, 0], [255, 255, 255], [255, 255, 255],
                          [0, 0, 0]],
                         [[0, 0, 0], [255, 255, 255], [255, 255, 255],
                          [0, 0, 0]],
                         [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]
        image = _encode_image(np.uint8(image_content), fmt='PNG')
        image_height = 4
        image_width = 4
        num_instances = 2
        xmins = [0, 0.25]
        xmaxs = [0.5, 1.0]
        ymins = [0, 0]
        ymaxs = [0.5, 1.0]
        labels = [3, 1]
        mask_content = [[[255, 255, 0, 0], [255, 255, 0, 0], [0, 0, 0, 0],
                         [0, 0, 0, 0]],
                        [[0, 255, 255, 255], [0, 255, 255, 255],
                         [0, 255, 255, 255], [0, 255, 255, 255]]]
        masks = [
            _encode_image(np.uint8(m), fmt='PNG') for m in list(mask_content)
        ]
        serialized_example = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image]))),
                'image/source_id': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[DUMP_SOURCE_ID]))),
                'image/height': (tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[image_height]))),
                'image/width': (tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[image_width]))),
                'image/object/bbox/xmin': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=xmins))),
                'image/object/bbox/xmax': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=xmaxs))),
                'image/object/bbox/ymin': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=ymins))),
                'image/object/bbox/ymax': (tf.train.Feature(
                    float_list=tf.train.FloatList(value=ymaxs))),
                'image/object/class/label': (tf.train.Feature(
                    int64_list=tf.train.Int64List(value=labels))),
                'image/object/mask': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=masks))),
            })).SerializeToString()
        decoded_tensors = decoder.decode(
            tf.convert_to_tensor(serialized_example))
        results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)

        self.assertAllEqual((image_height, image_width, 3),
                            results['image'].shape)
        self.assertAllEqual(image_content, results['image'])
        self.assertEqual(DUMP_SOURCE_ID, results['source_id'])
        self.assertEqual(image_height, results['height'])
        self.assertEqual(image_width, results['width'])
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_classes'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_is_crowd'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_area'].shape)
        self.assertAllEqual((num_instances, 4),
                            results['groundtruth_boxes'].shape)
        self.assertAllEqual((num_instances, image_height, image_width),
                            results['groundtruth_instance_masks'].shape)
        self.assertAllEqual((num_instances, ),
                            results['groundtruth_instance_masks_png'].shape)
        self.assertAllEqual([3, 1], results['groundtruth_classes'])
        self.assertAllEqual([False, False], results['groundtruth_is_crowd'])
        self.assertNDArrayNear([
            0.25 * image_height * image_width,
            0.75 * image_height * image_width
        ], results['groundtruth_area'], 1e-4)
        self.assertNDArrayNear([[0, 0, 0.5, 0.5], [0, 0.25, 1.0, 1.0]],
                               results['groundtruth_boxes'], 1e-4)
        self.assertNDArrayNear(mask_content,
                               results['groundtruth_instance_masks'], 1e-4)
        self.assertAllEqual(masks, results['groundtruth_instance_masks_png'])
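Note what this test demonstrates: the serialized example carries no `image/object/is_crowd` or `image/object/area` features, yet the decoded results contain both, so the decoder evidently backfills them. A sketch of the assumed backfill, reusing `mask_content` from the test above (the real decoder logic may differ):

import tensorflow as tf

masks = tf.constant(mask_content, tf.float32) / 255.0  # [K, H, W] in {0, 1}
areas = tf.reduce_sum(masks, axis=[1, 2])              # [4., 12.] pixels
is_crowd = tf.zeros([tf.shape(masks)[0]], tf.bool)     # default: not crowd
# 4 and 12 pixels are 0.25 and 0.75 of the 4x4 image, matching the assertions.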
Example #9
    def testRetinanetInputReader(self, output_size, skip_crowd_during_training,
                                 use_autoaugment, is_training):

        batch_size = 2
        min_level = 3
        max_level = 7
        num_scales = 3
        aspect_ratios = [0.5, 1.0, 2.0]
        anchor_size = 3
        max_num_instances = 100

        params = cfg.DataConfig(
            input_path='/placer/prod/home/snaggletooth/test/data/coco/val*',
            global_batch_size=batch_size,
            is_training=is_training)

        decoder = tf_example_decoder.TfExampleDecoder()
        parser = retinanet_input.Parser(
            output_size=output_size,
            min_level=min_level,
            max_level=max_level,
            num_scales=num_scales,
            aspect_ratios=aspect_ratios,
            anchor_size=anchor_size,
            skip_crowd_during_training=skip_crowd_during_training,
            use_autoaugment=use_autoaugment,
            max_num_instances=max_num_instances,
            dtype='bfloat16')

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        dataset = reader.read()

        iterator = iter(dataset)
        image, labels = next(iterator)
        np_image = image.numpy()
        np_labels = tf.nest.map_structure(lambda x: x.numpy(), labels)

        # Checks image shape.
        self.assertEqual(list(np_image.shape),
                         [batch_size, output_size[0], output_size[1], 3])
        # Checks keys in labels.
        if is_training:
            self.assertCountEqual(np_labels.keys(), [
                'cls_targets', 'box_targets', 'anchor_boxes', 'cls_weights',
                'box_weights', 'image_info'
            ])
        else:
            self.assertCountEqual(np_labels.keys(), [
                'cls_targets', 'box_targets', 'anchor_boxes', 'cls_weights',
                'box_weights', 'groundtruths', 'image_info'
            ])
        # Checks shapes of `image_info` and `anchor_boxes`.
        self.assertEqual(np_labels['image_info'].shape, (batch_size, 4, 2))
        n_anchors = 0
        for level in range(min_level, max_level + 1):
            stride = 2**level
            output_size_l = [output_size[0] / stride, output_size[1] / stride]
            anchors_per_location = num_scales * len(aspect_ratios)
            self.assertEqual(list(np_labels['anchor_boxes'][level].shape), [
                batch_size, output_size_l[0], output_size_l[1],
                4 * anchors_per_location
            ])
            n_anchors += (output_size_l[0] * output_size_l[1] *
                          anchors_per_location)
        # Checks shapes of training objectives.
        self.assertEqual(np_labels['cls_weights'].shape,
                         (batch_size, n_anchors))
        for level in range(min_level, max_level + 1):
            stride = 2**level
            output_size_l = [output_size[0] / stride, output_size[1] / stride]
            anchors_per_location = num_scales * len(aspect_ratios)
            self.assertEqual(list(np_labels['cls_targets'][level].shape), [
                batch_size, output_size_l[0], output_size_l[1],
                anchors_per_location
            ])
            self.assertEqual(list(np_labels['box_targets'][level].shape), [
                batch_size, output_size_l[0], output_size_l[1],
                4 * anchors_per_location
            ])
        # Checks shape of groundtruths for eval.
        if not is_training:
            self.assertEqual(np_labels['groundtruths']['source_id'].shape,
                             (batch_size, ))
            self.assertEqual(np_labels['groundtruths']['classes'].shape,
                             (batch_size, max_num_instances))
            self.assertEqual(np_labels['groundtruths']['boxes'].shape,
                             (batch_size, max_num_instances, 4))
            self.assertEqual(np_labels['groundtruths']['areas'].shape,
                             (batch_size, max_num_instances))
            self.assertEqual(np_labels['groundtruths']['is_crowds'].shape,
                             (batch_size, max_num_instances))
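The anchor bookkeeping above is easy to sanity-check by hand. A worked example for an illustrative `output_size` of (640, 640) with levels 3-7 and 3 scales x 3 aspect ratios:

output_size = (640, 640)      # illustrative; the test parameterizes this
anchors_per_location = 3 * 3  # num_scales * len(aspect_ratios)
n_anchors = sum(
    (output_size[0] // 2**level) * (output_size[1] // 2**level) *
    anchors_per_location for level in range(3, 7 + 1))
print(n_anchors)  # (80*80 + 40*40 + 20*20 + 10*10 + 5*5) * 9 = 76725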
Example #10
    def testMaskRCNNInputReader(self, output_size, skip_crowd_during_training,
                                include_mask, is_training):
        min_level = 3
        max_level = 7
        num_scales = 3
        aspect_ratios = [1.0, 2.0, 0.5]
        max_num_instances = 100
        batch_size = 2
        mask_crop_size = 112
        anchor_size = 4.0

        params = cfg.DataConfig(
            input_path='/placer/prod/home/snaggletooth/test/data/coco/val*',
            global_batch_size=batch_size,
            is_training=is_training)

        parser = maskrcnn_input.Parser(
            output_size=output_size,
            min_level=min_level,
            max_level=max_level,
            num_scales=num_scales,
            aspect_ratios=aspect_ratios,
            anchor_size=anchor_size,
            rpn_match_threshold=0.7,
            rpn_unmatched_threshold=0.3,
            rpn_batch_size_per_im=256,
            rpn_fg_fraction=0.5,
            aug_rand_hflip=True,
            aug_scale_min=0.8,
            aug_scale_max=1.2,
            skip_crowd_during_training=skip_crowd_during_training,
            max_num_instances=max_num_instances,
            include_mask=include_mask,
            mask_crop_size=mask_crop_size,
            dtype='bfloat16')

        decoder = tf_example_decoder.TfExampleDecoder(
            include_mask=include_mask)
        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        dataset = reader.read()
        iterator = iter(dataset)

        images, labels = next(iterator)

        np_images = images.numpy()
        np_labels = tf.nest.map_structure(lambda x: x.numpy(), labels)

        if is_training:
            self.assertAllEqual(
                np_images.shape,
                [batch_size, output_size[0], output_size[1], 3])
            self.assertAllEqual(np_labels['image_info'].shape,
                                [batch_size, 4, 2])
            self.assertAllEqual(np_labels['gt_boxes'].shape,
                                [batch_size, max_num_instances, 4])
            self.assertAllEqual(np_labels['gt_classes'].shape,
                                [batch_size, max_num_instances])
            if include_mask:
                self.assertAllEqual(np_labels['gt_masks'].shape, [
                    batch_size, max_num_instances, mask_crop_size,
                    mask_crop_size
                ])
            for level in range(min_level, max_level + 1):
                stride = 2**level
                output_size_l = [
                    output_size[0] / stride, output_size[1] / stride
                ]
                anchors_per_location = num_scales * len(aspect_ratios)
                self.assertAllEqual(
                    np_labels['rpn_score_targets'][level].shape, [
                        batch_size, output_size_l[0], output_size_l[1],
                        anchors_per_location
                    ])
                self.assertAllEqual(
                    np_labels['rpn_box_targets'][level].shape, [
                        batch_size, output_size_l[0], output_size_l[1],
                        4 * anchors_per_location
                    ])
                self.assertAllEqual(np_labels['anchor_boxes'][level].shape, [
                    batch_size, output_size_l[0], output_size_l[1],
                    4 * anchors_per_location
                ])
        else:
            self.assertAllEqual(
                np_images.shape,
                [batch_size, output_size[0], output_size[1], 3])
            self.assertAllEqual(np_labels['image_info'].shape,
                                [batch_size, 4, 2])
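For reference, the `[batch, 4, 2]` `image_info` tensor asserted on above packs the preprocessing metadata. In the Model Garden convention (an assumption worth verifying against your version's resize utilities) the rows are:

# Assumed row layout of image_info:
# image_info[:, 0, :] -> original [height, width]
# image_info[:, 1, :] -> desired (resized/padded) [height, width]
# image_info[:, 2, :] -> [y_scale, x_scale] applied during resizing
# image_info[:, 3, :] -> [y_offset, x_offset] of the crop
original_hw = np_labels['image_info'][:, 0, :]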