Beispiel #1
0
    def testInitializableIterator(self):
        """Exhaust an initializable per-replica iterator, then re-run it.

        Builds a 10-element random dataset on a single CPU device, consumes
        every element, checks that one more fetch raises OutOfRangeError, and
        then re-initializes the iterator and consumes all elements again.
        """
        num_elements = 10
        with context.graph_mode():
            devices = ["/device:CPU:0"]
            # Random input is chosen on purpose: it is only allowed with an
            # initializable iterator.
            random_input = random_ops.random_uniform((num_elements, ))
            dataset = dataset_ops.Dataset.from_tensor_slices(random_input)

            replica_dataset = values.PerReplicaDataset(dataset, devices)
            init_iterator = replica_dataset.make_initializable_iterator()

            self.evaluate(init_iterator.initializer)
            element = init_iterator.get_next()

            def consume_all():
                # Fetch exactly as many elements as the dataset holds.
                for _ in range(num_elements):
                    self.evaluate(element)

            consume_all()

            # The input is exhausted, so one more fetch must fail.
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element)

            # Running the initializer again should allow a second full pass.
            self.evaluate(init_iterator.initializer)
            consume_all()
 def distribute_dataset(self, dataset_fn):
     """Wrap the dataset built from `dataset_fn` for this strategy's device.

     Args:
       dataset_fn: passed to `self._call_dataset_fn`; presumably a callable
         that builds the input dataset (confirm against callers).

     Returns:
       A `values.PerReplicaDataset` over the single device `self._device`,
       with `self._prefetch_on_device` forwarded as the third argument.
     """
     built_dataset = self._call_dataset_fn(dataset_fn)
     target_devices = [self._device]
     return values.PerReplicaDataset(
         built_dataset, target_devices, self._prefetch_on_device)
Beispiel #3
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU listed in `self._devices`.

     Args:
       dataset_fn: passed to `self._call_dataset_fn` to obtain the dataset.

     Returns:
       A `values.PerReplicaDataset` spanning `self._devices`.
     """
     # TODO(yuefengz): shard the dataset.
     built_dataset = self._call_dataset_fn(dataset_fn)
     # Third positional argument: presumably the prefetch-on-device flag;
     # confirm against PerReplicaDataset's signature.
     prefetch_on_device = True
     return values.PerReplicaDataset(
         built_dataset, self._devices, prefetch_on_device)
Beispiel #4
0
 def distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU in `self._compute_devices`.

     Args:
       dataset_fn: passed to `self._call_dataset_fn` to obtain the dataset.

     Returns:
       A `values.PerReplicaDataset` spanning `self._compute_devices`.
     """
     built_dataset = self._call_dataset_fn(dataset_fn)
     # Third positional argument: presumably the prefetch-on-device flag;
     # confirm against PerReplicaDataset's signature.
     prefetch_on_device = True
     return values.PerReplicaDataset(
         built_dataset, self._compute_devices, prefetch_on_device)
 def make_dataset_iterator(self, dataset):
     """Return an initializable iterator over `dataset` on `self._device`.

     Args:
       dataset: the input dataset, wrapped in a `values.PerReplicaDataset`
         over the single device `self._device`.

     Returns:
       The result of the wrapper's `make_initializable_iterator()`.
     """
     # TODO(priyag): Return distribution strategy specific InputIterator
     return values.PerReplicaDataset(
         dataset, [self._device]).make_initializable_iterator()
Beispiel #6
0
 def _distribute_dataset(self, dataset_fn):
     """Wrap the dataset built from `dataset_fn` for the single `self._device`.

     Args:
       dataset_fn: passed to `self._call_dataset_fn` to obtain the dataset.

     Returns:
       A `values.PerReplicaDataset` over `[self._device]`, using the
       wrapper's default for any further arguments.
     """
     built_dataset = self._call_dataset_fn(dataset_fn)
     device_list = [self._device]
     return values.PerReplicaDataset(built_dataset, device_list)