def test_parser(self, output_size, dtype, is_training, aug_name,
                is_multilabel, decode_jpeg_only, image_format):
    """Round-trip one synthetic example through the classification Decoder/Parser.

    A serialized classification example of size
    `output_size[0]` x `output_size[1]` in `image_format` is built, decoded
    and parsed. The test then checks:
      * the parsed image shape equals `output_size`;
      * the label is 0 (single-label) or has shape [10] (multi-label);
      * the image dtype, when `dtype` is 'float32', 'float16' or 'bfloat16'.
    """
    serialized_example = tfexample_utils.create_classification_example(
        output_size[0], output_size[1], image_format, is_multilabel)

    # Only 'randaug' and 'autoaug' get an augmentation config; any other
    # aug_name leaves the parser without one.
    aug_type = None
    if aug_name == 'randaug':
        aug_type = common.Augmentation(
            type=aug_name, randaug=common.RandAugment(magnitude=10))
    elif aug_name == 'autoaug':
        aug_type = common.Augmentation(
            type=aug_name,
            autoaug=common.AutoAugment(augmentation_name='test'))

    decoder = classification_input.Decoder(
        image_field_key=IMAGE_FIELD_KEY,
        label_field_key=LABEL_FIELD_KEY,
        is_multilabel=is_multilabel)
    parser = classification_input.Parser(
        output_size=output_size[:2],
        num_classes=10,
        image_field_key=IMAGE_FIELD_KEY,
        label_field_key=LABEL_FIELD_KEY,
        is_multilabel=is_multilabel,
        decode_jpeg_only=decode_jpeg_only,
        aug_rand_hflip=False,
        aug_type=aug_type,
        dtype=dtype)

    decoded = decoder.decode(serialized_example)
    image, label = parser.parse_fn(is_training)(decoded)

    self.assertAllEqual(image.numpy().shape, output_size)
    if is_multilabel:
        self.assertAllEqual(label.numpy().shape, [10])
    else:
        self.assertAllEqual(label, 0)

    # First matching dtype name wins; unrecognised names skip the dtype check.
    known_dtypes = (
        ('float32', tf.float32),
        ('float16', tf.float16),
        ('bfloat16', tf.bfloat16),
    )
    for dtype_name, expected_dtype in known_dtypes:
        if dtype == dtype_name:
            self.assertAllEqual(image.dtype, expected_dtype)
            break
def build_inputs(
    self,
    params: base_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Create the classification input pipeline for one data split.

    Args:
      params: data config for the split being read. It supplies the TFDS
        name, file type, `decode_jpeg_only`, augmentation options, output
        dtype and `is_training`.
      input_context: optional distribution context, forwarded to
        `reader.read`.

    Returns:
      The dataset returned by the input reader built from `params`.

    Raises:
      ValueError: if `params.tfds_name` is set but has no entry in
        `tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP`.
    """
    model_config = self.task_config.model
    num_classes = model_config.num_classes
    input_size = model_config.input_size

    # NOTE(review): the field keys and the multilabel flag always come from
    # `train_data`, even when `params` describes another split (e.g.
    # validation). Confirm this is intended.
    train_config = self.task_config.train_data
    image_key = train_config.image_field_key
    label_key = train_config.label_field_key
    multilabel = train_config.is_multilabel

    tfds_name = params.tfds_name
    if not tfds_name:
        decoder = classification_input.Decoder(
            image_field_key=image_key,
            label_field_key=label_key,
            is_multilabel=multilabel)
    else:
        decoder_registry = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP
        if tfds_name not in decoder_registry:
            raise ValueError('TFDS {} is not supported'.format(tfds_name))
        decoder = decoder_registry[tfds_name]()

    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        image_field_key=image_key,
        label_field_key=label_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        is_multilabel=multilabel,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)