def _allgather_buffer(layer_id,
                      trainable_params,
                      group,
                      use_calc_stream,
                      task_flow,
                      sync_wait=False):
    for param in trainable_params[layer_id]:
        if param.status == "all":
            param.use_count += 1
            continue

        # Gather the flattened shard of this parameter from all ranks
        with paddle.amp.auto_cast(enable=False):
            full_param = _all_gather(
                param.fw_storage, group, use_calc_stream=use_calc_stream)

        # Optionally wait for the gather and rebuild the full parameter in place
        if sync_wait:
            with paddle.amp.auto_cast(enable=False):
                dist.wait(
                    tensor=full_param,
                    group=group,
                    use_calc_stream=use_calc_stream)
            core.VarBase(
                full_param._slice(0, param._numel()))._share_buffer_to(param)
            param.value().get_tensor()._set_dims(param.shape)
            param.fw_storage._clear()
            param.fw_storage = None
            param.status = "all"
            param.use_count += 1
        task_flow.full_param[param.name] = full_param
    return task_flow
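# A minimal, self-contained sketch of the gather-then-rebuild step used in
# _allgather_buffer above, written against the public paddle.distributed API
# instead of the internal _all_gather / _share_buffer_to helpers. The function
# name and arguments here are illustrative assumptions, not part of the
# original code; the local shard is assumed to be a flat (1-D) tensor, padded
# so that every rank contributes the same number of elements.
import paddle
import paddle.distributed as dist


def rebuild_full_param_sketch(local_shard, full_numel, full_shape, group=None):
    shards = []
    # Gather every rank's flat shard into a Python list
    dist.all_gather(shards, local_shard, group=group)
    # Concatenate, drop the tail padding and restore the original shape
    flat = paddle.concat(shards)
    return flat[:full_numel].reshape(full_shape)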
def get_all_parameters(self):
    assert len(self._trainable_params.keys()) > 0
    current_layer_params = self._layer.parameters(include_sublayers=True)
    trainable_params = list(
        filter(lambda x: x.trainable, current_layer_params))

    for param in trainable_params:
        if param.use_count > 0:
            continue
        assert hasattr(
            param, "fw_storage"
        ), "Param {} does not have a fw_storage attribute".format(param.name)

        # Gather the full parameter from the sharded fw_storage and rebuild it
        full_param = _all_gather(
            param.fw_storage, self._group, use_calc_stream=True)
        dist.wait(tensor=full_param, group=self._group, use_calc_stream=True)
        core.VarBase(
            full_param._slice(0, param._numel()))._share_buffer_to(param)
        param.value().get_tensor()._set_dims(param.shape)
        param.fw_storage._clear()
        param.fw_storage = None
        param.status = "all"
        param.use_count += 1

    # Restore the optimizer's original (unsharded) parameter lists
    self._optim._parameter_list = self._ori_parameter_list
    self._optim._param_groups = self._ori_param_groups
def reduce(*_):
    # Skip gradient reduction, do not change status information
    if self._grad_reduced[index]:
        assert param.grad is not None, "Parameter gradient cannot be None"

        # Change reduce information
        self._grad_reduced[index] = False
        grad_storage = self._grad_storages[param.dtype][dst_rank]
        grad_storage.params_checked_in += 1

        if grad_storage.all_checked_in:
            assert grad_storage.buffer is not None

            # Normalize all ranks' grad_storage
            if not self._accumulate_grads:
                grad_storage.buffer.scale_(scale=self._world_size_scaling)

            # Clearing up the grad_storage buffer
            def cleanup():
                if dst_rank != self._rank:
                    for p in grad_storage._params:
                        p.clear_gradient(False)
                        p._gradient_set_empty(False)
                    grad_storage.buffer.value().get_tensor()._clear()
                elif self._offload:
                    grad_storage.to(device=self._offload_device)
                    for p in grad_storage._params:
                        self._sharding_optimizers[0]._offload_acc_grad(
                            p.name, p.grad.cast(dtype=Type.fp32.value))
                        p.clear_gradient(False)
                        p._gradient_set_empty(False)
                    grad_storage._device = self._default_device
                    grad_storage.buffer.value().get_tensor()._clear()

            # Reduce the bucket
            grad_storage.sent = True
            self._tasks_flow.append(
                Taskflow(
                    task=dist.reduce(
                        tensor=grad_storage.buffer,
                        dst=grad_storage.destination,
                        group=self._group,
                        use_calc_stream=True),
                    callback=cleanup))

            # Multi stream operation will be supported later
            dist.wait(
                tensor=grad_storage.buffer,
                group=self._group,
                use_calc_stream=True)

    # Clear the task flow and trigger callback to clear the redundant gradient
    self._clear_task_flow()
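# Hedged sketch of the task-flow/callback pattern used in the bucket reduce
# above: the reduce op is launched right away, while the cleanup that clears
# redundant gradients is queued and only triggered later by _clear_task_flow().
# Taskflow and the queue class below are illustrative stand-ins for the repo's
# own helpers, not their exact definitions.
from collections import namedtuple

Taskflow = namedtuple("Taskflow", ["task", "callback"])


class TaskFlowQueueSketch:
    def __init__(self):
        self._tasks_flow = []

    def submit(self, task, callback):
        # The communication op has already been issued; only the cleanup of
        # redundant gradients is deferred.
        self._tasks_flow.append(Taskflow(task=task, callback=callback))

    def _clear_task_flow(self):
        # Drain the queue and run each deferred cleanup exactly once
        while self._tasks_flow:
            self._tasks_flow.pop(0).callback()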
def _sync_buffers(self):
    for buffer in self._layer.buffers(include_sublayers=True):
        dist.broadcast(
            buffer, self._global_root_rank, self._group, use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
def _sync_params_and_buffers(self):
    """
    Sync all model states for all ranks
    """
    for p in self._layer.parameters():
        dist.broadcast(
            p,
            src=self._global_root_rank,
            group=self._group,
            use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=p, group=self._group, use_calc_stream=True)
def __sync_buffers(self):
    """
    Sync all the param buffers from all ranks (e.g. batch norm statistics).
    """
    for buffer in self._layer.buffers(include_sublayers=True):
        dist.broadcast(
            buffer, self._global_root_rank, self._group, use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
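# Minimal sketch of the state-sync pattern shared by _sync_buffers,
# _sync_params_and_buffers and __sync_buffers above: every parameter and
# buffer is broadcast from a root rank so all ranks start from identical
# model state. The function name and arguments are assumptions made for
# illustration; it uses only the public paddle.distributed.broadcast call.
import paddle.distributed as dist


def sync_model_state_sketch(layer, root_rank=0, group=None):
    for tensor in list(layer.parameters()) + list(layer.buffers()):
        # Overwrite the local copy with the root rank's values
        dist.broadcast(tensor, src=root_rank, group=group)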
def _broadcast_params(self):
    """Broadcast the parameters of the current rank to each rank."""
    assert self._default_device == "gpu", "Only gpu is supported"

    # Exchange all the shards with the other ranks
    for dtype_per_rank in self.param_storages.values():
        for dst_rank, internal_storage in dtype_per_rank.items():
            dist.broadcast(
                tensor=internal_storage.buffer,
                src=dst_rank,
                group=self.group,
                use_calc_stream=True)
            # Multi stream operation will be supported later
            dist.wait(
                tensor=internal_storage.buffer,
                group=self.group,
                use_calc_stream=True)
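# Hedged sketch of the shard exchange performed by _broadcast_params: each
# rank owns one flat buffer of freshly updated parameters and broadcasts it to
# the others, so after the loop every rank holds all shards. "shard_buffers"
# (a rank -> flat tensor dict) is a hypothetical stand-in for the buffers kept
# in self.param_storages.
import paddle.distributed as dist


def exchange_shards_sketch(shard_buffers, group=None):
    for src_rank, buffer in sorted(shard_buffers.items()):
        # Each buffer is broadcast from the rank that updated it
        dist.broadcast(buffer, src=src_rank, group=group)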
def reduce(*_):
    if param.name in self._task_flow.full_grad.keys():
        full_grad = self._task_flow.full_grad[param.name]
        with paddle.amp.auto_cast(enable=False):
            if not self._accumulate_grads:
                full_grad.scale_(scale=self._world_size_scaling)
            # Only support sync allreduce current rank's layer now
            dist.all_reduce(
                tensor=full_grad, group=self._group, use_calc_stream=True)
            dist.wait(
                tensor=full_grad, group=self._group, use_calc_stream=True)

            # Keep only the slice of the reduced gradient owned by this rank
            start, end = self._param2buffer[param.name][self._rank]
            if not self._accumulate_grads or param.bw_storage is None:
                param.bw_storage = core.VarBase(
                    full_grad._slice(start, end)).detach().clone()
            else:
                param.bw_storage.add_(
                    core.VarBase(full_grad._slice(start, end)).detach()
                    .clone())
        param.clear_gradient(False)
        param._gradient_set_empty(False)
        tmp_var = self._task_flow.full_grad.pop(param.name)
        tmp_var._clear()

    if param.name in self._task_flow.full_param.keys():
        if param.status == "all":
            # Release the full parameter and fall back to this rank's slice
            param.use_count = 0
            param._clear()
            start, end = self._param2buffer[param.name][self._rank]
            with paddle.amp.auto_cast(enable=False):
                param.fw_storage = core.VarBase(
                    self._task_flow.full_param[param.name]._slice(start, end),
                    param.name + "@slice").detach().clone()
            param.status = "part"
            tmp_var = self._task_flow.full_param.pop(param.name)
            tmp_var._clear()
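# A small sketch of the gradient path in the reduce hook above, assuming a
# flat full-size gradient: all-reduce it across ranks, then keep only the
# [start, end) slice that this rank owns, optionally accumulating into an
# existing slice. The function and its arguments are illustrative; the slice
# bounds play the role of self._param2buffer[param.name][self._rank].
import paddle
import paddle.distributed as dist


def reduce_and_slice_grad_sketch(full_grad, start, end, group=None,
                                 accumulated=None):
    dist.all_reduce(full_grad, group=group)
    my_slice = full_grad.reshape([-1])[start:end].detach().clone()
    if accumulated is None:
        return my_slice
    # Gradient accumulation: add the new slice onto the stored one
    accumulated.add_(my_slice)
    return accumulated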
def reduce(*_):
    # Skip gradient reduction, do not change status information
    if self._grad_reduced[index]:
        assert param.grad is not None, "Parameter gradient cannot be None"

        # Change reduce information
        self._grad_reduced[index] = False
        if not self._accumulate_grads:
            param.grad.scale_(scale=self._world_size_scaling)
            param._reset_grad_inplace_version(True)

        # Clear the gradient that does not belong to the current rank
        # through the callback function
        def cleanup():
            if dst_rank != self._rank:
                param.clear_gradient(False)
            elif self._offload:
                self._sharding_optimizers[0]._offload_acc_grad(
                    param.name,
                    param.grad.cast(dtype=Type.fp32.value).cpu())
                param.clear_gradient(False)

        # Synchronize the reduce parameter gradient
        self._tasks_flow.append(
            Taskflow(
                task=dist.reduce(
                    tensor=param.grad,
                    dst=dst_rank,
                    group=self._group,
                    use_calc_stream=True),
                callback=cleanup))

        # Multi stream operation will be supported later
        dist.wait(tensor=param.grad, group=self._group, use_calc_stream=True)

    # Clear the task flow and trigger callback to clear the redundant gradient
    self._clear_task_flow()
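# Hedged sketch of the per-parameter reduce above: the gradient is averaged,
# reduced to the rank that owns the parameter shard, and then dropped on every
# other rank. Names such as owner_rank are assumptions standing in for
# dst_rank; the cleanup here is done inline instead of via a callback.
import paddle.distributed as dist


def reduce_grad_to_owner_sketch(grad, owner_rank, my_rank, world_size,
                                group=None):
    # Average before reducing, mirroring the world-size scaling above
    grad.scale_(1.0 / world_size)
    dist.reduce(grad, dst=owner_rank, group=group)
    if my_rank != owner_rank:
        # Non-owner ranks no longer need this gradient
        grad.zero_()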