コード例 #1
0
    def build_inputs(self, params, input_context=None):
        """Builds the segmentation input dataset.

        Args:
            params: Data configuration. Its `resize_eval_groundtruth`,
                `groundtruth_padded_size`, `aug_scale_min`, `aug_scale_max`,
                `dtype` and `is_training` fields are read here, and the whole
                object is also handed to the input reader.
            input_context: Optional context forwarded unchanged to the
                reader's `read()` call.

        Returns:
            The result of `InputReader.read()` for this configuration.
        """
        model_input_size = self.task_config.model.input_size

        # Only the first two entries of the model input size are passed on as
        # the parser output size (presumably height and width -- confirm
        # against the model config).
        parser_kwargs = dict(
            output_size=model_input_size[:2],
            ignore_label=self.task_config.losses.ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            dtype=params.dtype,
        )
        example_decoder = segmentation_input.Decoder()
        example_parser = segmentation_input.Parser(**parser_kwargs)

        # The parse function is selected by the training/eval flag.
        parse_fn = example_parser.parse_fn(params.is_training)
        data_reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=parse_fn,
        )
        return data_reader.read(input_context=input_context)
コード例 #2
0
    def build_inputs(
            self,
            params: cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the segmentation input dataset with a class-mapping parser.

        Args:
            params: Data configuration. Selects the decoder (via `tfds_name`),
                supplies the parser settings, the file type and the
                training/eval flag, and is also handed to the input reader.
            input_context: Optional distribution input context forwarded
                unchanged to the reader's `read()` call.

        Returns:
            The result of the generated reader's `read()` call.
        """
        # A TFDS-specific decoder is used when a TFDS name is configured;
        # otherwise the default segmentation decoder is used.
        example_decoder = (
            tfds_factory.get_segmentation_decoder(params.tfds_name)
            if params.tfds_name else segmentation_input.Decoder())

        parser_kwargs = {
            'output_size': params.output_size,
            'crop_size': params.crop_size,
            'ignore_label': self.task_config.losses.ignore_label,
            'resize_eval_groundtruth': params.resize_eval_groundtruth,
            'groundtruth_padded_size': params.groundtruth_padded_size,
            'aug_scale_min': params.aug_scale_min,
            'aug_scale_max': params.aug_scale_max,
            'aug_rand_hflip': params.aug_rand_hflip,
            'dtype': params.dtype,
        }
        example_parser = ClassMappingParser(**parser_kwargs)

        # Set before parse_fn is requested: the highest class id is one less
        # than the configured number of classes.
        example_parser.max_class = self.task_config.model.num_classes - 1

        data_reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        )
        return data_reader.read(input_context=input_context)
コード例 #3
0
    def build_inputs(self, params, input_context=None):
        """Builds the segmentation input dataset.

        Args:
            params: Data configuration. Selects the decoder (via `tfds_name`),
                supplies the parser settings, the file type and the
                training/eval flag, and is also handed to the input reader.
            input_context: Optional context forwarded unchanged to the
                reader's `read()` call.

        Returns:
            The result of the generated reader's `read()` call.

        Raises:
            ValueError: If `params.tfds_name` is set but has no entry in
                `tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP`.
        """
        # Decoder selection: default decoder when no TFDS name is configured,
        # otherwise the decoder registered for that TFDS name.
        if not params.tfds_name:
            example_decoder = segmentation_input.Decoder()
        elif params.tfds_name not in (
                tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP):
            raise ValueError(
                'TFDS {} is not supported'.format(params.tfds_name))
        else:
            decoder_cls = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
                params.tfds_name]
            example_decoder = decoder_cls()

        parser_kwargs = dict(
            output_size=params.output_size,
            train_on_crops=params.train_on_crops,
            ignore_label=self.task_config.losses.ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            aug_rand_hflip=params.aug_rand_hflip,
            dtype=params.dtype,
        )
        example_parser = segmentation_input.Parser(**parser_kwargs)

        data_reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        )
        return data_reader.read(input_context=input_context)
  def build_inputs(self, params, input_context=None):
    """Builds the segmentation input dataset.

    Args:
      params: Data configuration. Supplies the parser settings, the file type
        and the training/eval flag, and is also handed to the input reader.
      input_context: Optional context forwarded unchanged to the reader's
        `read()` call.

    Returns:
      The result of `InputReader.read()` for this configuration.
    """
    parser_kwargs = {
        'output_size': params.output_size,
        'train_on_crops': params.train_on_crops,
        'ignore_label': self.task_config.losses.ignore_label,
        'resize_eval_groundtruth': params.resize_eval_groundtruth,
        'groundtruth_padded_size': params.groundtruth_padded_size,
        'aug_scale_min': params.aug_scale_min,
        'aug_scale_max': params.aug_scale_max,
        'aug_rand_hflip': params.aug_rand_hflip,
        'dtype': params.dtype,
    }
    example_decoder = segmentation_input.Decoder()
    example_parser = segmentation_input.Parser(**parser_kwargs)

    # The dataset constructor is chosen from the configured file type; the
    # parse function is chosen from the training/eval flag.
    source_fn = dataset_fn.pick_dataset_fn(params.file_type)
    parse_fn = example_parser.parse_fn(params.is_training)

    data_reader = input_reader.InputReader(
        params,
        dataset_fn=source_fn,
        decoder_fn=example_decoder.decode,
        parser_fn=parse_fn)
    return data_reader.read(input_context=input_context)
コード例 #5
0
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the segmentation input dataset with extended augmentation.

        Args:
            params: Data configuration. Selects the decoder (via `tfds_name`),
                supplies the parser settings (sizes, scale/flip, augmentation
                policy, RandAugment, rotation and brightness ranges, dtype),
                the file type and the training/eval flag, and is also handed
                to the input reader.
            input_context: Optional distribution input context forwarded
                unchanged to the reader's `read()` call.

        Returns:
            The result of the generated reader's `read()` call.

        Raises:
            ValueError: If `params.tfds_name` is set but has no entry in
                `tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP`.
        """
        # Decoder selection: default decoder when no TFDS name is configured,
        # otherwise the decoder registered for that TFDS name.
        if not params.tfds_name:
            example_decoder = segmentation_input.Decoder()
        elif params.tfds_name not in (
                tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP):
            raise ValueError(
                'TFDS {} is not supported'.format(params.tfds_name))
        else:
            decoder_cls = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
                params.tfds_name]
            example_decoder = decoder_cls()

        parser_kwargs = {
            # Output geometry and label handling.
            'output_size': params.output_size,
            'crop_size': params.crop_size,
            'ignore_label': self.task_config.losses.ignore_label,
            'resize_eval_groundtruth': params.resize_eval_groundtruth,
            'groundtruth_padded_size': params.groundtruth_padded_size,
            # Augmentation settings, passed through from the data config.
            'aug_scale_min': params.aug_scale_min,
            'aug_scale_max': params.aug_scale_max,
            'aug_rand_hflip': params.aug_rand_hflip,
            'aug_policy': params.aug_policy,
            'randaug_magnitude': params.randaug_magnitude,
            'randaug_available_ops': params.randaug_available_ops,
            'preserve_aspect_ratio': params.preserve_aspect_ratio,
            'rotate_min': params.rotate_min,
            'rotate_max': params.rotate_max,
            'bright_min': params.bright_min,
            'bright_max': params.bright_max,
            'dtype': params.dtype,
        }
        example_parser = segmentation_input.Parser(**parser_kwargs)

        data_reader = input_reader_factory.input_reader_generator(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        )
        return data_reader.read(input_context=input_context)
コード例 #6
0
  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the BASNet input dataset.

    Args:
      params: Data configuration. Supplies `output_size`, `crop_size`,
        `aug_rand_hflip`, `dtype`, `file_type` and `is_training`, and is also
        handed to the input reader.
      input_context: Optional distribution input context forwarded unchanged
        to the reader's `read()` call.

    Returns:
      The result of `InputReader.read()` for this configuration.
    """
    parser_kwargs = dict(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=self.task_config.losses.ignore_label,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype,
    )
    example_decoder = segmentation_input.Decoder()
    example_parser = segmentation_input.Parser(**parser_kwargs)

    # The dataset constructor is chosen from the configured file type; the
    # parse function is chosen from the training/eval flag.
    source_fn = dataset_fn.pick_dataset_fn(params.file_type)
    parse_fn = example_parser.parse_fn(params.is_training)

    data_reader = input_reader.InputReader(
        params,
        dataset_fn=source_fn,
        decoder_fn=example_decoder.decode,
        parser_fn=parse_fn)
    return data_reader.read(input_context=input_context)