def step(self, closure=None):
        """
        Run one optimizer step for the fp16 groups through fp32 master params.

        Closures are not supported; ``closure`` is accepted for API
        compatibility and ignored.

        Returns:
            The overflow flag computed for this step. When it is truthy the
            loss scale has been updated and no parameter update is performed.
        """
        # The legacy fused LAMB path has a dedicated implementation.
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        attempted_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(attempted_scale,
                                                       self.cur_scale))
            return self.overflow

        per_group_norms = []
        for idx, fp16_group in enumerate(self.fp16_groups):
            # Only the first part returned by the split helper contributes to
            # the norm; a group with nothing in that part contributes 0.0.
            shared_grads, _ = split_params_grads_into_shared_and_expert_params(
                fp16_group)
            if len(shared_grads) > 0:
                group_norm = get_weight_norm(shared_grads, mpu=self.mpu)
            else:
                group_norm = 0.0
            per_group_norms.append(group_norm)

            # Give every fp32 master param an fp32 grad so the wrapped
            # optimizer works on fp32 parameters; a missing fp16 grad becomes
            # a zero tensor.
            for master, half in zip(self.fp32_groups[idx], fp16_group):
                if half.grad is not None:
                    master.grad = half.grad.to(master.dtype)
                else:
                    master.grad = torch.zeros(half.size(),
                                              dtype=master.dtype,
                                              device=master.device)

        self._global_grad_norm = get_global_norm(norm_list=per_group_norms)
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        # Release each fp32 grad and copy the updated master weights back into
        # the matching fp16 param.
        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for master, half in zip(fp32_group, fp16_group):
                master.grad = None
                half.data.copy_(master.data)

        return self.overflow
    def step_fused_adam(self, closure=None):
        """
        Run one step of the legacy fused Adam path on flattened fp16 grads.

        Closures are not supported; ``closure`` is ignored.

        Returns:
            The overflow flag for this step. When it is truthy the loss scale
            has been updated and the parameter update is skipped.
        """
        # Flatten each group's grads (missing grads become zeros) and take a
        # per-group norm first, so overflow can be detected from the norms.
        grads_groups_flat = []
        norm_groups = []
        for group in self.fp16_groups:
            padded_grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ]
            flat_grads = _flatten_dense_tensors(padded_grads)
            grads_groups_flat.append(flat_grads)
            norm_groups.append(get_weight_norm(flat_grads, mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        attempted_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(attempted_scale,
                                                       self.cur_scale))
            return self.overflow

        # These norms are of loss-scaled grads, i.e. norm * cur_scale.
        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     scaled_grad_norm,
                                                     apply_scale=False)

        # Stash the gradient norm with the loss scale divided out.
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        self.optimizer.step(grads=[[flat] for flat in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)

        # TODO: possibly redundant; kept for safety. Point each fp16 param's
        # ``.data`` at its piece of the updated flat fp16 buffer.
        for idx, group in enumerate(self.fp16_groups):
            pieces = _unflatten_dense_tensors(self.fp16_groups_flat[idx], group)
            for param, piece in zip(group, pieces):
                param.data = piece.data
        return self.overflow
    def step_fused_lamb(self, closure=None):
        """
        Run one step of the legacy fused LAMB path.

        Closures are not supported; ``closure`` is ignored.

        Returns:
            The overflow flag for this step. When it is truthy the loss scale
            has been updated and the parameter update is skipped.
        """

        def _norm_or_zero(grads):
            # Norm of the flattened grads, or 0.0 when the list is empty.
            if len(grads) > 0:
                return get_weight_norm(_flatten_dense_tensors(grads),
                                       mpu=self.mpu)
            return 0.0

        # First compute norms for all groups so we know if there is overflow.
        # Only the per-param grad lists are kept (the fused optimizer takes
        # them directly); no flattened copy of all grads is materialized.
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for group in self.fp16_groups:
            # Missing grads are replaced by zeros so every param has a grad.
            grads_groups.append([
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ])
            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(
                group)
            norm_groups.append(_norm_or_zero(grads_for_norm))
            expert_norm_groups.append(_norm_or_zero(expert_grads_for_norm))

        # Both shared and expert norms take part in the overflow check.
        self.overflow = self.overflow_checker.check_using_norm(
            norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        # Only the shared (non-expert) norms feed the norm used for clipping.
        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        combined_scale = self.unscale_and_clip_grads(self._global_grad_norm,
                                                     apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        # NOTE(review): the fp16 params were passed as ``output_params`` above
        # and are overwritten here from the fp32 masters; confirm both agree.
        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):
                # Remove the fp32 grad.
                fp32_param.grad = None
                # Copy data from fp32 to fp16.
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
Example #4
0
    def step(self, closure=None):
        """
        Run one optimizer step for the fp16 groups via flat fp32 master buffers.

        Closures are not supported; ``closure`` is ignored.

        Outline:
          1. Check the fp16 grads for overflow. On overflow, update the loss
             scale, drop all fp16 grads and return early.
          2. Flatten each group's fp16 grads into the matching fp32 master
             buffer's ``.grad``.
          3. Compute the gradient norm and average it across the data-parallel
             group, because MoE params are not synced by the gradient
             allreduce.
          4. Unscale/clip, run the wrapped optimizer, and copy the updated
             fp32 master weights back into the fp16 params.

        ``self._global_grad_norm`` is stashed with the loss scale divided out,
        while clipping receives the loss-scaled norm.

        Returns:
            The overflow flag for this step.
        """

        if self.fused_adam_legacy:
            return self.step_fused_adam()

        COMPUTE_NORM = "compute_norm"
        OVERFLOW_CHECK = 'overflow_check'
        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
        UNSCALE_AND_CLIP = 'unscale_and_clip'
        BASIC_STEP = 'basic_step'
        UPDATE_FP16 = 'update_fp16'
        STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

        # First determine if there is overflow (only params that have grads).
        self.start_timers([OVERFLOW_CHECK])
        fp16_params = [
            p for group in self.fp16_groups for p in group
            if p.grad is not None
        ]
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.stop_timers([OVERFLOW_CHECK])
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Clear gradients; this step is skipped.
            for group in self.fp16_groups:
                for p in group:
                    p.grad = None

            self.log_timers(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(),
                                dtype=data_type,
                                device=p.device)
                    if p.grad is None else p.grad.to(data_type) for p in group
                ]))

            # The fp16 grads are no longer needed once copied.
            for p in group:
                p.grad = None

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.start_timers([COMPUTE_NORM])

        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
        # Allreduce (average) the norm across ranks, because MoE params are
        # not synced during the gradient allreduce.
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
        # The reduction tensor lives on the device of the last fp32 master
        # buffer (the lists are indexed in parallel with fp16_groups above).
        scaled_norm_tensor = torch.tensor(scaled_norm,
                                          device=self.fp32_groups_flat[-1].device,
                                          dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()

        self.stop_timers([COMPUTE_NORM])

        # The norm is of loss-scaled grads: clip with the scaled value, but
        # stash the unscaled gradient norm.
        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.start_timers([UNSCALE_AND_CLIP])
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.stop_timers([UNSCALE_AND_CLIP])

        self.start_timers([BASIC_STEP])
        self.optimizer.step()
        self.stop_timers([BASIC_STEP])

        # Get rid of the fp32 gradients; they are not needed anymore.
        for group in self.fp32_groups_flat:
            group.grad = None

        self.start_timers([UPDATE_FP16])

        for i, group in enumerate(self.fp16_groups):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                      group)
            for p, q in zip(group, updated_params):
                p.data.copy_(q.data)

        self.stop_timers([UPDATE_FP16])

        self.log_timers(STEP_TIMERS)

        return self.overflow
    def step(self, closure=None):
        """
        Take one optimizer step through the flat fp32 master buffers.

        ``closure`` is not supported and is ignored.

        Returns:
            The overflow flag for this step. On overflow the fp16 grads are
            dropped, the loss scale is updated and no parameter update runs.
        """
        if self.fused_adam_legacy:
            return self.step_fused_adam()

        COMPUTE_NORM = "compute_norm"
        OVERFLOW_CHECK = "overflow_check"
        UNSCALE_AND_CLIP = "unscale_and_clip"
        BASIC_STEP = "basic_step"
        UPDATE_FP16 = "update_fp16"
        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
        STEP_TIMERS = [*OVERFLOW_TIMERS, UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

        # Overflow detection only looks at params that actually hold a grad.
        self.start_timers([OVERFLOW_CHECK])
        params_with_grads = [
            param for group in self.fp16_groups for param in group
            if param.grad is not None
        ]
        self.overflow = self.overflow_checker.has_overflow(params_with_grads)
        self.stop_timers([OVERFLOW_CHECK])

        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Drop every fp16 grad; this step is skipped.
            for group in self.fp16_groups:
                for param in group:
                    param.grad = None
            self.log_timers(OVERFLOW_TIMERS)
            return self.overflow

        # Move each group's grads into its fp32 master buffer as one flat
        # tensor (missing grads become zeros) and release the fp16 grads.
        grads_groups_flat = []
        for idx, group in enumerate(self.fp16_groups):
            master = self.fp32_groups_flat[idx]
            target_dtype = master.dtype
            converted = []
            for param in group:
                if param.grad is None:
                    converted.append(
                        torch.zeros(param.size(),
                                    dtype=target_dtype,
                                    device=param.device))
                else:
                    converted.append(param.grad.to(target_dtype))
            flat = _flatten_dense_tensors(converted)
            grads_groups_flat.append(flat)

            for param in group:
                param.grad = None

            master.grad = flat

        self.start_timers([COMPUTE_NORM])
        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
        self.stop_timers([COMPUTE_NORM])

        if self.has_moe_layers:
            all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
        # Stash the gradient norm with the loss scale divided out; clipping
        # below still receives the loss-scaled value.
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.start_timers([UNSCALE_AND_CLIP])
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.stop_timers([UNSCALE_AND_CLIP])

        self.start_timers([BASIC_STEP])
        self.optimizer.step()
        self.stop_timers([BASIC_STEP])

        # The fp32 grads are not needed after the optimizer step.
        for master in self.fp32_groups_flat:
            master.grad = None

        # Copy the updated fp32 master weights back into the fp16 params.
        self.start_timers([UPDATE_FP16])
        for idx, group in enumerate(self.fp16_groups):
            refreshed = _unflatten_dense_tensors(self.fp32_groups_flat[idx],
                                                 group)
            for param, new_value in zip(group, refreshed):
                param.data.copy_(new_value.data)
        self.stop_timers([UPDATE_FP16])

        self.log_timers(STEP_TIMERS)

        self.step_count += 1

        return self.overflow