Ejemplo n.º 1
0
class PaddedBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
    """Checkpoint/restore coverage for `Dataset.padded_batch`."""

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Pads variable-length int vectors using the default padding value."""
        # 32 random lengths in [1, 20); element x becomes a length-x vector
        # filled with x, so batches of 4 give 8 padded outputs.
        lengths = np.random.randint(1, 20, size=(32, )).astype(np.int32)

        def make_dataset():
            ds = dataset_ops.Dataset.from_tensor_slices(lengths)
            ds = ds.map(lambda x: array_ops.fill([x], x))
            return ds.padded_batch(batch_size=4, padded_shapes=[-1])

        verify_fn(self, make_dataset, num_outputs=8)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testNonDefaultPadding(self, verify_fn):
        """Pads an (int, string) tuple with explicit per-component values."""
        lengths = np.random.randint(1, 20, size=(32, )).astype(np.int32)

        def to_pair(x):
            ints = array_ops.fill([x], x)
            return ints, string_ops.as_string(ints)

        def make_dataset():
            ds = dataset_ops.Dataset.from_tensor_slices(lengths).map(to_pair)
            return ds.padded_batch(batch_size=4,
                                   padded_shapes=([-1], [-1]),
                                   padding_values=(-1, '<end>'))

        verify_fn(self, make_dataset, num_outputs=8)
Ejemplo n.º 2
0
class FromTensorSlicesCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):
    """Checkpoint tests for `Dataset.from_tensor_slices`."""

    def _build_tensor_slices_dataset(self, components):
        """Slices `components` along their leading dimension."""
        return dataset_ops.Dataset.from_tensor_slices(components)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Slices a tuple whose components share a leading dimension of 4."""
        first = np.tile(np.array([[1], [2], [3], [4]]), 20)
        second = np.tile(np.array([[12], [13], [14], [15]]), 22)
        third = np.array([37.0, 38.0, 39.0, 40.0])
        components = (first, second, third)

        def make_dataset():
            return self._build_tensor_slices_dataset(components)

        verify_fn(self, make_dataset, num_outputs=4)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testDict(self, verify_fn):
        """Slices a dict of equally long lists into 3 dict elements."""
        mapping = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}

        def make_dataset():
            return self._build_tensor_slices_dataset(mapping)

        verify_fn(self, make_dataset, num_outputs=3)
class ParallelInterleaveCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
    """Checkpoint tests for the experimental `parallel_interleave` transform."""

    def setUp(self):
        super(ParallelInterleaveCheckpointTest, self).setUp()
        self.input_values = np.array([2, 3], dtype=np.int64)
        self.num_repeats = 2
        # Input value x expands to range(10 * x, 11 * x), i.e. x elements,
        # and the input is repeated `num_repeats` times. Derive the count from
        # `num_repeats` rather than a literal so the two cannot drift apart.
        self.num_outputs = np.sum(self.input_values) * self.num_repeats

    def _build_ds(self, cycle_length, block_length, sloppy=False):
        """Builds the repeated input interleaved via `parallel_interleave`.

        Args:
          cycle_length: forwarded to `parallel_interleave`.
          block_length: forwarded to `parallel_interleave`.
          sloppy: forwarded to `parallel_interleave`.

        Returns:
          The interleaved dataset.
        """
        return (dataset_ops.Dataset.from_tensor_slices(
            self.input_values).repeat(self.num_repeats).apply(
                interleave_ops.parallel_interleave(
                    lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
                    cycle_length, block_length, sloppy)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            checkpoint_test_base.default_test_combinations(),
            combinations.combine(cycle_length=[1, 2], block_length=[1, 3])))
    def test(self, verify_fn, cycle_length, block_length):
        verify_fn(self, lambda: self._build_ds(cycle_length, block_length),
                  self.num_outputs)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(cycle_length=[1, 2], block_length=[1, 3])))
    def testWithSloppy(self, cycle_length, block_length):
        """Checks outputs across breaks when `sloppy=True`.

        Sloppy mode is compared as a sorted multiset rather than in order.
        """
        break_points = self.gen_break_points(self.num_outputs, 10)
        expected_outputs = np.repeat(
            np.concatenate(
                [np.arange(10 * x, 11 * x) for x in self.input_values]),
            self.num_repeats).tolist()

        actual = self.gen_outputs(
            lambda: self._build_ds(cycle_length, block_length, sloppy=True),
            break_points, self.num_outputs)
        self.assertSequenceEqual(sorted(actual), expected_outputs)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testSparse(self, verify_fn):
        """Interleaves datasets built from densified sparse tensors."""

        def _map_fn(i):
            # 2x2 sparse tensor with values i * [1, -1] at (0, 0) and (1, 1).
            return sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                                   values=(i * [1, -1]),
                                                   dense_shape=[2, 2])

        def _interleave_fn(x):
            return dataset_ops.Dataset.from_tensor_slices(
                sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

        def _build_dataset():
            return dataset_ops.Dataset.range(10).map(_map_fn).apply(
                interleave_ops.parallel_interleave(_interleave_fn, 1))

        verify_fn(self, _build_dataset, num_outputs=20)
Ejemplo n.º 4
0
class InterleaveDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):
  """Checkpoint tests for `Dataset.interleave` (sequential and parallel)."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              cycle_length=2,
              block_length=[1, 3],
              num_parallel_calls=[None, 1, 2])))
  def test(self, verify_fn, cycle_length, block_length, num_parallel_calls):
    """Interleaves, for each repeated input x, x copies of x."""
    repeats = 2
    values = np.array([2, 3], dtype=np.int64)
    # Each value x contributes x elements per pass over the input.
    expected_count = np.sum(values) * repeats

    def make_dataset():
      source = dataset_ops.Dataset.from_tensor_slices(values).repeat(repeats)
      return source.interleave(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
          cycle_length=cycle_length,
          block_length=block_length,
          num_parallel_calls=num_parallel_calls)

    verify_fn(self, make_dataset, expected_count)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testNested(self, verify_fn, num_parallel_calls):
    """Interleaves a dataset whose elements are themselves datasets."""

    def make_dataset():
      ten_ints = dataset_ops.Dataset.from_tensor_slices(range(10))
      outer = dataset_ops.Dataset.from_tensors(ten_ints).repeat(10)
      # The identity map flattens 10 nested 10-element datasets into 100.
      return outer.interleave(
          lambda inner: inner,
          cycle_length=5,
          num_parallel_calls=num_parallel_calls)

    verify_fn(self, make_dataset, num_outputs=100)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    """Interleaves datasets built from densified sparse tensors."""

    def to_sparse(i):
      # 2x2 sparse tensor with values i * [1, -1] at (0, 0) and (1, 1).
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def densify_rows(st):
      dense = sparse_ops.sparse_to_dense(st.indices, st.dense_shape, st.values)
      return dataset_ops.Dataset.from_tensor_slices(dense)

    def make_dataset():
      sparse_ds = dataset_ops.Dataset.range(10).map(to_sparse)
      return sparse_ds.interleave(densify_rows, cycle_length=1)

    verify_fn(self, make_dataset, num_outputs=20)
Ejemplo n.º 5
0
class FromSparseTensorSlicesCheckpointTest(
        checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
    """Checkpoint tests for `Dataset.from_sparse_tensor_slices`."""

    def _build_sparse_tensor_slice_dataset(self, slices):
        """Packs ragged rows into one SparseTensor and slices it by row.

        Args:
          slices: list of rows; row r contributes an entry at (r, c) for each
            of its values.

        Returns:
          A dataset yielding one sparse row per element of `slices`.
        """
        index_pairs = []
        flat_values = []
        for row, row_values in enumerate(slices):
            for col, value in enumerate(row_values):
                index_pairs.append([row, col])
                flat_values.append(value)
        indices = np.array(index_pairs, dtype=np.int64)
        values = np.array(flat_values, dtype=np.float64)
        # Width is one more than the longest row.
        num_cols = max(len(s) for s in slices) + 1
        dense_shape = np.array([len(slices), num_cols], dtype=np.int64)
        return dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices, values, dense_shape))

    @combinations.generate(
        combinations.times(test_base.v1_only_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Round-trips 9 ragged rows, several of them empty."""
        rows = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]

        def make_dataset():
            return self._build_sparse_tensor_slice_dataset(rows)

        verify_fn(self, make_dataset, num_outputs=9, sparse_tensors=True)
class MatchingFilesDatasetCheckpointTest(
        checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
    """Checkpoint tests for `MatchingFilesDataset`."""

    def _build_iterator_graph(self, test_patterns):
        """Returns a dataset of file names matching `test_patterns`."""
        return matching_files.MatchingFilesDataset(test_patterns)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Matches `.txt`/`.log` files at the bottom of a directory tree.

        Builds `width` directory chains of `depth` levels under a fresh temp
        directory. Intermediate levels hold `a.py`/`b.pyc`, and only the
        deepest level holds `c.txt`/`d.log`. The expected count is one match
        per chain per pattern.
        """
        tmp_dir = tempfile.mkdtemp()
        # Remove the tree even if building it or verification fails, so
        # failing runs do not leak temporary directories.
        try:
            width = 16
            depth = 8
            for i in range(width):
                for j in range(depth):
                    new_base = os.path.join(
                        tmp_dir, str(i),
                        *[str(dir_name) for dir_name in range(j)])
                    os.makedirs(new_base, exist_ok=True)
                    child_files = (['a.py', 'b.pyc'] if j < depth - 1 else
                                   ['c.txt', 'd.log'])
                    for f in child_files:
                        # Create an empty file.
                        with open(os.path.join(new_base, f), 'w'):
                            pass

            patterns = [
                os.path.join(tmp_dir,
                             os.path.join(*['**' for _ in range(depth)]),
                             suffix) for suffix in ['*.txt', '*.log']
            ]

            num_outputs = width * len(patterns)
            verify_fn(self, lambda: self._build_iterator_graph(patterns),
                      num_outputs)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
Ejemplo n.º 7
0
class BatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
    """Checkpoint tests for `Dataset.batch`, including sparse elements."""

    def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
        """Batches a 3-tuple of arrays sliced along their first dimension.

        Args:
          multiplier: scale applied to the third component.
          tensor_slice_len: number of slices before batching.
          batch_size: elements per batch.
        """
        steps = np.arange(tensor_slice_len)
        components = (steps,
                      np.array([[1, 2, 3]]) * steps[:, np.newaxis],
                      np.array(multiplier) * steps)
        return dataset_ops.Dataset.from_tensor_slices(components).batch(
            batch_size)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """8 slices in batches of 2 gives 4 outputs."""
        slice_len, per_batch = 8, 2

        def make_dataset():
            return self.build_dataset(15.0, slice_len, per_batch)

        verify_fn(self, make_dataset, slice_len // per_batch)

    def _sparse(self, i):
        """Wraps `i` into a length-1 sparse vector."""
        return sparse_tensor.SparseTensorValue(
            indices=[[0]], values=(i * [1]), dense_shape=[1])

    def _build_dataset_sparse(self, batch_size=5):
        """Batches 10 sparse scalars-as-vectors."""
        sparse_ds = dataset_ops.Dataset.range(10).map(self._sparse)
        return sparse_ds.batch(batch_size)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testSparse(self, verify_fn):
        verify_fn(self, self._build_dataset_sparse, num_outputs=2)

    def _build_dataset_nested_sparse(self):
        """Batches sparse elements twice (by 5, then by 2)."""
        sparse_ds = dataset_ops.Dataset.range(10).map(self._sparse)
        return sparse_ds.batch(5).batch(2)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testNestedSparse(self, verify_fn):
        verify_fn(self, self._build_dataset_nested_sparse, num_outputs=1)
Ejemplo n.º 8
0
class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
    """Checkpoint tests for `Dataset.filter` on scalar, dict and sparse data."""

    def _build_filter_range_graph(self, div):
        """Keeps x in range(100) whose remainder modulo `div` is not 2."""
        return dataset_ops.Dataset.range(100).filter(
            lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        div = 3
        # Compute the expected count from `div` (not a literal) so it stays in
        # sync with the predicate in `_build_filter_range_graph`.
        num_outputs = sum(x % div != 2 for x in range(100))
        verify_fn(self, lambda: self._build_filter_range_graph(div),
                  num_outputs)

    def _build_filter_dict_graph(self):
        """Filters dict elements on an even "bar", then sums the fields."""
        return dataset_ops.Dataset.range(10).map(lambda x: {
            "foo": x * 2,
            "bar": x**2
        }).filter(lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
            lambda d: d["foo"] + d["bar"])

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testDict(self, verify_fn):
        num_outputs = sum((x**2) % 2 == 0 for x in range(10))
        verify_fn(self, self._build_filter_dict_graph, num_outputs)

    def _build_sparse_filter(self):
        """Keeps (sparse, i) pairs with even i, then drops the index."""

        def _map_fn(i):
            return sparse_tensor.SparseTensor(indices=[[0, 0]],
                                              values=(i * [1]),
                                              dense_shape=[1, 1]), i

        def _filter_fn(_, i):
            return math_ops.equal(i % 2, 0)

        return dataset_ops.Dataset.range(10).map(_map_fn).filter(
            _filter_fn).map(lambda x, i: x)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testSparse(self, verify_fn):
        verify_fn(self, self._build_sparse_filter, num_outputs=5)
Ejemplo n.º 9
0
class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
    """Checkpoint tests for the internal save dataset and `io.save` args."""

    def _build_ds(self):
        """Wraps range(42) in the internal `_SaveDataset` writing to save dir."""
        return io._SaveDataset(dataset=dataset_ops.Dataset.range(42),
                               path=self._save_dir,
                               shard_func=None,
                               compression=None)

    # `_SaveDataset` is consumed internally by save(); checkpointing it
    # directly gives thorough coverage of its iterator save/restore logic.
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        verify_fn(self, self._build_ds, num_outputs=42)

    @combinations.generate(test_base.eager_only_combinations())
    def testSaveCheckpointingAPI(self):
        """Default interval checkpoints after every element."""
        args = {"directory": self._checkpoint_prefix, "max_to_keep": 50}
        io.save(dataset_ops.Dataset.range(40),
                self._save_dir,
                checkpoint_args=args)
        # One data file and one index file per checkpoint (40 of them), plus
        # the overall checkpoint file: 2 * 40 + 1.
        self.assertEqual(81, len(os.listdir(self._checkpoint_prefix)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSaveCheckpointingAPICustomCheckpointInterval(self):
        """A 5-step interval over 40 elements yields 8 checkpoints."""
        counter = variables.Variable(0, trainable=False)
        args = {
            "checkpoint_interval": 5,
            "step_counter": counter,
            "directory": self._checkpoint_prefix,
            "max_to_keep": 10,
        }
        io.save(dataset_ops.Dataset.range(40),
                self._save_dir,
                checkpoint_args=args)
        # 2 files per checkpoint * 8 checkpoints + 1 overall file.
        self.assertEqual(17, len(os.listdir(self._checkpoint_prefix)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSaveCheckpointingAPIIncorrectArgs(self):
        """Unknown checkpoint arguments are rejected with TypeError."""
        args = {
            "directory": self._checkpoint_prefix,
            "incorrect_arg": "incorrect_arg"
        }
        source = dataset_ops.Dataset.range(42)
        with self.assertRaises(TypeError):
            io.save(source, self._save_dir, checkpoint_args=args)
Ejemplo n.º 10
0
class AssertCardinalityCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):
    """Checkpoint tests for `assert_cardinality`."""

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Checkpoints range(200) with a matching cardinality assertion."""
        count = 200

        def make_dataset():
            return dataset_ops.Dataset.range(count).apply(
                cardinality.assert_cardinality(count))

        verify_fn(self, make_dataset, num_outputs=count)
Ejemplo n.º 11
0
class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
    """Checkpoint tests for datasets produced by `io.load`."""

    def _build_ds(self):
        """Loads whatever was previously saved to the save directory."""
        return io.load(self._save_dir)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Saves range(42), then checkpoints iteration over the loaded copy."""
        io.save(dataset_ops.Dataset.range(42), self._save_dir)
        verify_fn(self, self._build_ds, num_outputs=42)
Ejemplo n.º 12
0
class PrefetchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):
    """Checkpoint tests for `Dataset.prefetch` feeding a seeded shuffle."""

    def build_dataset(self, seed=10):
        """Prefetches range(100) and shuffles it deterministically by `seed`."""
        prefetched = dataset_ops.Dataset.range(100).prefetch(10)
        return prefetched.shuffle(buffer_size=10,
                                  seed=seed,
                                  reshuffle_each_iteration=False)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        verify_fn(self, self.build_dataset, num_outputs=100)
class ShuffleAndRepeatCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):
    """Checkpoint tests for the fused `shuffle_and_repeat` transform."""

    def _build_ds(self, seed):
        """Shuffles range(20) with a buffer of 5 and repeats it 5 times."""
        transform = shuffle_ops.shuffle_and_repeat(buffer_size=5,
                                                   count=5,
                                                   seed=seed)
        return dataset_ops.Dataset.range(20).apply(transform)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """20 elements * 5 repeats = 100 outputs."""

        def make_dataset():
            return self._build_ds(10)

        verify_fn(self, make_dataset, num_outputs=100)
Ejemplo n.º 14
0
class UniqueCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
    """Checkpoint tests for `Dataset.unique`."""

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """range(200) folded modulo 100 has exactly 100 distinct values."""
        total, distinct = 200, 100

        def make_dataset():
            folded = dataset_ops.Dataset.range(total).map(
                lambda x: x % distinct)
            return folded.unique()

        verify_fn(self, make_dataset, num_outputs=distinct)
Ejemplo n.º 15
0
class WindowCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
    """Checkpoint tests for `Dataset.window` flattened via interleave."""

    def _build_dataset(self):
        """Windows range(42) by 6 and re-flattens the windows in parallel."""
        windows = dataset_ops.Dataset.range(42).window(6)
        return windows.interleave(lambda w: w,
                                  cycle_length=2,
                                  num_parallel_calls=2)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        verify_fn(self, self._build_dataset, num_outputs=42)
Ejemplo n.º 16
0
class RepeatDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                  parameterized.TestCase):
    """Checkpoint tests for finite, empty and infinite `Dataset.repeat`."""

    def _build_repeat_dataset(self, count, take_count=3):
        """Takes `take_count` slices of range(10), then repeats `count` times.

        A `count` of -1 repeats forever.
        """
        head = dataset_ops.Dataset.from_tensor_slices(
            (np.arange(10), )).take(take_count)
        return head.repeat(count)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testFiniteRepeat(self, verify_fn):
        """3 taken elements repeated 10 times."""
        times, per_pass = 10, 3

        def make_dataset():
            return self._build_repeat_dataset(times, per_pass)

        verify_fn(self, make_dataset, num_outputs=per_pass * times)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testEmptyRepeat(self, verify_fn):
        """Repeating zero times produces nothing."""
        verify_fn(self, lambda: self._build_repeat_dataset(0), num_outputs=0)

    @combinations.generate(test_base.default_test_combinations())
    def testInfiniteRepeat(self):
        """Infinite repeat never exhausts, so skip exhaustion checks."""

        def endless():
            return self._build_repeat_dataset(-1)

        self.verify_unused_iterator(endless, 10, verify_exhausted=False)
        self.verify_multiple_breaks(endless, 20, verify_exhausted=False)
        self.verify_reset_restored_iterator(endless, 20,
                                            verify_exhausted=False)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testInfiniteEmptyRepeat(self, verify_fn):
        """Infinitely repeating an empty dataset still yields nothing."""

        def make_dataset():
            return self._build_repeat_dataset(-1, 0)

        verify_fn(self, make_dataset, num_outputs=0)
Ejemplo n.º 17
0
class RangeCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
    """Checkpoint tests for `Dataset.range` with explicit bounds."""

    def _build_range_dataset(self, start, stop):
        """Returns range(start, stop) as a dataset."""
        return dataset_ops.Dataset.range(start, stop)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        lo, hi = 2, 10

        def make_dataset():
            return self._build_range_dataset(lo, hi)

        verify_fn(self, make_dataset, hi - lo)
Ejemplo n.º 18
0
class FromTensorsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
    """Checkpoint tests for `Dataset.from_tensors`."""

    def _build_tensor_dataset(self, variable_array):
        """Wraps `variable_array` and two constants as one tuple element."""
        return dataset_ops.Dataset.from_tensors(
            (variable_array, np.array([1, 2, 3]), np.array(37.0)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """`from_tensors` always yields exactly one element."""
        scalar = np.array(1)

        def make_dataset():
            return self._build_tensor_dataset(scalar)

        verify_fn(self, make_dataset, num_outputs=1)
Ejemplo n.º 19
0
class RebatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
    """Checkpoint tests for the internal `_RebatchDataset`."""

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Splits 4 batches of 16 into 8 batches of 8."""
        total, small = 64, 8

        def make_dataset():
            big_batches = dataset_ops.Dataset.range(total).batch(
                2 * small, drop_remainder=True)
            return distribute._RebatchDataset(big_batches,
                                              batch_sizes=[small, small])

        verify_fn(self, make_dataset, num_outputs=8)
Ejemplo n.º 20
0
class TakeWhileCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                              parameterized.TestCase):
    """Checkpoint tests for `Dataset.take_while`."""

    def _build_dataset(self, num_elements, upper_bound):
        """Yields range(num_elements) until an element reaches `upper_bound`."""
        source = dataset_ops.Dataset.range(num_elements)
        return source.take_while(predicate=lambda v: v < upper_bound)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            checkpoint_test_base.default_test_combinations(),
            combinations.combine(num_elements=[10, 23], upper_bound=[10, 23])))
    def test(self, verify_fn, num_elements, upper_bound):
        """Output stops at whichever of the two limits comes first."""
        expected = min(num_elements, upper_bound)

        def make_dataset():
            return self._build_dataset(num_elements, upper_bound)

        verify_fn(self, make_dataset, expected)
Ejemplo n.º 21
0
class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Checkpoint tests for datasets produced by `io.load` (currently skipped)."""

  def _build_ds(self):
    """Loads the dataset previously written to the save directory."""
    return io.load(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    # Disabled until the snapshot reader can serialize its state; the body
    # below is kept so the test can be re-enabled by removing the skip.
    self.skipTest(
        "TODO(jsimsa): Re-enable once snapshot reader supports serialization.")
    io.save(dataset_ops.Dataset.range(42), self._save_dir)
    verify_fn(self, self._build_ds, num_outputs=42)
Ejemplo n.º 22
0
class SkipDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
    """Checkpoint tests for `Dataset.skip` over a 10-element dataset."""

    def _build_skip_dataset(self, count):
        """Skips the first `count` slices of range(10); -1 skips everything."""
        source = dataset_ops.Dataset.from_tensor_slices((np.arange(10), ))
        return source.skip(count)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            checkpoint_test_base.default_test_combinations(),
            combinations.combine(count=[5], num_outputs=[5]) +
            combinations.combine(count=[20, 10, -1], num_outputs=[0]) +
            combinations.combine(count=[0], num_outputs=[10])))
    def test(self, verify_fn, count, num_outputs):

        def make_dataset():
            return self._build_skip_dataset(count)

        verify_fn(self, make_dataset, num_outputs)
Ejemplo n.º 23
0
class ScanCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                         parameterized.TestCase):
    """Checkpoint tests for `Dataset.scan`."""

    def _build_dataset(self, num_elements):
        """Emits 1, 1, 2, 3, 5, ... (Fibonacci) for `num_elements` steps."""
        ticks = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)

        def fib_step(state, _):
            # state holds the pair (previous, current).
            prev, curr = state[0], state[1]
            return [curr, prev + curr], curr

        return ticks.scan(initial_state=[0, 1], scan_func=fib_step)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        steps = 5

        def make_dataset():
            return self._build_dataset(steps)

        verify_fn(self, make_dataset, num_outputs=steps)
Ejemplo n.º 24
0
class FromListCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):
    """Checkpoint tests for `from_list`."""

    def _build_list_dataset(self, elements):
        """Returns a dataset with one element per entry of `elements`."""
        return from_list.from_list(elements)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Three array elements give three outputs."""
        items = [
            np.tile(np.array([[1], [2], [3], [4]]), 20),
            np.tile(np.array([[12], [13], [14], [15]]), 22),
            np.array([37, 38, 39, 40]),
        ]

        def make_dataset():
            return self._build_list_dataset(items)

        verify_fn(self, make_dataset, num_outputs=3)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def testDict(self, verify_fn):
        """Three dict elements with matching keys give three outputs."""
        records = [
            {"foo": 1, "bar": 4.0},
            {"foo": 2, "bar": 5.0},
            {"foo": 3, "bar": 6.0},
        ]

        def make_dataset():
            return self._build_list_dataset(records)

        verify_fn(self, make_dataset, num_outputs=3)
Ejemplo n.º 25
0
class DenseToSparseBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
    """Checkpoint tests for `dense_to_sparse_batch`."""

    def _build_dataset(self, components):
        """Turns each x into fill([x], x) and sparse-batches by 4 into [12]."""
        filled = dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x: array_ops.fill([x], x))
        return filled.apply(batching.dense_to_sparse_batch(4, [12]))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """40 random lengths in [0, 5) sparse-batched by 4."""
        lengths = np.random.randint(5, size=(40, )).astype(np.int32)
        batches = len(lengths) // 4

        def make_dataset():
            return self._build_dataset(lengths)

        verify_fn(self, make_dataset, batches)
class GroupByReducerCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
    """Checkpoint tests for `group_by_reducer`."""

    def _build_dataset(self, components):
        """Sums `components` per key x % 5, giving one output per key."""
        summing = grouping.Reducer(init_func=lambda _: np.int64(0),
                                   reduce_func=lambda acc, v: acc + v,
                                   finalize_func=lambda acc: acc)
        grouped = grouping.group_by_reducer(lambda v: v % 5, summing)
        return dataset_ops.Dataset.from_tensor_slices(components).apply(grouped)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """0..9 modulo 5 covers 5 keys."""
        values = np.arange(10, dtype=np.int64)

        def make_dataset():
            return self._build_dataset(values)

        verify_fn(self, make_dataset, num_outputs=5)
Ejemplo n.º 27
0
class ShardCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
    """Checkpoint tests for `Dataset.shard`."""

    def _build_dataset(self, num_elements, num_shards, index):
        """Returns shard `index` of `num_shards` over range(num_elements)."""
        return dataset_ops.Dataset.range(num_elements).shard(num_shards, index)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            checkpoint_test_base.default_test_combinations(),
            combinations.combine(elems=[10, 100],
                                 num_shards=[2, 5],
                                 index=[0, 1])))
    def test(self, verify_fn, elems, num_shards, index):
        # Shard `index` keeps positions index, index + num_shards, ...; count
        # them exactly instead of using `elems // num_shards`, which is only
        # correct when `elems` is divisible by `num_shards`.
        num_outputs = len(range(index, elems, num_shards))
        verify_fn(self,
                  lambda: self._build_dataset(elems, num_shards, index),
                  num_outputs=num_outputs)
class FixedLengthRecordDatasetCheckpointTest(
        FixedLengthRecordDatasetTestBase,
        checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
    """Checkpoint tests for `FixedLengthRecordDataset`."""

    def _build_dataset(self, num_epochs, compression_type=None):
        """Reads the fixture files repeatedly for `num_epochs` epochs.

        Args:
          num_epochs: number of times to repeat the records.
          compression_type: optional compression, forwarded to both the file
            fixture and the reader so the two always agree. Before this fix,
            the argument was accepted and silently ignored.

        Returns:
          The repeated fixed-length record dataset.
        """
        # NOTE(review): assumes FixedLengthRecordDatasetTestBase._createFiles
        # accepts a `compression_type` keyword — confirm against the base.
        filenames = self._createFiles(compression_type=compression_type)
        return readers.FixedLengthRecordDataset(
            filenames,
            self._record_bytes,
            self._header_bytes,
            self._footer_bytes,
            compression_type=compression_type).repeat(num_epochs)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        num_epochs = 5
        num_outputs = num_epochs * self._num_files * self._num_records
        verify_fn(self, lambda: self._build_dataset(num_epochs), num_outputs)
Ejemplo n.º 29
0
class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
    """Checkpoint tests for `Dataset.sample_from_datasets`."""

    def _build_dataset(self, probs, num_samples):
        """Samples `num_samples` values from constant, endless datasets.

        Dataset i yields i forever, and `probs[i]` is its sampling weight.
        """
        constant_streams = [
            dataset_ops.Dataset.from_tensors(i).repeat(None)
            for i in range(len(probs))
        ]
        sampled = dataset_ops.Dataset.sample_from_datasets(constant_streams,
                                                           probs,
                                                           seed=1813)
        return sampled.take(num_samples)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Draws 100 samples from two equally weighted streams."""

        def make_dataset():
            return self._build_dataset([0.5, 0.5], 100)

        verify_fn(self, make_dataset, num_outputs=100)
Ejemplo n.º 30
0
class IgnoreErrorsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
    """Checkpoint tests for `ignore_errors`."""

    def _build_ds(self):
        """Drops the NaN element that makes `check_numerics` fail.

        The external-state policy is set to IGNORE on the returned dataset.
        """
        values = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
        opts = options_lib.Options()
        opts.experimental_external_state_policy = (
            options_lib.ExternalStatePolicy.IGNORE)
        checked = dataset_ops.Dataset.from_tensor_slices(values).map(
            lambda v: array_ops.check_numerics(v, "message"))
        return checked.apply(error_ops.ignore_errors()).with_options(opts)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Five inputs, one of them NaN, leave four outputs."""
        verify_fn(self, self._build_ds, num_outputs=4)