def __init__(self, meshes: List[layout_lib.Mesh], is_async=True):
    """Allocate a new DTensor custom device and register it with the context.

    Args:
      meshes: A flat list of `Mesh` objects indicating groups of devices to
        execute on. These may also be registered lazily.
      is_async: Indicates whether DTensor operations on this client will
        return immediately (with "non-ready" handles) or block until
        executed. This is on by default and is exposed as an option for ease
        of debugging.

    Raises:
      TypeError: If any element of `meshes` is not a `layout_lib.Mesh`.
    """
    global _next_device_number

    # Reject nested lists / wrong element types before touching global state.
    if not all(isinstance(entry, layout_lib.Mesh) for entry in meshes):
        raise TypeError(
            "Expected a flat list of Mesh objects, got {}".format(meshes))

    ctx = context.context()
    # The device counter is module-level state; the lock is held while a
    # number is read and incremented so each device gets a distinct name.
    with _next_device_number_lock:
        device_number = _next_device_number
        self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                 device_number)
        _next_device_number = device_number + 1

    capsules = _pywrap_dtensor_device.Allocate(self.name)
    device, device_info = capsules
    context.register_custom_device(device, self.name, device_info)

    self._device_info = device_info
    self._is_async = is_async
    self._current_output_layout = None
    self._current_default_mesh = None

    # The mesh set and its lock must exist before any mesh is registered.
    self._meshes = set()
    self._mesh_lock = threading.Lock()
    for entry in meshes:
        self._register_mesh(entry)
def __init__(self, components):
    """Create a device which executes operations in parallel on `components`.

    Args:
      components: A list of device names. Each operation executed on this
        device executes on these component devices. Every name is passed
        through `device_util.canonicalize` before being stored.

    Raises:
      ValueError: If `components` is empty.
    """
    global _next_device_number

    canonical_names = tuple(map(device_util.canonicalize, components))
    self.components = canonical_names
    if not canonical_names:
        raise ValueError("ParallelDevice requires at least one component.")

    ctx = context.context()
    with _next_device_number_lock:
        # TODO(allenl): Better names for parallel devices (right now "CUSTOM"
        # is special-cased).
        device_number = _next_device_number
        self._name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                  device_number)
        _next_device_number = device_number + 1

    capsules = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self._name, self.components)
    device, device_info = capsules
    context.register_custom_device(device, self._name, device_info)

    self._device_ids = None
    self._device_scope = None
    self._saving_scope = None
    # Record this instance in the module-level table, keyed by device name.
    _all_parallel_devices[self._name] = self
def testRegisterCustomDevice(self):
    """Registers a custom device and checks its flags around one multiply."""
    custom_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
    capsules = custom_device_testutil.GetLoggingDeviceCapsules(custom_name)
    device, device_info, arrived_flag, executed_flag = capsules
    context.register_custom_device(device, custom_name, device_info)

    flag_value = custom_device_testutil.FlagValue
    # Immediately after registration neither flag is expected to be set.
    self.assertFalse(flag_value(arrived_flag))
    self.assertFalse(flag_value(executed_flag))

    with ops.device(custom_name):
        one = constant_op.constant(1.)
        doubled = one * constant_op.constant(2.)
    self.assertTrue(flag_value(executed_flag))
    # There was no copy onto the device. Actually I'm not sure how to trigger
    # that from Python.
    self.assertFalse(flag_value(arrived_flag))

    # Reading the result back is expected to fail with the copy error.
    with self.assertRaisesRegex(errors.InternalError, 'Trying to copy'):
        doubled.numpy()
def __init__(self, components):
    """Create a device which executes operations in parallel on `components`.

    Args:
      components: A list of device names. Each operation executed on this
        device executes on these component devices.

    Raises:
      ValueError: If `components` is empty.
    """
    global _next_device_number

    self.components = tuple(components)
    # Fail fast, before the device counter is bumped or anything is
    # registered: a parallel device with no components has nothing to run on.
    if not self.components:
        raise ValueError("ParallelDevice requires at least one component.")

    ctx = context.context()
    with _next_device_number_lock:
        # TODO(allenl): Better names for parallel devices (right now "CUSTOM"
        # is special-cased).
        self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                 _next_device_number)
        _next_device_number += 1

    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self.name, self.components)
    context.register_custom_device(device, self.name, device_info)