def testCurrentDeviceWithGlobalGraph(self):
    """device_util.current() reflects ops.device scopes on the default graph.

    Three scope stacks are checked: a lone CPU scope, a job scope with a CPU
    scope inside it, and a CPU scope with a GPU scope inside it.
    """
    with ops.device("/cpu:0"):
        self.assertEqual(device_util.current(), "/device:CPU:0")

    # A job-only outer scope is merged with the inner device scope.
    with ops.device("/job:worker"), ops.device("/cpu:0"):
        self.assertEqual(device_util.current(), "/job:worker/device:CPU:0")

    # When both scopes name a device type, the inner one is what is reported.
    with ops.device("/cpu:0"), ops.device("/gpu:0"):
        self.assertEqual(device_util.current(), "/device:GPU:0")
def get(self, device=None):
    """Look up the per-device value, raising ValueError if it is missing.

    Args:
      device: Optional device string. When None, the device is resolved in
        this order: the active tower context's `device` (if
        `distribute_lib.get_tower_context()` returns a truthy context), else
        `distribute_lib.get_update_device()`, else `device_util.current()`
        when no update device is set.

    Returns:
      The entry of `self._index` stored under the canonicalized device.

    Raises:
      ValueError: if the canonicalized device is not a key of `self._index`.
    """
    if device is None:
        tower_ctx = distribute_lib.get_tower_context()
        if not tower_ctx:
            # Outside a tower: prefer the update device, then the ambient one.
            device = distribute_lib.get_update_device()
            if device is None:
                device = device_util.current()
        else:
            device = tower_ctx.device
    # Keys of self._index are looked up by canonical device name.
    canonical = device_util.canonicalize(device)
    try:
        return self._index[canonical]
    except KeyError:
        raise ValueError("Device %s not found in %s (current device %s)" %
                         (canonical, self._index.keys(), device_util.current()))
def device(self):
    """Return, as a string, the device this tower executes on.

    Calls `require_tower_context(self)` first (presumably this raises when
    not inside this tower's context; its behavior is not visible here), then
    reports `device_util.current()`.
    """
    require_tower_context(self)
    current_device = device_util.current()
    return current_device
def _get_cross_tower(self):
    """Return the current device's value wrapped in `array_ops.identity`.

    The current device is canonicalized and looked up in `self._index`. If
    it has no entry, `self._primary_var` is wrapped instead.
    """
    key = device_util.canonicalize(device_util.current())
    # The conditional expression only reads _primary_var on a miss.
    source = self._index[key] if key in self._index else self._primary_var
    return array_ops.identity(source)
def _get_cross_tower(self):
    """Return the stored value for the current device, else the first value.

    The current device (`device_util.current()`) is canonicalized and looked
    up in `self._index`. If it has no entry, the first value in the iteration
    order of `self._index` is returned as-is.

    Raises:
      IndexError: if the current device has no entry and `self._index` is
        empty.
    """
    key = device_util.canonicalize(device_util.current())
    if key in self._index:
        return self._index[key]
    # Take the first value lazily instead of copying every value into a list
    # just to index element 0.
    try:
        return next(iter(self._index.values()))
    except StopIteration:
        # Surface an IndexError rather than leaking StopIteration, which
        # PEP 479 would turn into RuntimeError if raised inside a generator.
        raise IndexError(
            "no per-device values stored in %s" % type(self).__name__)
def testCurrentDeviceWithEager(self):
    """In eager mode a CPU scope is reported as a fully-qualified device."""
    expected = "/job:localhost/replica:0/task:0/device:CPU:0"
    with context.eager_mode(), ops.device("/cpu:0"):
        self.assertEqual(device_util.current(), expected)
def testCurrentDeviceWithNonGlobalGraph(self):
    """A CPU scope inside a freshly created default graph is reported."""
    fresh_graph = ops.Graph()
    with fresh_graph.as_default(), ops.device("/cpu:0"):
        self.assertEqual(device_util.current(), "/device:CPU:0")
def testCurrentDeviceWithEager(self):
    """Eager execution yields a job/replica/task-qualified CPU device name."""
    with context.eager_mode(), ops.device("/cpu:0"):
        observed = device_util.current()
        self.assertEqual(observed,
                         "/job:localhost/replica:0/task:0/device:CPU:0")