Example #1
0
    def __init__(self, meshes: List[layout_lib.Mesh], is_async=True):
        """Create a new DTensorDevice backed by a freshly registered custom device.

        A unique name of the form "<host address space>/device:CUSTOM:<n>" is
        reserved, a device is allocated under that name and registered via
        `context.register_custom_device`, and each mesh in `meshes` is then
        registered on this device.

        Args:
          meshes: A list of `Mesh` objects indicating groups of devices to
            execute on. These may also be registered lazily.
          is_async: Indicates whether DTensor operations on this client will
            return immediately (with "non-ready" handles) or block until
            executed. This is on by default and is exposed as an option for
            ease of debugging.

        Raises:
          TypeError: If any element of `meshes` is not a `Mesh`.
        """
        # Materialize once: the validation below iterates `meshes`, and a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # before the registration loop at the end ever sees an element.
        meshes = list(meshes)
        if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
            raise TypeError(
                "Expected a flat list of Mesh objects, got {}".format(meshes))
        global _next_device_number
        ctx = context.context()
        # Read-and-increment under the lock so two devices created concurrently
        # do not receive the same number (assuming every user of the counter
        # takes this lock).
        with _next_device_number_lock:
            self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                     _next_device_number)
            _next_device_number += 1
        device, device_info = _pywrap_dtensor_device.Allocate(self.name)
        context.register_custom_device(device, self.name, device_info)

        self._device_info = device_info
        self._current_output_layout = None
        self._current_default_mesh = None
        self._is_async = is_async
        self._meshes = set()
        # NOTE(review): presumably guards `_meshes` inside `_register_mesh`;
        # confirm against that method.
        self._mesh_lock = threading.Lock()
        for mesh in meshes:
            self._register_mesh(mesh)
Example #2
0
    def __init__(self, components):
        """Creates a device which executes operations in parallel on `components`.

        The new device is registered under a unique name of the form
        "<host address space>/device:CUSTOM:<n>" (stored in `self._name`) and is
        recorded in the module-level `_all_parallel_devices` mapping.

        Args:
          components: A list of device names. Each operation executed on this
            device executes on these component devices. Each name is passed
            through `device_util.canonicalize` before use.

        Raises:
          ValueError: If `components` is empty.
        """
        # Only the counter is rebound here; the lock is merely read.
        global _next_device_number
        self.components = tuple(
            device_util.canonicalize(d) for d in components)
        if not self.components:
            raise ValueError("ParallelDevice requires at least one component.")
        ctx = context.context()
        with _next_device_number_lock:
            # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
            # special-cased).
            self._name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                      _next_device_number)
            _next_device_number += 1
        device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
            self._name, self.components)
        context.register_custom_device(device, self._name, device_info)
        self._device_ids = None
        self._device_scope = None
        self._saving_scope = None
        _all_parallel_devices[self._name] = self
 def testRegisterCustomDevice(self):
     """Registers a logging custom device and checks its flags around op execution."""
     name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
     capsules = custom_device_testutil.GetLoggingDeviceCapsules(name)
     device, device_info, arrived_flag, executed_flag = capsules
     flag_is_set = custom_device_testutil.FlagValue
     context.register_custom_device(device, name, device_info)
     # Right after registration neither flag is expected to be set.
     self.assertFalse(flag_is_set(arrived_flag))
     self.assertFalse(flag_is_set(executed_flag))
     with ops.device(name):
         one = constant_op.constant(1.)
         doubled = one * constant_op.constant(2.)
     # The multiply under the device scope is expected to set the executed flag.
     self.assertTrue(flag_is_set(executed_flag))
     # No copy onto the device happened. How to trigger one from Python is an
     # open question.
     self.assertFalse(flag_is_set(arrived_flag))
     # Fetching the value is expected to fail with an InternalError.
     with self.assertRaisesRegex(errors.InternalError, 'Trying to copy'):
         doubled.numpy()
Example #4
0
    def __init__(self, components):
        """Creates a device which executes operations in parallel on `components`.

        The new device is registered under a unique name of the form
        "<host address space>/device:CUSTOM:<n>", stored in `self.name`.

        Args:
          components: A list of device names. Each operation executed on this
            device executes on these component devices.

        Raises:
          ValueError: If `components` is empty.
        """
        # Only the counter is rebound here; the lock is merely read.
        global _next_device_number
        self.components = tuple(components)
        # Reject an empty component list up front, before a device number is
        # consumed or anything is registered.
        if not self.components:
            raise ValueError("ParallelDevice requires at least one component.")
        ctx = context.context()
        with _next_device_number_lock:
            # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
            # special-cased).
            self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                     _next_device_number)
            _next_device_number += 1
        device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
            self.name, self.components)
        context.register_custom_device(device, self.name, device_info)