def test_get_data_format_and_filenames(self, fake_filename_prefix, ref_data_format):
  """Checks format/filename resolution for three equivalent pattern forms."""
  file_pattern_sstables, ref_file_paths = self._create_fake_files(
      fake_filename_prefix=fake_filename_prefix,
      fake_file_type=ref_data_format)

  def _verify(pattern):
    # Every pattern variant must resolve to the same data format and the
    # same set of files (order-insensitive, hence the sorts).
    data_format, filenames = tfdata.get_data_format_and_filenames(pattern)
    self.assertAllEqual(sorted(filenames), sorted(ref_file_paths))
    self.assertEqual(data_format, ref_data_format)

  # Bare file pattern.
  _verify(file_pattern_sstables)
  # Pattern carrying an explicit '<format>:' prefix, which must be stripped.
  _verify('{}:{}'.format(ref_data_format, file_pattern_sstables))
  # Comma-separated string of concrete file paths.
  _verify(','.join(ref_file_paths))
def __init__(self, file_fraction=1.0, **parent_kwargs):
  """Create an instance.

  Args:
    file_fraction: If file_fraction < 1.0, keep only the first
      file_fraction of the files (rounded down to the nearest integer
      number of files).
    **parent_kwargs: All parent arguments.
  """
  super(FractionalRecordInputGenerator, self).__init__(**parent_kwargs)
  if file_fraction < 1.0:
    # Re-resolve the parent's patterns so the file list can be truncated.
    data_format, filenames = tfdata.get_data_format_and_filenames(
        self._file_patterns)
    keep_count = int(file_fraction * len(filenames))
    # Rebuild the pattern string from the retained prefix of files.
    self._file_patterns = '{}:{}'.format(
        data_format, ','.join(filenames[:keep_count]))
def parallel_read(file_patterns,
                  parse_fn,
                  shuffle_filenames=True,
                  num_train_samples_per_task=4,
                  num_val_samples_per_task=4,
                  shuffle_buffer_size=50,
                  filter_fn=None,
                  interleave_cycle_length=None,
                  mode=tf.estimator.ModeKeys.TRAIN):
  """Read and parse multiple examples per task, per train/val split.

  The pipeline:
    1. Shuffles & repeats filenames (one file per task).
    2. Opens and shuffles each file.
    3. Emits num_train + num_val examples from each shard as one batch.
    4. De-serializes the protos via `parse_fn`.

  Args:
    file_patterns: Comma-separated string of file patterns, where each
      individual file contains data for one task.
    parse_fn: Python function that takes as an argument a Dataset whose
      output is a string tensor and returns a Dataset whose output is a
      collection of parsed TFExamples.
    shuffle_filenames: If True, filenames are shuffled so tasks are sampled
      at random.
    num_train_samples_per_task: How many examples to dequeue for training.
    num_val_samples_per_task: How many examples to dequeue for validation.
    shuffle_buffer_size: How many examples to shuffle within each file.
    filter_fn: A callable which receives a valid tensorspec filled with
      tensors and returns a tf.bool; examples yielding False are dropped.
    interleave_cycle_length: Integer, cycle length when interleaving examples
      from different task files. If 0 or None, defaults to num_tasks.
    mode: (ModeKeys) Specifies if this is training, evaluation or prediction.

  Returns:
    Dataset whose output is that of `parse_fn` (e.g. features, labels).

  Raises:
    ValueError: File patterns do not exist.
  """
  data_format, filenames = tfdata.get_data_format_and_filenames(file_patterns)
  num_tasks = len(filenames)
  # One combined batch of train + val examples is drawn from each task file.
  samples_per_task = num_train_samples_per_task + num_val_samples_per_task

  filename_ds = tf.data.Dataset.from_tensor_slices(filenames)
  if shuffle_filenames:
    # Shuffle produces a fresh permutation per epoch; repeat() re-shuffles
    # once an epoch completes.
    filename_ds = filename_ds.shuffle(buffer_size=num_tasks).repeat()
  else:
    filename_ds = filename_ds.repeat()

  def _per_file_dataset(path):
    """Builds the parsed (and optionally filtered) dataset for one file.

    Args:
      path: path name of data.

    Returns:
      A dataset of parsed tensors, filtered by filter_fn.
    """
    file_ds = tfdata.DATA_FORMAT[data_format](path)
    # The shuffle buffer must hold at least one full batch so the batching
    # below never starves.
    effective_buffer = max(shuffle_buffer_size, samples_per_task)
    if mode == tf.estimator.ModeKeys.TRAIN:
      file_ds = file_ds.shuffle(buffer_size=effective_buffer).repeat()
    else:
      file_ds = file_ds.repeat()
    file_ds = file_ds.batch(batch_size=samples_per_task, drop_remainder=True)
    parsed_ds = parse_fn(file_ds)
    if filter_fn is not None:
      # Filter individual examples, then re-form full-size batches.
      parsed_ds = parsed_ds.unbatch().filter(filter_fn).batch(
          batch_size=samples_per_task, drop_remainder=True)
    return parsed_ds

  # Draw one example-set (a batch) from each task file in round-robin order.
  cycle_length = interleave_cycle_length or num_tasks
  return filename_ds.interleave(
      _per_file_dataset, cycle_length=cycle_length, block_length=1)
def test_get_data_format_and_filenames_raises(self):
  # The data format cannot be inferred from this pattern, so the helper
  # must reject it with ValueError.
  self.assertRaises(ValueError, tfdata.get_data_format_and_filenames,
                    'a wrong file pattern')