Ejemplo n.º 1
0
    def DISABLED_test_mnist_training_tpu(self):
        """Trains an MNIST model with fully replicated layouts on a 1-D TPU mesh.

        The test builds a single "batch" mesh dimension spanning every global
        TPU device. It runs batch-sharded training and asserts that the
        returned losses are non-increasing.
        """
        # TODO(scottzhu): Enable TPU test once the dtensor_test rule is migrated
        # out of learning/brain
        tpu_util.dtensor_initialize_tpu_system()
        tpu_mesh = tpu_util.create_tpu_mesh(
            ["batch"], [dtensor.num_global_devices("TPU")], "tpu_mesh")

        # A fixed seed is needed by the Keras initializers.
        tf_utils.set_random_seed(1337)

        replicated_layouts = (
            integration_test_utils.get_all_replicated_layout_map(tpu_mesh))
        model = integration_test_utils.get_model_with_layout_map(
            replicated_layouts)

        adam = optimizer_lib.Adam(learning_rate=0.001, mesh=tpu_mesh)
        adam.build(model.trainable_variables)

        losses = integration_test_utils.train_mnist_model_batch_sharded(
            model, adam, tpu_mesh,
            num_epochs=3, steps_per_epoch=100, global_batch_size=64)

        # Non-increasing losses are exactly the ones equal to their own
        # descending sort.
        self.assertEqual(losses, sorted(losses, reverse=True))
Ejemplo n.º 2
0
def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
                            mesh_name: str = '',
                            num_global_devices: Optional[int] = None,
                            num_clients: Optional[int] = None,
                            client_id: Optional[int] = None,
                            device_type: str = 'CPU') -> layout.Mesh:
    """Creates a single- or multi-client mesh.

    For CPU and GPU meshes, users can choose to use fewer local devices than
    what is available. If any argument is missing, it will be extracted from
    environment variables. The default values for these environment variables
    create a mesh using all devices (common for unit tests).

    For TPU meshes, users should not specify any of the nullable arguments.
    The DTensor runtime will set these arguments automatically, using all TPU
    cores available in the entire cluster.

    Args:
      mesh_dims: A non-empty list of (dim_name, dim_size) tuples.
      mesh_name: Name of the created mesh. Defaults to ''.
      num_global_devices: Number of devices in the DTensor cluster. Defaults to
        the corresponding environment variable.
      num_clients: Number of clients in the DTensor cluster. Defaults to the
        corresponding environment variable, DTENSOR_NUM_CLIENTS.
      client_id: This client's ID. Defaults to the corresponding environment
        variable, DTENSOR_CLIENT_ID.
      device_type: Type of device to build the mesh for: 'CPU', 'GPU' or 'TPU'
        (compared case-insensitively). Defaults to 'CPU'.

    Returns:
      A mesh created from specified or default arguments.

    Raises:
      ValueError: If `mesh_dims` is empty; if the (possibly defaulted)
        arguments are inconsistent with each other, the mesh shape, or the
        available local devices; if a multi-client topology is requested
        before multi-client initialization; if any nullable argument is given
        for a TPU mesh; or if `device_type` is not CPU, GPU or TPU.
    """
    if not mesh_dims:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError(
            'mesh_dims must be a non-empty list of (dim_name, dim_size) tuples')
    dim_names, shape = zip(*mesh_dims)

    if device_type.upper() in ['CPU', 'GPU']:
        # For CPU and GPU meshes, user-specified args take precedence over env
        # vars. This is particularly useful on single clients when users want
        # to create meshes that use fewer logical devices than what's available.

        if num_global_devices is None:
            num_global_devices = api.num_global_devices(device_type)
        if num_global_devices <= 0:
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be > 0')
        if num_global_devices != np.prod(shape):
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be '
                f'equal to total size of the mesh of shape {shape}')

        if num_clients is None:
            num_clients = api.num_clients()
        if num_clients <= 0:
            raise ValueError(f'num_clients ({num_clients}) must be > 0')

        # NOTE(review): `_in_multi_client_mode` is module-level state defined
        # outside this function; presumably it stays None until
        # dtensor.initialize_multi_client() has run — confirm.
        if _in_multi_client_mode is None and num_clients > 1:
            raise ValueError(
                'Invalid multi-client topology, run dtensor.initialize_multi_client() first'
            )

        if client_id is None:
            client_id = api.client_id()
        if client_id < 0:
            raise ValueError(f'client_id ({client_id}) must be >= 0')
        if client_id >= num_clients:
            raise ValueError(
                f'client_id ({client_id}) must be < {num_clients}')

        if num_global_devices % num_clients != 0:
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be '
                f'divisible by num_clients ({num_clients})')
        num_local_devices = num_global_devices // num_clients

        # It's allowed to create a CPU or GPU mesh using fewer logical devices
        # than what's available. If so, just use the first N logical devices.
        num_available_devices = api.num_local_devices(device_type)
        if num_local_devices > num_available_devices:
            raise ValueError(
                f'Not enough devices; {num_local_devices} needed, '
                f'only {num_available_devices} available')
        local_devices = api.local_devices(device_type,
                                          client_id)[:num_local_devices]

        # Global ids 0..N-1 are laid out row-major over the mesh shape; this
        # client owns the contiguous run starting at client_id * local count.
        global_device_ids = np.arange(num_global_devices).reshape(shape)
        flattened = np.ravel(global_device_ids).tolist()
        start_idx = num_local_devices * client_id
        local_device_ids = flattened[start_idx:start_idx + num_local_devices]

        mesh = layout.Mesh(dim_names=dim_names,
                           global_device_ids=global_device_ids,
                           local_device_ids=local_device_ids,
                           local_devices=local_devices,
                           mesh_name=mesh_name)
        _print_context(num_global_devices, num_clients, client_id, device_type,
                       mesh)
        return mesh

    if device_type.upper() == 'TPU':
        # TPU meshes can only be configured through environment variables that
        # reflect the actual TPU topology. Do not let users specify custom args.
        # Checked in the same order as the signature so the first offending
        # argument is the one reported.
        for arg_name, arg_value in (('num_global_devices', num_global_devices),
                                    ('num_clients', num_clients),
                                    ('client_id', client_id)):
            if arg_value is not None:
                raise ValueError(
                    f'Do not specify {arg_name} for {device_type.upper()} '
                    'meshes. It will be filled in automatically from '
                    'environmental variables. See api.py for the list of '
                    'environmental variables for DTensor.')
        mesh = tpu_util.create_tpu_mesh(dim_names, shape, mesh_name)
        _print_context(api.num_global_devices(device_type), api.num_clients(),
                       api.client_id(), device_type, mesh)
        return mesh

    raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
Ejemplo n.º 3
0
def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
                            mesh_name: str = '',
                            local_devices: Optional[List[str]] = None,
                            device_type: Optional[str] = None) -> layout.Mesh:
    """Creates a distributed mesh.

    This is similar to `create_mesh`, but with a different set of arguments to
    create a mesh that spans evenly across a multi-client DTensor cluster.

    For CPU and GPU meshes, users can choose to use fewer local devices than
    what is available by passing `local_devices`.

    For TPU, only meshes that use all TPU cores are supported by the DTensor
    runtime, so `local_devices` must not be given.

    Args:
      mesh_dims: A non-empty list of (dim_name, dim_size) tuples.
      mesh_name: Name of the created mesh. Defaults to ''.
      local_devices: String representations of devices to use. This is the
        device part of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available
        local logical devices.
      device_type: Type of device to build the mesh for. Supported values are
        'CPU', 'GPU', 'TPU' (compared case-insensitively). If None, the type
        is resolved together with `local_devices` by `_make_device_specs`.

    Returns:
      A mesh that spans evenly across all DTensor clients in the cluster.

    Raises:
      ValueError: If `mesh_dims` is empty; if `local_devices` is given for a
        TPU mesh; if a multi-client topology is used before
        dtensor.initialize_multi_client(); if the mesh size differs from the
        global device count; or if the device type is not CPU, GPU or TPU.
    """
    if not mesh_dims:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError(
            'mesh_dims must be a non-empty list of (dim_name, dim_size) tuples')
    dim_names, shape = zip(*mesh_dims)

    if device_type and device_type.upper() == 'TPU':
        # TODO(b/185940495): Allow multi-mesh and partial on TPU.
        # TPU meshes can only be configured through environment variables that
        # reflect the actual TPU topology. Do not let users specify custom args.
        if local_devices is not None:
            raise ValueError(
                f'Do not specify devices for {device_type.upper()} meshes. '
                f'Using a partial list of devices for {device_type.upper()} '
                f'is not supported.')

    device_specs, device_type = _make_device_specs(local_devices, device_type)

    if device_type.upper() in ['CPU', 'GPU']:
        # CPU and GPU meshes may span only the subset of local logical devices
        # selected through `local_devices`.
        # Read the cluster topology once so every use below sees one value.
        num_clients = api.num_clients()
        if num_clients > 1 and not multi_client_util.is_initialized():
            raise ValueError('Invalid multi-client topology, please run '
                             'dtensor.initialize_multi_client() first.')

        job_name = api.job_name()
        client_id = api.client_id()
        # Merge this client's job/replica/task into every local device spec.
        local_spec = tf_device.DeviceSpec(job=job_name,
                                          replica=0,
                                          task=client_id)
        device_specs = [local_spec.make_merged_spec(d) for d in device_specs]
        num_local_devices = len(device_specs)

        # Assumes identical number of local devices per client.
        num_global_devices = num_local_devices * num_clients

        if np.prod(shape) != num_global_devices:
            raise ValueError(
                f'Global number of devices '
                f'({num_local_devices} per client * {num_clients} clients '
                f'= {num_global_devices}) must be '
                f'equal to total size of the mesh of shape {shape}')

        # Global ids 0..N-1 are laid out row-major over the mesh shape; this
        # client owns the contiguous run starting at client_id * local count.
        global_device_ids = np.arange(num_global_devices).reshape(shape)
        flattened = np.ravel(global_device_ids).tolist()
        start_idx = num_local_devices * client_id
        local_device_ids = flattened[start_idx:start_idx + num_local_devices]

        mesh = layout.Mesh(dim_names=dim_names,
                           global_device_ids=global_device_ids,
                           local_device_ids=local_device_ids,
                           local_devices=device_specs,
                           mesh_name=mesh_name)
        _print_context(num_global_devices, num_clients, client_id,
                       device_type, mesh)
        return mesh

    if device_type.upper() == 'TPU':
        mesh = tpu_util.create_tpu_mesh(dim_names, shape, mesh_name)
        _print_context(api.num_global_devices(device_type), api.num_clients(),
                       api.client_id(), device_type, mesh)
        return mesh

    raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')