def test_parser(self, output_size, dtype, is_training, aug_name,
                is_multilabel, decode_jpeg_only, image_format):
    """Decodes and parses one generated classification example.

    A serialized example of size `output_size[:2]` and format `image_format`
    is built with `tfexample_utils.create_classification_example`. It is run
    through `classification_input.Decoder` and then through
    `classification_input.Parser.parse_fn(is_training)`. The test asserts:
    - the image shape equals `output_size`;
    - for single-label data the label is 0, otherwise the label has shape [10];
    - for 'float32' / 'float16' / 'bfloat16', the image dtype matches.
    """
    example = tfexample_utils.create_classification_example(
        output_size[0], output_size[1], image_format, is_multilabel)

    # Lazily-built augmentation configs; any other `aug_name` (e.g. None)
    # disables augmentation.
    augmentation_builders = {
        'randaug': lambda: common.Augmentation(
            type=aug_name, randaug=common.RandAugment(magnitude=10)),
        'autoaug': lambda: common.Augmentation(
            type=aug_name,
            autoaug=common.AutoAugment(augmentation_name='test')),
    }
    build_augmentation = augmentation_builders.get(aug_name)
    augmentation = build_augmentation() if build_augmentation else None

    example_decoder = classification_input.Decoder(
        image_field_key=IMAGE_FIELD_KEY,
        label_field_key=LABEL_FIELD_KEY,
        is_multilabel=is_multilabel)
    example_parser = classification_input.Parser(
        output_size=output_size[:2],
        num_classes=10,
        image_field_key=IMAGE_FIELD_KEY,
        label_field_key=LABEL_FIELD_KEY,
        is_multilabel=is_multilabel,
        decode_jpeg_only=decode_jpeg_only,
        aug_rand_hflip=False,
        aug_type=augmentation,
        dtype=dtype)

    features = example_decoder.decode(example)
    image, label = example_parser.parse_fn(is_training)(features)

    self.assertAllEqual(image.numpy().shape, output_size)

    if is_multilabel:
        # Multi-label labels are a length-`num_classes` vector.
        self.assertAllEqual(label.numpy().shape, [10])
    else:
        self.assertAllEqual(label, 0)

    # First matching dtype name wins; unrecognized names are not checked.
    for dtype_name, expected_dtype in (('float32', tf.float32),
                                       ('float16', tf.float16),
                                       ('bfloat16', tf.bfloat16)):
        if dtype == dtype_name:
            self.assertAllEqual(image.dtype, expected_dtype)
            break
  def build_inputs(
      self,
      params: base_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds the classification input pipeline for one data split.

    Args:
      params: data config for the split being built. This method reads its
        `tfds_name`, `decode_jpeg_only`, `aug_rand_hflip`, `aug_type`,
        `dtype`, `file_type` and `is_training` fields. The whole config is
        also handed to the input reader factory.
      input_context: optional distribution input context, forwarded to
        `reader.read`.

    Returns:
      The dataset returned by the reader's `read` call.

    Raises:
      ValueError: if `params.tfds_name` is set but is not a key of
        `TFDS_ID_TO_DECODER_MAP`.
    """
    model_config = self.task_config.model
    num_classes = model_config.num_classes
    input_size = model_config.input_size

    # NOTE(review): the field keys and the multilabel flag always come from
    # `train_data`, whichever split `params` describes, so values set only on
    # a validation config are not used here.
    train_config = self.task_config.train_data
    image_key = train_config.image_field_key
    label_key = train_config.label_field_key
    multilabel = train_config.is_multilabel

    tfds_name = params.tfds_name
    if tfds_name:
      decoder_map = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP
      if tfds_name not in decoder_map:
        raise ValueError('TFDS {} is not supported'.format(tfds_name))
      decoder = decoder_map[tfds_name]()
    else:
      decoder = classification_input.Decoder(
          image_field_key=image_key,
          label_field_key=label_key,
          is_multilabel=multilabel)

    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        image_field_key=image_key,
        label_field_key=label_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        is_multilabel=multilabel,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)