Example #1
 def testNestedOutputs(self):
   ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
                                                    Dataset.range(4)))))
   total = 0
   # The Iterator will return a nested structure of Tensor objects.
   # Some funkiness to compare against simple integers.
   for (i, x) in enumerate(datasets.Iterator(ds)):
     want = (i, (i, i))
     got = (x[0].numpy(), (x[1][0].numpy(), x[1][1].numpy()))
     self.assertEqual(got, want)
     total += 1
   self.assertEqual(4, total)
Example #2
    def testSaveRestoreMultipleIterator(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
        dataset = Dataset.from_tensor_slices(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
        dataset = dataset.map(math_ops.square).batch(2)
        iterator_1 = datasets.Iterator(dataset)
        iterator_2 = datasets.Iterator(dataset)
        dataset_2 = Dataset.range(10)
        iterator_3 = datasets.Iterator(dataset_2)

        checkpoint = checkpointable_utils.Checkpoint(iterator_1=iterator_1,
                                                     iterator_2=iterator_2,
                                                     iterator_3=iterator_3)
        self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
        self.assertEqual(0, iterator_3.get_next().numpy())
        self.assertEqual(1, iterator_3.get_next().numpy())
        self.assertEqual(2, iterator_3.get_next().numpy())

        save_path = checkpoint.save(checkpoint_prefix)
        self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
        self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
        self.assertEqual(3, iterator_3.get_next().numpy())
        checkpoint.restore(save_path)
        self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
        self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
        self.assertEqual(3, iterator_3.get_next().numpy())
Example #3
  def testMapAndFilter(self):
    def even(x):
      return math_ops.equal(math_ops.mod(x, 2), 0)

    it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 4, 16, 36], got)
Example #4
  def testSaveRestoreMultipleIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(dataset)
    iterator_2 = datasets.Iterator(dataset)
    dataset_2 = Dataset.range(10)
    iterator_3 = datasets.Iterator(dataset_2)

    checkpoint = checkpointable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    self.assertEqual(0, iterator_3.get_next().numpy())
    self.assertEqual(1, iterator_3.get_next().numpy())
    self.assertEqual(2, iterator_3.get_next().numpy())

    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
Example #5
  def testOverrideThreadPool(self):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:

      dataset = (
          Dataset.range(1000).map(
              lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
              num_parallel_calls=32).apply(unique.unique()))

      dataset = threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads, display_name='private_thread_pool_%d' % num_threads))

      thread_ids = []
      for next_element in datasets.Iterator(dataset):
        thread_ids.append(next_element)
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
Example #6
    def testOverrideThreadPool(self):
        def get_thread_id(_):
            # Python creates a dummy thread object to represent the current
            # thread when called from an "alien" thread (such as a
            # `PrivateThreadPool` thread in this case). It does not include
            # the TensorFlow-given display name, but it has a unique
            # identifier that maps one-to-one with the underlying OS thread.
            return np.array(threading.current_thread().ident).astype(np.int64)

        for num_threads in [1, 2, 4, 8, 16]:

            dataset = (Dataset.range(1000).map(
                lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
                num_parallel_calls=32).apply(unique.unique()))

            dataset = threadpool.override_threadpool(
                dataset,
                threadpool.PrivateThreadPool(
                    num_threads,
                    display_name='private_thread_pool_%d' % num_threads))

            thread_ids = []
            for next_element in datasets.Iterator(dataset):
                thread_ids.append(next_element)
            self.assertEqual(len(thread_ids), len(set(thread_ids)))
            self.assertGreater(len(thread_ids), 0)
            # NOTE(mrry): We don't control the thread pool scheduling, and
            # so cannot guarantee that all of the threads in the pool will
            # perform work.
            self.assertLessEqual(len(thread_ids), num_threads)
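The `threadpool.override_threadpool` call in the two examples above comes from the TF 1.x contrib/experimental layer. As a rough, version-dependent sketch of the present-day equivalent, a private thread pool can be requested through `tf.data.Options` (the exact attribute name is hedged in the comments):

import tensorflow as tf

options = tf.data.Options()
# Recent 2.x releases expose this as `options.threading`; older 2.x
# versions used `options.experimental_threading` instead.
options.threading.private_threadpool_size = 8
dataset = tf.data.Dataset.range(1000).with_options(options)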
Example #7
 def testGetNextOneShotIterator(self):
   iterator = Dataset.range(4).make_one_shot_iterator()
   self.assertEqual(0, iterator.get_next().numpy())
   self.assertEqual(1, iterator.get_next().numpy())
   self.assertEqual(2, iterator.get_next().numpy())
   self.assertEqual(3, iterator.get_next().numpy())
   with self.assertRaises(errors.OutOfRangeError):
     iterator.get_next()
Example #8
 def testGetNext(self):
   iterator = datasets.Iterator(Dataset.range(4))
   self.assertEqual(0, iterator.get_next().numpy())
   self.assertEqual(1, iterator.get_next().numpy())
   self.assertEqual(2, iterator.get_next().numpy())
   self.assertEqual(3, iterator.get_next().numpy())
   with self.assertRaises(errors.OutOfRangeError):
     iterator.get_next()
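A rough TF 2.x sketch of the same test: the builtin `iter()`/`next()` replace `datasets.Iterator`, and Python-style iteration signals exhaustion with `StopIteration` (while `get_next()` still raises `tf.errors.OutOfRangeError`):

import tensorflow as tf

iterator = iter(tf.data.Dataset.range(4))
for expected in range(4):
    assert next(iterator).numpy() == expected
try:
    next(iterator)  # the iterator is exhausted
except StopIteration:
    pass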
Example #9
 def testGetNext(self):
     iterator = datasets.Iterator(Dataset.range(4))
     self.assertEqual(0, iterator.get_next().numpy())
     self.assertEqual(1, iterator.get_next().numpy())
     self.assertEqual(2, iterator.get_next().numpy())
     self.assertEqual(3, iterator.get_next().numpy())
     with self.assertRaises(errors.OutOfRangeError):
         iterator.get_next()
Example #10
 def testGetNextOneShotIterator(self):
     iterator = Dataset.range(4).make_one_shot_iterator()
     self.assertEqual(0, iterator.get_next().numpy())
     self.assertEqual(1, iterator.get_next().numpy())
     self.assertEqual(2, iterator.get_next().numpy())
     self.assertEqual(3, iterator.get_next().numpy())
     with self.assertRaises(errors.OutOfRangeError):
         iterator.get_next()
Example #11
    def testPyFunc(self):
        def my_map(inp):
            return [[x + 1 for x in inp]]

        ds = Dataset.range(4).map(
            lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
        got = [x.numpy() for x in datasets.Iterator(ds)]
        self.assertAllEqual([[1], [2], [3], [4]], got)
Example #12
  def testPyFunc(self):

    def my_map(inp):
      return [[x + 1 for x in inp]]

    ds = Dataset.range(4).map(
        lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
    got = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([[1], [2], [3], [4]], got)
Example #13
  def testMultipleIteratorsOnTheSameDataset(self):
    ds = Dataset.range(4)
    it1 = datasets.Iterator(ds)
    it2 = datasets.Iterator(ds)
    got = [x.numpy() for x in it1]
    self.assertAllEqual([0, 1, 2, 3], got)

    got = [x.numpy() for x in it2]
    self.assertAllEqual([0, 1, 2, 3], got)
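For comparison, a minimal sketch of the same independence under current TF 2.x eager execution, where each `iter()` call creates its own iterator state:

import tensorflow as tf

ds = tf.data.Dataset.range(4)
it1 = iter(ds)  # independent of it2
it2 = iter(ds)
assert [int(x) for x in it1] == [0, 1, 2, 3]
assert [int(x) for x in it2] == [0, 1, 2, 3]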
Example #14
 def input_fn(self, mode: ModeKeys):
     # Select the dataset for `mode`, batch it (padding the variable-length
     # first component to the longest element in the batch), and return a
     # (features, labels) pair; labels are unused here, hence None.
     return {
         ModeKeys.TRAIN:
         lambda: self.train_ds.repeat(self.params.num_epoch).padded_batch(
             self.params.batch_size, padded_shapes=([None], [], [], [])),
         ModeKeys.EVAL:
         lambda: self.eval_ds.padded_batch(
             self.params.batch_size, padded_shapes=([None], [], [], [])),
         ModeKeys.INFER:
         lambda: Dataset.range(1)
     }[mode]().make_one_shot_iterator().get_next(), None
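A small self-contained sketch of the `padded_shapes=([None], [], [], [])` structure used above; the element layout (ids, length, label, weight) is an illustrative assumption, not the one produced by `self.train_ds`:

import tensorflow as tf

# Each element: (variable-length ids, length, label, weight).
ds = tf.data.Dataset.range(1, 5).map(
    lambda n: (tf.range(n), n, n % 2, tf.cast(n, tf.float32)))
# The first component is padded to the longest element in each batch;
# the scalar components need no padding, hence the empty shapes.
batched = ds.padded_batch(2, padded_shapes=([None], [], [], []))
for ids, length, label, weight in batched:
    print(ids.numpy(), length.numpy(), label.numpy(), weight.numpy())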
Example #15
 def testRestoreInReconstructedIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
     dataset = Dataset.range(10)
     for i in range(5):
         iterator = datasets.Iterator(dataset)
         checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
         checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
         for j in range(2):
             self.assertEqual(i * 2 + j, iterator.get_next().numpy())
         checkpoint.save(file_prefix=checkpoint_prefix)
Example #16
 def testRestoreInReconstructedIterator(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
   dataset = Dataset.range(10)
   for i in range(5):
     iterator = datasets.Iterator(dataset)
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
     checkpoint.restore(checkpoint_management.latest_checkpoint(
         checkpoint_directory))
     for j in range(2):
       self.assertEqual(i * 2 + j, iterator.get_next().numpy())
     checkpoint.save(file_prefix=checkpoint_prefix)
Example #17
  def testRestoreExhaustedIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(3)
    iterator = datasets.Iterator(dataset)

    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())
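A sketch of the same save-then-rewind behavior with present-day public APIs, assuming a writable temporary directory; `tf.train.Checkpoint` can track a `tf.data` iterator created with `iter()`:

import os
import tempfile

import tensorflow as tf

checkpoint_prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')
iterator = iter(tf.data.Dataset.range(3))
checkpoint = tf.train.Checkpoint(iterator=iterator)
assert next(iterator).numpy() == 0
assert next(iterator).numpy() == 1
save_path = checkpoint.save(checkpoint_prefix)
assert next(iterator).numpy() == 2
checkpoint.restore(save_path)
assert next(iterator).numpy() == 2  # rewound to the saved position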
Example #18
    def testRestoreExhaustedIterator(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
        dataset = Dataset.range(3)
        iterator = datasets.Iterator(dataset)

        checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
        self.assertEqual(0, iterator.get_next().numpy())
        self.assertEqual(1, iterator.get_next().numpy())
        save_path = checkpoint.save(checkpoint_prefix)
        self.assertEqual(2, iterator.get_next().numpy())
        checkpoint.restore(save_path)
        self.assertEqual(2, iterator.get_next().numpy())
Example #19
def datasets_interleave(datasets, block_length=None, cycle_length=None):
    """Interleaves elements from `datasets`.

    By default elements are drawn round-robin, one at a time. `block_length`
    (an int or a per-dataset sequence) draws that many consecutive elements
    from each dataset before moving to the next. `cycle_length` splits the
    datasets into groups of that size, interleaves each group, and
    concatenates the results. Relies on the companion helpers
    `datasets_concatenate` and `is_listing` defined alongside this function.
    """
    datasets = tuple(datasets)
    if cycle_length is not None:
        return datasets_concatenate([
            datasets_interleave(datasets[i:i + cycle_length],
                                block_length=block_length)
            for i in range(0, len(datasets), cycle_length)
        ])

    choices = Dataset.range(len(datasets))

    if block_length is not None:
        if not is_listing(block_length):
            block_length = tuple(block_length for _ in range(len(datasets)))
        choices = datasets_concatenate([
            Dataset.from_tensors(tf.convert_to_tensor(
                i, dtype=tf.int64)).repeat(block_length[i])
            for i in range(len(datasets))
        ])

    return tf.data.experimental.choose_from_datasets(datasets,
                                                     choices.cache().repeat())
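A usage sketch for the helper above (illustrative only); it assumes the companion helpers `datasets_concatenate` and `is_listing` referenced in the body are available from the same module, which is not shown here:

import tensorflow as tf

a = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1])
b = tf.data.Dataset.from_tensor_slices([2, 2, 2, 2])

# Round-robin: one element from each dataset in turn.
round_robin = datasets_interleave([a, b])
print([int(x) for x in round_robin.take(4)])  # expected: [1, 2, 1, 2]

# Two consecutive elements from each dataset before switching.
blocked = datasets_interleave([a, b], block_length=2)
print([int(x) for x in blocked.take(4)])  # expected: [1, 1, 2, 2]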
Example #20
 def testBasic(self):
   got = []
   for t in datasets.Iterator(Dataset.range(4)):
     got.append(t.numpy())
   self.assertAllEqual([0, 1, 2, 3], got)
Example #21
 def testBasicOneShotIterator(self):
   got = []
   for t in Dataset.range(4).make_one_shot_iterator():
     got.append(t.numpy())
   self.assertAllEqual([0, 1, 2, 3], got)
Example #22
 def testBasicImplicitIterator(self):
   got = []
   for t in Dataset.range(4):
     got.append(t.numpy())
   self.assertAllEqual([0, 1, 2, 3], got)
Example #23
 def testBasicImplicitIterator(self):
     got = []
     for t in Dataset.range(4):
         got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)
Example #24
 def testBasicOneShotIterator(self):
     got = []
     for t in Dataset.range(4).make_one_shot_iterator():
         got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)