def dump_gradient_norms(tag, param_groups, micro_step, global_step):
    """Print the gradient norm of every parameter group for debugging.

    Args:
        tag: label prepended to the printed line (e.g. an optimizer name).
        param_groups: iterable of parameter groups; each group is handed to
            ``get_grad_norm`` as-is.
        micro_step: current micro step, echoed verbatim in the output.
        global_step: current global step, echoed verbatim in the output.
    """
    norms = [get_grad_norm(group) for group in param_groups]
    print(
        "\n {} gradient_norms: micro_step={}, global_step={}, norms={}".format(
            tag, micro_step, global_step, norms))
def step(self, closure=None):
    """
    Not supporting closure.

    Performs one fp16 optimizer step: flattens fp16 gradients into the flat
    fp32 master copies, checks for overflow via the gradient norms, and —
    if no overflow occurred — unscales/clips, steps the wrapped optimizer,
    and copies the updated fp32 master weights back into the fp16 params.

    Returns the overflow flag (truthy when the step was skipped).
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    # First compute norm for all group so we know if there is overflow
    grads_groups_flat = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        data_type = self.fp32_groups_flat[i].dtype

        # Missing grads are replaced by zeros so the flat tensor keeps the
        # exact layout of the flat fp32 master copy.
        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=data_type, device=p.device)
                if p.grad is None else p.grad.to(data_type) for p in group
            ]))
        self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        # BUGFIX: compute the norm of THIS group only. The previous code
        # passed the entire ``self.fp32_groups_flat`` list here, which
        # (a) dereferences ``.grad`` of later groups that have not been
        # assigned yet in this loop (None -> AttributeError), and (b) would
        # make every ``norm_groups`` entry the combined norm, inflating the
        # global norm used for clipping by ~sqrt(num_groups).
        norm_groups.append(
            get_grad_norm([self.fp32_groups_flat[i]], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    self.unscale_and_clip_grads(grads_groups_flat, norm_groups)

    self.optimizer.step()

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.fp32_groups_flat:
        group.grad = None

    # Push the updated fp32 master weights back into the fp16 model params.
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data.copy_(q.data)

    return self.overflow
def step(self, closure=None):
    """
    Not supporting closure.

    One unfused optimizer step: bail out early on gradient overflow
    (adjusting the loss scale), otherwise mirror the fp16 gradients into
    the fp32 master params, unscale/clip, step the wrapped optimizer, and
    copy the fp32 results back into the fp16 params. Returns the overflow
    flag.
    """
    if self.fused_lamb_legacy:
        return self.step_fused_lamb()

    self.overflow = self.overflow_checker.check()
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    norm_groups = []
    for group_idx, fp16_group in enumerate(self.fp16_groups):
        norm_groups.append(get_grad_norm(fp16_group, mpu=self.mpu))

        # copying gradients to fp32 to work with fp32 parameters
        for master_param, model_param in zip(self.fp32_groups[group_idx],
                                             fp16_group):
            if model_param.grad is None:
                master_param.grad = torch.zeros(model_param.size(),
                                                dtype=master_param.dtype,
                                                device=master_param.device)
            else:
                master_param.grad = model_param.grad.to(master_param.dtype)

    self.unscale_and_clip_grads(norm_groups)

    self.optimizer.step()

    for master_group, model_group in zip(self.fp32_groups, self.fp16_groups):
        for master_param, model_param in zip(master_group, model_group):
            # remove the fp32 grad, then copy data from fp32 to fp16
            master_param.grad = None
            model_param.data.copy_(master_param.data)

    return self.overflow
def step(self, closure=None):
    """Take one optimizer step over this rank's fp32 sub-partitions.

    On gradient overflow the step is skipped: gradients are zeroed and the
    loss scale is reduced. Otherwise this rank builds flat gradients for the
    sub-partitions it owns, steps the wrapped optimizer on them, copies the
    updated fp32 sub-partitions into the fp16 sub-partitions, and
    all-gathers the fp16 weights across the data-parallel group.

    ``closure`` is accepted for torch.optim API compatibility but ignored.
    Returns the overflow flag (truthy when the step was skipped).
    """
    # First compute norm for all group so we know if there is overflow
    self.overflow = self.overflow_checker.check()

    prev_scale = self.loss_scale
    self._update_scale(self.overflow)
    if self.overflow:
        self.zero_grad()
        if self.verbose:
            print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                  "scale: {}, reducing to {}".format(prev_scale,
                                                     self.loss_scale))
        return self.overflow

    norm_groups = []
    local_sub_partitions_grad_groups = []

    # This rank's position within the data-parallel group decides which
    # sub-partitions it owns and updates.
    partition_id = dist.get_rank(group=self.dp_process_group)
    for i, group in enumerate(self.fp16_groups):
        #TODO RS: update get grad norm to support sub partitions
        norm_groups.append(get_grad_norm(group, mpu=self.mpu))

        #RS: update free grads w.r.t. sub partitions
        #free gradients for all the parameters that are not updated by this process
        self.free_grad_in_param_list(self.params_not_local[i])

        #create flat gradients for parameters updated by this process
        #tensor_list, first_offset, partition_size, dtype
        #single_grad_partition = self.get_flat_partition(
        #    tensor_list=self.params_in_partition[i],
        #    first_offset=self.first_offset[i],
        #    partition_size=self.partition_size[i],
        #    dtype=self.single_partition_of_fp32_groups[i].dtype
        #)

        #TODO RS: can we safely use dtype of the first sub-partition? i think so
        local_grad_sub_partitions = self.get_flat_sub_partitions(
            comm_tensor_list=self.params_in_rank_sub_partitions[i]
            [partition_id],
            comm_param_offsets=self.
            params_in_rank_sub_partitions_offsets[i][partition_id],
            sub_partition_size=self.sub_partition_sizes[i],
            dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
            num_comm_intervals=self.num_comm_intervals_per_group[i],
            default_device=self.local_sub_partitions_of_fp32_groups[i]
            [0].device)

        #RS: update all our local params with sub-partition grads
        # Attach the freshly built flat grads to the fp32 sub-partition
        # params so the wrapped optimizer can consume them.
        #print("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
        for idx, sub_partition_param in enumerate(
                self.local_sub_partitions_of_fp32_groups[i]):
            sub_partition_param.grad = local_grad_sub_partitions[idx]
        #self.single_partition_of_fp32_groups[i].grad = single_grad_partition

        #RS: update free grads for sub-partitions
        #release all the gradient since we have already created a necessary copy in dp_grad_partition
        self.free_grad_in_param_list(
            self.params_in_rank_sub_partitions[i][partition_id])

        local_sub_partitions_grad_groups.append(local_grad_sub_partitions)

    #RS: update unscale/clip with sub partitions
    self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)

    self.optimizer.step()

    #RS: clear our sub partition grads
    #get rid of the fp32 gradients. Not needed anymore
    for group in self.local_sub_partitions_of_fp32_groups:
        for idx, sub_partition_param in enumerate(group):
            sub_partition_param.grad = None
        #group.grad = None

    #NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed
    #RS: copy all sub-partition fp32 data to fp16 sub partitions
    # copy fp32 param data to fp16 partitions w.r.t. our local rank
    for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(
            self.parallel_sub_partitioned_fp16_groups,
            self.local_sub_partitions_of_fp32_groups):
        for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(
                fp16_all_sub_partitions[partition_id],
                fp32_local_sub_partitions):
            local_sub_partition_param_fp16.data.copy_(
                local_sub_partition_param_fp32.data)

    #RS: all_gather/broadcast sub-partitions in separate comm calls
    #gather the updated weights from everyone
    for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
        for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
            dist.all_gather(sub_partitions,
                            sub_partitions[partition_id],
                            group=self.dp_process_group)

    # TODO: we probably don't need this? just to be safe
    # NOTE(review): rebinding p.data (rather than copy_) makes the fp16
    # params share storage with the unflattened views — presumably the
    # intended aliasing; verify against how fp16_groups_flat is built.
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
def step(self, closure=None):
    """
    Not supporting closure.
    """
    # One optimizer step over this rank's fp32 partition. On overflow the
    # step is skipped (grads zeroed, loss scale reduced). Otherwise the
    # wrapped optimizer updates the local fp32 partition, the result is
    # copied to the local fp16 partition, and the fp16 partitions are
    # exchanged across the data-parallel group via sharded all_gather or
    # per-rank broadcast. Returns the overflow flag.

    # First compute norm for all group so we know if there is overflow
    self.overflow = self.overflow_checker.check()

    prev_scale = self.loss_scale
    self._update_scale(self.overflow)
    if self.overflow:
        self.zero_grad()
        if self.verbose:
            print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                  "scale: {}, reducing to {}".format(prev_scale,
                                                     self.loss_scale))
        return self.overflow

    norm_groups = []
    single_partition_grad_groups = []

    # This rank's index within the data-parallel group selects the
    # partition it owns.
    partition_id = dist.get_rank(group=self.dp_process_group)
    for i, group in enumerate(self.fp16_groups):
        norm_groups.append(get_grad_norm(group, mpu=self.mpu))

        #free gradients for all the parameters that are not updated by this process
        self.free_grad_in_param_list(self.params_not_in_partition[i])

        #create a flat gradients for parameters updated by this process
        single_grad_partition = self.get_flat_partition(
            self.params_in_partition[i],
            self.first_offset[i],
            self.partition_size[i],
            dtype=self.single_partition_of_fp32_groups[i].dtype)
        self.single_partition_of_fp32_groups[i].grad = single_grad_partition

        #release all the gradient since we have already created a necessary copy in dp_grad_partition
        self.free_grad_in_param_list(self.params_in_partition[i])

        single_partition_grad_groups.append(single_grad_partition)

    self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)

    self.optimizer.step()

    #get rid of the fp32 gradients. Not needed anymore
    for group in self.single_partition_of_fp32_groups:
        group.grad = None

    # Copy the updated fp32 master partition into this rank's fp16 partition.
    for fp16_partitions, fp32_partition in zip(
            self.parallel_partitioned_fp16_groups,
            self.single_partition_of_fp32_groups):
        fp16_partitions[partition_id].data.copy_(fp32_partition.data)

    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    #gather the updated weights from everyone
    for _, partitioned_params in enumerate(
            self.parallel_partitioned_fp16_groups):
        if self.all_gather_partitions:
            # controllable memory-time tradeoff: split each partition into
            # shards of ~allgather_size total bytes per collective call
            num_shards = max(
                1,
                partitioned_params[partition_id].numel() * dp_world_size //
                self.allgather_size)
            shard_size = partitioned_params[partition_id].numel(
            ) // num_shards
            num_elements = shard_size
            # num_shards + 1 iterations: the extra pass handles the
            # remainder when numel is not divisible by shard_size
            for shard_id in range(num_shards + 1):
                if shard_id == num_shards:
                    if shard_size * num_shards >= partitioned_params[
                            partition_id].numel():
                        break  # no remainder, nothing left to gather
                    else:
                        num_elements = partitioned_params[
                            partition_id].numel() - shard_id * shard_size
                # Build zero-copy views of this shard on every rank's
                # partition so all_gather writes results in place.
                shard_list = []
                for dp_id in range(dp_world_size):
                    curr_shard = partitioned_params[dp_id].narrow(
                        0, shard_id * shard_size, num_elements)
                    shard_list.append(curr_shard)
                dist.all_gather(shard_list,
                                shard_list[partition_id],
                                group=self.dp_process_group)
        else:
            #this should require less memory but should be faster
            for src, partitioned_param in enumerate(partitioned_params):
                global_src = _get_global_rank(self.dp_process_group, src)
                dist.broadcast(partitioned_param,
                               global_src,
                               group=self.dp_process_group)

    # TODO: we probably don't need this? just to be safe
    # NOTE(review): rebinding p.data aliases the fp16 params onto views of
    # the flat buffer — presumably intended; confirm against the setup of
    # fp16_groups_flat.
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
def step(self, closure=None):
    """
    Not supporting closure.

    Timed fp16 optimizer step: flatten fp16 grads into the flat fp32
    masters, compute a single global grad norm, check it for overflow,
    then (if no overflow) unscale/clip, step the wrapped optimizer, and
    copy the updated fp32 masters back into the fp16 params. Each phase is
    wrapped in a named timer. Returns the overflow flag.
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    # Timer labels for the individual phases of this step.
    COMPUTE_NORM = "compute_norm"
    OVERFLOW_CHECK = 'overflow_check'
    OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
    UNSCALE_AND_CLIP = 'unscale_and_clip'
    BASIC_STEP = 'basic_step'
    UPDATE_FP16 = 'update_fp16'
    STEP_TIMERS = OVERFLOW_TIMERS + [
        UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16
    ]

    # First compute norm for all group so we know if there is overflow.
    # Flatten every fp16 group's grads (zeros for missing grads) and attach
    # the flat tensor to the matching flat fp32 master.
    grads_groups_flat = []
    for group_idx, fp16_group in enumerate(self.fp16_groups):
        master_dtype = self.fp32_groups_flat[group_idx].dtype
        flat_grads = _flatten_dense_tensors([
            torch.zeros(p.size(), dtype=master_dtype, device=p.device)
            if p.grad is None else p.grad.to(master_dtype)
            for p in fp16_group
        ])
        grads_groups_flat.append(flat_grads)
        self.fp32_groups_flat[group_idx].grad = flat_grads

    self.start_timers([COMPUTE_NORM])
    all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
    self.stop_timers([COMPUTE_NORM])

    self.start_timers([OVERFLOW_CHECK])
    self.overflow = self.overflow_checker.check_using_norm([all_groups_norm])
    self.stop_timers([OVERFLOW_CHECK])

    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                  "scale: {}, reducing to {}".format(prev_scale,
                                                     self.cur_scale))
        self.log_timers(OVERFLOW_TIMERS)
        return self.overflow

    self.start_timers([UNSCALE_AND_CLIP])
    self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
    self.stop_timers([UNSCALE_AND_CLIP])

    self.start_timers([BASIC_STEP])
    self.optimizer.step()
    self.stop_timers([BASIC_STEP])

    # fp32 gradients are no longer needed once the optimizer has stepped.
    for flat_master in self.fp32_groups_flat:
        flat_master.grad = None

    self.start_timers([UPDATE_FP16])
    for group_idx in range(len(self.fp16_groups)):
        updated_params = _unflatten_dense_tensors(
            self.fp32_groups_flat[group_idx], self.fp16_groups[group_idx])
        for model_param, updated in zip(self.fp16_groups[group_idx],
                                        updated_params):
            model_param.data.copy_(updated.data)
    self.stop_timers([UPDATE_FP16])

    self.log_timers(STEP_TIMERS)

    return self.overflow