Exemplo n.º 1
0
  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the panoptic deeplab input pipeline.

    Args:
      params: Data config describing decoder, parser and file type.
      input_context: Optional `tf.distribute.InputContext` for sharding.

    Returns:
      A `tf.data.Dataset` produced by the configured input reader.

    Raises:
      ValueError: If the configured decoder type is not 'simple_decoder'.
    """
    decoder_params = params.decoder.get()
    # Guard clause: only the simple decoder is supported here.
    if params.decoder.type != 'simple_decoder':
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    example_decoder = panoptic_deeplab_input.TfExampleDecoder(
        regenerate_source_id=decoder_params.regenerate_source_id,
        panoptic_category_mask_key=decoder_params.panoptic_category_mask_key,
        panoptic_instance_mask_key=decoder_params.panoptic_instance_mask_key)

    parser_params = params.parser
    example_parser = panoptic_deeplab_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        ignore_label=parser_params.ignore_label,
        resize_eval_groundtruth=parser_params.resize_eval_groundtruth,
        groundtruth_padded_size=parser_params.groundtruth_padded_size,
        aug_scale_min=parser_params.aug_scale_min,
        aug_scale_max=parser_params.aug_scale_max,
        aug_rand_hflip=parser_params.aug_rand_hflip,
        aug_type=parser_params.aug_type,
        sigma=parser_params.sigma,
        dtype=parser_params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=example_decoder.decode,
        parser_fn=example_parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)
Exemplo n.º 2
0
  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the segmentation input pipeline with class-id remapping."""
    loss_ignore_label = self.task_config.losses.ignore_label

    # Prefer the TFDS-specific decoder when a TFDS dataset is configured.
    if params.tfds_name:
      example_decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
      example_decoder = segmentation_input.Decoder()

    example_parser = ClassMappingParser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=loss_ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)
    # Valid class ids run up to num_classes - 1.
    example_parser.max_class = self.task_config.model.num_classes - 1

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=example_decoder.decode,
        parser_fn=example_parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)
Exemplo n.º 3
0
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the video input pipeline.

        Args:
            params: Data configuration for this task.
            input_context: Optional `tf.distribute.InputContext`.

        Returns:
            A `tf.data.Dataset` produced by the configured input reader.
        """
        example_parser = video_input.Parser(
            input_params=params,
            image_key=params.image_field_key,
            label_key=params.label_field_key)

        post_fn = video_input.PostBatchProcessor(params)
        if params.mixup_and_cutmix is not None:
            mixup_params = params.mixup_and_cutmix

            def apply_mixup_and_cutmix(features, labels):
                # The augmenter is constructed per call, mirroring the
                # configured mixup/cutmix settings.
                augmenter = augment.MixupAndCutmix(
                    mixup_alpha=mixup_params.mixup_alpha,
                    cutmix_alpha=mixup_params.cutmix_alpha,
                    prob=mixup_params.prob,
                    label_smoothing=mixup_params.label_smoothing,
                    num_classes=self._get_num_classes())
                features['image'], labels = augmenter(
                    features['image'], labels)
                return features, labels

            # Mixup/cutmix replaces the default post-batch processor.
            post_fn = apply_mixup_and_cutmix

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=self._get_dataset_fn(params),
            decoder_fn=self._get_decoder_fn(params),
            parser_fn=example_parser.parse_fn(params.is_training),
            postprocess_fn=post_fn)
        return reader.read(input_context=input_context)
Exemplo n.º 4
0
    def build_inputs(
        self,
        params: exp_cfg.ExampleDataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds the input pipeline for the example task.

        The returned dataset has gone through decoding, parsing (including
        any configured augmentation), batching and shuffling.

        Args:
          params: The experiment config.
          input_context: An optional InputContext used by the input reader.

        Returns:
          A tf.data.Dataset object.
        """
        model_cfg = self.task_config.model
        example_parser = example_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_input.Decoder().decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
Exemplo n.º 5
0
    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None,
        dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None
    ) -> tf.data.Dataset:
        """Builds the Mask R-CNN input dataset.

        Raises:
            ValueError: If the configured decoder type is unknown.
        """
        decoder_params = params.decoder.get()
        decoder_type = params.decoder.type
        include_mask = self._task_config.model.include_mask
        if decoder_type == 'simple_decoder':
            example_decoder = tf_example_decoder.TfExampleDecoder(
                include_mask=include_mask,
                regenerate_source_id=decoder_params.regenerate_source_id,
                mask_binarize_threshold=decoder_params.mask_binarize_threshold)
        elif decoder_type == 'label_map_decoder':
            example_decoder = (
                tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                    label_map=decoder_params.label_map,
                    include_mask=include_mask,
                    regenerate_source_id=decoder_params.regenerate_source_id,
                    mask_binarize_threshold=(
                        decoder_params.mask_binarize_threshold)))
        else:
            raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        example_parser = maskrcnn_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=model_cfg.anchor.num_scales,
            aspect_ratios=model_cfg.anchor.aspect_ratios,
            anchor_size=model_cfg.anchor.anchor_size,
            dtype=params.dtype,
            rpn_match_threshold=parser_cfg.rpn_match_threshold,
            rpn_unmatched_threshold=parser_cfg.rpn_unmatched_threshold,
            rpn_batch_size_per_im=parser_cfg.rpn_batch_size_per_im,
            rpn_fg_fraction=parser_cfg.rpn_fg_fraction,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            aug_type=parser_cfg.aug_type,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances,
            include_mask=include_mask,
            mask_crop_size=parser_cfg.mask_crop_size)

        # Fall back to a dataset_fn derived from the configured file type
        # when the caller did not supply one.
        chosen_dataset_fn = (
            dataset_fn or dataset_fn_lib.pick_dataset_fn(params.file_type))

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=chosen_dataset_fn,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
    def build_inputs(
        self,
        params: exp_cfg.DataConfig,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Builds the panoptic Mask R-CNN input dataset.

        Raises:
            ValueError: If the decoder type is not 'simple_decoder'.
        """
        decoder_params = params.decoder.get()
        # Guard clause: only the simple decoder is supported here.
        if params.decoder.type != 'simple_decoder':
            raise ValueError(
                'Unknown decoder type: {}!'.format(params.decoder.type))

        example_decoder = panoptic_maskrcnn_input.TfExampleDecoder(
            regenerate_source_id=decoder_params.regenerate_source_id,
            mask_binarize_threshold=decoder_params.mask_binarize_threshold,
            include_panoptic_masks=decoder_params.include_panoptic_masks,
            panoptic_category_mask_key=(
                decoder_params.panoptic_category_mask_key),
            panoptic_instance_mask_key=(
                decoder_params.panoptic_instance_mask_key))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        example_parser = panoptic_maskrcnn_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=model_cfg.anchor.num_scales,
            aspect_ratios=model_cfg.anchor.aspect_ratios,
            anchor_size=model_cfg.anchor.anchor_size,
            dtype=params.dtype,
            rpn_match_threshold=parser_cfg.rpn_match_threshold,
            rpn_unmatched_threshold=parser_cfg.rpn_unmatched_threshold,
            rpn_batch_size_per_im=parser_cfg.rpn_batch_size_per_im,
            rpn_fg_fraction=parser_cfg.rpn_fg_fraction,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances,
            mask_crop_size=parser_cfg.mask_crop_size,
            segmentation_resize_eval_groundtruth=(
                parser_cfg.segmentation_resize_eval_groundtruth),
            segmentation_groundtruth_padded_size=(
                parser_cfg.segmentation_groundtruth_padded_size),
            segmentation_ignore_label=parser_cfg.segmentation_ignore_label,
            panoptic_ignore_label=parser_cfg.panoptic_ignore_label,
            include_panoptic_masks=parser_cfg.include_panoptic_masks)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
Exemplo n.º 7
0
  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds the image classification input pipeline."""
    model_cfg = self.task_config.model
    train_data_cfg = self.task_config.train_data
    image_key = train_data_cfg.image_field_key
    label_key = train_data_cfg.label_field_key
    multilabel = train_data_cfg.is_multilabel

    # Prefer a TFDS decoder when a TFDS dataset name is configured.
    if params.tfds_name:
      example_decoder = tfds_factory.get_classification_decoder(
          params.tfds_name)
    else:
      example_decoder = classification_input.Decoder(
          image_field_key=image_key,
          label_field_key=label_key,
          is_multilabel=multilabel)

    example_parser = classification_input.Parser(
        output_size=model_cfg.input_size[:2],
        num_classes=model_cfg.num_classes,
        image_field_key=image_key,
        label_field_key=label_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_crop=params.aug_crop,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=multilabel,
        dtype=params.dtype)

    # Mixup/cutmix, when configured, runs as a post-batch processor.
    post_fn = None
    if params.mixup_and_cutmix:
      post_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=model_cfg.num_classes)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=example_decoder.decode,
        parser_fn=example_parser.parse_fn(params.is_training),
        postprocess_fn=post_fn)
    return reader.read(input_context=input_context)
Exemplo n.º 8
0
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the RetinaNet detection input dataset.

        Raises:
            ValueError: If the configured decoder type is unknown.
        """
        if params.tfds_name:
            example_decoder = tfds_factory.get_detection_decoder(
                params.tfds_name)
        else:
            decoder_params = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                example_decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_params.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                example_decoder = (
                    tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                        label_map=decoder_params.label_map,
                        regenerate_source_id=(
                            decoder_params.regenerate_source_id)))
            else:
                raise ValueError(
                    'Unknown decoder type: {}!'.format(decoder_type))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        example_parser = retinanet_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=model_cfg.anchor.num_scales,
            aspect_ratios=model_cfg.anchor.aspect_ratios,
            anchor_size=model_cfg.anchor.anchor_size,
            dtype=params.dtype,
            match_threshold=parser_cfg.match_threshold,
            unmatched_threshold=parser_cfg.unmatched_threshold,
            aug_type=parser_cfg.aug_type,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the video input pipeline."""
    example_parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=example_parser.parse_fn(params.is_training),
        postprocess_fn=video_input.PostBatchProcessor(params))
    return reader.read(input_context=input_context)
Exemplo n.º 10
0
  def build_inputs(
      self,
      params: base_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds the classification input pipeline.

    Raises:
      ValueError: If a TFDS dataset name is configured (unsupported here).
    """
    model_cfg = self.task_config.model
    train_data_cfg = self.task_config.train_data
    image_key = train_data_cfg.image_field_key
    label_key = train_data_cfg.label_field_key
    multilabel = train_data_cfg.is_multilabel

    # This task only reads from files; TFDS sources are rejected up front.
    if params.tfds_name:
      raise ValueError('TFDS {} is not supported'.format(params.tfds_name))

    example_decoder = classification_input.Decoder(
        image_field_key=image_key,
        label_field_key=label_key,
        is_multilabel=multilabel)

    example_parser = classification_input.Parser(
        output_size=model_cfg.input_size[:2],
        num_classes=model_cfg.num_classes,
        image_field_key=image_key,
        label_field_key=label_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        is_multilabel=multilabel,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=example_decoder.decode,
        parser_fn=example_parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)
Exemplo n.º 11
0
    def build_inputs(self, params, input_context=None):
        """Builds the classification input pipeline."""
        model_cfg = self.task_config.model
        train_data_cfg = self.task_config.train_data
        image_key = train_data_cfg.image_field_key
        label_key = train_data_cfg.label_field_key
        multilabel = train_data_cfg.is_multilabel

        # Prefer a TFDS decoder when a TFDS dataset name is configured.
        if params.tfds_name:
            example_decoder = tfds_factory.get_classification_decoder(
                params.tfds_name)
        else:
            example_decoder = classification_input_base.Decoder(
                image_field_key=image_key,
                label_field_key=label_key,
                is_multilabel=multilabel)

        example_parser = classification_input.Parser(
            output_size=model_cfg.input_size[:2],
            num_classes=model_cfg.num_classes,
            image_field_key=image_key,
            label_field_key=label_key,
            decode_jpeg_only=params.decode_jpeg_only,
            aug_rand_hflip=params.aug_rand_hflip,
            aug_type=params.aug_type,
            is_multilabel=multilabel,
            dtype=params.dtype)

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
Exemplo n.º 12
0
    def build_inputs(
            self,
            params,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the DETR detection input dataset.

        Raises:
            ValueError: If the configured decoder type is unknown.
        """
        # A COCO-specific config gets its own dedicated loader.
        if isinstance(params, coco.COCODataConfig):
            return coco.COCODataLoader(params).load(input_context)

        if params.tfds_name:
            example_decoder = tfds_factory.get_detection_decoder(
                params.tfds_name)
        else:
            decoder_params = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                example_decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_params.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                example_decoder = (
                    tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                        label_map=decoder_params.label_map,
                        regenerate_source_id=(
                            decoder_params.regenerate_source_id)))
            else:
                raise ValueError(
                    'Unknown decoder type: {}!'.format(decoder_type))

        example_parser = detr_input.Parser(
            class_offset=self._task_config.losses.class_offset,
            output_size=self._task_config.model.input_size[:2])

        reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)