Exemplo n.º 1
0
    def testUseDefaultDevice(self, use_default_device):
        """Check the device reported from inside a `_ScanDataset` scan function.

        The scan function multiplies each sample by a weight matrix,
        accumulates the product into a variable, and emits the output of a
        device-placement op colocated with the matmul. The first emitted
        element is expected to contain b"CPU:0" when `use_default_device` is
        true and b"GPU:0" otherwise.

        Args:
          use_default_device: Forwarded unchanged to `_ScanDataset`; also
            selects which device string the assertion looks for.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPUs available.")

        square = (1000, 1000)
        weights = variables.Variable(initial_value=array_ops.zeros(square))
        result = variables.Variable(initial_value=array_ops.zeros(square))

        def scan_fn(state, sample):
            matmul_out = math_ops.matmul(sample, weights)
            result.assign_add(matmul_out)
            # Tie the probe op to the matmul so the value it yields reflects
            # where the matmul was placed (presumably the op's own device
            # string -- confirm against test_ops).
            with ops.colocate_with(matmul_out):
                placement = test_ops.device_placement_op()
            return state, placement

        # A single (1000, 1000) slice, so the dataset has exactly one element.
        data = variables.Variable(
            initial_value=array_ops.zeros((1,) + square))
        scanned = scan_ops._ScanDataset(
            dataset_ops.Dataset.from_tensor_slices(data),
            np.int64(1),
            scan_fn,
            use_default_device=use_default_device)
        get_next = self.getNext(scanned)

        expected_device = b"CPU:0" if use_default_device else b"GPU:0"
        self.assertIn(expected_device, self.evaluate(get_next()))
Exemplo n.º 2
0
def _general_purpose_scan(ds, init_state, body):
  """Scans `ds` with `body`, treating the work as general-purpose computation.

  Datasets are normally meant for data preprocessing, but inside autograph
  loops they usually stand in for general-purpose computations (a custom
  training loop, for instance). The two use cases call for significantly
  different optimization policies, device placement being the most important
  one.

  Args:
    ds: The dataset to scan over.
    init_state: Initial state handed to the scan.
    body: The scan function.

  Returns:
    The `_ScanDataset` built from the arguments with
    `use_default_device=False`.
  """
  # Overriding use_default_device to False instructs the runtime to treat the
  # computation as general-purpose, rather than as data preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  scanned = scan_ops._ScanDataset(  # pylint:disable=protected-access
      ds, init_state, body, use_default_device=False)
  return scanned