def testUseDefaultDevice(self, use_default_device):
  """Checks that _ScanDataset honors the `use_default_device` flag.

  When the flag is set, the scan computation is pinned to the default
  (CPU) device; otherwise colocation allows placement on the GPU.
  """
  if not test_util.is_gpu_available():
    self.skipTest("No GPUs available.")

  kernel = variables.Variable(initial_value=array_ops.zeros((1000, 1000)))
  accumulator = variables.Variable(
      initial_value=array_ops.zeros((1000, 1000)))

  def step(state, batch):
    # Side effect: fold the matmul product into `accumulator`.
    prod = math_ops.matmul(batch, kernel)
    accumulator.assign_add(prod)
    # Report the device where `prod` was actually placed.
    with ops.colocate_with(prod):
      placement = test_ops.device_placement_op()
    return state, placement

  inputs = variables.Variable(
      initial_value=array_ops.zeros((1, 1000, 1000)))
  ds = dataset_ops.Dataset.from_tensor_slices(inputs)
  ds = scan_ops._ScanDataset(
      ds, np.int64(1), step, use_default_device=use_default_device)
  get_next = self.getNext(ds)

  expected = b"CPU:0" if use_default_device else b"GPU:0"
  self.assertIn(expected, self.evaluate(get_next()))
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. In autograph
  # loops, however, they commonly express general-purpose computation (for
  # example, a custom training loop). The two use cases call for
  # significantly different optimization policies, the most important of
  # which is device placement. Overriding use_default_device below tells
  # the runtime to treat the computation as general-purpose rather than as
  # data preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  return scan_ops._ScanDataset(ds, init_state, body, use_default_device=False)