def testNestedOutputs(self):
  """Iterating a zipped dataset yields a nested tuple of Tensors."""
  dataset = Dataset.zip(
      (Dataset.range(4), Dataset.zip((Dataset.range(4), Dataset.range(4)))))
  count = 0
  # Each element is a nested structure of Tensor objects; unpack to plain
  # Python ints before comparing against simple integers.
  for index, element in enumerate(datasets.Iterator(dataset)):
    expected = (index, (index, index))
    actual = (element[0].numpy(),
              (element[1][0].numpy(), element[1][1].numpy()))
    self.assertEqual(actual, expected)
    count += 1
  self.assertEqual(4, count)
def testSaveRestoreMultipleIterator(self):
  """A single Checkpoint saves/restores several iterators independently."""
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  batched_squares = Dataset.from_tensor_slices(
      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]).map(math_ops.square).batch(2)
  iterator_1 = datasets.Iterator(batched_squares)
  iterator_2 = datasets.Iterator(batched_squares)
  iterator_3 = datasets.Iterator(Dataset.range(10))
  checkpoint = checkpointable_utils.Checkpoint(
      iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
  # Advance iterator_1 by one batch and iterator_3 by three elements, then
  # capture that state.
  self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
  for expected in (0, 1, 2):
    self.assertEqual(expected, iterator_3.get_next().numpy())
  save_path = checkpoint.save(checkpoint_prefix)
  # Move the iterators past the saved positions.
  self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
  self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
  self.assertEqual(3, iterator_3.get_next().numpy())
  # Restoring rewinds every iterator to where it stood at save time.
  checkpoint.restore(save_path)
  self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
  self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
  self.assertEqual(3, iterator_3.get_next().numpy())
def testMapAndFilter(self):
  """map() followed by filter() keeps only the even squares of range(8)."""

  def is_even(value):
    return math_ops.equal(math_ops.mod(value, 2), 0)

  iterator = datasets.Iterator(
      Dataset.range(8).map(math_ops.square).filter(is_even))
  results = [element.numpy() for element in iterator]
  self.assertAllEqual([0, 4, 16, 36], results)
def testSaveRestoreMultipleIterator(self):
  """Checkpointing three iterators saves and restores each one's position."""
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
  dataset = dataset.map(math_ops.square).batch(2)
  iterator_1 = datasets.Iterator(dataset)
  iterator_2 = datasets.Iterator(dataset)
  dataset_2 = Dataset.range(10)
  iterator_3 = datasets.Iterator(dataset_2)
  checkpoint = checkpointable_utils.Checkpoint(
      iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
  # Consume one batch from iterator_1 and three scalars from iterator_3
  # before saving; iterator_2 has consumed nothing yet.
  self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
  self.assertEqual(0, iterator_3.get_next().numpy())
  self.assertEqual(1, iterator_3.get_next().numpy())
  self.assertEqual(2, iterator_3.get_next().numpy())
  save_path = checkpoint.save(checkpoint_prefix)
  # Advance each iterator beyond its checkpointed state.
  self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
  self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
  self.assertEqual(3, iterator_3.get_next().numpy())
  # After restore, each iterator resumes from its saved position.
  checkpoint.restore(save_path)
  self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
  self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
  self.assertEqual(3, iterator_3.get_next().numpy())
def testOverrideThreadPool(self):
  """override_threadpool caps the number of distinct worker threads used."""

  def get_thread_id(_):
    # Python creates a dummy thread object to represent the current
    # thread when called from an "alien" thread (such as a
    # `PrivateThreadPool` thread in this case). It does not include
    # the TensorFlow-given display name, but it has a unique
    # identifier that maps one-to-one with the underlying OS thread.
    return np.array(threading.current_thread().ident).astype(np.int64)

  for num_threads in [1, 2, 4, 8, 16]:
    dataset = (
        Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))
    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            num_threads,
            display_name='private_thread_pool_%d' % num_threads))
    thread_ids = [element for element in datasets.Iterator(dataset)]
    # unique() already deduplicated, so all observed ids are distinct.
    self.assertEqual(len(thread_ids), len(set(thread_ids)))
    self.assertGreater(len(thread_ids), 0)
    # NOTE(mrry): We don't control the thread pool scheduling, and
    # so cannot guarantee that all of the threads in the pool will
    # perform work.
    self.assertLessEqual(len(thread_ids), num_threads)
def testOverrideThreadPool(self):
  """A PrivateThreadPool bounds how many threads execute the map calls."""

  def get_thread_id(_):
    # Python creates a dummy thread object to represent the current
    # thread when called from an "alien" thread (such as a
    # `PrivateThreadPool` thread in this case). It does not include
    # the TensorFlow-given display name, but it has a unique
    # identifier that maps one-to-one with the underlying OS thread.
    return np.array(threading.current_thread().ident).astype(np.int64)

  for num_threads in [1, 2, 4, 8, 16]:
    pipeline = Dataset.range(1000).map(
        lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
        num_parallel_calls=32).apply(unique.unique())
    pool = threadpool.PrivateThreadPool(
        num_threads, display_name='private_thread_pool_%d' % num_threads)
    pipeline = threadpool.override_threadpool(pipeline, pool)
    thread_ids = []
    for next_element in datasets.Iterator(pipeline):
      thread_ids.append(next_element)
    # unique() already removed duplicates, so every id is distinct.
    self.assertEqual(len(thread_ids), len(set(thread_ids)))
    self.assertGreater(len(thread_ids), 0)
    # NOTE(mrry): We don't control the thread pool scheduling, and
    # so cannot guarantee that all of the threads in the pool will
    # perform work.
    self.assertLessEqual(len(thread_ids), num_threads)
def testGetNextOneShotIterator(self):
  """make_one_shot_iterator() yields each element, then raises OutOfRange."""
  iterator = Dataset.range(4).make_one_shot_iterator()
  for expected in range(4):
    self.assertEqual(expected, iterator.get_next().numpy())
  with self.assertRaises(errors.OutOfRangeError):
    iterator.get_next()
def testGetNext(self):
  """datasets.Iterator.get_next() walks the dataset, then raises OutOfRange."""
  iterator = datasets.Iterator(Dataset.range(4))
  self.assertEqual(0, iterator.get_next().numpy())
  self.assertEqual(1, iterator.get_next().numpy())
  self.assertEqual(2, iterator.get_next().numpy())
  self.assertEqual(3, iterator.get_next().numpy())
  # The iterator is exhausted after four elements.
  with self.assertRaises(errors.OutOfRangeError):
    iterator.get_next()
def testPyFunc(self):
  """py_func inside map() runs arbitrary Python code per element."""

  def my_map(inp):
    return [[value + 1 for value in inp]]

  dataset = Dataset.range(4).map(
      lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
  results = [element.numpy() for element in datasets.Iterator(dataset)]
  self.assertAllEqual([[1], [2], [3], [4]], results)
def testMultipleIteratorsOnTheSameDataset(self):
  """Two iterators over one dataset each see the full element sequence."""
  dataset = Dataset.range(4)
  first = datasets.Iterator(dataset)
  second = datasets.Iterator(dataset)
  self.assertAllEqual([0, 1, 2, 3], [t.numpy() for t in first])
  self.assertAllEqual([0, 1, 2, 3], [t.numpy() for t in second])
def input_fn(self, mode: ModeKeys):
  """Build the (features, labels) input tensors for the given mode.

  TRAIN yields repeated, padded batches of the training dataset; EVAL yields
  padded batches of the evaluation dataset; INFER yields a single dummy
  element. Labels are always None here.
  """
  # All non-INFER modes pad the same four components: a variable-length
  # sequence followed by three scalars.
  padded_shapes = ([None], [], [], [])
  builders = {
      ModeKeys.TRAIN:
          lambda: self.train_ds.repeat(self.params.num_epoch).padded_batch(
              self.params.batch_size, padded_shapes=padded_shapes),
      ModeKeys.EVAL:
          lambda: self.eval_ds.padded_batch(
              self.params.batch_size, padded_shapes=padded_shapes),
      ModeKeys.INFER:
          lambda: Dataset.range(1),
  }
  features = builders[mode]().make_one_shot_iterator().get_next()
  return features, None
def testRestoreInReconstructedIterator(self):
  """A freshly built iterator restored from checkpoint resumes mid-dataset."""
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  dataset = Dataset.range(10)
  for epoch in range(5):
    # Rebuild the iterator from scratch each pass; restoring from the
    # latest checkpoint (a no-op on the first pass) resumes its position.
    iterator = datasets.Iterator(dataset)
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
    for offset in range(2):
      self.assertEqual(epoch * 2 + offset, iterator.get_next().numpy())
    checkpoint.save(file_prefix=checkpoint_prefix)
def testRestoreInReconstructedIterator(self):
  """Restoring a rebuilt iterator continues from the checkpointed position."""
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  dataset = Dataset.range(10)
  for pass_index in range(5):
    iterator = datasets.Iterator(dataset)
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    # latest_checkpoint returns None on the first pass, making restore a
    # no-op; later passes pick up where the previous save left off.
    latest = checkpoint_management.latest_checkpoint(checkpoint_directory)
    checkpoint.restore(latest)
    for step in range(2):
      self.assertEqual(pass_index * 2 + step, iterator.get_next().numpy())
    checkpoint.save(file_prefix=checkpoint_prefix)
def testRestoreExhaustedIterator(self):
  """Restoring rewinds an iterator so the last element can be re-read."""
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  iterator = datasets.Iterator(Dataset.range(3))
  checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
  # Consume the first two elements, then checkpoint.
  self.assertEqual(0, iterator.get_next().numpy())
  self.assertEqual(1, iterator.get_next().numpy())
  save_path = checkpoint.save(checkpoint_prefix)
  # Drain the final element, restore, and read that element again.
  self.assertEqual(2, iterator.get_next().numpy())
  checkpoint.restore(save_path)
  self.assertEqual(2, iterator.get_next().numpy())
def datasets_interleave(datasets, block_length=None, cycle_length=None):
  """Interleave elements from `datasets` in round-robin order.

  Args:
    datasets: iterable of `tf.data.Dataset`s to interleave.
    block_length: optional; number of consecutive elements taken from each
      dataset before moving to the next. May be a single value applied to
      every dataset or a per-dataset listing.
    cycle_length: optional int; interleave at most this many datasets at a
      time, concatenating the interleaved groups.

  Returns:
    A `tf.data.Dataset` that draws from the inputs via
    `tf.data.experimental.choose_from_datasets`.
  """
  datasets = tuple(datasets)
  count = len(datasets)
  if cycle_length is not None:
    # Interleave in groups of `cycle_length` datasets, then concatenate
    # the groups end to end.
    groups = [
        datasets_interleave(datasets[start:start + cycle_length],
                            block_length=block_length)
        for start in range(0, count, cycle_length)
    ]
    return datasets_concatenate(groups)
  if block_length is None:
    # Plain round-robin: 0, 1, ..., count-1 repeated.
    choices = Dataset.range(count)
  else:
    if not is_listing(block_length):
      block_length = tuple(block_length for _ in range(count))
    # Emit dataset index i `block_length[i]` times before moving on.
    choices = datasets_concatenate([
        Dataset.from_tensors(
            tf.convert_to_tensor(i, dtype=tf.int64)).repeat(block_length[i])
        for i in range(count)
    ])
  return tf.data.experimental.choose_from_datasets(
      datasets, choices.cache().repeat())
def testBasic(self):
  """datasets.Iterator yields every element of the dataset in order."""
  results = [t.numpy() for t in datasets.Iterator(Dataset.range(4))]
  self.assertAllEqual([0, 1, 2, 3], results)
def testBasicOneShotIterator(self):
  """A one-shot iterator is directly iterable in eager mode."""
  results = [t.numpy() for t in Dataset.range(4).make_one_shot_iterator()]
  self.assertAllEqual([0, 1, 2, 3], results)
def testBasicImplicitIterator(self):
  """A Dataset itself is iterable in eager mode without an explicit iterator."""
  results = [t.numpy() for t in Dataset.range(4)]
  self.assertAllEqual([0, 1, 2, 3], results)