Example #1
    def step_fused_lamb(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups.append([p.grad for p in group])
            grads_groups_flat.append(_flatten_dense_tensors(grads_groups[i]))
            norm_groups.append(
                get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale

        if self.overflow:
            self._update_scale(self.overflow)
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.cur_scale))
            return self.overflow

        combined_scale = self.unscale_and_clip_grads(norm_groups,
                                                     apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        return self.overflow
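A minimal standalone sketch of the flatten-then-norm pattern used above, assuming get_weight_norm is importable from deepspeed.runtime.utils as in the surrounding code and that every parameter in the group has a populated .grad; the helper name grad_norm_of_group is illustrative, not part of DeepSpeed:

import torch
from torch._utils import _flatten_dense_tensors
from deepspeed.runtime.utils import get_weight_norm


def grad_norm_of_group(params, mpu=None):
    # Flatten every gradient in the group into one contiguous buffer and take a
    # single norm over it, mirroring the per-group loop in step_fused_lamb.
    flat = _flatten_dense_tensors([p.grad for p in params])
    return get_weight_norm(flat, mpu=mpu)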
Example #2
    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(),
                                dtype=p.dtype,
                                device=p.device) if p.grad is None else p.grad
                    for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(
                                prev_scale,
                                self.cur_scale))
            return self.overflow
        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     norm_groups,
                                                     apply_scale=False)
        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
        return self.overflow
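The copy-back loop at the end of step_fused_adam relies on torch's flatten/unflatten helpers preserving shapes and element order. A minimal sketch of that idiom in isolation (the tensors and the in-place update below are placeholders, not DeepSpeed code):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(3, 4), torch.randn(5)]
flat = _flatten_dense_tensors(params)  # one contiguous buffer holding both tensors
flat.mul_(2.0)  # stand-in for the fused optimizer updating the flat copy
for p, q in zip(params, _unflatten_dense_tensors(flat, params)):
    p.data = q.data  # copy the updated slices back into the original tensors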
Example #3
def dump_weight_norms(tag, param_groups, micro_step, global_step):
    norm_groups = []
    for i, group in enumerate(param_groups):
        norm_groups.append(get_weight_norm(group))
    print("\n {} weight_norms: micro_step={}, global_step={}, norms={}".format(
        tag, micro_step, global_step, norm_groups))
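For comparison with the first two examples, get_weight_norm accepts either an iterable of parameters (as in dump_weight_norms) or a single pre-flattened tensor (as in the optimizer steps). A short usage sketch under that assumption, again importing get_weight_norm from deepspeed.runtime.utils:

import torch
from torch._utils import _flatten_dense_tensors
from deepspeed.runtime.utils import get_weight_norm

group = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(3)]

# Norm over an iterable of parameters, as dump_weight_norms does.
norm_from_params = get_weight_norm(group)

# Norm over a single flattened tensor, as the fused optimizer steps do.
flat = _flatten_dense_tensors([p.data for p in group])
norm_from_flat = get_weight_norm(flat)

print(norm_from_params, norm_from_flat)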