def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_lamb_legacy:
        return self.step_fused_lamb()

    self.overflow = self.overflow_checker.check()
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        norm_groups.append(get_grad_norm(group, mpu=self.mpu))

        # copy gradients to fp32 to work with the fp32 parameters
        for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
            if fp16_param.grad is None:
                fp32_param.grad = torch.zeros(fp16_param.size(),
                                              dtype=fp32_param.dtype,
                                              device=fp32_param.device)
            else:
                fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

    self.unscale_and_clip_grads(norm_groups)

    self.optimizer.step()

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for fp32_param, fp16_param in zip(fp32_group, fp16_group):
            # remove the fp32 grad
            fp32_param.grad = None

            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
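# For reference, a minimal sketch of the unscale-and-clip logic invoked above
# (an illustration under assumed attribute names, not the actual DeepSpeed
# implementation): combine the per-group norms into a global norm, then divide
# every fp32 gradient by the loss scale, enlarging the divisor when the
# unscaled norm would exceed `clip_grad`.
def _unscale_and_clip_grads_sketch(fp32_groups, norm_groups, cur_scale, clip_grad=0.0):
    import math

    # Global L2 norm across all parameter groups.
    total_norm = math.sqrt(sum(norm**2 for norm in norm_groups))

    combined_scale = cur_scale
    if clip_grad > 0.0:
        # The measured norm still carries the loss scale, so unscale it
        # before comparing against the clipping threshold.
        clip = (total_norm / cur_scale) / clip_grad
        if clip > 1.0:
            combined_scale = clip * cur_scale

    # A single multiply both unscales and (if needed) clips the gradients.
    for group in fp32_groups:
        for param in group:
            if param.grad is not None:
                param.grad.data.mul_(1.0 / combined_scale)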
def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    COMPUTE_NORM = 'compute_norm'
    OVERFLOW_CHECK = 'overflow_check'
    OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
    UNSCALE_AND_CLIP = 'unscale_and_clip'
    BASIC_STEP = 'basic_step'
    UPDATE_FP16 = 'update_fp16'
    STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

    # First determine if there is overflow.
    self.start_timers([OVERFLOW_CHECK])
    fp16_params = []
    for i, group in enumerate(self.fp16_groups):
        fp16_params.extend([p for p in group if p.grad is not None])
    self.overflow = self.overflow_checker.has_overflow(fp16_params)
    self.stop_timers([OVERFLOW_CHECK])

    prev_scale = self.cur_scale
    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            log_dist(
                "Overflow detected. Skipping step. Attempted loss "
                f"scale: {prev_scale}, reducing to {self.cur_scale}",
                ranks=[0])
        # Clear gradients
        for i, group in enumerate(self.fp16_groups):
            for p in group:
                p.grad = None

        self.log_timers(OVERFLOW_TIMERS)
        return self.overflow

    grads_groups_flat = []
    for i, group in enumerate(self.fp16_groups):
        data_type = self.fp32_groups_flat[i].dtype

        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=data_type, device=p.device)
                if p.grad is None else p.grad.to(data_type) for p in group
            ]))

        for p in group:
            p.grad = None

        self.fp32_groups_flat[i].grad = grads_groups_flat[i]

    self.start_timers([COMPUTE_NORM])
    all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
    self.stop_timers([COMPUTE_NORM])

    self.start_timers([UNSCALE_AND_CLIP])
    self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
    self.stop_timers([UNSCALE_AND_CLIP])

    self.start_timers([BASIC_STEP])
    self.optimizer.step()
    self.stop_timers([BASIC_STEP])

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.fp32_groups_flat:
        group.grad = None

    self.start_timers([UPDATE_FP16])
    for i in range(len(self.fp16_groups)):
        updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data.copy_(q.data)
    self.stop_timers([UPDATE_FP16])

    self.log_timers(STEP_TIMERS)

    return self.overflow
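# A self-contained illustration (not part of the optimizer) of the
# flatten/unflatten round-trip the step above relies on:
# `_flatten_dense_tensors` packs a list of tensors into one contiguous
# 1-D buffer, and `_unflatten_dense_tensors` slices a buffer back into
# tensors shaped like the originals, so the updated flat fp32 weights can
# be copied element-for-element into the individual fp16 params.
def _flatten_round_trip_demo():
    import torch
    # Private torch helpers, but the same ones this optimizer imports.
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    params = [torch.randn(2, 3), torch.randn(5)]
    flat = _flatten_dense_tensors(params)  # one buffer of 11 elements
    assert flat.numel() == sum(p.numel() for p in params)

    # Slice the buffer back into tensors shaped like `params`.
    updated_params = _unflatten_dense_tensors(flat, params)
    for p, q in zip(params, updated_params):
        p.data.copy_(q.data)  # same copy pattern as UPDATE_FP16 above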
def step(self, closure=None):
    # First compute norms for all groups so we know if there is overflow
    self.overflow = self.overflow_checker.check()

    prev_scale = self.loss_scale
    self._update_scale(self.overflow)
    if self.overflow:
        self.zero_grad()
        if self.verbose:
            logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                        "scale: {}, reducing to {}".format(prev_scale, self.loss_scale))
        return self.overflow

    norm_groups = []
    local_sub_partitions_grad_groups = []

    partition_id = dist.get_rank(group=self.dp_process_group)
    for i, group in enumerate(self.fp16_groups):
        # TODO RS: update get_grad_norm to support sub partitions
        norm_groups.append(get_grad_norm(group, mpu=self.mpu))

        # RS: update free grads w.r.t. sub partitions
        # free gradients for all the parameters that are not updated by this process
        self.free_grad_in_param_list(self.params_not_local[i])

        # create flat gradients for the parameters updated by this process
        # TODO RS: can we safely use dtype of the first sub-partition? i think so
        local_grad_sub_partitions = self.get_flat_sub_partitions(
            comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
            comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][partition_id],
            sub_partition_size=self.sub_partition_sizes[i],
            dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
            num_comm_intervals=self.num_comm_intervals_per_group[i],
            default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)

        # RS: update all our local params with sub-partition grads
        for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
            sub_partition_param.grad = local_grad_sub_partitions[idx]

        # RS: update free grads for sub-partitions
        # release all the gradients since we have already created a necessary
        # copy in dp_grad_partition
        self.free_grad_in_param_list(self.params_in_rank_sub_partitions[i][partition_id])

        local_sub_partitions_grad_groups.append(local_grad_sub_partitions)

    # RS: update unscale/clip with sub partitions
    self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)

    self.optimizer.step()

    # RS: clear our sub-partition grads
    # get rid of the fp32 gradients. Not needed anymore
    for group in self.local_sub_partitions_of_fp32_groups:
        for sub_partition_param in group:
            sub_partition_param.grad = None

    # NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed
    # RS: copy all sub-partition fp32 data to fp16 sub partitions
    # copy fp32 param data to fp16 partitions w.r.t. our local rank
    for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(
            self.parallel_sub_partitioned_fp16_groups,
            self.local_sub_partitions_of_fp32_groups):
        for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(
                fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
            local_sub_partition_param_fp16.data.copy_(local_sub_partition_param_fp32.data)

    # RS: all_gather/broadcast sub-partitions in separate comm calls
    # gather the updated weights from everyone
    for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
        for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
            dist.all_gather(sub_partitions,
                            sub_partitions[partition_id],
                            group=self.dp_process_group)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
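# For context, a minimal sketch of the dynamic loss-scale policy behind the
# `_update_scale` calls in all of these step functions (illustrative only;
# attribute names such as `scale_factor`, `scale_window`, and `min_loss_scale`
# are assumptions, not the exact DeepSpeed fields): back off on overflow, and
# grow the scale again after a window of consecutive overflow-free steps.
def _update_scale_sketch(self, has_overflow):
    if has_overflow:
        # Shrink the scale (never below the floor) and restart the
        # count of consecutive good steps.
        self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
        self.last_overflow_iter = self.cur_iter
    elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
        # A full window without overflow: try a larger scale so small
        # gradients keep getting representable in fp16.
        self.cur_scale *= self.scale_factor
    self.cur_iter += 1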
def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    COMPUTE_NORM = 'compute_norm'
    OVERFLOW_CHECK = 'overflow_check'
    OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
    UNSCALE_AND_CLIP = 'unscale_and_clip'
    BASIC_STEP = 'basic_step'
    UPDATE_FP16 = 'update_fp16'
    STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

    # First determine if there is overflow.
    self.start_timers([OVERFLOW_CHECK])
    fp16_params = []
    for i, group in enumerate(self.fp16_groups):
        fp16_params.extend([p for p in group if p.grad is not None])
    self.overflow = self.overflow_checker.has_overflow(fp16_params)
    self.stop_timers([OVERFLOW_CHECK])

    prev_scale = self.cur_scale
    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            log_dist(
                "Overflow detected. Skipping step. Attempted loss "
                f"scale: {prev_scale}, reducing to {self.cur_scale}",
                ranks=[0])
        # Clear gradients
        for i, group in enumerate(self.fp16_groups):
            for p in group:
                p.grad = None

        self.log_timers(OVERFLOW_TIMERS)
        return self.overflow

    grads_groups_flat = []
    for i, group in enumerate(self.fp16_groups):
        data_type = self.fp32_groups_flat[i].dtype

        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=data_type, device=p.device)
                if p.grad is None else p.grad.to(data_type) for p in group
            ]))

        for p in group:
            p.grad = None

        self.fp32_groups_flat[i].grad = grads_groups_flat[i]

    self.start_timers([COMPUTE_NORM])
    all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)

    # Need to all-reduce (average) the norms across the data-parallel ranks
    # because MoE params will not be synced during allreduce
    if self.using_pipeline:
        pg = self.deepspeed.mpu.get_data_parallel_group()
    else:
        pg = groups._get_data_parallel_group()
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.tensor(scaled_norm,
                                      device=self.fp32_groups_flat[i].device,
                                      dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=pg)
    all_groups_norm = scaled_norm_tensor.item()
    self.stop_timers([COMPUTE_NORM])

    self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

    self.start_timers([UNSCALE_AND_CLIP])
    self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
    self.stop_timers([UNSCALE_AND_CLIP])

    self.start_timers([BASIC_STEP])
    self.optimizer.step()
    self.stop_timers([BASIC_STEP])

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.fp32_groups_flat:
        group.grad = None

    self.start_timers([UPDATE_FP16])
    for i in range(len(self.fp16_groups)):
        updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data.copy_(q.data)
    self.stop_timers([UPDATE_FP16])

    self.log_timers(STEP_TIMERS)

    return self.overflow
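# Typical driver loop for these fp16 optimizers (illustrative sketch; `model`,
# `data_loader`, and `criterion` are placeholders, not names from this file).
# DeepSpeed's fp16 optimizers expose `backward(loss)`, which scales the loss
# by the current loss scale before the backward pass, and `step()`, which
# returns the overflow flag so callers can tell skipped steps apart.
def _train_loop_sketch(model, optimizer, data_loader, criterion):
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        optimizer.backward(loss)     # scaled backward pass
        overflow = optimizer.step()  # False -> weights were actually updated
        if overflow:
            # Step was skipped and the loss scale reduced; just move on
            # to the next batch.
            continue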