Exemplo n.º 1
0
    def testInitializableIterator(self):
        """Checks that an initializable iterator can be re-initialized and re-run."""
        with context.graph_mode():
            devices = ["/device:CPU:0"]
            # Random input is only permitted with an initializable iterator,
            # which is exactly what this test exercises.
            ds = dataset_ops.Dataset.from_tensor_slices(
                random_ops.random_uniform((10, )))

            per_device = values.PerDeviceDataset(
                ds, devices, prefetch_on_device=False)
            it = per_device.make_initializable_iterator()

            self.evaluate(it.initializer)
            element = it.get_next()
            for _ in range(10):
                self.evaluate(element)

            # Once the input is exhausted, fetching must raise.
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element)

            # Re-running the initializer restores a full pass over the data.
            self.evaluate(it.initializer)
            for _ in range(10):
                self.evaluate(element)
Exemplo n.º 2
0
    def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
        """Runs a prefetching one-shot iterator and checks the combined output."""
        if context.executing_eagerly():
            return
        per_device = values.PerDeviceDataset(
            dataset, devices, prefetch_on_device=True)
        iterator = per_device.make_one_shot_iterator()

        # With prefetching, which input lands on which device is not
        # deterministic, so compare the full multiset of values seen across
        # all devices against the full set of expected values.
        seen = []
        wanted = []
        for batch in expected_values:
            element = iterator.get_next()
            per_device_values = self.evaluate(
                [values.select_device(d, element) for d in devices])
            seen.extend(per_device_values)
            wanted.extend(batch)

        self.assertEqual(set(wanted), set(seen))

        with self.assertRaises(errors.OutOfRangeError):
            element = iterator.get_next()
            self.evaluate(
                [values.select_device(d, element) for d in devices])
Exemplo n.º 3
0
 def distribute_dataset(self, dataset_fn):
   """Wraps `dataset_fn`'s dataset for multi-worker or per-device consumption."""
   if not self._cluster_spec:
     # Single-worker case: build the dataset once and split it per device.
     return values.PerDeviceDataset(
         self._call_dataset_fn(dataset_fn), self._devices,
         self._prefetch_on_device)
   # Multi-worker case: defer dataset construction to each worker.
   return values.MultiWorkerDataset(
       partial(self._call_dataset_fn, dataset_fn), self._worker_devices,
       self._prefetch_on_device, self._auto_shard_dataset)
Exemplo n.º 4
0
 def distribute_dataset(self, dataset_fn):
   """Wraps `dataset_fn`'s dataset for multi-worker or per-device consumption."""
   if not self._cluster_spec:
     # Single-worker case: materialize the dataset here, feeding the
     # devices from the resolved host CPU.
     return values.PerDeviceDataset(
         self._call_dataset_fn(dataset_fn),
         self._devices,
         self._prefetch_on_device,
         source_device=device_util.resolve("/device:CPU:0"))
   # Multi-worker case: defer dataset construction to each worker.
   return values.MultiWorkerDataset(
       partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
       self._prefetch_on_device)
Exemplo n.º 5
0
  def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
    """Runs a non-prefetching one-shot iterator and checks each step exactly."""
    per_device = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=False)
    iterator = per_device.make_one_shot_iterator()

    # Without prefetching, device placement is deterministic, so every
    # step can be compared against its exact expected value.
    for wanted in expected_values:
      element = iterator.get_next()
      got = self.evaluate(
          [values.select_device(d, element) for d in devices])
      self.assertEqual(wanted, got)

    with self.assertRaises(errors.OutOfRangeError):
      element = iterator.get_next()
      self.evaluate(
          [values.select_device(d, element) for d in devices])
Exemplo n.º 6
0
    def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
        """Runs a prefetching initializable iterator and checks each step."""
        if context.executing_eagerly():
            return
        per_device = values.PerDeviceDataset(
            dataset, devices, prefetch_on_device=True)
        iterator = per_device.make_initializable_iterator()
        self.evaluate([iterator.initializer])

        # This variant checks each step against its exact expected value.
        for wanted in expected_values:
            element = iterator.get_next()
            got = self.evaluate(
                [values.select_device(d, element) for d in devices])
            self.assertEqual(wanted, got)

        with self.assertRaises(errors.OutOfRangeError):
            element = iterator.get_next()
            self.evaluate(
                [values.select_device(d, element) for d in devices])
 def distribute_dataset(self, dataset_fn):
     """Builds the dataset from `dataset_fn` and splits it over the compute devices."""
     dataset = self._call_dataset_fn(dataset_fn)
     # Prefetching on device is always enabled for this strategy.
     return values.PerDeviceDataset(dataset, self._compute_devices, True)
Exemplo n.º 8
0
 def distribute_dataset(self, dataset_fn):
     """Builds the dataset from `dataset_fn` and splits it over local devices."""
     dataset = self._call_dataset_fn(dataset_fn)
     return values.PerDeviceDataset(
         dataset, self._devices, self._prefetch_on_device)
Exemplo n.º 9
0
 def distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU."""
     # TODO(yuefengz): shard the dataset.
     dataset = self._call_dataset_fn(dataset_fn)
     # Prefetching on device is always enabled for this strategy.
     return values.PerDeviceDataset(dataset, self._devices, True)
Exemplo n.º 10
0
 def distribute_dataset(self, dataset):
     """Splits `dataset` over the local devices and returns a one-shot iterator."""
     wrapped = values.PerDeviceDataset(
         dataset, self._devices, self._prefetch_on_device)
     return wrapped.make_one_shot_iterator()
Exemplo n.º 11
0
 def _replicated_input_fn():
     """Calls `input_fn` and replicates the resulting dataset across `devices`.

     Raises:
       ValueError: if `input_fn` does not return a `Dataset`.
     """
     result = input_fn()
     if not isinstance(result, dataset_ops.Dataset):
         # Fixed: the two adjacent literals previously joined as
         # "...object oftype Dataset..." — a space was missing.
         raise ValueError("input_fn should return an object of "
                          "type Dataset for it to be replicated")
     return values.PerDeviceDataset(result, devices, True)