Example #1
    def __init__(self, dvariable, name):
        with ops.device(dvariable.device):
            original_layout = api.fetch_layout(dvariable)
        # Record original layout to allow restore.
        self._original_layout = original_layout
        self._dvariable = dvariable

        def pack(tensors, layout):
            with ops.device(dvariable.device):
                return api.pack(tensors, layout)

        host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                        original_layout.mesh.host_mesh())

        def get_host_dvariable():
            # Copy to host mesh if needed.
            if original_layout.mesh.device_type().upper() != 'CPU':
                with ops.device(dvariable.device):
                    host_dvariable = DVariable(
                        api.pack(api.unpack(dvariable.read_value()),
                                 host_layout))
            else:
                host_dvariable = dvariable
            return (math_ops.cast(host_dvariable, dtypes.bfloat16)
                    if self.should_cast(host_dvariable) else host_dvariable)

        num_local_devices = original_layout.mesh.num_local_devices()
        super(_DVariableSaveable, self).__init__(
            None,
            [
                DSaveSpec(
                    tensor=get_host_dvariable,
                    slice_spec=pack([''] * num_local_devices,
                                    layout_lib.Layout.replicated(
                                        original_layout.mesh.host_mesh(),
                                        rank=0)),
                    name=pack([name] * num_local_devices,
                              layout_lib.Layout.replicated(
                                  original_layout.mesh.host_mesh(), rank=0)),
                    global_shape=dvariable.shape,
                    # The layout is attached as an attribute; there is no need
                    # to put it as a Tensor on the DTensorDevice.
                    layout=host_layout.to_string(),
                    dtype=dtypes.bfloat16
                    if self.should_cast(dvariable) else dvariable.dtype,
                    device=dvariable.device)
            ],
            name)
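A minimal usage sketch (not from the source) of how this saveable can be exercised. _DVariableSaveable is an internal class that the DTensor checkpointing path builds for each DVariable; the mesh, layout, and checkpoint key below are illustrative assumptions.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Single-device CPU mesh; real meshes usually span several devices.
mesh = dtensor.create_mesh([("batch", 1)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=1)

# A replicated DVariable; __init__ above fetches and records its layout.
v = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))

# The checkpointing machinery builds one saveable per DVariable; its name and
# slice specs are packed replicated onto the host (CPU) mesh so SaveV2 can run
# there. The checkpoint key string here is purely illustrative.
saveable = _DVariableSaveable(v, "v/.ATTRIBUTES/VARIABLE_VALUE")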
Example #2
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
    """Generates all sharded prefix in distributed Save.

  DTensor SaveV2 SPMD would generate multiple SaveV2 ops on saving devices,
  and it is desired to not save with same shard_prefix so that content will
  not be overwritten.

  (prefix, tensor_names, tensors(with layouts)) and saving mesh collectively
  defines a unique set of shard prefix that is generated for all the Save ops.
  Usually, (prefix, tensor_names, shape_and_slices, tensors) should match what
  is used in save.

  Args:
    mesh: The mesh used in the Save op. Usually a CPU mesh, and matches what is
      used in the Save op.
    prefix: The prefix of the saved files.
    tensor_names: a list of tensor names used in the Save op.
    shape_and_slices: a list of shape and slice specifications used in the Save
      op. The only supported value is "" since distributed saving with slices
      is not supported yet.
    tensors: a list of tensors used in the Save op. The order should match
      tensor_names.

  Returns:
    A 1-D string tensor representing all generated shard prefixes.
  """
    layout_str = array_ops.stack(
        [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
    layouts = api.pack([layout_str] * mesh.num_local_devices(),
                       layout_lib.Layout.replicated(mesh, rank=1))

    mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                               layout_lib.Layout.replicated(mesh, rank=0))
    return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                   tensor_names,
                                                   shape_and_slices,
                                                   mesh_str_tensor,
                                                   layouts=layouts,
                                                   tensors=tensors)
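A hedged sketch (not from the source) of how this helper would be called; sharded_prefix is internal to the DTensor save/restore module, so its import path is omitted and the prefix value is hypothetical. The arguments mirror those handed to the Save op, with one empty shape-and-slice string per tensor.

import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("batch", 1)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=1)
t = dtensor.call_with_layout(tf.zeros, layout, shape=[4])

# One shard prefix per saving device, derived from the prefix, the tensor
# names, the tensors' layouts, and the saving (CPU) mesh.
prefixes = sharded_prefix(
    mesh=mesh,                  # already a CPU/host mesh here
    prefix=["/tmp/ckpt"],       # hypothetical checkpoint prefix
    tensor_names=["t"],
    shape_and_slices=[""],      # "" only; slices are not supported yet
    tensors=[t],
)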
Example #3
    def __init__(self, initial_value, *args, dtype=None, **kwargs):
        """Overrides tf.Variable to fix VarHandleOp placements."""
        # Variables by default use the current device scope for placement. This
        # wrapper has them follow the initial value's placement instead (which will
        # be the DTensor device if the initial value has a layout).
        if callable(initial_value):
            initial_value = initial_value()

        initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
        variable_device = initial_value.device
        self._save_as_bf16 = False
        # TODO(b/159035705): The following code enables variable creation inside
        # a tf.function. However, it requires a global dtensor device.
        # if not variable_device and not tf.executing_eagerly():
        #   try:
        #     initial_value.op.get_attr("_layout")
        #   except ValueError:
        #     pass
        #   else:
        #     # The initial value is a DTensor, but because the DTensor device is
        #     # only active during eager execution at the moment we need to
        #     # translate that into a placement for the eager VarHandleOp.
        #     variable_device = _dtensor_device().name
        with ops.device(variable_device):
            # If the initial tensor assigned to the DVariable is a DTensor,
            # record the layout of the resource so that it can be queried.
            self.layout = None
            if context.executing_eagerly():
                try:
                    self.layout = api.fetch_layout(initial_value)
                except (errors.InvalidArgumentError, errors.NotFoundError):
                    # For non-DTensor tensors, fetching the layout raises the
                    # expected InvalidArgumentError or NotFoundError, depending
                    # on whether the API is called within a DTensor device
                    # scope or not.
                    self.layout = None
            mesh = self.layout.mesh if self.layout else None
            with api.run_on(mesh) if mesh else contextlib.nullcontext():
                super(DVariable, self).__init__(initial_value,
                                                *args,
                                                dtype=dtype,
                                                **kwargs)
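A short sketch (not from the source) of the behavior this constructor adds: when the initial value is a DTensor, the resulting DVariable records its layout and is placed on the DTensor device. The mesh and layout below are illustrative.

import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("batch", 1)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=2)

# The initial value is a DTensor, so __init__ above fetches its layout and
# creates the variable on the DTensor device rather than the outer scope.
init = dtensor.call_with_layout(tf.zeros, layout, shape=[2, 3])
v = dtensor.DVariable(init)

print(v.layout)  # the layout fetched from the initial value
print(v.device)  # the DTensor device hosting the variable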
Example #4
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
    """Restores from checkpoint_prefix to name based DTensors.

  This requires already-initialized DTensor variables whose shape/dtype match
  the tensors being restored.

  Also, we currently only support name-based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors are restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary mapping tensor names to DTensors.
      The DTensor shape/dtype must match the tensors being saved/restored for
      now.

  Returns:
    A dictionary mapping each name to its restored DTensor value.
  """
    if not context.executing_eagerly():
        raise ValueError('name-based restore must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Make sure that all tensors are on CPU mesh for now.
    # This might not be a hard limitation in the future.
    for name, tensor in ordered_name_tensor_dict.items():
        try:
            if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
                raise ValueError(
                    'Restoring a non-CPU Tensor is currently not supported. '
                    'Offending tensor name: {tensor_name}'.format(tensor_name=name))
        except errors_impl.OpError as op_error:
            raise ValueError(
                'Saving/Restoring tensor must be a DTensor') from op_error

    # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    # Explicitly pack to the mesh to avoid implicit small-constant extraction,
    # which does not work for larger restores that have many names.
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))
    shape_and_slices = api.pack([[''] * len(ordered_name_tensor_dict)] *
                                mesh.num_local_devices(),
                                layout_lib.Layout.replicated(mesh.host_mesh(),
                                                             rank=1))
    # A list of TensorShape representing all shapes for the input tensors.
    input_shapes = [
        tensor.shape for tensor in ordered_name_tensor_dict.values()
    ]
    input_layouts = [
        api.fetch_layout(tensor).to_string()
        for tensor in ordered_name_tensor_dict.values()
    ]

    with ops.device(api.device_name()):
        restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
            prefix=checkpoint_prefix,
            tensor_names=tensor_names,
            shape_and_slices=shape_and_slices,
            input_shapes=input_shapes,
            input_layouts=input_layouts,
            dtypes=[
                tensor.dtype for tensor in ordered_name_tensor_dict.values()
            ])

    return collections.OrderedDict(
        zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
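A hedged usage sketch (not from the source), assuming the name_based_save/name_based_restore pair exported under tf.experimental.dtensor; the checkpoint prefix is hypothetical and signatures may differ between TF versions.

import collections
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("batch", 1)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=1)

w = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))
name_tensor_dict = collections.OrderedDict([("w", w)])

prefix = "/tmp/name_based_ckpt"  # hypothetical checkpoint prefix
dtensor.name_based_save(mesh, prefix, name_tensor_dict)

# Restored values come back on the CPU mesh, keyed by the same names.
restored = dtensor.name_based_restore(mesh, prefix, name_tensor_dict)
w.assign(restored["w"])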
Example #5
    def __init__(self, initial_value, *args, dtype=None, **kwargs):
        """Overrides tf.Variable to fix VarHandleOp placements."""
        # Variables by default use the current device scope for placement. This
        # wrapper has them follow the initial value's placement instead (which will
        # be the DTensor device if the initial value has a layout).

        # Pop layout from kwargs since Keras make_variable may pass a 'layout'
        # keyword argument. We need to pop it because we pass kwargs to the
        # superclass constructor.
        layout = kwargs.pop('layout', None)
        shape = kwargs.get('shape', None)

        if callable(initial_value):
            unwrapped = initial_value
            if issubclass(type(initial_value), functools.partial):
                unwrapped = initial_value.func

            # If the unwrapped callable is a CheckpointInitialValueCallable,
            # we are creating a Variable during a checkpoint restore. The
            # restore then happens through this callable, and the DVariable is
            # created from the restored DTensor.
            if issubclass(type(unwrapped),
                          trackable.CheckpointInitialValueCallable):
                if not shape or not layout:
                    raise ValueError(
                        'Expected shape and layout to be not None.')

                # CheckpointInitialValueCallable will call an eager tf.RestoreV2,
                # which has no shape or layout information attached. Thus we do
                # two things to have them correctly specified:
                #
                # The default layout scope allows us to correctly specify the
                # output layout of the tf.RestoreV2 that will be called.
                #
                # Passing shard_info with the correct shape allows the
                # tf.RestoreV2 ShapeInference to extract the shape.
                initial_value = api.call_with_layout(
                    initial_value,
                    layout,
                    shard_info=trackable.ShardInfo(shape=shape,
                                                   offset=[0] * len(shape)))
            else:
                initial_value = initial_value()

        # When the initial value came from a Checkpoint restoration, fetch tensor.
        if isinstance(initial_value, trackable.CheckpointInitialValue):
            initial_value = initial_value.wrapped_value

        initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
        variable_device = initial_value.device
        self._save_as_bf16 = False
        # TODO(b/159035705): The following code enables variable creation inside
        # a tf.function. However, it requires a global dtensor device.
        # if not variable_device and not tf.executing_eagerly():
        #   try:
        #     initial_value.op.get_attr("_layout")
        #   except ValueError:
        #     pass
        #   else:
        #     # The initial value is a DTensor, but because the DTensor device is
        #     # only active during eager execution at the moment we need to
        #     # translate that into a placement for the eager VarHandleOp.
        #     variable_device = _dtensor_device().name
        with ops.device(variable_device):
            # If the initial tensor assigned to the DVariable is a DTensor,
            # record the layout of the resource so that it can be queried.
            self.layout = None
            if context.executing_eagerly():
                try:
                    self.layout = api.fetch_layout(initial_value)
                except (errors.InvalidArgumentError, errors.NotFoundError):
                    # For non-DTensor tensors, fetching the layout raises the
                    # expected InvalidArgumentError or NotFoundError, depending
                    # on whether the API is called within a DTensor device
                    # scope or not.
                    self.layout = None
            mesh = self.layout.mesh if self.layout else None
            with api.run_on(mesh) if mesh else contextlib.nullcontext():
                super(DVariable, self).__init__(initial_value,
                                                *args,
                                                dtype=dtype,
                                                **kwargs)
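A minimal round-trip sketch (not from the source), assuming tf.train.Checkpoint handles DVariables on the running TF version (older releases used tf.experimental.dtensor.DTensorCheckpoint). When restoration is deferred, for example for variables created inside a Keras layer after restore() was called, the initializer arrives as a CheckpointInitialValueCallable and the branch above supplies the shape and layout to the restore; the checkpoint path below is hypothetical.

import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("batch", 1)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=1)

v = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))
path = tf.train.Checkpoint(v=v).save("/tmp/dvariable_ckpt")  # hypothetical path

# Restoring into an existing DVariable of the same shape and layout.
v2 = dtensor.DVariable(dtensor.call_with_layout(tf.zeros, layout, shape=[4]))
tf.train.Checkpoint(v=v2).restore(path)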