Пример #1
0
                                               repeat_count=1,
                                               file_buffer_size=32,
                                               file_parallelism=1,
                                               bucket_upper_bound=[10],
                                               bucket_batch_limit=[2])
        return inputs


def _get_test_dataset(num: int) -> tf.data.Dataset:
    """Returns a dataset of `num` single-field dicts {'data': i}, i in [0, num).

    Elements come from tf.data.Dataset.range, so each `i` is a scalar tensor
    (tf.int64 by default), not a Python int.
    """
    indices = tf.data.Dataset.range(num)
    return indices.map(lambda idx: {'data': idx})


# Input generator class built from _get_test_dataset via DefineTFDataInput();
# subclassed below by TestDatasetOverride.
TestDataset = base_input_generator.DefineTFDataInput(
    'TestDataset', _get_test_dataset)


class TestDatasetOverride(TestDataset):
    """TestDataset variant that adds a derived `data2` field to each batch."""

    def GetPreprocessedInputBatch(self) -> py_utils.NestedMap:
        """Returns the parent's batch with `data2 = data * 2 + 1` set on it.

        The NestedMap returned by the parent is modified in place and returned.
        """
        preprocessed = super().GetPreprocessedInputBatch()
        # Narrow the type before attribute access on the batch.
        assert isinstance(preprocessed, py_utils.NestedMap)
        doubled = preprocessed.data * 2
        preprocessed.data2 = doubled + 1
        return preprocessed


class InputTest(test_util.JaxTestCase):
    def test_lingvo_input(self):
        tmp = os.path.join(FLAGS.test_tmpdir, 'tmptest')
        batch_size = 2
        num_batches = 10
Пример #2
0
class _TestDatasetClass:
    """Builds a tf.data pipeline from one of its (bound) member functions.

    The start of the generated range is fixed when the object is constructed;
    the end is supplied on each DatasetFn call.
    """

    def __init__(self, begin):
        # Start of the range produced by DatasetFn.
        self._begin = begin

    def DatasetFn(self, end=10):
        """Returns a dataset of {'value': x} dicts for x in tf.range(begin, end)."""
        values = tf.range(self._begin, end)
        dataset = tf.data.Dataset.from_tensor_slices(values)
        return dataset.map(lambda elem: {'value': elem})


# Module-level instance of _TestDatasetClass, created when this module is
# imported, so its bound method DatasetFn can be handed to DefineTFDataInput().
_TestDatasetObject = _TestDatasetClass(begin=0)

# Input generators for TFDataInputTest. Each wraps a dataset function via
# DefineTFDataInput(); the variants differ in the wrapped function or in the
# argument-handling options passed.
_TestTFDataInput = base_input_generator.DefineTFDataInput(
    '_TestTFDataInput',
    _TestDatasetFn,
)
# NOTE(review): presumably excludes `begin` from the generated params — confirm
# against DefineTFDataInput.
_TestTFDataInputWithIgnoreArgs = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithIgnoreArgs',
    _TestDatasetFn,
    ignore_args=('begin',),
)
# NOTE(review): presumably exposes the function's `end` argument under the
# param name `num_samples` — confirm against DefineTFDataInput.
_TestTFDataInputWithMapArgs = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithMapArgs',
    _TestDatasetFn,
    map_args={'end': 'num_samples'},
)
_TestTFDataInputWithoutDefault = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithoutDefault',
    _TestDatasetFnWithoutDefault,
)
_TestTFDataInputWithRepeat = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithRepeat',
    _TestDatasetFnWithRepeat,
)
# Wraps a bound method of a module-level object instead of a free function.
_TestTFDataInputWithBoundMethod = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputWithBoundMethod',
    _TestDatasetObject.DatasetFn,
)
_TestTFDataInputV1 = base_input_generator.DefineTFDataInput(
    '_TestTFDataInputV1',
    _TestDatasetFnV1,
)
_TestTFDataInputV2 = base_input_generator.DefineTFDataInput(