Example #1
def shuffle_and_repeat(buffer_size, count=None, seed=None):
  """Shuffles and repeats a Dataset returning a new permutation for each epoch.

  `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))`

  is equivalent to

  `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`

  The difference is that the latter dataset is not serializable. So,
  if you need to checkpoint an input pipeline with reshuffling, you must
  use this implementation.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
      maximum number of elements that will be buffered when prefetching.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset to be repeated
      indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      `tf.compat.v1.set_random_seed` for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
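A minimal usage sketch of the two forms compared in the docstring above, assuming a TensorFlow release that still exposes `tf.data.experimental.shuffle_and_repeat` (it was later deprecated in favor of the chained form):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# Fused transformation: serializable, so a reshuffling pipeline can be checkpointed.
fused = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=10, count=2))

# Chained equivalent: same element stream, but not serializable.
chained = dataset.shuffle(10, reshuffle_each_iteration=True).repeat(2)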
Example #2
def shuffle_and_repeat(buffer_size, count=None, seed=None):
    """Shuffles and repeats a Dataset returning a new permutation for each epoch.

  `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))`

  is equivalent to

  `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`

  The difference is that the latter dataset is not serializable. So,
  if you need to checkpoint an input pipeline with reshuffling, you must
  use this implementation.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
      maximum number of elements that will be buffered when prefetching.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset to be repeated
      indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      `tf.set_random_seed` for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
    return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
Example #3
def testLargeBufferSize(self):
  with ops.Graph().as_default() as g:
    ds = dataset_ops.Dataset.range(20).apply(
        shuffle_ops.shuffle_and_repeat(buffer_size=21))
    get_next_op = ds.make_one_shot_iterator().get_next()
    with self.session(graph=g) as sess:
      self.evaluate(get_next_op)
Example #5
def _maybe_shuffle_and_repeat(
    dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
  """Optionally shuffle and repeat dataset, as requested."""
  if num_epochs != 1 and shuffle:
    # Use shuffle_and_repeat for perf
    return dataset.apply(
        shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
                                       shuffle_seed))
  elif shuffle:
    return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
  elif num_epochs != 1:
    return dataset.repeat(num_epochs)
  return dataset
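A hypothetical call sketch for the helper above; `dataset` stands for any `tf.data.Dataset` built earlier in the pipeline, and the argument values simply mirror typical reader settings:

dataset = _maybe_shuffle_and_repeat(
    dataset,
    num_epochs=10,              # repeat the data 10 times
    shuffle=True,               # reshuffle between epochs
    shuffle_buffer_size=10000,  # elements held in the shuffle buffer
    shuffle_seed=None)          # None -> nondeterministic order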
Example #7
 def testVeryLargeBufferSize(self):
     num_epochs = 1000 * 1000
     # Each element being shuffled and repeated has shape (100,). This will OOM
     # or timeout if we actually load everything into the buffer.
     ds = dataset_ops.Dataset.range(500).batch(100).apply(
         shuffle_ops.shuffle_and_repeat(buffer_size=5 * num_epochs,
                                        count=num_epochs))
     # Verify two epochs worth of output.
     output = self._gen_outputs(lambda: ds, 2 * 5, verify_exhausted=False)
     for i in range(2):
         sorted_epoch = sorted(output[i * 5:(i + 1) * 5],
                               key=lambda batch: batch[0])
         self.assertAllEqual(sorted_epoch, np.arange(500).reshape([5, 100]))
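For scale (a rough estimate): a naive implementation that filled the whole buffer would hold 5,000,000 batches of 100 `int64` values each, on the order of 4 GB, which is why the test depends on the op producing output without materializing the full buffer.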
Example #8
    def testRerandomizeOnReplicate(self):
        random_seed.set_random_seed(None)
        # When no seeds are fixed, each instantiation of the dataset should
        # produce elements in a different order.
        num_epochs = 2
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements).apply(
            shuffle_ops.shuffle_and_repeat(buffer_size=num_elements,
                                           count=num_epochs))

        shuffle_1 = self.getDatasetOutput(ds)
        ds = self.graphRoundTrip(ds)
        shuffle_2 = self.getDatasetOutput(ds)

        self.assertCountEqual(shuffle_1, shuffle_2)
        self.assertNotEqual(shuffle_1, shuffle_2)
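For contrast, a sketch of the seeded case (the seed value is arbitrary): with a fixed `seed`, re-instantiating the dataset is expected to reproduce the same shuffled order rather than a new one.

ds = dataset_ops.Dataset.range(100).apply(
    shuffle_ops.shuffle_and_repeat(buffer_size=100, count=2, seed=42))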
Example #9
 def _build_ds(self, seed):
     return dataset_ops.Dataset.range(20).apply(
         shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
Example #10
 def _build_ds(self, seed, count=5, num_elements=20):
     return dataset_ops.Dataset.range(num_elements).apply(
         shuffle_ops.shuffle_and_repeat(buffer_size=5,
                                        count=count,
                                        seed=seed))
Example #11
 def testLargeBufferSize(self):
     ds = dataset_ops.Dataset.range(20).apply(
         shuffle_ops.shuffle_and_repeat(buffer_size=21))
     get_next = self.getNext(ds)
     self.evaluate(get_next())
Example #13
def make_batched_features_dataset_multi_task(
        file_pattern,
        batch_size,
        features,
        reader=core_readers.TFRecordDataset,
        label_key=None,
        weight_key=None,
        reader_args=None,
        num_epochs=None,
        shuffle=True,
        shuffle_buffer_size=10000,
        shuffle_seed=None,
        prefetch_buffer_size=optimization.AUTOTUNE,
        reader_num_threads=32,
        parser_num_threads=32,
        sloppy_ordering=True,
        drop_final_batch=False):
    """Returns a `Dataset` of feature dictionaries from `Example` protos.
    Returns:
    A dataset of `dict` elements, (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
    """
    if shuffle_seed is None:
        shuffle_seed = int(time.time())

    filenames = list(gfile.Glob(file_pattern))
    dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
    if shuffle:
        dataset = dataset.shuffle(len(filenames), shuffle_seed)

    # Read `Example` records from files as tensor objects.
    if reader_args is None:
        reader_args = []

    # Read files sequentially (if reader_num_threads=1) or in parallel
    dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          block_length=200,
          sloppy=sloppy_ordering))

    # Extract values if the `Example` tensors are stored as key-value tuples.
    if dataset_ops.get_legacy_output_types(dataset) == (
          dtypes.string, dtypes.string):
        dataset = dataset_ops.MapDataset(
          dataset, lambda _, v: v, use_inter_op_parallelism=True)

    # Apply dataset repeat and shuffle transformations.
    dataset = dataset.apply(
        shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
                                       shuffle_seed))

    dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

    # Parse `Example` tensors to a dictionary of `Feature` tensors.
    dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

        
    if weight_key:
        assert weight_key in features
        if label_key:
            if label_key not in features:
                raise ValueError(
                    "The `label_key` provided (%r) must be one of the `features` "
                    "keys." % label_key)
            assert label_key != weight_key
        # Split each parsed dict into (features, labels, weight). The label is
        # popped and replicated five times, one copy per task output.
        dataset = dataset.map(
            lambda x: (x, tuple([x.pop(label_key)] * 5), x.pop(weight_key)))
    else:
        if label_key:
            if label_key not in features:
                raise ValueError(
                    "The `label_key` provided (%r) must be one of the `features` "
                    "keys." % label_key)
        # Split each parsed dict into (features, labels) with the label
        # replicated five times.
        dataset = dataset.map(lambda x: (x, tuple([x.pop(label_key)] * 5)))
    dataset = dataset.prefetch(prefetch_buffer_size)
    return dataset
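A hypothetical call sketch for the function above; the feature spec, key names, and file pattern are invented for illustration (using the public `tf.io` feature-spec aliases), and TFRecord files of `Example` protos are assumed to already exist at the pattern:

import tensorflow as tf

features = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "weight": tf.io.FixedLenFeature([], tf.float32),
    "tokens": tf.io.VarLenFeature(tf.string),
}
ds = make_batched_features_dataset_multi_task(
    file_pattern="/data/train-*.tfrecord",   # made-up path
    batch_size=128,
    features=features,
    label_key="label",
    weight_key="weight",
    num_epochs=1)
# Each element is a (feature_dict, (label,) * 5, weight) tuple.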