예제 #1
0
    def testUseDefaultDevice(self, use_default_device):
        """Checks which device a `_ScanDataset` scan function is placed on.

        A one-element dataset is scanned with a function that multiplies the
        element by a captured variable and then runs `device_placement_op`
        colocated with the product. With `use_default_device=True` the op's
        output is expected to mention CPU:0; otherwise it should mention GPU:0.
        The test is skipped when no GPU is available.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPUs available.")

        matrix_shape = (1000, 1000)
        weights = variables.Variable(
            initial_value=array_ops.zeros(matrix_shape))
        accumulator = variables.Variable(
            initial_value=array_ops.zeros(matrix_shape))
        # A single (1000, 1000) slice, so the dataset yields exactly one element.
        data = variables.Variable(
            initial_value=array_ops.zeros((1,) + matrix_shape))

        def scan_fn(state, sample):
            product = math_ops.matmul(sample, weights)
            accumulator.assign_add(product)
            # Colocate with the matmul output so the placement op reports the
            # device that `product` was computed on.
            with ops.colocate_with(product):
                placement = test_ops.device_placement_op()
            return state, placement

        scanned = dataset_ops._ScanDataset(
            dataset_ops.Dataset.from_tensor_slices(data),
            np.int64(1),
            scan_fn,
            use_default_device=use_default_device)
        get_next = self.getNext(scanned)

        expected_device = b"CPU:0" if use_default_device else b"GPU:0"
        self.assertIn(expected_device, self.evaluate(get_next()))
예제 #2
0
def _general_purpose_scan(ds, init_state, body):
  """Scans `ds` like `Dataset.scan`, but as a general-purpose computation.

  Datasets are normally treated as data-preprocessing pipelines. Inside
  autograph loops, however, they usually stand for general-purpose work (for
  example, a custom training loop), and that calls for different optimization
  policies, most importantly device placement. Passing
  `use_default_device=False` tells the runtime to treat the scan as
  general-purpose computation rather than input preprocessing.

  Args:
    ds: The dataset to scan over.
    init_state: The initial scan state.
    body: The scan function. It is called with `(state, element)` and returns
      `(new_state, output_element)`, as with `Dataset.scan`.

  Returns:
    A `dataset_ops._ScanDataset` built with `use_default_device=False`.
  """
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  scan_dataset_cls = dataset_ops._ScanDataset
  # pylint:enable=protected-access
  return scan_dataset_cls(ds, init_state, body, use_default_device=False)