def __init__(self, dvariable, name):
  with ops.device(dvariable.device):
    original_layout = api.fetch_layout(dvariable)
  # Record original layout to allow restore.
  self._original_layout = original_layout
  self._dvariable = dvariable

  def pack(tensors, layout):
    with ops.device(dvariable.device):
      return api.pack(tensors, layout)

  host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                  original_layout.mesh.host_mesh())

  def get_host_dvariable():
    # Copy to host mesh if needed.
    if original_layout.mesh.device_type().upper() != 'CPU':
      with ops.device(dvariable.device):
        host_dvariable = DVariable(
            api.pack(api.unpack(dvariable.read_value()), host_layout))
    else:
      host_dvariable = dvariable
    return (math_ops.cast(host_dvariable, dtypes.bfloat16)
            if self.should_cast(host_dvariable) else host_dvariable)

  num_local_devices = original_layout.mesh.num_local_devices()
  super(_DVariableSaveable, self).__init__(
      None,
      [
          DSaveSpec(
              tensor=get_host_dvariable,
              slice_spec=pack([''] * num_local_devices,
                              layout_lib.Layout.replicated(
                                  original_layout.mesh.host_mesh(), rank=0)),
              name=pack([name] * num_local_devices,
                        layout_lib.Layout.replicated(
                            original_layout.mesh.host_mesh(), rank=0)),
              global_shape=dvariable.shape,
              # Layout is attached as an attribute; no need to put it as a
              # Tensor on the DTensorDevice.
              layout=host_layout.to_string(),
              dtype=dtypes.bfloat16
              if self.should_cast(dvariable) else dvariable.dtype,
              device=dvariable.device)
      ],
      name)
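# Illustrative usage (a minimal sketch, not part of this module): saving a
# DVariable through the standard checkpoint path, which wraps each variable in
# a _DVariableSaveable like the one above. The mesh setup, paths, and the
# tf.experimental.dtensor entry points below are assumptions for the sketch;
# older TF releases may need dtensor.DTensorCheckpoint(mesh=mesh, ...) in
# place of tf.train.Checkpoint.
import tensorflow as tf
from tensorflow.experimental import dtensor

# Two virtual CPUs so the mesh has more than one local device.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices('CPU')[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([('x', 2)], device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)
v = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))

ckpt = tf.train.Checkpoint(v=v)
save_path = ckpt.save('/tmp/dtensor_ckpt/ckpt')  # path is illustrative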
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
  """Generates all sharded prefixes for a distributed save.

  DTensor SaveV2 SPMD generates multiple SaveV2 ops on the saving devices, and
  each should save under a different shard_prefix so that no content is
  overwritten.

  (prefix, tensor_names, tensors (with layouts)) and the saving mesh
  collectively define the unique set of shard prefixes generated for all the
  Save ops. Usually, (prefix, tensor_names, shape_and_slices, tensors) should
  match what is used in the save.

  Args:
    mesh: The mesh used in the save op. Usually a CPU mesh, and matches what
      is used in the Save op.
    prefix: The prefix of the saved files.
    tensor_names: A list of tensor names used in the save op.
    shape_and_slices: A list of shape and slice specifications used in the
      save op. The only supported value is "" as distributed saving with
      slices is not supported yet.
    tensors: A list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A 1-D string tensor containing all generated shard prefixes.
  """
  layout_str = array_ops.stack(
      [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
  layouts = api.pack([layout_str] * mesh.num_local_devices(),
                     layout_lib.Layout.replicated(mesh, rank=1))

  mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                             layout_lib.Layout.replicated(mesh, rank=0))
  return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                 tensor_names,
                                                 shape_and_slices,
                                                 mesh_str_tensor,
                                                 layouts=layouts,
                                                 tensors=tensors)
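# Illustrative sketch (not part of this module): at the public API level, a
# DTensor save of a sharded DVariable generates one SaveV2 per saving device,
# each writing under a distinct shard prefix of the kind computed by
# sharded_prefix above. The mesh setup, layout, and paths below are
# assumptions for the sketch; older TF releases may require
# dtensor.DTensorCheckpoint instead of tf.train.Checkpoint.
import tensorflow as tf
from tensorflow.experimental import dtensor

tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices('CPU')[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([('x', 2)], device_type='CPU')
sharded_layout = dtensor.Layout(['x'], mesh)  # rank-1, sharded over 'x'
v = dtensor.DVariable(
    dtensor.call_with_layout(tf.zeros, sharded_layout, shape=[4]))

save_path = tf.train.Checkpoint(v=v).save('/tmp/sharded_ckpt/ckpt')
# Typically the merged index file plus data shard(s) written per saving device.
print(tf.io.gfile.glob(save_path + '*'))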
def __init__(self, initial_value, *args, dtype=None, **kwargs):
  """Overrides tf.Variable to fix VarHandleOp placements."""
  # Variables by default use the current device scope for placement. This
  # wrapper has them follow the initial value's placement instead (which will
  # be the DTensor device if the initial value has a layout).
  if callable(initial_value):
    initial_value = initial_value()

  initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
  variable_device = initial_value.device
  self._save_as_bf16 = False
  # TODO(b/159035705): The following code enables variable creation inside
  # a tf.function. However, it requires a global dtensor device.
  # if not variable_device and not tf.executing_eagerly():
  #   try:
  #     initial_value.op.get_attr("_layout")
  #   except ValueError:
  #     pass
  #   else:
  #     # The initial value is a DTensor, but because the DTensor device is
  #     # only active during eager execution at the moment we need to
  #     # translate that into a placement for the eager VarHandleOp.
  #     variable_device = _dtensor_device().name

  with ops.device(variable_device):
    # If the initial tensor assigned to the DVariable is a DTensor, record the
    # layout of the resource so that it can be queried later.
    self.layout = None
    if context.executing_eagerly():
      try:
        self.layout = api.fetch_layout(initial_value)
      except (errors.InvalidArgumentError, errors.NotFoundError):
        # For non-DTensor tensors, fetching the layout raises an expected
        # InvalidArgumentError or NotFoundError, depending on whether the API
        # is called within the DTensor device scope or not.
        self.layout = None

    mesh = self.layout.mesh if self.layout else None
    with api.run_on(mesh) if mesh else contextlib.nullcontext():
      super(DVariable, self).__init__(
          initial_value, *args, dtype=dtype, **kwargs)
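# Illustrative usage (a minimal sketch, not part of this module): when the
# initial value is a DTensor, the DVariable is placed on the DTensor device
# and its layout is recorded on the Python object. The
# tf.experimental.dtensor entry points below are assumptions for the sketch.
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', 1)], device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)

init = dtensor.call_with_layout(tf.ones, layout, shape=[8])
v = dtensor.DVariable(init)
print(v.layout)  # the layout fetched from the DTensor initial value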
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name-based DTensors.

  The DTensor variables being restored into must already be initialized and
  have the same shape/dtype as the tensors being restored. Currently, only
  name-based restore onto a single mesh is supported.

  Args:
    mesh: The single mesh that all Tensors will be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary of tensor names to DTensors. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary mapping each name to its restored DTensor value.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on the CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non CPU Tensor is not supported currently. Offending '
            'tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that all tensors are on the CPU mesh, issue a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to the mesh to avoid implicit small-constant extraction,
  # which does not work for larger restores that have many names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShapes representing the shapes of all input tensors.
  input_shapes = [
      tensor.shape for tensor in ordered_name_tensor_dict.values()
  ]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[
            tensor.dtype for tensor in ordered_name_tensor_dict.values()
        ])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
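# Illustrative usage (a minimal sketch, not part of this module): name-based
# save followed by name-based restore onto already-initialized CPU-mesh
# DTensors. Names and paths are assumptions for the sketch; name_based_save
# and name_based_restore are assumed to be exposed via
# tf.experimental.dtensor.
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', 1)], device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)
w = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))

dtensor.name_based_save(mesh, '/tmp/name_based/ckpt', {'w': w})
restored = dtensor.name_based_restore(mesh, '/tmp/name_based/ckpt', {'w': w})
w.assign(restored['w'])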
def __init__(self, initial_value, *args, dtype=None, **kwargs):
  """Overrides tf.Variable to fix VarHandleOp placements."""
  # Variables by default use the current device scope for placement. This
  # wrapper has them follow the initial value's placement instead (which will
  # be the DTensor device if the initial value has a layout).

  # Pop 'layout' from kwargs since keras make_variable may pass a 'layout'
  # keyword argument. We need to pop it because we are passing kwargs to the
  # super class constructor.
  layout = kwargs.pop('layout', None)
  shape = kwargs.get('shape', None)

  if callable(initial_value):
    unwrapped = initial_value
    if issubclass(type(initial_value), functools.partial):
      unwrapped = initial_value.func

    # If the unwrapped callable is a CheckpointInitialValueCallable, we are
    # creating a Variable during a checkpoint restore. The restore happens now
    # through this callable, and the DVariable is created with the restored
    # DTensor.
    if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable):
      if not shape or not layout:
        raise ValueError('Expected shape and layout to be not None.')

      # CheckpointInitialValueCallable will call an eager tf.RestoreV2,
      # which does not have any shape or layout information attached. Thus we
      # do two things to have them correctly specified:
      #
      # The default layout scope allows us to correctly specify the output
      # layout of the tf.RestoreV2 that will be called.
      #
      # Passing shard_info with the correct shape allows the tf.RestoreV2
      # shape inference to extract the shape.
      initial_value = api.call_with_layout(
          initial_value,
          layout,
          shard_info=trackable.ShardInfo(
              shape=shape, offset=[0] * len(shape)))
    else:
      initial_value = initial_value()

  # When the initial value came from a checkpoint restoration, fetch the
  # wrapped tensor.
  if isinstance(initial_value, trackable.CheckpointInitialValue):
    initial_value = initial_value.wrapped_value

  initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
  variable_device = initial_value.device
  self._save_as_bf16 = False
  # TODO(b/159035705): The following code enables variable creation inside
  # a tf.function. However, it requires a global dtensor device.
  # if not variable_device and not tf.executing_eagerly():
  #   try:
  #     initial_value.op.get_attr("_layout")
  #   except ValueError:
  #     pass
  #   else:
  #     # The initial value is a DTensor, but because the DTensor device is
  #     # only active during eager execution at the moment we need to
  #     # translate that into a placement for the eager VarHandleOp.
  #     variable_device = _dtensor_device().name

  with ops.device(variable_device):
    # If the initial tensor assigned to the DVariable is a DTensor, record the
    # layout of the resource so that it can be queried later.
    self.layout = None
    if context.executing_eagerly():
      try:
        self.layout = api.fetch_layout(initial_value)
      except (errors.InvalidArgumentError, errors.NotFoundError):
        # For non-DTensor tensors, fetching the layout raises an expected
        # InvalidArgumentError or NotFoundError, depending on whether the API
        # is called within the DTensor device scope or not.
        self.layout = None

    mesh = self.layout.mesh if self.layout else None
    with api.run_on(mesh) if mesh else contextlib.nullcontext():
      super(DVariable, self).__init__(
          initial_value, *args, dtype=dtype, **kwargs)
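# Illustrative round trip (a minimal sketch, not part of this module): save a
# DVariable, overwrite it, and restore it from the checkpoint. Restoring into
# an existing DVariable goes through the saveable machinery; it is variable
# creation *during* restore (e.g. under Keras) that reaches the
# CheckpointInitialValueCallable branch above, where the shape and layout
# kwargs must be supplied. The tf.experimental.dtensor entry points and paths
# below are assumptions; older TF releases may require
# dtensor.DTensorCheckpoint instead of tf.train.Checkpoint.
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', 1)], device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)
v = dtensor.DVariable(dtensor.call_with_layout(tf.ones, layout, shape=[4]))

ckpt = tf.train.Checkpoint(v=v)
path = ckpt.save('/tmp/dvariable_ckpt/ckpt')
v.assign(dtensor.call_with_layout(tf.zeros, layout, shape=[4]))
ckpt.restore(path)
print(dtensor.unpack(v.read_value())[0])  # back to ones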