def testInitializableIterator(self):
        """Checks that an initializable PerReplicaDataset iterator can be re-run.

        Builds a random 10-element dataset on a single CPU device. Then:
        1. Runs the initializer.
        2. Drains the dataset and checks that one more evaluation raises
           `errors.OutOfRangeError`.
        3. Runs the initializer again and checks the same drain-then-fail
           sequence.
        """
        num_elements = 10

        def _drain_then_expect_end(next_element):
            # With a single device, each evaluation should consume one
            # element. After `num_elements` evaluations the input is finished,
            # so the next evaluation must fail.
            for _ in range(num_elements):
                self.evaluate(next_element)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)

        with context.graph_mode():
            devices = ["/device:CPU:0"]
            # Using random input since that is only allowed with initializable
            # iterator.
            dataset = dataset_ops.Dataset.from_tensor_slices(
                random_ops.random_uniform((num_elements, )))

            device_map = values.ReplicaDeviceMap(devices)
            input_workers = input_lib.InputWorkers(device_map)
            per_replica_dataset = input_lib.PerReplicaDataset(
                dataset, input_workers, 0)
            iterator = per_replica_dataset.make_initializable_iterator()

            self.evaluate(iterator.initializer)
            next_element = iterator.get_next_as_list()
            # Should fail after the input is finished.
            _drain_then_expect_end(next_element)

            # After re-initializing the iterator, should be able to iterate
            # again, and the second pass must end at the same point as the first.
            self.evaluate(iterator.initializer)
            _drain_then_expect_end(next_element)
Beispiel #2
0
 def _distribute_dataset(self, dataset_fn):
   """Distributes the dataset to each local GPU.

   Args:
     dataset_fn: callable forwarded to `self._call_dataset_fn` to build the
       input dataset.

   Returns:
     An `input_lib.PerReplicaDataset` over `self._input_workers` for worker
     index 0, with on-device prefetching enabled.
   """
   dataset = self._call_dataset_fn(dataset_fn)
   # Always worker index 0 here; presumably a single local worker. Confirm.
   worker_index = 0
   return input_lib.PerReplicaDataset(
       dataset, self._input_workers, worker_index, prefetch_on_device=True)
Beispiel #3
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU.

     Args:
       dataset_fn: callable passed to `self._call_dataset_fn` to obtain the
         dataset.

     Returns:
       An `input_lib.PerReplicaDataset` for worker index 0 with on-device
       prefetching enabled.
     """
     # TODO(yuefengz): shard the dataset.
     built_dataset = self._call_dataset_fn(dataset_fn)
     return input_lib.PerReplicaDataset(
         built_dataset,
         self._input_workers,
         0,  # worker index: the dataset is not sharded, so always worker 0.
         prefetch_on_device=True)
Beispiel #4
0
 def _distribute_dataset(self, dataset_fn):
     """Wraps the dataset from `dataset_fn` for distribution to input workers.

     In local mode the dataset is built immediately and wrapped in a
     `PerReplicaDataset` for worker index 0. Otherwise, a `functools.partial`
     that builds the dataset is handed to `MultiWorkerDataset` along with
     `auto_shard=self._auto_shard_dataset`.
     """
     if not self._local_mode:
         # The dataset is not built here. The bound factory is passed on,
         # presumably so MultiWorkerDataset can build it per worker. Confirm.
         dataset_factory = functools.partial(self._call_dataset_fn, dataset_fn)
         return input_lib.MultiWorkerDataset(
             dataset_factory,
             self._input_workers,
             auto_shard=self._auto_shard_dataset)
     local_dataset = self._call_dataset_fn(dataset_fn)
     return input_lib.PerReplicaDataset(local_dataset, self._input_workers, 0)
    def _test_iterator(self, devices, dataset, expected_values):
        """Asserts that a PerReplicaDataset over `devices` yields `expected_values`.

        Args:
          devices: device strings used to build the `ReplicaDeviceMap`.
          dataset: the dataset to wrap in a `PerReplicaDataset` (worker 0).
          expected_values: per-step expected results. Each is compared with
            `assertEqual` against the evaluated `get_next_as_list()`.

        Once every expected value has been consumed, one further fetch must
        raise `errors.OutOfRangeError`.
        """
        input_workers = input_lib.InputWorkers(
            values.ReplicaDeviceMap(devices))
        per_replica_dataset = input_lib.PerReplicaDataset(
            dataset, input_workers, 0)

        if not context.executing_eagerly():
            # Graph mode: the iterator has to be explicitly initialized.
            iterator = per_replica_dataset.make_initializable_iterator()
            self.evaluate([iterator.initializer])
        else:
            iterator = per_replica_dataset.make_one_shot_iterator()

        def _fetch_next():
            # A fresh get_next_as_list() call is made for every fetch, in both
            # eager and graph mode.
            return self.evaluate(iterator.get_next_as_list())

        for expected in expected_values:
            self.assertEqual(expected, _fetch_next())

        # The iterator must now be exhausted.
        with self.assertRaises(errors.OutOfRangeError):
            _fetch_next()
 def _distribute_dataset(self, dataset_fn):
     """Builds the dataset from `dataset_fn` and wraps it per replica.

     Returns:
       An `input_lib.PerReplicaDataset` over `self._input_workers` for worker
       index 0. No on-device prefetching is requested.
     """
     dataset = self._call_dataset_fn(dataset_fn)
     workers = self._input_workers
     return input_lib.PerReplicaDataset(dataset, workers, 0)