예제 #1
0
  def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    """Smoke-test eval batching of SQuAD through the generic preprocess fn.

    Binds `generic_text_dataset_preprocess_fn` parameters via gin (an
    `spm_path` from `_spm_path()` and a single text preprocessor wrapping
    `t5_processors.squad`), builds bucketed inputs with `inputs.batcher`,
    pulls one eval batch and checks that its leading dimension is a multiple
    of the device count.
    """
    n_devices = 3

    def squad_text_preprocess(ds, training):
      del training  # Not needed by the SQuAD processor.
      return t5_processors.squad(ds)

    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.text_preprocess_fns',
        [squad_text_preprocess])

    def make_streams():
      # Zero-argument factory handed to the batcher as `data_streams`.
      return tf_inputs.data_streams(
          'squad',
          data_dir=_TESTDATA,
          input_name='inputs',
          target_name='targets',
          bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
          shuffle_buffer_size=1)

    # Bucket spec is (length boundaries, batch sizes); every batch size is
    # the device count so batches can be split across devices.
    bucket_boundaries = [513]
    bucket_batch_sizes = [n_devices] * 2
    squad_inputs = inputs.batcher(
        data_streams=make_streams,
        max_eval_length=512,
        buckets=(bucket_boundaries, bucket_batch_sizes))

    # Just make sure pulling a batch doesn't throw.
    inps, tgts, _ = next(squad_inputs.eval_stream(n_devices))

    # Only the batch dimension is predictable: it must divide by n_devices.
    for batch_part in (inps, tgts):
      self.assertEqual(batch_part.shape[0] % n_devices, 0)
예제 #2
0
    def test_inputs_using_generic_text_dataset_preprocess_fn(self):
        """Smoke-test train batching of SQuAD through the generic preprocess fn.

        Configures `generic_text_dataset_preprocess_fn` via gin (`spm_path`,
        `copy_plaintext=True`, and `t5_processors.squad` as the text
        preprocessor), builds inputs with `inputs.batcher`, pulls one training
        batch and checks that its leading dimension is a multiple of the
        device count.
        """
        gin_bindings = (
            ('generic_text_dataset_preprocess_fn.spm_path', _spm_path()),
            ('generic_text_dataset_preprocess_fn.copy_plaintext', True),
            ('generic_text_dataset_preprocess_fn.text_preprocess_fn',
             t5_processors.squad),
        )
        for binding_key, binding_value in gin_bindings:
            gin.bind_parameter(binding_key, binding_value)

        def make_streams():
            # Zero-argument factory handed to the batcher as `data_streams`.
            return tf_inputs.data_streams(
                'squad',
                data_dir=_TESTDATA,
                input_name='inputs',
                target_name='targets',
                bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn)

        squad_inputs = inputs.batcher(data_streams=make_streams,
                                      batch_size_per_device=2,
                                      eval_batch_size=2,
                                      max_eval_length=50)

        n_devices = 3
        # Just make sure pulling a batch doesn't throw.
        inps, tgts = next(squad_inputs.train_stream(n_devices))

        # Only the batch dimension is predictable: it must divide by n_devices.
        for batch_part in (inps, tgts):
            self.assertEqual(batch_part.shape[0] % n_devices, 0)
예제 #3
0
파일: mnist_test.py 프로젝트: yangliuy/trax
def _mnist_dataset():
    """Build batched MNIST inputs for the tests.

    Gets the MNIST streams from `tf_inputs.data_streams('mnist')` and wraps
    them with `inputs.batcher` using `variable_shapes=False`, a per-device
    training batch size of 256 and an eval batch size of 256.

    NOTE(review): this function keeps no in-memory cache of its own; any
    caching of the dataset presumably happens inside `tf_inputs`/TFDS —
    TODO confirm.
    """
    batch_size = 256
    mnist_streams = tf_inputs.data_streams('mnist')
    return inputs.batcher(
        mnist_streams,
        variable_shapes=False,
        batch_size_per_device=batch_size,
        eval_batch_size=batch_size,
    )
예제 #4
0
def dataset_from_trax_tfds(dataset_name='mnist',
                           variable_shapes=False,
                           tfds_dir=None,
                           batch_size_per_device=256,
                           eval_batch_size=256,
                           **kwargs):
    """Build batched trax inputs for a TFDS dataset.

    Args:
        dataset_name: Dataset name passed to `tf_inputs.data_streams`.
        variable_shapes: Forwarded to `inputs.batcher`.
        tfds_dir: Forwarded to `tf_inputs.data_streams` as `data_dir`.
        batch_size_per_device: Forwarded to `inputs.batcher`.
        eval_batch_size: Forwarded to `inputs.batcher`.
        **kwargs: Extra keyword arguments; they are not used. Entries whose
            value is not None are reported in a warning.

    Returns:
        The object produced by `inputs.batcher` for the dataset's streams.
    """
    # None-valued extras are dropped silently; only the rest are reported.
    ignored = {k: v for k, v in kwargs.items() if v is not None}
    if ignored:
        # `Logger.warn` is a deprecated alias of `warning`; lazy %-args keep
        # the emitted text identical to the old str.format version.
        logger.warning('dataset_from_trax_tfds: ignoring arguments %s',
                       ignored)
    streams = tf_inputs.data_streams(dataset_name, data_dir=tfds_dir)
    return inputs.batcher(streams,
                          variable_shapes=variable_shapes,
                          batch_size_per_device=batch_size_per_device,
                          eval_batch_size=eval_batch_size)