repeat_count=1, file_buffer_size=32, file_parallelism=1, bucket_upper_bound=[10], bucket_batch_limit=[2])
  return inputs


def _get_test_dataset(num: int) -> tf.data.Dataset:
  """Builds a dataset whose elements are dicts of the form {'data': i}.

  Args:
    num: Argument passed to `tf.data.Dataset.range`.

  Returns:
    `tf.data.Dataset.range(num)` with every element mapped through `to_map`.
  """

  def to_map(i: int):
    # Wrap each element in a dict under the key 'data'.
    return {'data': i}

  return tf.data.Dataset.range(num).map(to_map)


# Input class generated from `_get_test_dataset`; it is subclassed below.
TestDataset = base_input_generator.DefineTFDataInput('TestDataset', _get_test_dataset)


class TestDatasetOverride(TestDataset):
  """TestDataset variant that adds a derived `data2` field to each batch."""

  def GetPreprocessedInputBatch(self) -> py_utils.NestedMap:
    """Returns the parent's batch with `data2 = data * 2 + 1` attached."""
    batch = super().GetPreprocessedInputBatch()
    # Check the parent returned a NestedMap before setting an attribute on it.
    assert isinstance(batch, py_utils.NestedMap)
    batch.data2 = batch.data * 2 + 1
    return batch


class InputTest(test_util.JaxTestCase):

  def test_lingvo_input(self):
    tmp = os.path.join(FLAGS.test_tmpdir, 'tmptest')
    batch_size = 2
    num_batches = 10
class _TestDatasetClass:
  """A class that generates tf.data by its member function."""

  def __init__(self, begin):
    """Stores `begin`, used by `DatasetFn` as the first argument to tf.range."""
    self._begin = begin

  def DatasetFn(self, end=10):
    """Returns a dataset of {'value': x} dicts built from a tf.range.

    Args:
      end: Second argument passed to `tf.range(self._begin, end)`.

    Returns:
      The sliced range dataset with each element mapped to `{'value': x}`.
    """
    ds = tf.data.Dataset.from_tensor_slices(tf.range(self._begin, end))
    return ds.map(lambda x: {'value': x})


# A class object which will be instantiated at importing the module.
# It can be used in DefineTFDataInput().
_TestDatasetObject = _TestDatasetClass(begin=0)

# InputGenerators for TFDataInputTest.
_TestTFDataInput = base_input_generator.DefineTFDataInput(
    '_TestTFDataInput', _TestDatasetFn)
# Same dataset function, with 'begin' listed in ignore_args.
_TestTFDataInputWithIgnoreArgs = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithIgnoreArgs', _TestDatasetFn, ignore_args=('begin', ))
# Same dataset function, with map_args pairing 'end' with 'num_samples'.
# NOTE(review): presumably 'end' is then supplied from a 'num_samples' param;
# confirm against DefineTFDataInput.
_TestTFDataInputWithMapArgs = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithMapArgs', _TestDatasetFn,
    map_args={'end': 'num_samples'})
_TestTFDataInputWithoutDefault = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithoutDefault', _TestDatasetFnWithoutDefault)
_TestTFDataInputWithRepeat = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithRepeat', _TestDatasetFnWithRepeat)
# Uses a bound method of the module-level instance instead of a plain function.
_TestTFDataInputWithBoundMethod = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithBoundMethod', _TestDatasetObject.DatasetFn)
_TestTFDataInputV1 = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputV1', _TestDatasetFnV1)
_TestTFDataInputV2 = base_input_generator.DefineTFDataInput(