def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_lamb_legacy:
        return self.step_fused_lamb()

    self.overflow = self.overflow_checker.check()
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                        "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(group)
        norm_group_value = 0.0
        if len(grads_for_norm) > 0:
            norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
        norm_groups.append(norm_group_value)

        # copying gradients to fp32 to work with fp32 parameters
        for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
            if fp16_param.grad is None:
                fp32_param.grad = torch.zeros(fp16_param.size(),
                                              dtype=fp32_param.dtype,
                                              device=fp32_param.device)
            else:
                fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

    self._global_grad_norm = get_global_norm(norm_list=norm_groups)
    self.unscale_and_clip_grads(self._global_grad_norm)

    self.optimizer.step()

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for fp32_param, fp16_param in zip(fp32_group, fp16_group):
            # remove the fp32 grad
            fp32_param.grad = None
            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
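# The step() above leans on unscale_and_clip_grads(total_norm) to undo the loss scale
# and apply gradient clipping on the fp32 copies before optimizer.step(). Below is a
# minimal sketch of that helper, assuming the fp32 gradients live in self.fp32_groups
# and that self.cur_scale / self.clip_grad hold the loss scale and clip threshold used
# in the surrounding code. It is an illustration, not DeepSpeed's verbatim implementation.
def unscale_and_clip_grads(self, total_norm, apply_scale=True):
    # total_norm is measured on the *scaled* fp16 grads, so it already carries a
    # factor of cur_scale; fold unscaling and clipping into one combined divisor.
    combined_scale = self.cur_scale
    if self.clip_grad > 0.:
        clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
        if clip > 1:
            combined_scale = clip * self.cur_scale

    if apply_scale:
        # Divide every fp32 gradient by the combined scale in place.
        for group in self.fp32_groups:
            for param in group:
                if param.grad is not None:
                    param.grad.data.mul_(1. / combined_scale)

    return combined_scale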
def step_fused_adam(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad
                for p in group
            ]))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                        "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    scaled_grad_norm = get_global_norm(norm_list=norm_groups)

    combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)

    # Stash unscaled gradient norm
    self._global_grad_norm = scaled_grad_norm / self.cur_scale  # norm is in fact norm*cur_scale

    self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                        output_params=[[p] for p in self.fp16_groups_flat],
                        scale=combined_scale,
                        grad_norms=norm_groups)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
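# step_fused_adam() hands the fused kernel one flat buffer per parameter group and then
# scatters the result back into the individual fp16 parameters. The round trip uses
# torch._utils._flatten_dense_tensors / _unflatten_dense_tensors; the standalone toy
# example below (made-up tensor shapes) shows the contract the code above depends on:
# unflattening returns tensors shaped like the originals that view into the flat buffer.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(2, 3), torch.randn(5)]
flat = _flatten_dense_tensors(params)             # single contiguous 1-D buffer
assert flat.numel() == sum(p.numel() for p in params)

updated = _unflatten_dense_tensors(flat, params)  # tensors shaped like the originals
for p, q in zip(params, updated):
    p.data.copy_(q.data)                          # same copy-back pattern as above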
def step_fused_lamb(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    grads_groups = []
    norm_groups = []
    expert_norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads = [
            torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad
            for p in group
        ]
        grads_groups.append(grads)
        grads_groups_flat.append(_flatten_dense_tensors(grads))

        grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
        norm_group_value = 0.0
        if len(grads_for_norm) > 0:
            norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
        norm_groups.append(norm_group_value)

        expert_norm_group_value = 0.0
        if len(expert_grads_for_norm) > 0:
            expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm), mpu=self.mpu)
        expert_norm_groups.append(expert_norm_group_value)

    self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                        "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    self._global_grad_norm = get_global_norm(norm_list=norm_groups)
    combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
    self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale)

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for fp32_param, fp16_param in zip(fp32_group, fp16_group):
            # remove the fp32 grad
            fp32_param.grad = None
            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
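# step_fused_lamb() and the unfused step() above both call
# split_params_grads_into_shared_and_expert_params(group) so that expert (MoE)
# gradients and shared gradients get separate norms: expert grads are only reduced
# within their expert-parallel group, so folding them into one norm would be wrong.
# A rough sketch of such a splitter is shown below; the allreduce-attribute test is
# an assumption about how expert params are tagged, not necessarily DeepSpeed's
# exact logic.
def split_params_grads_into_shared_and_expert_params(group):
    shared_grads, expert_grads = [], []
    for p in group:
        if p.grad is None:
            continue
        # Assumed marker: expert parameters are flagged at creation time
        # (e.g. p.allreduce set to False by the MoE layer).
        if getattr(p, 'allreduce', True) is False:
            expert_grads.append(p.grad)
        else:
            shared_grads.append(p.grad)
    return shared_grads, expert_grads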
def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    COMPUTE_NORM = "compute_norm"
    OVERFLOW_CHECK = 'overflow_check'
    OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
    UNSCALE_AND_CLIP = 'unscale_and_clip'
    BASIC_STEP = 'basic_step'
    UPDATE_FP16 = 'update_fp16'
    STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

    # First determine if there is overflow.
    self.start_timers([OVERFLOW_CHECK])
    fp16_params = []
    for group in self.fp16_groups:
        fp16_params.extend([p for p in group if p.grad is not None])
    self.overflow = self.overflow_checker.has_overflow(fp16_params)
    self.stop_timers([OVERFLOW_CHECK])
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            log_dist("Overflow detected. Skipping step. Attempted loss "
                     f"scale: {prev_scale}, reducing to {self.cur_scale}",
                     ranks=[0])
        # Clear gradients
        for group in self.fp16_groups:
            for p in group:
                p.grad = None

        self.log_timers(OVERFLOW_TIMERS)
        return self.overflow

    grads_groups_flat = []
    for i, group in enumerate(self.fp16_groups):
        data_type = self.fp32_groups_flat[i].dtype

        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
                for p in group
            ]))

        for p in group:
            p.grad = None

        self.fp32_groups_flat[i].grad = grads_groups_flat[i]

    self.start_timers([COMPUTE_NORM])
    all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
    # all_groups_norm_old = all_groups_norm
    # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
    if self.using_pipeline:
        pg = self.deepspeed.mpu.get_data_parallel_group()
    else:
        pg = groups._get_data_parallel_group()
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[i].device, dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=pg)
    all_groups_norm = scaled_norm_tensor.item()
    # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}")
    self.stop_timers([COMPUTE_NORM])

    self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

    self.start_timers([UNSCALE_AND_CLIP])
    self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
    self.stop_timers([UNSCALE_AND_CLIP])

    self.start_timers([BASIC_STEP])
    self.optimizer.step()
    self.stop_timers([BASIC_STEP])

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.fp32_groups_flat:
        group.grad = None

    self.start_timers([UPDATE_FP16])
    for i in range(len(self.fp16_groups)):
        updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data.copy_(q.data)
    self.stop_timers([UPDATE_FP16])

    self.log_timers(STEP_TIMERS)

    return self.overflow
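# Unlike the unfused variant sketched earlier, this fused step() passes the flat
# per-group gradient buffers to unscale_and_clip_grads() explicitly. A sketch of
# that flat-buffer variant is below, under the same assumptions (self.cur_scale and
# self.clip_grad hold the loss scale and clip threshold); it is illustrative rather
# than DeepSpeed's verbatim code.
def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
    combined_scale = self.cur_scale
    if self.clip_grad > 0.:
        # total_norm was measured on scaled grads, so divide by cur_scale first.
        clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
        if clip > 1:
            combined_scale = clip * self.cur_scale

    if apply_scale:
        # Scale each group's flat fp32 gradient buffer in place.
        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)

    return combined_scale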
def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_adam_legacy:
        return self.step_fused_adam()

    COMPUTE_NORM = "compute_norm"
    OVERFLOW_CHECK = 'overflow_check'
    OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
    UNSCALE_AND_CLIP = 'unscale_and_clip'
    BASIC_STEP = 'basic_step'
    UPDATE_FP16 = 'update_fp16'
    STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

    # First determine if there is overflow.
    self.start_timers([OVERFLOW_CHECK])
    fp16_params = []
    for group in self.fp16_groups:
        fp16_params.extend([p for p in group if p.grad is not None])
    self.overflow = self.overflow_checker.has_overflow(fp16_params)
    self.stop_timers([OVERFLOW_CHECK])
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            log_dist("Overflow detected. Skipping step. Attempted loss "
                     f"scale: {prev_scale}, reducing to {self.cur_scale}",
                     ranks=[0])
        # Clear gradients
        for group in self.fp16_groups:
            for p in group:
                p.grad = None

        self.log_timers(OVERFLOW_TIMERS)
        return self.overflow

    grads_groups_flat = []
    for i, group in enumerate(self.fp16_groups):
        data_type = self.fp32_groups_flat[i].dtype

        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
                for p in group
            ]))

        for p in group:
            p.grad = None

        self.fp32_groups_flat[i].grad = grads_groups_flat[i]

    self.start_timers([COMPUTE_NORM])
    all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
    self.stop_timers([COMPUTE_NORM])

    if self.has_moe_layers:
        all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

    scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

    # Stash unscaled gradient norm
    self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

    self.start_timers([UNSCALE_AND_CLIP])
    self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
    self.stop_timers([UNSCALE_AND_CLIP])

    self.start_timers([BASIC_STEP])
    self.optimizer.step()
    self.stop_timers([BASIC_STEP])

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.fp32_groups_flat:
        group.grad = None

    self.start_timers([UPDATE_FP16])
    for i in range(len(self.fp16_groups)):
        updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data.copy_(q.data)
    self.stop_timers([UPDATE_FP16])

    self.log_timers(STEP_TIMERS)

    self.step_count += 1

    return self.overflow
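# This step() defers the MoE norm correction to _get_norm_with_moe_layers(), while the
# previous variant inlines the same logic. A sketch consistent with that inline code is
# given below: average the local norm across the data-parallel group, because expert
# gradients are not synchronized by the regular data-parallel allreduce. It assumes the
# same module-level torch / dist / groups imports as the code above and is an
# illustration rather than the library's exact method.
def _get_norm_with_moe_layers(self, all_groups_norm):
    if self.using_pipeline:
        pg = self.deepspeed.mpu.get_data_parallel_group()
    else:
        pg = groups._get_data_parallel_group()
    # Divide locally, then sum across ranks: together this averages the norm.
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.tensor(scaled_norm,
                                      device=self.fp32_groups_flat[0].device,
                                      dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=pg)
    return scaled_norm_tensor.item()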