def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
  """Generates all sharded prefixes in a distributed save.

  DTensor SaveV2 SPMD generates multiple SaveV2 ops on the saving devices,
  and it is desirable not to save with the same shard prefix so that content
  is not overwritten.

  (prefix, tensor_names, tensors (with layouts)) and the saving mesh
  collectively define a unique set of shard prefixes that is generated for
  all the Save ops. Usually, (prefix, tensor_names, shape_and_slices,
  tensors) should match what is used in save.

  Args:
    mesh: The mesh that is used in the save op. Usually a CPU mesh, and
      matches what is used in the Save op.
    prefix: The prefix of the saving files.
    tensor_names: A list of tensor names used in the save op.
    shape_and_slices: A list of shape and slice specifications used in the
      save op. The only supported value is "" as we don't support distributed
      saving with slices yet.
    tensors: A list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A rank-1 string tensor containing all generated shard prefixes.
  """
  layout_str = array_ops.stack(
      [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
  layouts = api.pack([layout_str] * mesh.num_local_devices(),
                     layout_lib.Layout.replicated(mesh, rank=1))

  mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                             layout_lib.Layout.replicated(mesh, rank=0))
  return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                 tensor_names,
                                                 shape_and_slices,
                                                 mesh_str_tensor,
                                                 layouts=layouts,
                                                 tensors=tensors)
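# A minimal usage sketch for `sharded_prefix` (hypothetical: `cpu_mesh`,
# `kernel`, and `bias` are assumed to exist in the caller's context, and the
# arguments mirror what a subsequent save op would receive):
def _example_sharded_prefix(cpu_mesh, kernel, bias):
  tensors = [kernel, bias]
  # Returns the rank-1 string tensor of shard prefixes for these tensors.
  return sharded_prefix(
      mesh=cpu_mesh,
      prefix=['/tmp/ckpt/model'],
      tensor_names=['dense/kernel', 'dense/bias'],
      shape_and_slices=[''] * len(tensors),
      tensors=tensors)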
def name_based_save(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: Union[str, ops.Tensor],
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]]):
  """Saves name-based Tensors into a checkpoint.

  The function prepares the input dictionary to match the format of a
  `sharded_save`, so that it can take advantage of DTensor SPMD based
  distributed save. As with restore, the function only supports saving on a
  single mesh.

  Args:
    mesh: The single mesh from which all Tensors are saved.
    checkpoint_prefix: The prefix of the checkpoint to be saved.
    name_tensor_dict: An ordered dictionary of tensor names to DTensors. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # the DTensor device singleton. The API itself will eventually be moved to
  # a public API that provides the global singleton in a DTensor context.
  # For now, we just use the current `internal` API and aim at migrating in
  # one shot later.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  # pylint: disable=protected-access
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))

  sharded_save(
      mesh,
      file_prefix=checkpoint_prefix,
      tensor_names=tensor_names,
      shape_and_slices=[''] * len(ordered_name_tensor_dict),
      tensors=list(ordered_name_tensor_dict.values()))
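# A minimal usage sketch for `name_based_save` (hypothetical names; assumes
# `cpu_mesh` is a CPU mesh and the DVariables already live on it with the
# intended layouts):
def _example_name_based_save(cpu_mesh, kernel_dvar, bias_dvar):
  name_based_save(
      cpu_mesh, '/tmp/ckpt/model',
      collections.OrderedDict(
          [('dense/kernel', kernel_dvar), ('dense/bias', bias_dvar)]))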
def get_host_dvariable():
  # Copy to the host mesh if needed. `dvariable`, `original_layout`,
  # `host_layout`, and `self` are captured from the enclosing scope.
  if original_layout.mesh.device_type().upper() != 'CPU':
    with ops.device(dvariable.device):
      host_dvariable = DVariable(
          api.pack(api.unpack(dvariable.read_value()), host_layout))
  else:
    host_dvariable = dvariable
  return (math_ops.cast(host_dvariable, dtypes.bfloat16)
          if self.should_cast(host_dvariable) else host_dvariable)
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
  """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  will have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of
  pending ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the
  # clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while mesh has actual '
        'size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed, but dropping it might cause
  # confusing behaviors for users. Consider dropping this if there is a `big`
  # performance hit.
  context.async_wait()

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)
def save(
    self,
    file_prefix: str,
    options: Optional[checkpoint_options.CheckpointOptions] = None
) -> Optional[ops.Operation]:
  """Saves the saveable objects to a checkpoint with `file_prefix`.

  Also queries the generated shards from the distributed DTensor SaveV2 ops
  and runs a MergeV2 on them. Each op here is backed by a global_barrier to
  avoid races between multiple clients.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object. This is unused in DTensor.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  if options is not None and options.experimental_io_device is not None:
    raise ValueError(
        "Specified experimental_io_device in DTensor checkpoint is not "
        "supported.")
  del options
  tensor_names = []
  tensors = []
  tensor_slices = []
  for saveable in self._saveable_objects:
    for spec in saveable.specs:
      tensor = spec.tensor
      # A tensor value of `None` indicates that this SaveableObject gets
      # recorded in the object graph, but that no value is saved in the
      # checkpoint.
      if tensor is not None:
        if api.device_name() != spec.device:
          # Some small tensors are placed on CPU0 by the save manager and
          # broadcast to the DTensor mesh, e.g., SaveCounter.
          tensor = api.pack(
              [tensor] * self._mesh.host_mesh().num_local_devices(),
              layout.Layout.replicated(
                  self._mesh.host_mesh(), rank=tensor.shape.rank))
        tensor_names.append(spec.name)
        tensors.append(tensor)
        tensor_slices.append(spec.slice_spec)
  return save_restore.sharded_save(self._mesh, file_prefix, tensor_names,
                                   tensor_slices, tensors)
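# A hedged usage sketch for `save` (hypothetical; assumes `saver` is an
# instance of the enclosing DTensor saver class, already constructed from
# saveable objects on a mesh):
def _example_dsaver_save(saver):
  # Eagerly this returns None; in graph mode it returns the merged save
  # Operation to run in a session.
  return saver.save('/tmp/ckpt/model')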
def get_next(self):
  """Returns the next element.

  Returns:
    A possibly nested structure of values matching
    `tf.data.Iterator.element_spec`.

  Raises:
    `tf.errors.OutOfRangeError`: if the end of the underlying iterators has
      been reached.
    RuntimeError: if any of the underlying iterators do not return the
      expected number of items.
  """
  # Create the data structure to store the individual elements of the current
  # batch. We store a list per element in the flattened dataset batch, and
  # each list should contain as many tensors as there are local devices.
  curr_batch_elems = [[] for _ in range(len(self._flattened_layouts))]

  for _, iterator in self._iterators:
    for _ in range(self._num_local_devices_per_replica):
      element = iterator.get_next()

      # Separate the dataset elements based on the structure of the dataset.
      flattened_element = nest.flatten(element)
      for idx, batch in enumerate(flattened_element):
        curr_batch_elems[idx].append(batch)

  flattened_output = []
  for batch_elems, layout in zip(curr_batch_elems, self._flattened_layouts):
    expected_num_elems = layout.mesh.num_local_devices()
    actual_num_elems = len(batch_elems)
    if actual_num_elems != expected_num_elems:
      raise RuntimeError('Expected to pack %d elements in batch but got %d' %
                         (expected_num_elems, actual_num_elems))
    flattened_output.append(api.pack(batch_elems, layout))

  return nest.pack_sequence_as(self._layouts, flattened_output)
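# A minimal consumption loop for `get_next` (hypothetical; assumes `d_iter`
# is an instance of this DTensor iterator class and that `errors_impl` from
# tensorflow.python.framework is importable in this scope):
def _example_drain_iterator(d_iter):
  batches = []
  try:
    while True:
      # Each returned element is a (possibly nested) structure of packed
      # DTensors following the configured layouts.
      batches.append(d_iter.get_next())
  except errors_impl.OutOfRangeError:
    pass
  return batches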
def restore(self, save_path, options=None):
  """Restores a training checkpoint with host mesh placement."""
  options = options or checkpoint_options.CheckpointOptions()
  if save_path is None:
    return util.InitializationOnlyStatus(self._graph_view, ops.uid())
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  graph_building = not context.executing_eagerly()
  if graph_building:
    dtype_map = None
  else:
    dtype_map = reader.get_variable_to_dtype_map()
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try the
    # name-based compatibility mode.
    restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
        save_path=save_path,
        dtype_map=dtype_map)
    if not graph_building:
      for existing_trackable in self._graph_view.list_objects():
        # pylint: disable=protected-access
        existing_trackable._maybe_initialize_trackable()
        existing_trackable._name_based_restores.add(restore_coordinator)
        existing_trackable._name_based_attribute_restore(restore_coordinator)
        # pylint: enable=protected-access
    return util.NameBasedSaverStatus(
        restore_coordinator, graph_view=self._graph_view)

  if graph_building:
    if self._file_prefix_placeholder is None:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh.
      self._file_prefix_placeholder = api.pack(
          [constant_op.constant("model")] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    # DTensor change: provide a hint for mesh broadcasting to put the input
    # onto the host mesh.
    file_prefix_tensor = api.pack(
        [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
        layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_feed_dict = None
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  # DTensor change: hook the proper DSaver in restore.
  checkpoint = _DCheckpointRestoreCoordinator(
      mesh=self._mesh,
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      save_path_tensor=file_prefix_tensor,
      reader=reader,
      restore_op_cache=self._restore_op_cache,
      graph_view=self._graph_view,
      options=options,
      saveables_cache=self._saveables_cache)
  base.CheckpointPosition(
      checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

  # Attached dependencies are not attached to the root, so should be restored
  # separately.
  if self._graph_view.attached_dependencies:
    for ref in self._graph_view.attached_dependencies:
      if ref.name == "root":
        # Root dependency is automatically added to attached dependencies --
        # this can be ignored since it maps back to the root object.
        continue
      proto_id = None
      # Find the proto ID of the attached dependency (if it is in the proto).
      for proto_ref in object_graph_proto.nodes[0].children:
        if proto_ref.local_name == ref.name:
          proto_id = proto_ref.node_id
          break
      if proto_id in checkpoint.object_by_proto_id:
        # Object has already been restored. This can happen when there's an
        # indirect connection from the attached object to the root.
        continue
      base.CheckpointPosition(
          checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)
  load_status = util.CheckpointLoadStatus(
      checkpoint,
      graph_view=self._graph_view,
      feed_dict=file_prefix_feed_dict)
  return load_status
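# A hedged usage sketch for `restore` (hypothetical; assumes `checkpoint` is
# an instance of the enclosing DTensor checkpoint class, and that the
# returned status exposes the standard trackable-status assertions):
def _example_restore(checkpoint, save_path):
  status = checkpoint.restore(save_path)
  # Eagerly, the assertion runs immediately; in graph mode it is deferred
  # until the restore ops have run.
  status.assert_existing_objects_matched()
  return status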
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name-based DTensors.

  It is required to have already-initialized DTensor variables with matching
  shape/dtype for the tensors being restored. Also, we currently only support
  a name-based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary of tensor names to DTensors. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary of names to their restored DTensor values.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on the CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non-CPU Tensor is not supported currently. '
            'Offending tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that we have all tensors on the CPU mesh, do a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to the mesh to avoid implicit small constant extraction,
  # which does not work for larger restores that have lots of names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShapes representing the shapes of all input tensors.
  input_shapes = [tensor.shape for tensor in ordered_name_tensor_dict.values()]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
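# A minimal usage sketch for `name_based_restore` (hypothetical names;
# `kernel_dvar`/`bias_dvar` must be already-initialized CPU-mesh DTensors
# whose shapes/dtypes match the checkpoint entries):
def _example_name_based_restore(cpu_mesh, kernel_dvar, bias_dvar):
  restored = name_based_restore(
      cpu_mesh, '/tmp/ckpt/model',
      collections.OrderedDict(
          [('dense/kernel', kernel_dvar), ('dense/bias', bias_dvar)]))
  return restored['dense/kernel'], restored['dense/bias']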
def pack(tensors, layout):
  # `dvariable` is captured from the enclosing scope; pack on its device so
  # the packed DTensor is placed alongside the variable.
  with ops.device(dvariable.device):
    return api.pack(tensors, layout)