def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False):
    """Build the batched dataset used for training or evaluation.

    The files returned by `get_filenames` are read as fixed-length records of
    `_RECORD_BYTES` bytes each, optionally sharded across input pipelines, and
    then passed to `resnet_run_loop.process_record_dataset`, which produces
    the dataset that is returned.

    Args:
      is_training: Boolean, True when the data is consumed for training.
      data_dir: Directory that holds the input data files.
      batch_size: Number of samples in each batch.
      num_epochs: Number of epochs to repeat the dataset for.
      dtype: Data type to use for the images/features.
      datasets_num_private_threads: Number of private threads for tf.data.
      parse_record_fn: Function used to parse each record.
      input_context: Optional `tf.distribute.InputContext` supplied by
        `tf.distribute.Strategy`. When truthy, the dataset is sharded by its
        `num_input_pipelines` / `input_pipeline_id`.
      drop_remainder: Boolean, whether to drop the remainder of the batches.
        If True, the batch dimension will be static.

    Returns:
      A dataset that can be used for iteration.
    """
    record_files = get_filenames(is_training, data_dir)
    records = tf.data.FixedLengthRecordDataset(record_files, _RECORD_BYTES)

    if input_context:
        pipeline_id = input_context.input_pipeline_id
        pipeline_count = input_context.num_input_pipelines
        tf.compat.v1.logging.info(
            'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
            % (pipeline_id, pipeline_count))
        # Keep only the shard that belongs to this input pipeline.
        records = records.shard(pipeline_count, pipeline_id)

    # The shuffle buffer size is taken from the 'train' entry of NUM_IMAGES,
    # regardless of whether this call is for training or evaluation.
    return resnet_run_loop.process_record_dataset(
        dataset=records,
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=NUM_IMAGES['train'],
        parse_record_fn=parse_record_fn,
        num_epochs=num_epochs,
        dtype=dtype,
        datasets_num_private_threads=datasets_num_private_threads,
        drop_remainder=drop_remainder)
def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False,
             tf_data_experimental_slack=False):
    """Build the batched dataset used for training or evaluation.

    Starts from a dataset of the file names returned by `get_filenames`,
    optionally shards it across input pipelines, shuffles the file names when
    training, interleaves the files as `tf.data.TFRecordDataset`s, and passes
    the resulting record dataset to `resnet_run_loop.process_record_dataset`.

    Args:
      is_training: Boolean, True when the data is consumed for training.
      data_dir: Directory that holds the input data files.
      batch_size: Number of samples in each batch.
      num_epochs: Number of epochs to repeat the dataset for.
      dtype: Data type to use for the images/features.
      datasets_num_private_threads: Number of private threads for tf.data.
      parse_record_fn: Function used to parse each record.
      input_context: Optional `tf.distribute.InputContext` supplied by
        `tf.distribute.Strategy`. When truthy, the file-name dataset is
        sharded by its `num_input_pipelines` / `input_pipeline_id`.
      drop_remainder: Boolean, whether to drop the remainder of the batches.
        If True, the batch dimension will be static.
      tf_data_experimental_slack: Whether to enable tf.data's
        `experimental_slack` option.

    Returns:
      A dataset that can be used for iteration.
    """
    file_names = get_filenames(is_training, data_dir)
    files = tf.data.Dataset.from_tensor_slices(file_names)

    if input_context:
        pipeline_id = input_context.input_pipeline_id
        pipeline_count = input_context.num_input_pipelines
        tf.compat.v1.logging.info(
            'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
            % (pipeline_id, pipeline_count))
        # Sharding happens at file granularity, before any file is opened.
        files = files.shard(pipeline_count, pipeline_id)

    if is_training:
        # Randomize the order in which the input files are visited.
        files = files.shuffle(buffer_size=_NUM_TRAIN_FILES)

    # Expand each file into its individual records. Up to
    # `files_read_in_parallel` files are read and deserialized concurrently;
    # consider raising it on machines with a large number of CPU cores.
    files_read_in_parallel = 10
    records = files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=files_read_in_parallel,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return resnet_run_loop.process_record_dataset(
        dataset=records,
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=_SHUFFLE_BUFFER,
        parse_record_fn=parse_record_fn,
        num_epochs=num_epochs,
        dtype=dtype,
        datasets_num_private_threads=datasets_num_private_threads,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=tf_data_experimental_slack,
    )