class PaddedBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
  """Checkpointing tests for `Dataset.padded_batch`."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def make_dataset(lengths):
      # Each element is a vector of length x filled with the value x,
      # padded to the longest element within its batch of 4.
      dataset = dataset_ops.Dataset.from_tensor_slices(lengths)
      dataset = dataset.map(lambda x: array_ops.fill([x], x))
      return dataset.padded_batch(batch_size=4, padded_shapes=[-1])

    lengths = np.random.randint(1, 20, size=(32,)).astype(np.int32)
    # 32 elements / batch_size 4 -> 8 batches.
    verify_fn(self, lambda: make_dataset(lengths), num_outputs=8)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testNonDefaultPadding(self, verify_fn):

    def make_dataset(lengths):

      def to_pair(x):
        # Pair the filled vector with its string form so the two tuple
        # components exercise padding values of different dtypes.
        filled = array_ops.fill([x], x)
        return (filled, string_ops.as_string(filled))

      dataset = dataset_ops.Dataset.from_tensor_slices(lengths).map(to_pair)
      return dataset.padded_batch(
          batch_size=4,
          padded_shapes=([-1], [-1]),
          padding_values=(-1, '<end>'))

    lengths = np.random.randint(1, 20, size=(32,)).astype(np.int32)
    verify_fn(self, lambda: make_dataset(lengths), num_outputs=8)
class FromTensorSlicesCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):
  """Checkpointing tests for `Dataset.from_tensor_slices`."""

  def _build_tensor_slices_dataset(self, components):
    return dataset_ops.Dataset.from_tensor_slices(components)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Components share the leading (slice) dimension of 4 but differ in
    # trailing shape and dtype.
    components = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37.0, 38.0, 39.0, 40.0]),
    )
    verify_fn(
        self,
        lambda: self._build_tensor_slices_dataset(components),
        num_outputs=4)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    # Dict-structured components; both entries slice into 3 elements.
    dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
    verify_fn(
        self,
        lambda: self._build_tensor_slices_dataset(dict_components),
        num_outputs=3)
class ParallelInterleaveCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
  """Checkpointing tests for `experimental.parallel_interleave`."""

  def setUp(self):
    super(ParallelInterleaveCheckpointTest, self).setUp()
    self.input_values = np.array([2, 3], dtype=np.int64)
    self.num_repeats = 2
    # Each input value x contributes x elements (range(10*x, 11*x)) per
    # repeat. Derive the total from num_repeats (previously a hard-coded
    # `* 2`) so the two values cannot drift apart if num_repeats changes.
    self.num_outputs = np.sum(self.input_values) * self.num_repeats

  def _build_ds(self, cycle_length, block_length, sloppy=False):
    """Builds repeat -> parallel_interleave over range(10*x, 11*x)."""
    return (dataset_ops.Dataset.from_tensor_slices(
        self.input_values).repeat(self.num_repeats).apply(
            interleave_ops.parallel_interleave(
                lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
                cycle_length, block_length, sloppy)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(cycle_length=[1, 2], block_length=[1, 3])))
  def test(self, verify_fn, cycle_length, block_length):
    verify_fn(self, lambda: self._build_ds(cycle_length, block_length),
              self.num_outputs)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(cycle_length=[1, 2], block_length=[1, 3])))
  def testWithSloppy(self, cycle_length, block_length):
    break_points = self.gen_break_points(self.num_outputs, 10)
    expected_outputs = np.repeat(
        np.concatenate(
            [np.arange(10 * x, 11 * x) for x in self.input_values]),
        self.num_repeats).tolist()
    actual = self.gen_outputs(
        lambda: self._build_ds(cycle_length, block_length, True),
        break_points, self.num_outputs)
    # Sloppy interleaving may reorder elements, so compare as multisets.
    self.assertSequenceEqual(sorted(actual), expected_outputs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):

    def _map_fn(i):
      # A 2x2 sparse value with diagonal entries i and -i.
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=(i * [1, -1]),
          dense_shape=[2, 2])

    def _interleave_fn(x):
      # Densify and slice, so each sparse input yields 2 dense rows.
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    def _build_dataset():
      return dataset_ops.Dataset.range(10).map(_map_fn).apply(
          interleave_ops.parallel_interleave(_interleave_fn, 1))

    verify_fn(self, _build_dataset, num_outputs=20)
class InterleaveDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):
  """Checkpointing tests for `Dataset.interleave`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              cycle_length=2,
              block_length=[1, 3],
              num_parallel_calls=[None, 1, 2])))
  def test(self, verify_fn, cycle_length, block_length, num_parallel_calls):
    repeats = 2
    values = np.array([2, 3], dtype=np.int64)

    def make_dataset():
      # Each value x expands into x copies of itself.
      base = dataset_ops.Dataset.from_tensor_slices(values).repeat(repeats)
      return base.interleave(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
          cycle_length, block_length, num_parallel_calls)

    verify_fn(self, make_dataset, np.sum(values) * repeats)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testNested(self, verify_fn, num_parallel_calls):

    def build_ds():
      # Interleave datasets-of-datasets: 10 copies of a 10-element inner
      # dataset -> 100 outputs.
      inner_ds = dataset_ops.Dataset.from_tensor_slices(range(10))
      outer_ds = dataset_ops.Dataset.from_tensors(inner_ds).repeat(10)
      return outer_ds.interleave(
          lambda x: x, cycle_length=5, num_parallel_calls=num_parallel_calls)

    verify_fn(self, build_ds, num_outputs=100)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):

    def _map_fn(i):
      # A 2x2 sparse value with diagonal entries i and -i.
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def _interleave_fn(x):
      # Densify and slice, so each sparse input yields 2 dense rows.
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    def _build_dataset():
      return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
          _interleave_fn, cycle_length=1)

    verify_fn(self, _build_dataset, num_outputs=20)
class FromSparseTensorSlicesCheckpointTest(
    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
  """Checkpointing tests for `Dataset.from_sparse_tensor_slices`."""

  def _build_sparse_tensor_slice_dataset(self, slices):
    """Packs ragged `slices` into one SparseTensor and slices it back."""
    # COO coordinates: one (row, col) pair per stored value, in row order.
    # pylint: disable=g-complex-comprehension
    indices = np.array(
        [[row, col] for row, s in enumerate(slices) for col in range(len(s))],
        dtype=np.int64)
    values = np.array([v for s in slices for v in s], dtype=np.float64)
    # pylint: enable=g-complex-comprehension
    # Dense width is one more than the longest slice.
    dense_shape = np.array(
        [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
    packed = sparse_tensor.SparseTensor(indices, values, dense_shape)
    return dataset_ops.Dataset.from_sparse_tensor_slices(packed)

  @combinations.generate(
      combinations.times(test_base.v1_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Mix of multi-element, single-element, and empty rows.
    slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
    verify_fn(
        self,
        lambda: self._build_sparse_tensor_slice_dataset(slices),
        num_outputs=9,
        sparse_tensors=True)
class MatchingFilesDatasetCheckpointTest(
    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
  """Checkpointing tests for `MatchingFilesDataset`."""

  def _build_iterator_graph(self, test_patterns):
    return matching_files.MatchingFilesDataset(test_patterns)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    tmp_dir = tempfile.mkdtemp()
    # Clean up in `finally` so a failing verify_fn does not leak the
    # temporary tree (previously rmtree only ran on success).
    try:
      width = 16
      depth = 8
      # Lay out `width` directory chains of increasing depth; only the
      # deepest level holds the .txt/.log files the patterns match.
      for i in range(width):
        for j in range(depth):
          new_base = os.path.join(tmp_dir, str(i),
                                  *[str(dir_name) for dir_name in range(j)])
          os.makedirs(new_base, exist_ok=True)
          child_files = ['a.py', 'b.pyc'
                        ] if j < depth - 1 else ['c.txt', 'd.log']
          for f in child_files:
            filename = os.path.join(new_base, f)
            open(filename, 'w').close()

      patterns = [
          os.path.join(tmp_dir, os.path.join(*['**' for _ in range(depth)]),
                       suffix) for suffix in ['*.txt', '*.log']
      ]

      # One match per (chain, pattern) pair.
      num_outputs = width * len(patterns)
      verify_fn(self, lambda: self._build_iterator_graph(patterns),
                num_outputs)
    finally:
      shutil.rmtree(tmp_dir, ignore_errors=True)
class BatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
  """Checkpointing tests for `Dataset.batch`."""

  def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
    # Three components of matching leading dimension but distinct shapes.
    base = np.arange(tensor_slice_len)
    components = (
        base,
        np.array([[1, 2, 3]]) * base[:, np.newaxis],
        np.array(multiplier) * base,
    )
    return dataset_ops.Dataset.from_tensor_slices(components).batch(
        batch_size)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    tensor_slice_len = 8
    batch_size = 2
    num_outputs = tensor_slice_len // batch_size
    verify_fn(
        self,
        lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
        num_outputs)

  def _sparse(self, i):
    # A single-entry sparse vector holding the value i.
    return sparse_tensor.SparseTensorValue(
        indices=[[0]], values=(i * [1]), dense_shape=[1])

  def _build_dataset_sparse(self, batch_size=5):
    return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    # 10 elements / batch_size 5 -> 2 batches.
    verify_fn(self, self._build_dataset_sparse, num_outputs=2)

  def _build_dataset_nested_sparse(self):
    # Batch twice: 10 -> 2 batches of 5 -> 1 batch of 2.
    return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testNestedSparse(self, verify_fn):
    verify_fn(self, self._build_dataset_nested_sparse, num_outputs=1)
class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpointing tests for `Dataset.filter`."""

  def _build_filter_range_graph(self, div):
    return dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    div = 3
    # Count survivors of the filter using `div` (previously a hard-coded
    # literal 3) so the expectation tracks the divisor actually used.
    num_outputs = sum(x % div != 2 for x in range(100))
    verify_fn(self, lambda: self._build_filter_range_graph(div), num_outputs)

  def _build_filter_dict_graph(self):
    # map -> filter -> map over dict-structured elements.
    return dataset_ops.Dataset.range(10).map(lambda x: {
        "foo": x * 2,
        "bar": x**2
    }).filter(lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
        lambda d: d["foo"] + d["bar"])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    # Survivors are elements whose square is even.
    num_outputs = sum((x**2) % 2 == 0 for x in range(10))
    verify_fn(self, self._build_filter_dict_graph, num_outputs)

  def _build_sparse_filter(self):

    def _map_fn(i):
      # Pair a sparse value with the index used by the predicate.
      return sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
        lambda x, i: x)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    # Even indices in range(10): 5 outputs.
    verify_fn(self, self._build_sparse_filter, num_outputs=5)
class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Checkpointing tests for the internal `_SaveDataset` and `io.save`."""

  def _build_ds(self):
    return io._SaveDataset(
        dataset=dataset_ops.Dataset.range(42),
        path=self._save_dir,
        shard_func=None,
        compression=None)

  # _SaveDataset is consumed internally by save(); exercising it directly
  # gives its checkpointing machinery thorough coverage.
  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self._build_ds, num_outputs=42)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPI(self):
    dataset = dataset_ops.Dataset.range(40)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "max_to_keep": 50,
    }
    io.save(dataset, self._save_dir, checkpoint_args=checkpoint_args)
    # Every increment is checkpointed by default. Each checkpoint writes a
    # data file and an index file, plus one overall checkpoint file:
    # (2 * 40) + 1 files.
    num_checkpoint_files = len(os.listdir(self._checkpoint_prefix))
    self.assertEqual(81, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPICustomCheckpointInterval(self):
    dataset = dataset_ops.Dataset.range(40)
    step_counter = variables.Variable(0, trainable=False)
    checkpoint_args = {
        "checkpoint_interval": 5,
        "step_counter": step_counter,
        "directory": self._checkpoint_prefix,
        "max_to_keep": 10,
    }
    io.save(dataset, self._save_dir, checkpoint_args=checkpoint_args)
    # 40 elements at an interval of 5 yields 8 checkpoints: (2 * 8) + 1.
    num_checkpoint_files = len(os.listdir(self._checkpoint_prefix))
    self.assertEqual(17, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPIIncorrectArgs(self):
    dataset = dataset_ops.Dataset.range(42)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "incorrect_arg": "incorrect_arg",
    }
    # Unknown checkpoint arguments must be rejected.
    with self.assertRaises(TypeError):
      io.save(dataset, self._save_dir, checkpoint_args=checkpoint_args)
class AssertCardinalityCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):
  """Checkpointing tests for `experimental.assert_cardinality`."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def make_dataset(num_elements):
      # Assert the (correct) cardinality on a range dataset.
      dataset = dataset_ops.Dataset.range(num_elements)
      return dataset.apply(cardinality.assert_cardinality(num_elements))

    verify_fn(self, lambda: make_dataset(200), num_outputs=200)
class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Checkpointing tests for `io.load`."""

  def _build_ds(self):
    # Reload whatever was previously written to the save directory.
    return io.load(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Materialize a dataset first so there is something to load back.
    io.save(dataset_ops.Dataset.range(42), self._save_dir)
    verify_fn(self, self._build_ds, num_outputs=42)
class PrefetchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):
  """Checkpointing tests for `Dataset.prefetch`."""

  def build_dataset(self, seed=10):
    # Prefetch followed by a seeded, non-reshuffling shuffle so the
    # prefetch buffer carries state across checkpoints deterministically.
    dataset = dataset_ops.Dataset.range(100).prefetch(10)
    return dataset.shuffle(
        buffer_size=10, seed=seed, reshuffle_each_iteration=False)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self.build_dataset, num_outputs=100)
class ShuffleAndRepeatCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):
  """Checkpointing tests for `experimental.shuffle_and_repeat`."""

  def _build_ds(self, seed):
    dataset = dataset_ops.Dataset.range(20)
    return dataset.apply(
        shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # 20 elements repeated 5 times -> 100 outputs.
    verify_fn(self, lambda: self._build_ds(10), num_outputs=100)
class UniqueCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpointing tests for `Dataset.unique`."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def make_dataset(num_elements, unique_elem_range):
      # Fold the range into [0, unique_elem_range) so duplicates exist,
      # then deduplicate.
      dataset = dataset_ops.Dataset.range(num_elements)
      return dataset.map(lambda x: x % unique_elem_range).unique()

    # 200 elements mod 100 -> exactly 100 unique values.
    verify_fn(self, lambda: make_dataset(200, 100), num_outputs=100)
class WindowCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpointing tests for `Dataset.window`."""

  def _build_dataset(self):
    # Window into sub-datasets of 6 and flatten them back with a parallel
    # interleave; output count stays 42.
    windowed = dataset_ops.Dataset.range(42).window(6)
    return windowed.interleave(
        lambda x: x, cycle_length=2, num_parallel_calls=2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self._build_dataset, num_outputs=42)
class RepeatDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                  parameterized.TestCase):
  """Checkpointing tests for `Dataset.repeat`."""

  def _build_repeat_dataset(self, count, take_count=3):
    # take(take_count) limits the base dataset before repeating it.
    components = (np.arange(10),)
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    return dataset.take(take_count).repeat(count)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testFiniteRepeat(self, verify_fn):
    count = 10
    verify_fn(
        self,
        lambda: self._build_repeat_dataset(count),
        num_outputs=(3 * count))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testEmptyRepeat(self, verify_fn):
    verify_fn(self, lambda: self._build_repeat_dataset(0), num_outputs=0)

  @combinations.generate(test_base.default_test_combinations())
  def testInfiniteRepeat(self):
    # An infinite dataset can never be exhausted, so run the individual
    # verification helpers with verify_exhausted=False.
    self.verify_unused_iterator(
        lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
    self.verify_multiple_breaks(
        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
    self.verify_reset_restored_iterator(
        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testInfiniteEmptyRepeat(self, verify_fn):
    # Repeating an empty dataset forever still yields nothing.
    verify_fn(self, lambda: self._build_repeat_dataset(-1, 0), num_outputs=0)
class RangeCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
  """Checkpointing tests for `Dataset.range`."""

  def _build_range_dataset(self, start, stop):
    return dataset_ops.Dataset.range(start, stop)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    start, stop = 2, 10
    verify_fn(self, lambda: self._build_range_dataset(start, stop),
              stop - start)
class FromTensorsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
  """Checkpointing tests for `Dataset.from_tensors`."""

  def _build_tensor_dataset(self, variable_array):
    # A single element whose structure mixes the caller's array with
    # fixed vector and scalar components.
    components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
    return dataset_ops.Dataset.from_tensors(components)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    arr = np.array(1)
    verify_fn(self, lambda: self._build_tensor_dataset(arr), num_outputs=1)
class RebatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
  """Checkpointing tests for the internal `_RebatchDataset`."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def make_dataset(num_elements, batch_size):
      # Batch at 2x the target size, then rebatch each batch into two
      # batches of `batch_size`.
      batched = dataset_ops.Dataset.range(num_elements).batch(
          2 * batch_size, drop_remainder=True)
      return distribute._RebatchDataset(
          batched, batch_sizes=[batch_size, batch_size])

    # 64 elements / (2 * 8) = 4 input batches -> 8 rebatched outputs.
    verify_fn(self, lambda: make_dataset(64, 8), num_outputs=8)
class TakeWhileCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                              parameterized.TestCase):
  """Checkpointing tests for `Dataset.take_while`."""

  def _build_dataset(self, num_elements, upper_bound):
    return dataset_ops.Dataset.range(num_elements).take_while(
        predicate=lambda x: x < upper_bound)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(num_elements=[10, 23], upper_bound=[10, 23])))
  def test(self, verify_fn, num_elements, upper_bound):
    # The predicate cuts the range off at upper_bound, so the output
    # length is whichever of the two limits is smaller.
    verify_fn(self, lambda: self._build_dataset(num_elements, upper_bound),
              min(num_elements, upper_bound))
# NOTE(review): this class shares its name with the LoadCheckpointTest defined
# earlier in this file. Within a single module the later definition shadows the
# earlier one, so only this (skipped) version would be registered — presumably
# the two originate from different source files; confirm and deduplicate or
# rename as appropriate.
class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Checkpointing tests for `io.load` (currently skipped)."""

  def _build_ds(self):
    # Reload the dataset previously written by io.save().
    return io.load(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Skipped pending snapshot-reader serialization support; the code below
    # is unreachable until the skip is removed.
    self.skipTest(
        "TODO(jsimsa): Re-enable once snapshot reader supports serialization.")
    dataset = dataset_ops.Dataset.range(42)
    io.save(dataset, self._save_dir)
    verify_fn(self, self._build_ds, num_outputs=42)
class SkipDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
  """Checkpointing tests for `Dataset.skip`."""

  def _build_skip_dataset(self, count):
    components = (np.arange(10),)
    return dataset_ops.Dataset.from_tensor_slices(components).skip(count)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          # Partial skip, over-skip / skip-all, and no-op skip.
          combinations.combine(count=[5], num_outputs=[5]) +
          combinations.combine(count=[20, 10, -1], num_outputs=[0]) +
          combinations.combine(count=[0], num_outputs=[10])))
  def test(self, verify_fn, count, num_outputs):
    verify_fn(self, lambda: self._build_skip_dataset(count), num_outputs)
class ScanCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                         parameterized.TestCase):
  """Checkpointing tests for `Dataset.scan`."""

  def _build_dataset(self, num_elements):

    def fibonacci_step(state, _):
      # State is [f(n), f(n+1)]; emit f(n+1) and advance one step.
      return [state[1], state[0] + state[1]], state[1]

    dataset = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)
    return dataset.scan(initial_state=[0, 1], scan_func=fibonacci_step)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    num_outputs = 5
    verify_fn(
        self,
        lambda: self._build_dataset(num_outputs),
        num_outputs=num_outputs)
class FromListCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):
  """Checkpointing tests for `experimental.from_list`."""

  def _build_list_dataset(self, elements):
    return from_list.from_list(elements)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Elements of equal length but differing shapes.
    elements = [
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37, 38, 39, 40]),
    ]
    verify_fn(
        self, lambda: self._build_list_dataset(elements), num_outputs=3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    # Dict-structured elements with a consistent schema.
    dict_elements = [
        {"foo": 1, "bar": 4.0},
        {"foo": 2, "bar": 5.0},
        {"foo": 3, "bar": 6.0},
    ]
    verify_fn(
        self, lambda: self._build_list_dataset(dict_elements), num_outputs=3)
class DenseToSparseBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
  """Checkpointing tests for `experimental.dense_to_sparse_batch`."""

  def _build_dataset(self, components):
    # Variable-length vectors (length x, filled with x) batched into
    # sparse batches of 4 with dense row shape [12].
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(lambda x: array_ops.fill([x], x))
    return dataset.apply(batching.dense_to_sparse_batch(4, [12]))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    components = np.random.randint(5, size=(40,)).astype(np.int32)
    num_outputs = len(components) // 4
    verify_fn(self, lambda: self._build_dataset(components), num_outputs)
class GroupByReducerCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
  """Checkpointing tests for `experimental.group_by_reducer`."""

  def _build_dataset(self, components):
    # Sum-reducer: accumulate each group's values starting from 0.
    sum_reducer = grouping.Reducer(
        init_func=lambda _: np.int64(0),
        reduce_func=lambda acc, value: acc + value,
        finalize_func=lambda acc: acc)
    return dataset_ops.Dataset.from_tensor_slices(components).apply(
        grouping.group_by_reducer(lambda x: x % 5, sum_reducer))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
    # Grouping by x % 5 yields 5 groups, hence 5 outputs.
    verify_fn(self, lambda: self._build_dataset(components), num_outputs=5)
class ShardCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
  """Checkpointing tests for `Dataset.shard`."""

  def _build_dataset(self, num_elements, num_shards, index):
    return dataset_ops.Dataset.range(num_elements).shard(num_shards, index)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              elems=[10, 100], num_shards=[2, 5], index=[0, 1])))
  def test(self, verify_fn, elems, num_shards, index):
    # Each shard receives an equal split (elems is divisible by
    # num_shards in all parameterizations above).
    verify_fn(
        self,
        lambda: self._build_dataset(elems, num_shards, index),
        num_outputs=elems // num_shards)
class FixedLengthRecordDatasetCheckpointTest(
    FixedLengthRecordDatasetTestBase, checkpoint_test_base.CheckpointTestBase,
    parameterized.TestCase):
  """Checkpointing tests for `FixedLengthRecordDataset`."""

  def _build_dataset(self, num_epochs, compression_type=None):
    # compression_type is currently unused; kept for interface
    # compatibility with callers/subclasses.
    filenames = self._createFiles()
    reader = readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
    return reader.repeat(num_epochs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    verify_fn(self, lambda: self._build_dataset(num_epochs), num_outputs)
class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
  """Checkpointing tests for `Dataset.sample_from_datasets`."""

  def _build_dataset(self, probs, num_samples):
    # One infinite constant dataset per probability; a fixed seed keeps
    # the sampling sequence reproducible across checkpoint restores.
    sources = [
        dataset_ops.Dataset.from_tensors(i).repeat(None)
        for i in range(len(probs))
    ]
    sampled = dataset_ops.Dataset.sample_from_datasets(
        sources, probs, seed=1813)
    return sampled.take(num_samples)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(
        self, lambda: self._build_dataset([0.5, 0.5], 100), num_outputs=100)
class IgnoreErrorsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
  """Checkpointing tests for `experimental.ignore_errors`."""

  def _build_ds(self):
    # One NaN element triggers check_numerics; ignore_errors drops it,
    # leaving 4 of the 5 inputs.
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(lambda x: array_ops.check_numerics(x, "message"))
    dataset = dataset.apply(error_ops.ignore_errors())
    # Allow checkpointing despite the stateful error-swallowing transform.
    options = options_lib.Options()
    options.experimental_external_state_policy = (
        options_lib.ExternalStatePolicy.IGNORE)
    return dataset.with_options(options)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self._build_ds, num_outputs=4)