コード例 #1
0
    def build_inputs(self, params, is_training=True, input_context=None):
        """Builds the detection input dataset from a YOLO parser.

        Args:
          params: data config. Its `parser` sub-config supplies the parser
            options, and the config itself is handed to the input reader.
          is_training: chooses the parse function and decides whether the
            parser's post-process function is attached to the reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        # NOTE(review): the config flag is forced to False regardless of the
        # `is_training` argument -- confirm this override is intentional.
        params.is_training = False

        coco_decoder = tfds_coco_decoder.MSCOCODecoder()

        parser_cfg = params.parser
        model_cfg = self._task.task_config.model
        yolo_parser = yolo_input.Parser(
            image_w=parser_cfg.image_w,
            image_h=parser_cfg.image_h,
            num_classes=model_cfg.num_classes,
            fixed_size=parser_cfg.fixed_size,
            jitter_im=parser_cfg.jitter_im,
            jitter_boxes=parser_cfg.jitter_boxes,
            net_down_scale=parser_cfg.net_down_scale,
            min_process_size=parser_cfg.min_process_size,
            max_process_size=parser_cfg.max_process_size,
            max_num_instances=parser_cfg.max_num_instances,
            random_flip=parser_cfg.random_flip,
            pct_rand=parser_cfg.pct_rand,
            seed=parser_cfg.seed,
            anchors=model_cfg.boxes)

        # The post-process step is only attached when training.
        post_process_fn = yolo_parser.postprocess_fn() if is_training else None

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=coco_decoder.decode,
            parser_fn=yolo_parser.parse_fn(is_training),
            postprocess_fn=post_process_fn).read(input_context=input_context)
コード例 #2
0
    def build_inputs(self, params, input_context=None):
        """Builds the classification input dataset.

        Args:
          params: data config providing `tfds_name`, `aug_policy`, `dtype`,
            `file_type` and `is_training`; also passed to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.

        Raises:
          ValueError: if `params.tfds_name` is set but has no entry in
            `TFDS_ID_TO_DECODER_MAP`.
        """
        model_cfg = self.task_config.model
        num_classes = model_cfg.num_classes
        input_size = model_cfg.input_size

        if not params.tfds_name:
            example_decoder = classification_input.Decoder()
        else:
            decoder_map = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP
            if params.tfds_name not in decoder_map:
                raise ValueError('TFDS {} is not supported'.format(
                    params.tfds_name))
            example_decoder = decoder_map[params.tfds_name]()

        example_parser = classification_input.Parser(
            output_size=input_size[:2],
            num_classes=num_classes,
            aug_policy=params.aug_policy,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #3
0
ファイル: basnet.py プロジェクト: PurdueCAM2Project/tf-models
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the BASNet input dataset.

        Args:
          params: data config supplying the parser options (`output_size`,
            `crop_size`, `aug_rand_hflip`, `dtype`), the `file_type` and the
            `is_training` flag; also passed to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        seg_decoder = segmentation_input.Decoder()
        seg_parser = segmentation_input.Parser(
            output_size=params.output_size,
            crop_size=params.crop_size,
            ignore_label=self.task_config.losses.ignore_label,
            aug_rand_hflip=params.aug_rand_hflip,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=seg_decoder.decode,
            parser_fn=seg_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #4
0
    def testSegmentationInputReader(self, input_size, num_classes,
                                    num_channels):
        """Reads one eval batch of 3D segmentation data and checks its shapes."""
        batch_size = 2
        data_cfg = cfg.DataConfig(input_path=self._data_path,
                                  global_batch_size=batch_size,
                                  is_training=False)

        seg_decoder = segmentation_input_3d.Decoder()
        seg_parser = segmentation_input_3d.Parser(input_size=input_size,
                                                  num_classes=num_classes,
                                                  num_channels=num_channels)

        dataset = input_reader.InputReader(
            data_cfg,
            dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
            decoder_fn=seg_decoder.decode,
            parser_fn=seg_parser.parse_fn(data_cfg.is_training)).read()
        image, labels = next(iter(dataset))

        # Both tensors share the leading batch and three spatial dimensions;
        # only the trailing dimension differs.
        leading_dims = [batch_size, input_size[0], input_size[1], input_size[2]]
        self.assertEqual(list(image.numpy().shape),
                         leading_dims + [num_channels])
        self.assertEqual(list(labels.numpy().shape),
                         leading_dims + [num_classes])
コード例 #5
0
ファイル: coco.py プロジェクト: tensorflow/models
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
   """Builds an input reader without a decoder and reads a dataset from it.

   Args:
     input_context: optional input context passed through to `read`.

   Returns:
     Whatever `InputReader.read` returns for this loader's params.
   """
   reader_kwargs = dict(
       params=self._params,
       decoder_fn=None,
       transform_and_batch_fn=self._transform_and_batch_fn)
   return input_reader.InputReader(**reader_kwargs).read(input_context)
コード例 #6
0
  def build_inputs(self, params, input_context=None):
    """Builds the classification input dataset.

    Each augmentation switch handed to the parser is the `aug_rand` value when
    that is truthy, otherwise the individual `aug_rand_*` value.

    Args:
      params: data config supplying `tfds_name`, the `parser` sub-config,
        `dtype` and `is_training`; also passed to the input reader.
      input_context: optional input context forwarded to `reader.read`.

    Returns:
      The dataset returned by `InputReader.read`.
    """
    model_cfg = self.task_config.model
    num_classes = model_cfg.num_classes
    input_size = model_cfg.input_size

    if params.tfds_name is None:
      example_decoder = classification_input.Decoder()
    else:
      example_decoder = cli.Decoder()

    parser_cfg = params.parser
    example_parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        aug_rand_saturation=(parser_cfg.aug_rand or
                             parser_cfg.aug_rand_saturation),
        aug_rand_brightness=(parser_cfg.aug_rand or
                             parser_cfg.aug_rand_brightness),
        aug_rand_zoom=parser_cfg.aug_rand or parser_cfg.aug_rand_zoom,
        aug_rand_rotate=parser_cfg.aug_rand or parser_cfg.aug_rand_rotate,
        aug_rand_hue=parser_cfg.aug_rand or parser_cfg.aug_rand_hue,
        aug_rand_aspect=parser_cfg.aug_rand or parser_cfg.aug_rand_aspect,
        scale=parser_cfg.scale,
        seed=parser_cfg.seed,
        dtype=params.dtype)

    return input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=example_decoder.decode,
        parser_fn=example_parser.parse_fn(params.is_training),
    ).read(input_context=input_context)
コード例 #7
0
ファイル: simclr.py プロジェクト: kia-ctw/models
    def build_inputs(self, params, input_context=None):
        """Builds the SimCLR input dataset.

        Args:
          params: data config supplying `tfds_name`, the `decoder` and
            `parser` sub-configs, `dtype` and `is_training`; also passed to
            the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        input_size = self.task_config.model.input_size

        # A TFDS source uses the TFDS decoder; anything else uses the default.
        decoder_cls = (simclr_input.TFDSDecoder
                       if params.tfds_name else simclr_input.Decoder)
        example_decoder = decoder_cls(params.decoder.decode_label)

        parser_cfg = params.parser
        example_parser = simclr_input.Parser(
            output_size=input_size[:2],
            aug_rand_crop=parser_cfg.aug_rand_crop,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_color_distort=parser_cfg.aug_color_distort,
            aug_color_jitter_strength=parser_cfg.aug_color_jitter_strength,
            aug_color_jitter_impl=parser_cfg.aug_color_jitter_impl,
            aug_rand_blur=parser_cfg.aug_rand_blur,
            parse_label=parser_cfg.parse_label,
            test_crop=parser_cfg.test_crop,
            mode=parser_cfg.mode,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #8
0
    def test_parser(self, output_size, dtype, is_training):
        """Reads one batch through the classification pipeline and checks it.

        Verifies the image batch shape, the label batch shape and, for the
        recognised dtype strings, the image dtype.
        """
        # NOTE(review): the `is_training` argument is not used; the config
        # below is always created with is_training=True.
        data_cfg = cfg.DataConfig(input_path='imagenet-2012-tfrecord/train*',
                                  global_batch_size=2,
                                  is_training=True,
                                  examples_consume=4)

        example_decoder = classification_input.Decoder()
        example_parser = classification_input.Parser(
            output_size=output_size[:2],
            num_classes=1001,
            aug_rand_hflip=False,
            dtype=dtype)

        dataset = input_reader.InputReader(
            data_cfg,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(data_cfg.is_training)).read()

        images, labels = next(iter(dataset))

        batch_dim = [data_cfg.global_batch_size]
        self.assertAllEqual(images.numpy().shape, batch_dim + output_size)
        self.assertAllEqual(labels.numpy().shape, batch_dim)

        # Only these dtype strings are checked; any other value is skipped.
        expected_dtypes = {
            'float32': tf.float32,
            'float16': tf.float16,
            'bfloat16': tf.bfloat16,
        }
        if dtype in expected_dtypes:
            self.assertAllEqual(images.dtype, expected_dtypes[dtype])
コード例 #9
0
    def build_inputs(self, params, input_context=None):
        """Builds the segmentation input dataset.

        Args:
          params: data config supplying `tfds_name`, the parser options,
            `file_type` and `is_training`; also passed to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.

        Raises:
          ValueError: if `params.tfds_name` is set but has no entry in
            `TFDS_ID_TO_DECODER_MAP`.
        """
        ignore_label = self.task_config.losses.ignore_label

        if not params.tfds_name:
            seg_decoder = segmentation_input.Decoder()
        else:
            decoder_map = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP
            if params.tfds_name not in decoder_map:
                raise ValueError('TFDS {} is not supported'.format(
                    params.tfds_name))
            seg_decoder = decoder_map[params.tfds_name]()

        parser_kwargs = dict(
            output_size=params.output_size,
            train_on_crops=params.train_on_crops,
            ignore_label=ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            aug_rand_hflip=params.aug_rand_hflip,
            dtype=params.dtype)
        seg_parser = segmentation_input.Parser(**parser_kwargs)

        return input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=seg_decoder.decode,
            parser_fn=seg_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #10
0
    def build_inputs(self, params, input_context=None):
        """Builds the segmentation input dataset.

        The parser output size is the first two entries of the model's
        `input_size`.

        Args:
          params: data config supplying the parser options and
            `is_training`; also passed to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        task_cfg = self.task_config
        input_size = task_cfg.model.input_size
        ignore_label = task_cfg.losses.ignore_label

        seg_decoder = segmentation_input.Decoder()
        seg_parser = segmentation_input.Parser(
            output_size=input_size[:2],
            ignore_label=ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=seg_decoder.decode,
            parser_fn=seg_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
  def build_inputs(self, params, input_context=None):
    """Builds the segmentation input dataset.

    Args:
      params: data config supplying the parser options, `file_type` and
        `is_training`; also passed to the input reader.
      input_context: optional input context forwarded to `reader.read`.

    Returns:
      The dataset returned by `InputReader.read`.
    """
    seg_decoder = segmentation_input.Decoder()

    parser_kwargs = dict(
        output_size=params.output_size,
        train_on_crops=params.train_on_crops,
        ignore_label=self.task_config.losses.ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)
    seg_parser = segmentation_input.Parser(**parser_kwargs)

    return input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=seg_decoder.decode,
        parser_fn=seg_parser.parse_fn(params.is_training),
    ).read(input_context=input_context)
コード例 #12
0
    def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
        """Builds the input dataset from the yt8m pipeline stages.

        Args:
          params: configuration for input data.
          input_context: indicates information about the compute replicas and
            input pipelines.

        Returns:
          dataset: dataset fetched from reader.
        """
        # Every stage is constructed from the same data config.
        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=yt8m_input.Decoder(input_params=params).decode,
            parser_fn=yt8m_input.Parser(input_params=params).parse_fn(
                params.is_training),
            postprocess_fn=yt8m_input.PostBatchProcessor(
                input_params=params).post_fn,
            transform_and_batch_fn=yt8m_input.TransformBatcher(
                input_params=params).batch_fn)
        return reader.read(input_context=input_context)
コード例 #13
0
def test_yolo_input():
    """Builds a YOLO detection dataset on CPU using a fixed anchor list.

    Returns:
      The dataset produced by `InputReader.read` with `input_context=None`.
    """
    # Nine hard-coded anchor pairs handed to the parser as-is.
    fixed_anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0],
                     [46.0, 114.0], [133.0, 127.0], [79.0, 225.0],
                     [301.0, 150.0], [172.0, 286.0], [348.0, 340.0]]

    with tf.device('/CPU:0'):
        data_cfg = DataConfig(is_training=True)
        coco_decoder = tfds_coco_decoder.MSCOCODecoder()

        parser_cfg = data_cfg.parser
        yolo_parser = YOLO_Detection_Input.Parser(
            image_w=parser_cfg.image_w,
            fixed_size=parser_cfg.fixed_size,
            jitter_im=parser_cfg.jitter_im,
            jitter_boxes=parser_cfg.jitter_boxes,
            min_level=parser_cfg.min_level,
            max_level=parser_cfg.max_level,
            min_process_size=parser_cfg.min_process_size,
            max_process_size=parser_cfg.max_process_size,
            max_num_instances=parser_cfg.max_num_instances,
            random_flip=parser_cfg.random_flip,
            pct_rand=parser_cfg.pct_rand,
            seed=parser_cfg.seed,
            anchors=fixed_anchors)

        dataset = input_reader.InputReader(
            data_cfg,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=coco_decoder.decode,
            parser_fn=yolo_parser.parse_fn(data_cfg.is_training),
        ).read(input_context=None)
    return dataset
コード例 #14
0
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
   """Reads a dataset, choosing the batching function by training mode.

   Args:
     input_context: optional input context passed through to `read`.

   Returns:
     Whatever `InputReader.read` returns for this loader's params.
   """
   if self._params.is_training:
     batch_fn = self._bucketize_and_batch
   else:
     batch_fn = self._inference_padded_batch

   return input_reader.InputReader(
       params=self._params,
       decoder_fn=self._decode,
       transform_and_batch_fn=batch_fn).read(input_context)
コード例 #15
0
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Reads a dataset through this loader's decode and parse functions.

     When an input context is supplied, its `num_replicas_in_sync` is stored
     on the instance before the reader is built.

     Args:
       input_context: optional input context passed through to `read`.

     Returns:
       Whatever `InputReader.read` returns for this loader's params.
     """
     if input_context:
         self._num_replicas_in_sync = input_context.num_replicas_in_sync

     return input_reader.InputReader(
         params=self._params,
         decoder_fn=self._decode,
         parser_fn=self._parse,
     ).read(input_context)
コード例 #16
0
 def load(
     self,
     input_context: Optional[tf.distribute.InputContext] = None
 ) -> tf.data.Dataset:
   """Reads a dataset through this loader's decode and parse functions.

   Args:
     input_context: optional input context passed through to `read`.

   Returns:
     The result of `InputReader.read` for this loader's params.
   """
   reader_kwargs = dict(
       params=self._params,
       decoder_fn=self._decode,
       parser_fn=self._parse)
   return input_reader.InputReader(**reader_kwargs).read(input_context)
コード例 #17
0
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Reads a dataset and applies the BERT preprocess step afterwards.

     Args:
       input_context: optional input context passed through to `read`.

     Returns:
       Whatever `InputReader.read` returns for this loader's params.
     """
     source_fn = dataset_fn.pick_dataset_fn(self._params.file_type)
     # A decoder is only attached when an input path is configured.
     decoder = self._decode if self._params.input_path else None

     return input_reader.InputReader(
         dataset_fn=source_fn,
         decoder_fn=decoder,
         params=self._params,
         postprocess_fn=self._bert_preprocess,
     ).read(input_context)
コード例 #18
0
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Reads a dataset using the dataset function picked from `file_type`.

     Args:
       input_context: optional input context passed through to `read`.

     Returns:
       Whatever `InputReader.read` returns for this loader's params.
     """
     source_fn = dataset_fn.pick_dataset_fn(self._params.file_type)
     return input_reader.InputReader(
         dataset_fn=source_fn,
         params=self._params,
         decoder_fn=self._decode,
         parser_fn=self._parse,
     ).read(input_context)
コード例 #19
0
 def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Reads TFRecord data and applies the BERT preprocess step afterwards.

     Args:
       input_context: optional input context passed through to `read`.

     Returns:
       Whatever `InputReader.read` returns for this loader's params.
     """
     # Skip `decoder_fn` for tfds input.
     decoder = self._decode if self._params.input_path else None

     return input_reader.InputReader(
         params=self._params,
         decoder_fn=decoder,
         dataset_fn=tf.data.TFRecordDataset,
         postprocess_fn=self._bert_preprocess,
     ).read(input_context)
コード例 #20
0
ファイル: yolo.py プロジェクト: ric-rhee/test-githubactions
  def build_inputs(self, params, input_context=None):
    """Builds the YOLO detection input dataset.

    Args:
      params: data config supplying the `parser` sub-config, `dtype` and
        `is_training`; also passed to the input reader.
      input_context: optional input context forwarded to `reader.read`.

    Returns:
      The dataset returned by `InputReader.read`.
    """
    decoder = tfds_coco_decoder.MSCOCODecoder()
    model = self.task_config.model

    # `_get_masks` returns three values; only the masks are needed to build
    # the parser, so the two scale values are discarded.
    masks, _, _ = self._get_masks()
    anchors = self._get_boxes(gen_boxes=params.is_training)

    parser_cfg = params.parser
    parser = yolo_input.Parser(
        image_w=parser_cfg.image_w,
        image_h=parser_cfg.image_h,
        num_classes=model.num_classes,
        min_level=model.min_level,
        max_level=model.max_level,
        fixed_size=parser_cfg.fixed_size,
        jitter_im=parser_cfg.jitter_im,
        jitter_boxes=parser_cfg.jitter_boxes,
        masks=masks,
        letter_box=parser_cfg.letter_box,
        cutmix=parser_cfg.cutmix,
        use_tie_breaker=parser_cfg.use_tie_breaker,
        min_process_size=parser_cfg.min_process_size,
        max_process_size=parser_cfg.max_process_size,
        max_num_instances=parser_cfg.max_num_instances,
        random_flip=parser_cfg.random_flip,
        pct_rand=parser_cfg.pct_rand,
        seed=parser_cfg.seed,
        aug_rand_saturation=parser_cfg.aug_rand_saturation,
        aug_rand_brightness=parser_cfg.aug_rand_brightness,
        aug_rand_zoom=parser_cfg.aug_rand_zoom,
        aug_rand_hue=parser_cfg.aug_rand_hue,
        anchors=anchors,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=parser.postprocess_fn(params.is_training))
    return reader.read(input_context=input_context)
コード例 #21
0
  def load(
      self,
      input_context = None
  ):
    """Returns the test dataset or a dataset read through the input reader.

    Args:
      input_context: optional input context passed through to `read`; not
        used when the input path is the literal "test".

    Returns:
      `test_dataset(seq_length)` when `input_path` equals "test", otherwise
      the result of `InputReader.read`.
    """
    params = self._params
    if params.input_path != "test":
      return input_reader.InputReader(
          params=params,
          decoder_fn=self._decode,
          parser_fn=self._parse).read(input_context)
    return test_dataset(params.seq_length)
コード例 #22
0
    def build_inputs(self, params, input_context=None):
        """Builds the RetinaNet detection input dataset.

        Args:
          params: data config supplying the `decoder` and `parser`
            sub-configs, `dtype` and `is_training`; also passed to the input
            reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.

        Raises:
          ValueError: if `params.decoder.type` is neither 'simple_decoder'
            nor 'label_map_decoder'.
        """
        # The decoder is selected exactly once. The label-map decoder class
        # comes from `tf_example_label_map_decoder`, not `tf_example_decoder`.
        decoder_cfg = params.decoder.get()
        if params.decoder.type == 'simple_decoder':
            decoder = tf_example_decoder.TfExampleDecoder(
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        elif params.decoder.type == 'label_map_decoder':
            decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                label_map=decoder_cfg.label_map,
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        else:
            raise ValueError('Unknown decoder type: {}!'.format(
                params.decoder.type))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        parser = retinanet_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=model_cfg.anchor.num_scales,
            aspect_ratios=model_cfg.anchor.aspect_ratios,
            anchor_size=model_cfg.anchor.anchor_size,
            dtype=params.dtype,
            match_threshold=parser_cfg.match_threshold,
            unmatched_threshold=parser_cfg.unmatched_threshold,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances)

        reader = input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        return reader.read(input_context=input_context)
コード例 #23
0
    def create_input_reader(self, params):
        """Creates an InputReader wired with the yt8m pipeline stages.

        Args:
          params: data config used to construct every stage and the reader.

        Returns:
          The constructed `InputReader`; `read` is not called here.
        """
        # Stages are built in the order decoder, parser, post-processor,
        # batcher, each from the same config.
        stage_fns = {
            'decoder_fn':
                yt8m_input.Decoder(input_params=params).decode,
            'parser_fn':
                yt8m_input.Parser(input_params=params).parse_fn(
                    params.is_training),
            'postprocess_fn':
                yt8m_input.PostBatchProcessor(input_params=params).post_fn,
            'transform_and_batch_fn':
                yt8m_input.TransformBatcher(input_params=params).batch_fn,
        }
        return input_reader.InputReader(params,
                                        dataset_fn=tf.data.TFRecordDataset,
                                        **stage_fns)
  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds the video input dataset.

    Args:
      params: data config used to build the parser, the post-batch processor
        and the reader, and to select the dataset and decoder functions.
      input_context: optional input context forwarded to `reader.read`.

    Returns:
      The dataset returned by `InputReader.read`.
    """
    video_parser = video_input.Parser(input_params=params)
    # The processor instance itself is handed over as the post-process fn.
    post_batch_processor = video_input.PostBatchProcessor(params)

    return input_reader.InputReader(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=video_parser.parse_fn(params.is_training),
        postprocess_fn=post_batch_processor,
    ).read(input_context=input_context)
コード例 #25
0
    def build_inputs(self, params, input_context=None):
        """Builds the Mask R-CNN input dataset.

        Args:
          params: data config supplying the `decoder` and `parser`
            sub-configs, `dtype`, `file_type` and `is_training`; also passed
            to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.

        Raises:
          ValueError: if `params.decoder.type` is neither 'simple_decoder'
            nor 'label_map_decoder'.
        """
        decoder_cfg = params.decoder.get()
        decoder_type = params.decoder.type
        if decoder_type == 'simple_decoder':
            example_decoder = tf_example_decoder.TfExampleDecoder(
                include_mask=self._task_config.model.include_mask,
                regenerate_source_id=decoder_cfg.regenerate_source_id,
                mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
        elif decoder_type == 'label_map_decoder':
            example_decoder = (
                tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                    label_map=decoder_cfg.label_map,
                    include_mask=self._task_config.model.include_mask,
                    regenerate_source_id=decoder_cfg.regenerate_source_id,
                    mask_binarize_threshold=decoder_cfg.
                    mask_binarize_threshold))
        else:
            raise ValueError('Unknown decoder type: {}!'.format(
                params.decoder.type))

        model_cfg = self.task_config.model
        anchor_cfg = model_cfg.anchor
        parser_cfg = params.parser
        example_parser = maskrcnn_input.Parser(
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=anchor_cfg.num_scales,
            aspect_ratios=anchor_cfg.aspect_ratios,
            anchor_size=anchor_cfg.anchor_size,
            dtype=params.dtype,
            rpn_match_threshold=parser_cfg.rpn_match_threshold,
            rpn_unmatched_threshold=parser_cfg.rpn_unmatched_threshold,
            rpn_batch_size_per_im=parser_cfg.rpn_batch_size_per_im,
            rpn_fg_fraction=parser_cfg.rpn_fg_fraction,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances,
            include_mask=self._task_config.model.include_mask,
            mask_crop_size=parser_cfg.mask_crop_size)

        return input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #26
0
ファイル: centernet.py プロジェクト: vishalbelsare/models
    def build_inputs(
            self,
            params: exp_cfg.DataConfig,
            input_context: Optional[tf.distribute.InputContext] = None):
        """Builds the CenterNet input dataset.

        Args:
          params: data config supplying `tfds_name`, the `decoder` and
            `parser` sub-configs, `dtype` and `is_training`; also passed to
            the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.

        Raises:
          ValueError: if no TFDS name is set and `params.decoder.type` is
            neither 'simple_decoder' nor 'label_map_decoder'.
        """
        if params.tfds_name:
            example_decoder = tfds_factory.get_detection_decoder(
                params.tfds_name)
        else:
            decoder_cfg = params.decoder.get()
            decoder_type = params.decoder.type
            if decoder_type == 'simple_decoder':
                example_decoder = tf_example_decoder.TfExampleDecoder(
                    regenerate_source_id=decoder_cfg.regenerate_source_id)
            elif decoder_type == 'label_map_decoder':
                example_decoder = (
                    tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                        label_map=decoder_cfg.label_map,
                        regenerate_source_id=decoder_cfg.regenerate_source_id))
            else:
                raise ValueError('Unknown decoder type: {}!'.format(
                    params.decoder.type))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        example_parser = centernet_input.CenterNetParser(
            output_height=model_cfg.input_size[0],
            output_width=model_cfg.input_size[1],
            max_num_instances=model_cfg.max_num_instances,
            bgr_ordering=parser_cfg.bgr_ordering,
            channel_means=parser_cfg.channel_means,
            channel_stds=parser_cfg.channel_stds,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            aug_rand_hue=parser_cfg.aug_rand_hue,
            aug_rand_brightness=parser_cfg.aug_rand_brightness,
            aug_rand_contrast=parser_cfg.aug_rand_contrast,
            aug_rand_saturation=parser_cfg.aug_rand_saturation,
            odapi_augmentation=parser_cfg.odapi_augmentation,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #27
0
def input_reader_generator(params: cfg.DataConfig,
                           **kwargs) -> core_input_reader.InputReader:
  """Creates the input reader that matches the given data config.

  A `CombinationDatasetInputReader` is returned when the config is in
  training mode and has a truthy `pseudo_label_data` entry; in every other
  case the core `InputReader` is returned.

  Args:
    params: A config_definitions.DataConfig object.
    **kwargs: Additional arguments passed to input reader initialization.

  Returns:
    An InputReader object.
  """
  wants_pseudo_labels = (
      params.is_training and params.get('pseudo_label_data', False))
  if not wants_pseudo_labels:
    return core_input_reader.InputReader(params, **kwargs)

  pseudo_label_fn = dataset_fn_util.pick_dataset_fn(
      params.pseudo_label_data.file_type)
  return vision_input_reader.CombinationDatasetInputReader(
      params, pseudo_label_dataset_fn=pseudo_label_fn, **kwargs)
コード例 #28
0
    def build_inputs(self, params, input_context=None):
        """Builds the classification input dataset from TFRecord data.

        Args:
          params: data config supplying `dtype` and `is_training`; also
            passed to the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        model_cfg = self.task_config.model
        num_classes = model_cfg.num_classes
        input_size = model_cfg.input_size

        example_decoder = classification_input.Decoder()
        example_parser = classification_input.Parser(
            output_size=input_size[:2],
            num_classes=num_classes,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #29
0
ファイル: simclr.py プロジェクト: kia-ctw/models
    def build_inputs(self, params, input_context=None):
        """Builds the SimCLR input dataset without augmentation options.

        Args:
          params: data config supplying `tfds_name`, the `decoder` and
            `parser` sub-configs, `dtype` and `is_training`; also passed to
            the input reader.
          input_context: optional input context forwarded to `reader.read`.

        Returns:
          The dataset returned by `InputReader.read`.
        """
        input_size = self.task_config.model.input_size

        # A TFDS source uses the TFDS decoder; anything else uses the default.
        decoder_cls = (simclr_input.TFDSDecoder
                       if params.tfds_name else simclr_input.Decoder)
        example_decoder = decoder_cls(params.decoder.decode_label)

        parser_cfg = params.parser
        example_parser = simclr_input.Parser(
            output_size=input_size[:2],
            parse_label=parser_cfg.parse_label,
            test_crop=parser_cfg.test_crop,
            mode=parser_cfg.mode,
            dtype=params.dtype)

        return input_reader.InputReader(
            params,
            dataset_fn=tf.data.TFRecordDataset,
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training),
        ).read(input_context=input_context)
コード例 #30
0
    def load(self, input_context: Optional[tf.distribute.InputContext] = None):
        """Reads a dataset, optionally tokenizing, bucketizing and batching it.

        Args:
          input_context: optional input context passed through to `read`.

        Returns:
          Whatever `InputReader.read` returns for this loader's params.
        """
        # A decoder is only attached when an input path is configured.
        decoder_fn = self._decode if self._params.input_path else None

        if self._params.transform_and_batch:
            transform_and_batch_fn = self._tokenize_bucketize_and_batch
        else:

            def _passthrough(
                    dataset,
                    input_context: Optional[tf.distribute.InputContext] = None):
                """Returns the dataset unchanged, ignoring the input context."""
                del input_context
                return dataset

            transform_and_batch_fn = _passthrough

        return input_reader.InputReader(
            params=self._params,
            decoder_fn=decoder_fn,
            transform_and_batch_fn=transform_and_batch_fn,
        ).read(input_context)