Example #1
    def step_fused_lamb(self, closure=None):
        """
        Closure is not supported.
        """
        # First compute norms for all groups so we know if there is an overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))
            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(
                group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(
                    _flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
            norm_groups.append(norm_group_value)
            expert_norm_group_value = 0.0
            if len(expert_grads_for_norm) > 0:
                expert_norm_group_value = get_weight_norm(
                    _flatten_dense_tensors(expert_grads_for_norm),
                    mpu=self.mpu)
            expert_norm_groups.append(expert_norm_group_value)

        self.overflow = self.overflow_checker.check_using_norm(
            norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        combined_scale = self.unscale_and_clip_grads(norm_groups,
                                                     apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param,
                      fp16_param) in enumerate(zip(fp32_group, fp16_group)):

                # remove the fp32 grad
                fp32_param.grad = None

                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
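
The step methods in these examples signal an overflow through their return value instead of raising, so the caller decides what to do with a skipped iteration. As a rough illustration, here is a minimal training-loop sketch; `model`, `fp16_optimizer`, and `data_loader` are placeholder names, not part of the examples.

# Minimal usage sketch (hypothetical names): driving a loss-scaled fp16
# optimizer step such as step_fused_lamb() from a training loop.
for batch in data_loader:
    loss = model(batch)
    # The wrapper scales the loss before backprop so fp16 gradients
    # stay in a representable range.
    fp16_optimizer.backward(loss)
    # step() returns True when the overflow check tripped; the update
    # was skipped and the loss scale has already been reduced.
    overflow = fp16_optimizer.step()
    if overflow:
        continue  # gradients were unusable this iteration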
Example #2
    def test(self, check_using_norm):
        groups._create_expert_and_data_parallel(2)

        param1 = torch.nn.Parameter(torch.Tensor([0]))
        param1.grad = torch.Tensor([1])
        param2 = torch.nn.Parameter(torch.Tensor([0]))
        if dist.get_rank() == 0:
            param2.grad = torch.Tensor([1])
        else:
            param2.grad = torch.Tensor([float("inf")])
        param2.allreduce = False
        # param2 is now an MoE parameter
        parameters = [param1, param2]
        if check_using_norm:
            grads_group_flat = [
                _flatten_dense_tensors([p.grad for p in parameters])
            ]
            norm = ds_utils.get_weight_norm(grads_group_flat)
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check_using_norm([norm],
                                                         reduce_overflow=False)
        else:
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check()
        assert overflow
Example #3
    def step(self, closure=None):
        """
        Closure is not supported.
        """

        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(
                group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(grads_for_norm,
                                                   mpu=self.mpu)
            norm_groups.append(norm_group_value)

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i],
                                              self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param,
                      fp16_param) in enumerate(zip(fp32_group, fp16_group)):

                # remove the fp32 grad
                fp32_param.grad = None

                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
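
For reference, the per-group norms collected above are combined into a single global norm by the usual root-sum-of-squares rule before clipping. The snippet below is a self-contained sketch of that combination, not necessarily DeepSpeed's exact get_global_norm implementation.

import math

def global_norm_from_groups(norm_groups):
    # sqrt(n1^2 + n2^2 + ...) equals the L2 norm over all gradients
    # when the groups partition the parameters.
    return math.sqrt(sum(n ** 2 for n in norm_groups))

# e.g. two groups with norms 3.0 and 4.0 give a global norm of 5.0
assert abs(global_norm_from_groups([3.0, 4.0]) - 5.0) < 1e-12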
Example #4
    def step_fused_adam(self, closure=None):
        """
        Closure is not supported.
        """

        # First compute norms for all groups so we know if there is an overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                    if p.grad is None else p.grad for p in group
                ]))
            norm_groups.append(
                get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     scaled_grad_norm,
                                                     apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
        return self.overflow
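
The copy-back loop at the end relies on _unflatten_dense_tensors returning views into the flat buffer that the fused optimizer updated in place. Assuming the helpers come from torch._utils (where PyTorch provides them), the round trip looks like this self-contained sketch; the tensors here are toy placeholders.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.nn.Parameter(torch.ones(2, 3)), torch.nn.Parameter(torch.zeros(4))]
flat = _flatten_dense_tensors([p.data for p in params])  # single 1-D buffer
flat.mul_(2.0)  # stand-in for the fused optimizer's in-place update
updated = _unflatten_dense_tensors(flat, [p.data for p in params])
for p, q in zip(params, updated):
    # rebind each parameter to its updated view, as in the loop above
    p.data = q.data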
Example #5
    def step_fused_lamb(self, closure=None):
        """
        Closure is not supported.
        """
        # First compute norms for all groups so we know if there is an overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))
            norm_groups.append(
                get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        combined_scale = self.unscale_and_clip_grads(norm_groups,
                                                     apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        return self.overflow
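
In the fused paths above, apply_scale=False asks unscale_and_clip_grads to only compute a combined scale factor, which the fused kernel then applies itself when it divides the gradients. Conceptually the combined scale folds gradient clipping into the loss-scale division, as in this simplified sketch (not DeepSpeed's exact implementation; total_norm is the norm of the still-scaled gradients).

def combined_scale(total_norm, loss_scale, clip_grad):
    # Gradients currently hold loss_scale * true_grad. Dividing by the
    # returned value both unscales them and, when the unscaled global
    # norm exceeds clip_grad, clips that norm down to clip_grad.
    scale = loss_scale
    if clip_grad > 0.0:
        clip = (total_norm / loss_scale) / clip_grad
        if clip > 1.0:
            scale = clip * loss_scale
    return scale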