Example #1
  def test_get_data_format_and_filenames(self, fake_filename_prefix,
                                         ref_data_format):
    file_pattern_sstables, ref_file_paths = self._create_fake_files(
        fake_filename_prefix=fake_filename_prefix,
        fake_file_type=ref_data_format)
    data_format, filenames = tfdata.get_data_format_and_filenames(
        file_pattern_sstables)
    self.assertAllEqual(sorted(filenames), sorted(ref_file_paths))
    self.assertEqual(data_format, ref_data_format)

    # Test the ability to strip the filetype prefix.
    data_format, filenames = tfdata.get_data_format_and_filenames(
        '{}:{}'.format(ref_data_format, file_pattern_sstables))
    self.assertAllEqual(sorted(filenames), sorted(ref_file_paths))
    self.assertEqual(data_format, ref_data_format)

    # Test the ability to use a comma-separated string of file paths.
    data_format, filenames = tfdata.get_data_format_and_filenames(
        ','.join(ref_file_paths))
    self.assertAllEqual(sorted(filenames), sorted(ref_file_paths))
    self.assertEqual(data_format, ref_data_format)
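
For reference, a minimal usage sketch of the three input forms this test exercises; the paths below are hypothetical:

# Form 1: data format inferred from the file pattern itself.
data_format, filenames = tfdata.get_data_format_and_filenames(
    '/tmp/data/train.sstable-*')
# Form 2: explicit '<data_format>:' prefix, stripped before matching.
data_format, filenames = tfdata.get_data_format_and_filenames(
    'sstable:/tmp/data/train.sstable-*')
# Form 3: comma-separated string of concrete file paths.
data_format, filenames = tfdata.get_data_format_and_filenames(
    '/tmp/data/train.sstable-00000-of-00002,'
    '/tmp/data/train.sstable-00001-of-00002')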
Example #2
    def __init__(self, file_fraction=1.0, **parent_kwargs):
        """Create an instance.

        Args:
          file_fraction: If file_fraction < 1.0, keep only the first
            file_fraction of the files (rounded down to the nearest integer
            number of files).
          **parent_kwargs: All parent arguments.
        """
        super(FractionalRecordInputGenerator, self).__init__(**parent_kwargs)
        if file_fraction < 1.0:
            data_format, filenames = tfdata.get_data_format_and_filenames(
                self._file_patterns)
            n = int(file_fraction * len(filenames))
            filenames = filenames[:n]
            self._file_patterns = '{}:{}'.format(data_format,
                                                 ','.join(filenames))
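
A hedged usage sketch: with file_fraction=0.5, the generator keeps the first half of the matched files. The file_patterns keyword is an assumption about the parent class, which is not shown in this snippet:

generator = FractionalRecordInputGenerator(
    file_fraction=0.5,
    # Hypothetical parent kwarg, presumably backing self._file_patterns.
    file_patterns='sstable:/tmp/data/train.sstable-*')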
Example #3
def parallel_read(file_patterns,
                  parse_fn,
                  shuffle_filenames=True,
                  num_train_samples_per_task=4,
                  num_val_samples_per_task=4,
                  shuffle_buffer_size=50,
                  filter_fn=None,
                  interleave_cycle_length=None,
                  mode=tf.estimator.ModeKeys.TRAIN):
    """Read and parse multiple examples per task, per train/val split.

    This pipeline does the following:
      1. Shuffles & repeats filenames.
      2. Opens and shuffles each file.
      3. Outputs num_train + num_val examples from each shard.
      4. De-serializes the protos.

    Args:
      file_patterns: Comma-separated string of file patterns, where each
        individual file contains data for one task.
      parse_fn: Python function that takes as an argument a Dataset whose
        output is a string tensor and returns a Dataset whose output is a
        collection of parsed TFExamples.
      shuffle_filenames: If True, filenames are shuffled so tasks are sampled
        at random.
      num_train_samples_per_task: How many examples to dequeue for training.
      num_val_samples_per_task: How many examples to dequeue for validation.
      shuffle_buffer_size: How many examples to shuffle within each file.
      filter_fn: A callable that receives a valid tensorspec filled with
        tensors and returns a tf.bool. If True, the dataset keeps the
        example; otherwise the example is dropped.
      interleave_cycle_length: Integer, cycle length when interleaving
        examples from different task files. If 0 or None, the cycle length
        defaults to the number of tasks.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.

    Returns:
      Dataset whose output is that of `parse_fn` (e.g. features, labels).

    Raises:
      ValueError: If the file patterns do not exist.
    """
    data_format, filenames = tfdata.get_data_format_and_filenames(
        file_patterns)

    # Build a dataset that dequeues filenames, one per task file.
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    num_tasks = len(filenames)
    # Shuffle returns a new permutation *per epoch*. Upon epoch completion,
    # shuffling is repeated.
    if shuffle_filenames:
        dataset = dataset.shuffle(buffer_size=num_tasks).repeat()
    else:
        dataset = dataset.repeat()
    # From each task file, dequeue 2*N elements at a time, where N is
    # num_samples_per_task. These form N training and N validation instances
    # for meta-learning.
    def filename_map_fn(x):
        """Builds a dataset for a single file path.

        Args:
          x: Path of one task's data file.

        Returns:
          A dataset of parsed tensors, filtered by filter_fn.
        """
        dataset_ = tfdata.DATA_FORMAT[data_format](x)
        # It is important to have at least num_train_samples_per_task +
        # num_val_samples_per_task in the dataset in order to apply the batching
        # thereafter.
        effective_shuffle_buffer_size = max(
            shuffle_buffer_size,
            num_train_samples_per_task + num_val_samples_per_task)
        if mode == tf.estimator.ModeKeys.TRAIN:
            dataset_ = dataset_.shuffle(
                buffer_size=effective_shuffle_buffer_size).repeat()
        else:
            dataset_ = dataset_.repeat()
        dataset_ = dataset_.batch(
            batch_size=num_train_samples_per_task + num_val_samples_per_task,
            drop_remainder=True)
        parsed_dataset = parse_fn(dataset_)
        if filter_fn is not None:
            parsed_dataset = parsed_dataset.unbatch().filter(filter_fn).batch(
                batch_size=num_train_samples_per_task +
                num_val_samples_per_task,
                drop_remainder=True)
        return parsed_dataset

    # Sample 1 example-set from each task (file). Dequeues one task's worth
    # (a batch) of SARSTransitions. [num_samples_per_task * 2, ...]
    if not interleave_cycle_length:
        interleave_cycle_length = num_tasks
    dataset = dataset.interleave(filename_map_fn,
                                 cycle_length=interleave_cycle_length,
                                 block_length=1)
    return dataset
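
A minimal usage sketch, assuming one TFRecord file per task; parse_episode and its feature_spec are hypothetical stand-ins for a real parse_fn:

feature_spec = {'state': tf.io.FixedLenFeature([8], tf.float32)}  # Hypothetical.

def parse_episode(dataset):
    # parse_fn receives batches of serialized protos (batching happens inside
    # parallel_read, before parse_fn is applied).
    return dataset.map(
        lambda serialized: tf.io.parse_example(serialized, feature_spec))

dataset = parallel_read(
    file_patterns='/tmp/tasks/task_*.tfrecord',
    parse_fn=parse_episode,
    num_train_samples_per_task=4,
    num_val_samples_per_task=4,
    mode=tf.estimator.ModeKeys.TRAIN)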
Example #4
  def test_get_data_format_and_filenames_raises(self):
    # This should raise since the data format cannot be inferred from the
    # file pattern.
    with self.assertRaises(ValueError):
      tfdata.get_data_format_and_filenames('a wrong file pattern')
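
Since format inference can fail, callers may want to surface a more actionable error; a defensive sketch (pattern is a hypothetical variable, and the message wording is an assumption):

try:
  data_format, filenames = tfdata.get_data_format_and_filenames(pattern)
except ValueError:
  # Suggest the explicit-prefix form demonstrated in Example #1.
  raise ValueError(
      'Could not infer data format from %r; pass an explicit '
      '"<data_format>:" prefix, e.g. "sstable:%s".' % (pattern, pattern))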