def build_inputs(self, params, is_training=True, input_context=None):
  params.is_training = False
  decoder = tfds_coco_decoder.MSCOCODecoder()
  parser = yolo_input.Parser(
      image_w=params.parser.image_w,
      image_h=params.parser.image_h,
      num_classes=self._task.task_config.model.num_classes,
      fixed_size=params.parser.fixed_size,
      jitter_im=params.parser.jitter_im,
      jitter_boxes=params.parser.jitter_boxes,
      net_down_scale=params.parser.net_down_scale,
      min_process_size=params.parser.min_process_size,
      max_process_size=params.parser.max_process_size,
      max_num_instances=params.parser.max_num_instances,
      random_flip=params.parser.random_flip,
      pct_rand=params.parser.pct_rand,
      seed=params.parser.seed,
      anchors=self._task.task_config.model.boxes)

  if is_training:
    post_process_fn = parser.postprocess_fn()
  else:
    post_process_fn = None

  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(is_training),
      postprocess_fn=post_process_fn)
  dataset = reader.read(input_context=input_context)
  return dataset
def build_inputs(self, params, input_context=None): """Builds classification input.""" num_classes = self.task_config.model.num_classes input_size = self.task_config.model.input_size if params.tfds_name: if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP: decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[ params.tfds_name]() else: raise ValueError('TFDS {} is not supported'.format( params.tfds_name)) else: decoder = classification_input.Decoder() parser = classification_input.Parser(output_size=input_size[:2], num_classes=num_classes, aug_policy=params.aug_policy, dtype=params.dtype) reader = input_reader.InputReader( params, dataset_fn=dataset_fn.pick_dataset_fn(params.file_type), decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None):
  """Builds BASNet input."""
  ignore_label = self.task_config.losses.ignore_label

  decoder = segmentation_input.Decoder()
  parser = segmentation_input.Parser(
      output_size=params.output_size,
      crop_size=params.crop_size,
      ignore_label=ignore_label,
      aug_rand_hflip=params.aug_rand_hflip,
      dtype=params.dtype)

  reader = input_reader.InputReader(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read(input_context=input_context)
  return dataset
def testSegmentationInputReader(self, input_size, num_classes, num_channels):
  params = cfg.DataConfig(
      input_path=self._data_path,
      global_batch_size=2,
      is_training=False)

  decoder = segmentation_input_3d.Decoder()
  parser = segmentation_input_3d.Parser(
      input_size=input_size,
      num_classes=num_classes,
      num_channels=num_channels)

  reader = input_reader.InputReader(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read()
  iterator = iter(dataset)
  image, labels = next(iterator)

  # Checks image shape.
  self.assertEqual(
      list(image.numpy().shape),
      [2, input_size[0], input_size[1], input_size[2], num_channels])
  self.assertEqual(
      list(labels.numpy().shape),
      [2, input_size[0], input_size[1], input_size[2], num_classes])
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" reader = input_reader.InputReader( params=self._params, decoder_fn=None, transform_and_batch_fn=self._transform_and_batch_fn) return reader.read(input_context)
def build_inputs(self, params, input_context=None): """Builds classification input.""" num_classes = self.task_config.model.num_classes input_size = self.task_config.model.input_size if params.tfds_name is not None: decoder = cli.Decoder() else: decoder = classification_input.Decoder() parser = classification_input.Parser( output_size=input_size[:2], num_classes=num_classes, aug_rand_saturation=params.parser.aug_rand or params.parser.aug_rand_saturation, aug_rand_brightness=params.parser.aug_rand or params.parser.aug_rand_brightness, aug_rand_zoom=params.parser.aug_rand or params.parser.aug_rand_zoom, aug_rand_rotate=params.parser.aug_rand or params.parser.aug_rand_rotate, aug_rand_hue=params.parser.aug_rand or params.parser.aug_rand_hue, aug_rand_aspect=params.parser.aug_rand or params.parser.aug_rand_aspect, scale=params.parser.scale, seed=params.parser.seed, dtype=params.dtype) reader = input_reader.InputReader( params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(self, params, input_context=None):
  input_size = self.task_config.model.input_size

  if params.tfds_name:
    decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
  else:
    decoder = simclr_input.Decoder(params.decoder.decode_label)

  parser = simclr_input.Parser(
      output_size=input_size[:2],
      aug_rand_crop=params.parser.aug_rand_crop,
      aug_rand_hflip=params.parser.aug_rand_hflip,
      aug_color_distort=params.parser.aug_color_distort,
      aug_color_jitter_strength=params.parser.aug_color_jitter_strength,
      aug_color_jitter_impl=params.parser.aug_color_jitter_impl,
      aug_rand_blur=params.parser.aug_rand_blur,
      parse_label=params.parser.parse_label,
      test_crop=params.parser.test_crop,
      mode=params.parser.mode,
      dtype=params.dtype)

  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read(input_context=input_context)
  return dataset
def test_parser(self, output_size, dtype, is_training):
  params = cfg.DataConfig(
      input_path='imagenet-2012-tfrecord/train*',
      global_batch_size=2,
      is_training=True,
      examples_consume=4)

  decoder = classification_input.Decoder()
  parser = classification_input.Parser(
      output_size=output_size[:2],
      num_classes=1001,
      aug_rand_hflip=False,
      dtype=dtype)

  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read()
  images, labels = next(iter(dataset))

  self.assertAllEqual(images.numpy().shape,
                      [params.global_batch_size] + output_size)
  self.assertAllEqual(labels.numpy().shape, [params.global_batch_size])

  if dtype == 'float32':
    self.assertAllEqual(images.dtype, tf.float32)
  elif dtype == 'float16':
    self.assertAllEqual(images.dtype, tf.float16)
  elif dtype == 'bfloat16':
    self.assertAllEqual(images.dtype, tf.bfloat16)
def build_inputs(self, params, input_context=None): """Builds classification input.""" ignore_label = self.task_config.losses.ignore_label if params.tfds_name: if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP: decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[ params.tfds_name]() else: raise ValueError('TFDS {} is not supported'.format( params.tfds_name)) else: decoder = segmentation_input.Decoder() parser = segmentation_input.Parser( output_size=params.output_size, train_on_crops=params.train_on_crops, ignore_label=ignore_label, resize_eval_groundtruth=params.resize_eval_groundtruth, groundtruth_padded_size=params.groundtruth_padded_size, aug_scale_min=params.aug_scale_min, aug_scale_max=params.aug_scale_max, aug_rand_hflip=params.aug_rand_hflip, dtype=params.dtype) reader = input_reader.InputReader( params, dataset_fn=dataset_fn.pick_dataset_fn(params.file_type), decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(self, params, input_context=None): """Builds classification input.""" input_size = self.task_config.model.input_size ignore_label = self.task_config.losses.ignore_label decoder = segmentation_input.Decoder() parser = segmentation_input.Parser( output_size=input_size[:2], ignore_label=ignore_label, resize_eval_groundtruth=params.resize_eval_groundtruth, groundtruth_padded_size=params.groundtruth_padded_size, aug_scale_min=params.aug_scale_min, aug_scale_max=params.aug_scale_max, dtype=params.dtype) reader = input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn( params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(self, params, input_context=None): """Builds classification input.""" ignore_label = self.task_config.losses.ignore_label decoder = segmentation_input.Decoder() parser = segmentation_input.Parser( output_size=params.output_size, train_on_crops=params.train_on_crops, ignore_label=ignore_label, resize_eval_groundtruth=params.resize_eval_groundtruth, groundtruth_padded_size=params.groundtruth_padded_size, aug_scale_min=params.aug_scale_min, aug_scale_max=params.aug_scale_max, aug_rand_hflip=params.aug_rand_hflip, dtype=params.dtype) reader = input_reader.InputReader( params, dataset_fn=dataset_fn.pick_dataset_fn(params.file_type), decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
  """Builds input.

  Args:
    params: configuration for input data
    input_context: indicates information about the compute replicas and
      input pipelines

  Returns:
    dataset: dataset fetched from reader
  """
  decoder = yt8m_input.Decoder(input_params=params)
  decoder_fn = decoder.decode
  parser = yt8m_input.Parser(input_params=params)
  parser_fn = parser.parse_fn(params.is_training)
  postprocess = yt8m_input.PostBatchProcessor(input_params=params)
  postprocess_fn = postprocess.post_fn
  transform_batch = yt8m_input.TransformBatcher(input_params=params)
  batch_fn = transform_batch.batch_fn

  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder_fn,
      parser_fn=parser_fn,
      postprocess_fn=postprocess_fn,
      transform_and_batch_fn=batch_fn)

  dataset = reader.read(input_context=input_context)
  return dataset
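# Hedged usage sketch (not taken from the original sources): shows how a
# build_inputs method like the one above is commonly wired into tf.distribute
# so each replica receives its own tf.distribute.InputContext, which
# InputReader.read() uses for file sharding and per-replica batching. The
# `strategy`, `task`, and `params` arguments are hypothetical stand-ins, not
# names defined in the snippets here.
def make_distributed_dataset(strategy, task, params):
  """Builds one dataset per replica via distribute_datasets_from_function."""
  def dataset_fn(input_context):
    # The InputContext is forwarded so the reader can shard across replicas.
    return task.build_inputs(params, input_context=input_context)
  return strategy.distribute_datasets_from_function(dataset_fn)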
def test_yolo_input():
  with tf.device('/CPU:0'):
    params = DataConfig(is_training=True)
    num_boxes = 9

    decoder = tfds_coco_decoder.MSCOCODecoder()
    # anchors = box_rd.read(k=num_boxes, image_width=params.parser.image_w,
    #                       input_context=None)
    anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0], [46.0, 114.0],
               [133.0, 127.0], [79.0, 225.0], [301.0, 150.0], [172.0, 286.0],
               [348.0, 340.0]]
    # write the boxes to a file

    parser = YOLO_Detection_Input.Parser(
        image_w=params.parser.image_w,
        fixed_size=params.parser.fixed_size,
        jitter_im=params.parser.jitter_im,
        jitter_boxes=params.parser.jitter_boxes,
        min_level=params.parser.min_level,
        max_level=params.parser.max_level,
        min_process_size=params.parser.min_process_size,
        max_process_size=params.parser.max_process_size,
        max_num_instances=params.parser.max_num_instances,
        random_flip=params.parser.random_flip,
        pct_rand=params.parser.pct_rand,
        seed=params.parser.seed,
        anchors=anchors)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=None)
    return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" reader = input_reader.InputReader( params=self._params, decoder_fn=self._decode, transform_and_batch_fn=self._bucketize_and_batch if self._params.is_training else self._inference_padded_batch) return reader.read(input_context)
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" if input_context: self._num_replicas_in_sync = input_context.num_replicas_in_sync reader = input_reader.InputReader(params=self._params, decoder_fn=self._decode, parser_fn=self._parse) return reader.read(input_context)
def load(
    self,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  """Returns a tf.dataset.Dataset."""
  reader = input_reader.InputReader(
      params=self._params,
      decoder_fn=self._decode,
      parser_fn=self._parse)
  return reader.read(input_context)
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" reader = input_reader.InputReader( dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type), decoder_fn=self._decode if self._params.input_path else None, params=self._params, postprocess_fn=self._bert_preprocess) return reader.read(input_context)
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" reader = input_reader.InputReader( dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type), params=self._params, decoder_fn=self._decode, parser_fn=self._parse) return reader.read(input_context)
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" reader = input_reader.InputReader( params=self._params, # Skip `decoder_fn` for tfds input. decoder_fn=self._decode if self._params.input_path else None, dataset_fn=tf.data.TFRecordDataset, postprocess_fn=self._bert_preprocess) return reader.read(input_context)
def build_inputs(self, params, input_context=None): """Build input dataset.""" decoder = tfds_coco_decoder.MSCOCODecoder() """ decoder_cfg = params.decoder.get() if params.decoder.type == 'simple_decoder': decoder = tf_example_decoder.TfExampleDecoder( regenerate_source_id=decoder_cfg.regenerate_source_id) elif params.decoder.type == 'label_map_decoder': decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap( label_map=decoder_cfg.label_map, regenerate_source_id=decoder_cfg.regenerate_source_id) else: raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) """ model = self.task_config.model masks, path_scales, xy_scales = self._get_masks() anchors = self._get_boxes(gen_boxes=params.is_training) print(masks, path_scales, xy_scales) parser = yolo_input.Parser( image_w=params.parser.image_w, image_h=params.parser.image_h, num_classes=model.num_classes, min_level=model.min_level, max_level=model.max_level, fixed_size=params.parser.fixed_size, jitter_im=params.parser.jitter_im, jitter_boxes=params.parser.jitter_boxes, masks=masks, letter_box=params.parser.letter_box, cutmix=params.parser.cutmix, use_tie_breaker=params.parser.use_tie_breaker, min_process_size=params.parser.min_process_size, max_process_size=params.parser.max_process_size, max_num_instances=params.parser.max_num_instances, random_flip=params.parser.random_flip, pct_rand=params.parser.pct_rand, seed=params.parser.seed, aug_rand_saturation=params.parser.aug_rand_saturation, aug_rand_brightness=params.parser.aug_rand_brightness, aug_rand_zoom=params.parser.aug_rand_zoom, aug_rand_hue=params.parser.aug_rand_hue, anchors=anchors, dtype=params.dtype) reader = input_reader.InputReader( params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training), postprocess_fn=parser.postprocess_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def load(self, input_context=None):
  """Returns a tf.dataset.Dataset."""
  if self._params.input_path == "test":
    return test_dataset(self._params.seq_length)
  reader = input_reader.InputReader(
      params=self._params,
      decoder_fn=self._decode,
      parser_fn=self._parse)
  return reader.read(input_context)
def build_inputs(self, params, input_context=None): """Build input dataset.""" decoder_cfg = params.decoder.get() if params.decoder.type == 'simple_decoder': decoder = tf_example_decoder.TfExampleDecoder( regenerate_source_id=decoder_cfg.regenerate_source_id) elif params.decoder.type == 'label_map_decoder': decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap( label_map=decoder_cfg.label_map, regenerate_source_id=decoder_cfg.regenerate_source_id) else: raise ValueError('Unknown decoder type: {}!'.format( params.decoder.type)) decoder_cfg = params.decoder.get() if params.decoder.type == 'simple_decoder': decoder = tf_example_decoder.TfExampleDecoder( regenerate_source_id=decoder_cfg.regenerate_source_id) elif params.decoder.type == 'label_map_decoder': decoder = tf_example_decoder.TfExampleDecoderLabelMap( label_map=decoder_cfg.label_map, regenerate_source_id=decoder_cfg.regenerate_source_id) else: raise ValueError('Unknown decoder type: {}!'.format( params.decoder.type)) parser = retinanet_input.Parser( output_size=self.task_config.model.input_size[:2], min_level=self.task_config.model.min_level, max_level=self.task_config.model.max_level, num_scales=self.task_config.model.anchor.num_scales, aspect_ratios=self.task_config.model.anchor.aspect_ratios, anchor_size=self.task_config.model.anchor.anchor_size, dtype=params.dtype, match_threshold=params.parser.match_threshold, unmatched_threshold=params.parser.unmatched_threshold, aug_rand_hflip=params.parser.aug_rand_hflip, aug_scale_min=params.parser.aug_scale_min, aug_scale_max=params.parser.aug_scale_max, skip_crowd_during_training=params.parser. skip_crowd_during_training, max_num_instances=params.parser.max_num_instances) reader = input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn( params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def create_input_reader(self, params):
  decoder = yt8m_input.Decoder(input_params=params)
  decoder_fn = decoder.decode
  parser = yt8m_input.Parser(input_params=params)
  parser_fn = parser.parse_fn(params.is_training)
  postprocess = yt8m_input.PostBatchProcessor(input_params=params)
  postprocess_fn = postprocess.post_fn
  transform_batch = yt8m_input.TransformBatcher(input_params=params)
  batch_fn = transform_batch.batch_fn

  return input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder_fn,
      parser_fn=parser_fn,
      postprocess_fn=postprocess_fn,
      transform_and_batch_fn=batch_fn)
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
  """Builds classification input."""
  parser = video_input.Parser(input_params=params)
  postprocess_fn = video_input.PostBatchProcessor(params)

  reader = input_reader.InputReader(
      params,
      dataset_fn=self._get_dataset_fn(params),
      decoder_fn=self._get_decoder_fn(params),
      parser_fn=parser.parse_fn(params.is_training),
      postprocess_fn=postprocess_fn)

  dataset = reader.read(input_context=input_context)
  return dataset
def build_inputs(self, params, input_context=None): """Build input dataset.""" decoder_cfg = params.decoder.get() if params.decoder.type == 'simple_decoder': decoder = tf_example_decoder.TfExampleDecoder( include_mask=self._task_config.model.include_mask, regenerate_source_id=decoder_cfg.regenerate_source_id, mask_binarize_threshold=decoder_cfg.mask_binarize_threshold) elif params.decoder.type == 'label_map_decoder': decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap( label_map=decoder_cfg.label_map, include_mask=self._task_config.model.include_mask, regenerate_source_id=decoder_cfg.regenerate_source_id, mask_binarize_threshold=decoder_cfg.mask_binarize_threshold) else: raise ValueError('Unknown decoder type: {}!'.format( params.decoder.type)) parser = maskrcnn_input.Parser( output_size=self.task_config.model.input_size[:2], min_level=self.task_config.model.min_level, max_level=self.task_config.model.max_level, num_scales=self.task_config.model.anchor.num_scales, aspect_ratios=self.task_config.model.anchor.aspect_ratios, anchor_size=self.task_config.model.anchor.anchor_size, dtype=params.dtype, rpn_match_threshold=params.parser.rpn_match_threshold, rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold, rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im, rpn_fg_fraction=params.parser.rpn_fg_fraction, aug_rand_hflip=params.parser.aug_rand_hflip, aug_scale_min=params.parser.aug_scale_min, aug_scale_max=params.parser.aug_scale_max, skip_crowd_during_training=params.parser. skip_crowd_during_training, max_num_instances=params.parser.max_num_instances, include_mask=self._task_config.model.include_mask, mask_crop_size=params.parser.mask_crop_size) reader = input_reader.InputReader( params, dataset_fn=dataset_fn.pick_dataset_fn(params.file_type), decoder_fn=decoder.decode, parser_fn=parser.parse_fn(params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs( self, params: exp_cfg.DataConfig, input_context: Optional[tf.distribute.InputContext] = None): """Build input dataset.""" if params.tfds_name: decoder = tfds_factory.get_detection_decoder(params.tfds_name) else: decoder_cfg = params.decoder.get() if params.decoder.type == 'simple_decoder': decoder = tf_example_decoder.TfExampleDecoder( regenerate_source_id=decoder_cfg.regenerate_source_id) elif params.decoder.type == 'label_map_decoder': decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap( label_map=decoder_cfg.label_map, regenerate_source_id=decoder_cfg.regenerate_source_id) else: raise ValueError('Unknown decoder type: {}!'.format( params.decoder.type)) parser = centernet_input.CenterNetParser( output_height=self.task_config.model.input_size[0], output_width=self.task_config.model.input_size[1], max_num_instances=self.task_config.model.max_num_instances, bgr_ordering=params.parser.bgr_ordering, channel_means=params.parser.channel_means, channel_stds=params.parser.channel_stds, aug_rand_hflip=params.parser.aug_rand_hflip, aug_scale_min=params.parser.aug_scale_min, aug_scale_max=params.parser.aug_scale_max, aug_rand_hue=params.parser.aug_rand_hue, aug_rand_brightness=params.parser.aug_rand_brightness, aug_rand_contrast=params.parser.aug_rand_contrast, aug_rand_saturation=params.parser.aug_rand_saturation, odapi_augmentation=params.parser.odapi_augmentation, dtype=params.dtype) reader = input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn( params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def input_reader_generator(params: cfg.DataConfig,
                           **kwargs) -> core_input_reader.InputReader:
  """Instantiates an input reader class according to the params.

  Args:
    params: A config_definitions.DataConfig object.
    **kwargs: Additional arguments passed to input reader initialization.

  Returns:
    An InputReader object.
  """
  if params.is_training and params.get('pseudo_label_data', False):
    return vision_input_reader.CombinationDatasetInputReader(
        params,
        pseudo_label_dataset_fn=dataset_fn_util.pick_dataset_fn(
            params.pseudo_label_data.file_type),
        **kwargs)
  else:
    return core_input_reader.InputReader(params, **kwargs)
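# Hedged usage sketch: one way a task's build_inputs might call the factory
# above instead of constructing an InputReader directly. The decoder, parser,
# and DataConfig fields reuse patterns from the surrounding snippets but are
# assumptions, not a verified call site from the original code.
def build_inputs_with_generator(task, params, input_context=None):
  decoder = classification_input.Decoder()
  parser = classification_input.Parser(
      output_size=task.task_config.model.input_size[:2],
      num_classes=task.task_config.model.num_classes,
      dtype=params.dtype)
  # input_reader_generator returns CombinationDatasetInputReader only when
  # training with pseudo_label_data configured; otherwise a plain InputReader.
  reader = input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)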
def build_inputs(self, params, input_context=None): """Builds classification input.""" num_classes = self.task_config.model.num_classes input_size = self.task_config.model.input_size decoder = classification_input.Decoder() parser = classification_input.Parser(output_size=input_size[:2], num_classes=num_classes, dtype=params.dtype) reader = input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset, decoder_fn=decoder.decode, parser_fn=parser.parse_fn( params.is_training)) dataset = reader.read(input_context=input_context) return dataset
def build_inputs(self, params, input_context=None):
  input_size = self.task_config.model.input_size

  if params.tfds_name:
    decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
  else:
    decoder = simclr_input.Decoder(params.decoder.decode_label)

  parser = simclr_input.Parser(
      output_size=input_size[:2],
      parse_label=params.parser.parse_label,
      test_crop=params.parser.test_crop,
      mode=params.parser.mode,
      dtype=params.dtype)

  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))

  dataset = reader.read(input_context=input_context)
  return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None): """Returns a tf.dataset.Dataset.""" decoder_fn = None # Only decode for TFRecords. if self._params.input_path: decoder_fn = self._decode def _identity( dataset, input_context: Optional[tf.distribute.InputContext] = None): del input_context return dataset transform_and_batch_fn = _identity if self._params.transform_and_batch: transform_and_batch_fn = self._tokenize_bucketize_and_batch reader = input_reader.InputReader( params=self._params, decoder_fn=decoder_fn, transform_and_batch_fn=transform_and_batch_fn) return reader.read(input_context)
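# Hedged sketch of the transform_and_batch_fn contract implied by _identity
# above: the callable receives the unbatched dataset plus an optional
# tf.distribute.InputContext and returns the (batched) dataset. The
# per-replica batch size computation follows the usual tf.distribute
# convention; the global_batch_size argument is an assumption borrowed from
# the DataConfig objects used elsewhere in these snippets, not part of the
# original interface.
def example_transform_and_batch(dataset, input_context=None,
                                global_batch_size=32):
  per_replica_batch_size = (
      input_context.get_per_replica_batch_size(global_batch_size)
      if input_context else global_batch_size)
  return dataset.batch(per_replica_batch_size, drop_remainder=True)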