def step_fused_lamb(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    grads_groups = []
    norm_groups = []
    expert_norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads = [
            torch.zeros(p.size(),
                        dtype=p.dtype,
                        device=p.device) if p.grad is None else p.grad
            for p in group
        ]
        grads_groups.append(grads)
        grads_groups_flat.append(_flatten_dense_tensors(grads))
        grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
        norm_group_value = 0.0
        if len(grads_for_norm) > 0:
            norm_group_value = get_weight_norm(
                _flatten_dense_tensors(grads_for_norm),
                mpu=self.mpu)
        norm_groups.append(norm_group_value)
        expert_norm_group_value = 0.0
        if len(expert_grads_for_norm) > 0:
            expert_norm_group_value = get_weight_norm(
                _flatten_dense_tensors(expert_grads_for_norm),
                mpu=self.mpu)
        expert_norm_groups.append(expert_norm_group_value)

    self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
    self.optimizer.step(grads=grads_groups,
                        output_params=self.fp16_groups,
                        scale=combined_scale)

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
            # remove the fp32 grad
            fp32_param.grad = None
            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
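# Hypothetical helper (not DeepSpeed's implementation) sketching how the shared/expert
# split used above could work. The `allreduce` attribute check is an assumption inferred
# from the test below, which marks a MoE parameter by setting `allreduce = False`.
def split_shared_and_expert_grads(group):
    """Split parameter gradients into shared and expert (MoE) lists."""
    shared_grads, expert_grads = [], []
    for p in group:
        if p.grad is None:
            continue
        is_expert = hasattr(p, "allreduce") and not p.allreduce
        (expert_grads if is_expert else shared_grads).append(p.grad)
    return shared_grads, expert_grads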
def test(self, check_using_norm):
    groups._create_expert_and_data_parallel(2)

    param1 = torch.nn.Parameter(torch.Tensor([0]))
    param1.grad = torch.Tensor([1])
    param2 = torch.nn.Parameter(torch.Tensor([0]))
    if dist.get_rank() == 0:
        param2.grad = torch.Tensor([1])
    else:
        param2.grad = torch.Tensor([float("inf")])
    param2.allreduce = False
    # param2 is now a MoE parameter
    parameters = [param1, param2]
    if check_using_norm:
        grads_group_flat = [_flatten_dense_tensors([p.grad for p in parameters])]
        norm = ds_utils.get_weight_norm(grads_group_flat)
        overflow_checker = ds_utils.CheckOverflow([parameters])
        overflow = overflow_checker.check_using_norm([norm], reduce_overflow=False)
    else:
        overflow_checker = ds_utils.CheckOverflow([parameters])
        overflow = overflow_checker.check()
    assert overflow
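# Illustrative sketch (an assumption, not DeepSpeed's code): the norm-based overflow
# check that the test above exercises amounts to "a gradient norm of inf or nan means
# overflow". The helper name below is hypothetical.
import math

import torch


def norm_signals_overflow(grads):
    """Return True if the combined gradient norm is inf or nan."""
    flat = torch.cat([g.float().flatten() for g in grads])
    norm = flat.norm().item()
    return math.isinf(norm) or math.isnan(norm)


# An inf gradient trips the check, mirroring the non-zero ranks in the test above.
assert norm_signals_overflow([torch.Tensor([1]), torch.Tensor([float("inf")])])
assert not norm_signals_overflow([torch.Tensor([1]), torch.Tensor([2])])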
def step(self, closure=None):
    """
    Not supporting closure.
    """
    if self.fused_lamb_legacy:
        return self.step_fused_lamb()

    self.overflow = self.overflow_checker.check()
    prev_scale = self.cur_scale

    self._update_scale(self.overflow)
    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(group)
        norm_group_value = 0.0
        if len(grads_for_norm) > 0:
            norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
        norm_groups.append(norm_group_value)

        # copying gradients to fp32 to work with fp32 parameters
        for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
            if fp16_param.grad is None:
                fp32_param.grad = torch.zeros(fp16_param.size(),
                                              dtype=fp32_param.dtype,
                                              device=fp32_param.device)
            else:
                fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

    self._global_grad_norm = get_global_norm(norm_list=norm_groups)
    self.unscale_and_clip_grads(self._global_grad_norm)

    self.optimizer.step()

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
            # remove the fp32 grad
            fp32_param.grad = None
            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
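# Minimal sketch of the dynamic loss-scaling control flow that step() above follows,
# written against plain PyTorch only. The class and attribute names are hypothetical
# and illustrate the pattern, not DeepSpeed's implementation: skip the update and
# shrink the scale on overflow; otherwise unscale, clip, step the fp32 masters, and
# copy them back into the fp16 model parameters.
import torch


class DynamicLossScaleStep:
    def __init__(self, optimizer, fp16_params, fp32_params,
                 init_scale=2.0**16, scale_factor=2.0, clip=1.0):
        # `optimizer` is assumed to be constructed over `fp32_params`, the fp32
        # master copies of the fp16 model parameters in `fp16_params`.
        self.optimizer = optimizer
        self.fp16_params = fp16_params
        self.fp32_params = fp32_params
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.clip = clip

    def step(self):
        # Copy fp16 grads onto the fp32 masters and detect inf/nan on the way.
        overflow = False
        for p16, p32 in zip(self.fp16_params, self.fp32_params):
            grad = torch.zeros_like(p16) if p16.grad is None else p16.grad
            p32.grad = grad.float()
            overflow |= not bool(torch.isfinite(grad).all())

        if overflow:
            # Overflow: lower the loss scale and skip this optimizer step.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1.0)
            return True

        # No overflow: remove the loss scale, clip, and step on the fp32 masters.
        for p32 in self.fp32_params:
            p32.grad.div_(self.cur_scale)
        torch.nn.utils.clip_grad_norm_(self.fp32_params, self.clip)
        self.optimizer.step()

        # Copy the updated fp32 masters back into the fp16 model parameters.
        for p16, p32 in zip(self.fp16_params, self.fp32_params):
            p16.data.copy_(p32.data)
            p32.grad = None
        return False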
def step_fused_adam(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(),
                            dtype=p.dtype,
                            device=p.device) if p.grad is None else p.grad
                for p in group
            ]))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    scaled_grad_norm = get_global_norm(norm_list=norm_groups)

    combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                 scaled_grad_norm,
                                                 apply_scale=False)

    # Stash unscaled gradient norm
    self._global_grad_norm = scaled_grad_norm / self.cur_scale

    # norm is in fact norm*cur_scale
    self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                        output_params=[[p] for p in self.fp16_groups_flat],
                        scale=combined_scale,
                        grad_norms=norm_groups)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
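# Sketch (an assumption, not the library's exact code) of how a combined
# unscale-and-clip factor can be derived when, as the "norm is in fact norm*cur_scale"
# comment above notes, the norm was computed on gradients that still carry the loss
# scale. Dividing the scaled gradients by the returned factor both unscales and clips
# them. The helper name is hypothetical.
def combined_unscale_clip_factor(scaled_norm, cur_scale, clip_grad):
    combined_scale = cur_scale
    if clip_grad > 0.0:
        # Norm the gradients would have once the loss scale is removed.
        unscaled_norm = scaled_norm / cur_scale
        clip = unscaled_norm / clip_grad
        if clip > 1.0:
            # Clipping needed: divide by cur_scale * clip instead of cur_scale alone,
            # so the resulting gradient norm equals clip_grad.
            combined_scale = clip * cur_scale
    return combined_scale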
def step_fused_lamb(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    grads_groups = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads = [
            torch.zeros(p.size(),
                        dtype=p.dtype,
                        device=p.device) if p.grad is None else p.grad
            for p in group
        ]
        grads_groups.append(grads)
        grads_groups_flat.append(_flatten_dense_tensors(grads))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
    self.optimizer.step(grads=grads_groups,
                        output_params=self.fp16_groups,
                        scale=combined_scale)

    return self.overflow