Example #1
0
    def test_parser(self, output_size, dtype, is_training):
        """Checks the pipeline yields batches of the expected shape and dtype."""

        params = cfg.DataConfig(
            input_path='imagenet-2012-tfrecord/train*',
            global_batch_size=2,
            is_training=True,
            examples_consume=4)

        decoder = classification_input.Decoder()
        parser = classification_input.Parser(
            output_size=output_size[:2],
            num_classes=1001,
            aug_rand_hflip=False,
            dtype=dtype)

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        # Pull a single batch off the built dataset.
        images, labels = next(iter(reader.read()))

        expected_image_shape = [params.global_batch_size] + output_size
        self.assertAllEqual(images.numpy().shape, expected_image_shape)
        self.assertAllEqual(labels.numpy().shape, [params.global_batch_size])

        # Map the string dtype parameter to the matching tf dtype; dtypes
        # outside this table were not asserted on in the original either.
        dtype_by_name = {
            'float32': tf.float32,
            'float16': tf.float16,
            'bfloat16': tf.bfloat16,
        }
        if dtype in dtype_by_name:
            self.assertAllEqual(images.dtype, dtype_by_name[dtype])
Example #2
0
    def build_inputs(self, params, input_context=None):
        """Builds classification input."""

        model_cfg = self.task_config.model

        # Prefer a TFDS-specific decoder when a TFDS dataset name is given;
        # unknown TFDS names are rejected up front.
        if params.tfds_name:
            decoder_map = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP
            if params.tfds_name not in decoder_map:
                raise ValueError('TFDS {} is not supported'.format(
                    params.tfds_name))
            decoder = decoder_map[params.tfds_name]()
        else:
            decoder = classification_input.Decoder()

        parser = classification_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes,
            aug_policy=params.aug_policy,
            dtype=params.dtype)

        reader = input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        return reader.read(input_context=input_context)
    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds classification input."""

        model_cfg = self.task_config.model
        data_cfg = self.task_config.train_data
        image_key = data_cfg.image_field_key
        label_key = data_cfg.label_field_key
        multilabel = data_cfg.is_multilabel

        # TFDS datasets come with their own decoder; otherwise decode raw
        # tf.Examples using the configured field keys.
        if params.tfds_name:
            decoder = tfds_factory.get_classification_decoder(params.tfds_name)
        else:
            decoder = classification_input.Decoder(
                image_field_key=image_key,
                label_field_key=label_key,
                is_multilabel=multilabel)

        parser = classification_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes,
            image_field_key=image_key,
            label_field_key=label_key,
            decode_jpeg_only=params.decode_jpeg_only,
            aug_rand_hflip=params.aug_rand_hflip,
            aug_type=params.aug_type,
            color_jitter=params.color_jitter,
            random_erasing=params.random_erasing,
            is_multilabel=multilabel,
            dtype=params.dtype)

        # Optional batch-level Mixup/CutMix augmentation applied post-parse.
        mixup_cfg = params.mixup_and_cutmix
        postprocess_fn = None
        if mixup_cfg:
            postprocess_fn = augment.MixupAndCutmix(
                mixup_alpha=mixup_cfg.mixup_alpha,
                cutmix_alpha=mixup_cfg.cutmix_alpha,
                prob=mixup_cfg.prob,
                label_smoothing=mixup_cfg.label_smoothing,
                num_classes=model_cfg.num_classes)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training),
            postprocess_fn=postprocess_fn)

        return reader.read(input_context=input_context)
    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds classification input."""

        model_cfg = self.task_config.model
        data_cfg = self.task_config.train_data

        # Use the registered TFDS decoder when a TFDS dataset is requested;
        # unrecognized TFDS names fail fast.
        if params.tfds_name:
            decoder_map = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP
            if params.tfds_name not in decoder_map:
                raise ValueError('TFDS {} is not supported'.format(
                    params.tfds_name))
            decoder = decoder_map[params.tfds_name]()
        else:
            decoder = classification_input.Decoder(
                image_field_key=data_cfg.image_field_key,
                label_field_key=data_cfg.label_field_key,
                is_multilabel=data_cfg.is_multilabel)

        parser = classification_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes,
            image_field_key=data_cfg.image_field_key,
            label_field_key=data_cfg.label_field_key,
            aug_rand_hflip=params.aug_rand_hflip,
            aug_type=params.aug_type,
            is_multilabel=data_cfg.is_multilabel,
            dtype=params.dtype)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        return reader.read(input_context=input_context)
    def build_inputs(self, params, input_context=None):
        """Builds classification input."""

        model_cfg = self.task_config.model

        decoder = classification_input.Decoder()
        parser = classification_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes,
            dtype=params.dtype)

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        return reader.read(input_context=input_context)
Example #6
0
    def test_decoder(self, image_height, image_width, num_instances):
        """Checks Decoder round-trips a serialized tf.Example's label and keys."""
        decoder = classification_input.Decoder()

        raw_image = np.uint8(
            np.random.rand(image_height, image_width, 3) * 255)
        encoded_image = _encode_image(raw_image, fmt='JPEG')
        label = 2

        # Build the serialized tf.Example the decoder is expected to parse.
        feature_map = {
            'image/encoded':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[encoded_image])),
            'image/class/label':
                tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])),
        }
        serialized_example = tf.train.Example(
            features=tf.train.Features(feature=feature_map)).SerializeToString()

        decoded_tensors = decoder.decode(
            tf.convert_to_tensor(serialized_example))
        results = tf.nest.map_structure(lambda t: t.numpy(), decoded_tensors)

        self.assertCountEqual(['image/encoded', 'image/class/label'],
                              results.keys())
        self.assertEqual(label, results['image/class/label'])
  def build_inputs(self, params, input_context=None):
    """Builds classification input."""

    model_cfg = self.task_config.model
    data_cfg = self.task_config.train_data

    # TFDS datasets use the factory-provided decoder; plain tf.Example data
    # is decoded with the configured image/label field keys.
    if params.tfds_name:
      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
    else:
      decoder = classification_input_base.Decoder(
          image_field_key=data_cfg.image_field_key,
          label_field_key=data_cfg.label_field_key,
          is_multilabel=data_cfg.is_multilabel)

    parser = classification_input.Parser(
        output_size=model_cfg.input_size[:2],
        num_classes=model_cfg.num_classes,
        image_field_key=data_cfg.image_field_key,
        label_field_key=data_cfg.label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        is_multilabel=data_cfg.is_multilabel,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    return reader.read(input_context=input_context)