Example #1
# Imports needed by this snippet: trax's data modules (not shown in the
# original excerpt).
from trax.data import inputs
from trax.data import tf_inputs


def _mnist_dataset():
    """Loads (and caches) the standard MNIST data set."""
    streams = tf_inputs.data_streams('mnist')
    # MNIST examples all share one shape, so fixed-shape batching is used.
    return inputs.batcher(streams,
                          variable_shapes=False,
                          batch_size_per_device=256,
                          eval_batch_size=256)
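
A minimal usage sketch (not part of the original example), assuming the Trax Inputs API: the object returned by inputs.batcher exposes train_stream(n_devices) and eval_stream(n_devices) generators that yield already-batched arrays, as Example #2 below also relies on.

# Usage sketch (illustrative); assumes _mnist_dataset() and the imports above.
dataset = _mnist_dataset()
batch = next(dataset.train_stream(1))  # single-device batches
print(batch[0].shape)  # inputs batch, e.g. (256, 28, 28, 1) for MNIST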
Example #2
  # Method of a test class; gin, t5_processors, inputs and tf_inputs are
  # module-level imports, and _spm_path() / _TESTDATA are helpers defined
  # elsewhere in the enclosing test module.
  def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path',
                       _spm_path())
    gin.bind_parameter('generic_text_dataset_preprocess_fn.text_preprocess_fns',
                       [lambda ds, training: t5_processors.squad(ds)])

    # Just make sure this doesn't throw.
    def data_streams():
      return tf_inputs.data_streams(
          'squad',
          data_dir=_TESTDATA,
          input_name='inputs',
          target_name='targets',
          bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
          shuffle_buffer_size=1)

    n_devices = 3

    squad_inputs = inputs.batcher(
        data_streams=data_streams,
        max_eval_length=512,
        buckets=([513], [n_devices, n_devices]))

    eval_stream = squad_inputs.eval_stream(n_devices)
    inps, tgts, _ = next(eval_stream)

    # We can only assert that the batch dim is divisible by n_devices.
    self.assertEqual(inps.shape[0] % n_devices, 0)
    self.assertEqual(tgts.shape[0] % n_devices, 0)
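
For context on the buckets argument above (a hedged reading, not the trax implementation): it pairs a list of length boundaries with a list of per-bucket batch sizes, so examples of up to 513 tokens land in the first bucket and longer ones in the overflow bucket, each batched n_devices at a time. A plain-Python sketch of that pairing:

# Illustrative sketch of length bucketing: boundaries define the buckets,
# batch_sizes gives one batch size per bucket (len(boundaries) + 1 entries).
def bucket_index(length, boundaries=(513,)):
  for i, boundary in enumerate(boundaries):
    if length <= boundary:
      return i
  return len(boundaries)

n_devices = 3
batch_sizes = [n_devices, n_devices]  # as in the test above
# Every per-bucket batch size is a multiple of n_devices, which is why the
# test can only assert divisibility of the batch dimension.
assert all(size % n_devices == 0 for size in batch_sizes)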
Example #3
# Imports assumed for this snippet, as in Example #1.
from trax.data import inputs
from trax.data import tf_inputs


def _mnist_brightness_dataset():
  """Loads (and caches) an MNIST mean-brightness data set."""
  def preprocess_stream(stream):
    """Replaces each class label with the image's mean brightness."""
    def new_stream():
      for (image, _) in stream():
        # Target is the mean pixel brightness in [0, 1], shaped (1,) so it
        # batches like a regular regression target.
        yield (image, (image / 255).mean()[None])
    return new_stream

  streams = tuple(map(preprocess_stream, tf_inputs.data_streams('mnist')))
  return inputs.batcher(streams, variable_shapes=False,
                        batch_size_per_device=256,
                        eval_batch_size=256)
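
As in Example #1, a minimal usage sketch (illustrative): each brightness target has shape (1,), so a batch of targets should stack to shape (batch, 1).

# Usage sketch (illustrative); assumes _mnist_brightness_dataset() above.
dataset = _mnist_brightness_dataset()
batch = next(dataset.train_stream(1))
images, brightness = batch[0], batch[1]
print(images.shape, brightness.shape)  # e.g. (256, 28, 28, 1) and (256, 1)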