def step_fused_lamb(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    grads_groups = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_groups.append([p.grad for p in group])
        grads_groups_flat.append(_flatten_dense_tensors(grads_groups[i]))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale

    if self.overflow:
        # Skip the step entirely and shrink the loss scale for the next iteration.
        self._update_scale(self.overflow)
        if self.verbose:
            print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                  "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    # Fold loss-scale removal and gradient clipping into a single factor and
    # let the fused optimizer apply it during the update.
    combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
    self.optimizer.step(grads=grads_groups,
                        output_params=self.fp16_groups,
                        scale=combined_scale)

    return self.overflow
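
# The `check_using_norm` call above relies on the fact that an fp16 overflow
# shows up as a non-finite gradient norm. A minimal sketch of that idea
# (hypothetical helper, not the actual OverflowChecker API):
def _norms_indicate_overflow(norm_groups):
    # A group has overflowed if its norm is inf or nan (nan != nan).
    return any(norm == float('inf') or norm == -float('inf') or norm != norm
               for norm in norm_groups)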
def step_fused_adam(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        # Substitute a zero tensor for any parameter without a gradient so the
        # flattened buffer stays aligned with the parameter group.
        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(),
                            dtype=p.dtype,
                            device=p.device) if p.grad is None else p.grad
                for p in group
            ]))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                        "scale: {}, reducing to {}".format(prev_scale,
                                                           self.cur_scale))
        return self.overflow

    combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                 norm_groups,
                                                 apply_scale=False)
    # norm_groups is in fact norm*cur_scale, since it was computed on
    # loss-scaled gradients
    self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                        output_params=[[p] for p in self.fp16_groups_flat],
                        scale=combined_scale,
                        grad_norms=norm_groups)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
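
# For reference, the combined scale returned by
# `unscale_and_clip_grads(..., apply_scale=False)` typically folds the loss
# scale and the gradient-clipping factor into one divisor that the fused
# optimizer applies during its update. A sketch of that pattern (assumed here;
# the actual method may differ in details):
def _combined_scale_sketch(total_norm, cur_scale, clip_grad):
    combined_scale = cur_scale
    if clip_grad > 0.0:
        # total_norm was measured on loss-scaled gradients, so remove the
        # loss scale before comparing against the clipping threshold.
        clip = ((total_norm / cur_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * cur_scale
    # Dividing the raw fp16 grads by combined_scale both unscales and clips them.
    return combined_scale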
def dump_weight_norms(tag, param_groups, micro_step, global_step):
    # Debug helper: print the weight norm of each parameter group.
    norm_groups = []
    for group in param_groups:
        norm_groups.append(get_weight_norm(group))

    print("\n {} weight_norms: micro_step={}, global_step={}, norms={}".format(
        tag,
        micro_step,
        global_step,
        norm_groups))
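
# Illustrative call site (commented out; `optimizer` and the step counters are
# assumed to exist in the surrounding training loop):
#
#   dump_weight_norms("fp16",
#                     optimizer.fp16_groups,
#                     micro_step=micro_step,
#                     global_step=global_step)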