Example #1
def test_classification_decoder(self, tfds_name):
    decoder = tfds_factory.get_classification_decoder(tfds_name)
    self.assertIsInstance(decoder, base_decoder.Decoder)
    decoded_tensor = decoder.decode(self._create_test_example())
    self.assertLen(decoded_tensor, 2)
    self.assertIn('image/encoded', decoded_tensor)
    self.assertIn('image/class/label', decoded_tensor)
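The helper self._create_test_example() is not shown on this page. A minimal
stand-in, assuming the TFDS classification decoders accept a decoded
TFDS-style sample with 'image' and 'label' entries and remap them to the
'image/encoded' and 'image/class/label' keys asserted above:

import tensorflow as tf

# Hypothetical fixture, not part of the original test: a decoded TFDS-style
# classification sample with a raw image tensor and an integer label.
def _create_test_example():
    image = tf.random.uniform([32, 32, 3], maxval=256, dtype=tf.int32)
    return {'image': tf.cast(image, tf.uint8),
            'label': tf.constant(7, dtype=tf.int64)}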
Example #2

    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds classification input."""

        num_classes = self.task_config.model.num_classes
        input_size = self.task_config.model.input_size
        image_field_key = self.task_config.train_data.image_field_key
        label_field_key = self.task_config.train_data.label_field_key
        is_multilabel = self.task_config.train_data.is_multilabel

        # Prefer the dataset-specific TFDS decoder when reading from a TFDS
        # source; otherwise fall back to the generic tf.train.Example decoder
        # driven by the configured image and label field keys.
        if params.tfds_name:
            decoder = tfds_factory.get_classification_decoder(params.tfds_name)
        else:
            decoder = classification_input.Decoder(
                image_field_key=image_field_key,
                label_field_key=label_field_key,
                is_multilabel=is_multilabel)

        # The parser resizes images to the model input size, applies the
        # configured augmentations, and formats labels for single- or
        # multi-label classification.
        parser = classification_input.Parser(
            output_size=input_size[:2],
            num_classes=num_classes,
            image_field_key=image_field_key,
            label_field_key=label_field_key,
            decode_jpeg_only=params.decode_jpeg_only,
            aug_rand_hflip=params.aug_rand_hflip,
            aug_type=params.aug_type,
            color_jitter=params.color_jitter,
            random_erasing=params.random_erasing,
            is_multilabel=is_multilabel,
            dtype=params.dtype)

        # Mixup/CutMix operates on whole batches, so it is attached as a
        # postprocess step that runs after parsing when enabled.
        postprocess_fn = None
        if params.mixup_and_cutmix:
            postprocess_fn = augment.MixupAndCutmix(
                mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
                cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
                prob=params.mixup_and_cutmix.prob,
                label_smoothing=params.mixup_and_cutmix.label_smoothing,
                num_classes=num_classes)

        # Assemble the tf.data pipeline: file reading, decoding, parsing and
        # optional batch-level postprocessing.
        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training),
            postprocess_fn=postprocess_fn)

        dataset = reader.read(input_context=input_context)

        return dataset
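A minimal usage sketch (not from the page) of the TFDS branch above. The
DataConfig field names follow the attributes referenced in build_inputs plus
the usual is_training/global_batch_size fields; the task instance and the
dataset name are assumptions:

params = exp_cfg.DataConfig(
    tfds_name='imagenet2012',  # assumed to be a dataset the factory supports
    tfds_split='train',
    is_training=True,
    global_batch_size=128)
dataset = task.build_inputs(params)  # `task`: an assumed classification task
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)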
Example #3

    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds classification input."""

        num_classes = self.task_config.model.num_classes
        input_size = self.task_config.model.input_size
        image_field_key = self.task_config.train_data.image_field_key
        label_field_key = self.task_config.train_data.label_field_key
        is_multilabel = self.task_config.train_data.is_multilabel

        if params.tfds_name:
            decoder = tfds_factory.get_classification_decoder(params.tfds_name)
        else:
            decoder = classification_input.Decoder(
                image_field_key=image_field_key,
                label_field_key=label_field_key,
                is_multilabel=is_multilabel)

        parser = classification_input.Parser(
            output_size=input_size[:2],
            num_classes=num_classes,
            image_field_key=image_field_key,
            label_field_key=label_field_key,
            decode_jpeg_only=params.decode_jpeg_only,
            aug_rand_hflip=params.aug_rand_hflip,
            aug_type=params.aug_type,
            is_multilabel=is_multilabel,
            dtype=params.dtype)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        dataset = reader.read(input_context=input_context)

        return dataset
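This second variant omits the color jitter, random erasing and Mixup/CutMix
options, so it builds a plain decode-and-parse pipeline. A sketch (assumed,
not from the page) of driving the non-TFDS branch from TFRecord files; the
file pattern and task instance are placeholders:

params = exp_cfg.DataConfig(
    input_path='/path/to/train-*.tfrecord',
    file_type='tfrecord',
    is_training=True,
    global_batch_size=64)
dataset = task.build_inputs(params)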
Example #4
def test_doesnt_exist_classification_decoder(self, tfds_name):
    with self.assertRaises(ValueError):
        _ = tfds_factory.get_classification_decoder(tfds_name)
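The ValueError checked above also supports a runtime fallback. A small sketch
(assumed, not from the page) that drops back to the generic decoder, using the
keys asserted in Example #1, when the factory has no TFDS-specific decoder for
tfds_name:

try:
    decoder = tfds_factory.get_classification_decoder(tfds_name)
except ValueError:
    # No TFDS-specific decoder registered for this dataset name; use the
    # generic tf.train.Example classification decoder instead.
    decoder = classification_input.Decoder(
        image_field_key='image/encoded',
        label_field_key='image/class/label',
        is_multilabel=False)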