def dump_gradient_norms(tag, param_groups, micro_step, global_step):
    norm_groups = [get_grad_norm(group) for group in param_groups]
    print(
        "\n {} gradient_norms: micro_step={}, global_step={}, norms={}".format(
            tag, micro_step, global_step, norm_groups))
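The helper above assumes a get_grad_norm utility that is not shown in this snippet. Below is a hedged, minimal sketch of what such a helper typically computes, the total L2 norm over a group's gradients; the real DeepSpeed helper also accepts an mpu argument for model-parallel reductions, and all names here are illustrative only.

import torch


def get_grad_norm_sketch(parameters, norm_type=2.0):
    """Return the total gradient norm of an iterable of parameters."""
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += p.grad.data.float().norm(norm_type).item() ** norm_type
    return total ** (1.0 / norm_type)


# Tiny smoke test: one parameter whose gradient is all ones.
p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.ones(4)
print(get_grad_norm_sketch([p]))  # -> 2.0 (the L2 norm of four ones)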
Example #2
    def step(self, closure=None):
        """
        Not supporting closure.
        """

        if self.fused_adam_legacy:
            return self.step_fused_adam()

        # First compute the norm for all groups so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []

        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device)
                    if p.grad is None else p.grad.to(data_type) for p in group
                ]))

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

            norm_groups.append(
                get_grad_norm(self.fp32_groups_flat, mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        self.unscale_and_clip_grads(grads_groups_flat, norm_groups)

        self.optimizer.step()

        #get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)

        return self.overflow
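The step above flattens each group's fp16 gradients into a single fp32 buffer for the wrapped optimizer, then unflattens the updated fp32 master weights back onto the individual fp16 parameters. A small standalone sketch of that flatten/unflatten round trip, using the torch._utils helpers of the same name outside of any optimizer:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(3, 2).half(), torch.randn(5).half()]

# flatten per-parameter tensors into one contiguous fp32 buffer, as done for
# grads_groups_flat above
flat = _flatten_dense_tensors([p.float() for p in params])
assert flat.numel() == sum(p.numel() for p in params)

# ... an optimizer would update `flat` in place here ...

# unflatten into views shaped like the originals and copy back into fp16
for p, q in zip(params, _unflatten_dense_tensors(flat, params)):
    p.data.copy_(q.data)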
Example #3
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i],
                                              self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self.unscale_and_clip_grads(norm_groups)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):

                #remove the fp32 grad
                fp32_param.grad = None

                #copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
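Both steps above defer to self.unscale_and_clip_grads(...), whose body is not shown here (and whose signature differs slightly between the two variants). A hedged sketch of what such a routine typically does, assuming the gradients were produced under a loss scale cur_scale and an optional max norm clip_grad; the names and the exact formula are illustrative, not DeepSpeed's code:

import math


def unscale_and_clip_sketch(grad_tensors, norm_groups, cur_scale, clip_grad=0.0):
    # combine per-group norms (still inflated by the loss scale) into one global norm
    total_norm = math.sqrt(sum(n ** 2.0 for n in norm_groups))

    # fold loss-scale removal and norm clipping into a single multiplier
    combined_scale = cur_scale
    if clip_grad > 0.0:
        clip = (total_norm / cur_scale) / clip_grad
        if clip > 1.0:
            combined_scale = clip * cur_scale

    # undo the loss scale (and any clipping) in place
    for g in grad_tensors:
        g.data.mul_(1.0 / combined_scale)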
Example #4
    def step(self, closure=None):
        # First compute the norm for all groups so we know if there is overflow

        self.overflow = self.overflow_checker.check()

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            self.zero_grad()
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.loss_scale))
            return self.overflow

        norm_groups = []
        local_sub_partitions_grad_groups = []

        partition_id = dist.get_rank(group=self.dp_process_group)
        for i, group in enumerate(self.fp16_groups):

            #TODO RS: update get grad norm to support sub partitions
            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            #RS: update free grads w.r.t. sub partitions
            #free gradients for all the parameters that are not updated by this process
            self.free_grad_in_param_list(self.params_not_local[i])

            #create flat gradients for parameters updated by this process
            #tensor_list, first_offset, partition_size, dtype
            #single_grad_partition = self.get_flat_partition(
            #    tensor_list=self.params_in_partition[i],
            #    first_offset=self.first_offset[i],
            #    partition_size=self.partition_size[i],
            #    dtype=self.single_partition_of_fp32_groups[i].dtype
            #)

            #TODO RS: can we safely use dtype of the first sub-partition? I think so
            local_grad_sub_partitions = self.get_flat_sub_partitions(
                comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
                comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][partition_id],
                sub_partition_size=self.sub_partition_sizes[i],
                dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
                num_comm_intervals=self.num_comm_intervals_per_group[i],
                default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)

            #RS: update all our local params with sub-partition grads
            #print("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
            for idx, sub_partition_param in enumerate(
                    self.local_sub_partitions_of_fp32_groups[i]):
                sub_partition_param.grad = local_grad_sub_partitions[idx]
            #self.single_partition_of_fp32_groups[i].grad = single_grad_partition

            #RS: update free grads for sub-partitions
            #release all the gradients since we have already created the necessary copies in dp_grad_partition
            self.free_grad_in_param_list(
                self.params_in_rank_sub_partitions[i][partition_id])

            local_sub_partitions_grad_groups.append(local_grad_sub_partitions)

        #RS: update unscale/clip with sub partitions
        self.unscale_and_clip_grads(local_sub_partitions_grad_groups,
                                    norm_groups)

        self.optimizer.step()

        #RS: clear our sub partition grads
        #get rid of the fp32 gradients. Not needed anymore
        for group in self.local_sub_partitions_of_fp32_groups:
            for idx, sub_partition_param in enumerate(group):
                sub_partition_param.grad = None
            #group.grad = None

        #NOTE RS: removed norm_groups outer loop from original code, I don't think it's needed
        #RS: copy all sub-partition fp32 data to fp16 sub partitions
        # copy fp32 param data to fp16 partitions w.r.t. our local rank
        for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(
                self.parallel_sub_partitioned_fp16_groups,
                self.local_sub_partitions_of_fp32_groups):
            for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(
                    fp16_all_sub_partitions[partition_id],
                    fp32_local_sub_partitions):
                local_sub_partition_param_fp16.data.copy_(
                    local_sub_partition_param_fp32.data)

        #RS: all_gather/broadcast sub-partitions in separate comm calls
        #gather the updated weights from everyone
        for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
            for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
                dist.all_gather(sub_partitions,
                                sub_partitions[partition_id],
                                group=self.dp_process_group)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow
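The final all_gather loop above is what lets every data-parallel rank see the weights updated by its peers: each rank contributes its own freshly updated fp16 sub-partition and receives the others in place. A minimal sketch of that collective in isolation, assuming torch.distributed is already initialized; the function name is illustrative:

import torch.distributed as dist


def gather_updated_partitions(partitions, dp_process_group):
    """partitions: one pre-allocated tensor per rank; ours is already updated."""
    my_rank = dist.get_rank(group=dp_process_group)
    # all_gather fills every slot of `partitions` with the corresponding rank's
    # tensor, so afterwards each rank holds the full set of updated shards
    dist.all_gather(partitions, partitions[my_rank], group=dp_process_group)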
Example #5
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is overflow

        self.overflow = self.overflow_checker.check()

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            self.zero_grad()
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.loss_scale))
            return self.overflow

        norm_groups = []
        single_partition_grad_groups = []

        partition_id = dist.get_rank(group=self.dp_process_group)
        for i, group in enumerate(self.fp16_groups):

            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            #free gradients for all the parameters that are not updated by this process
            self.free_grad_in_param_list(self.params_not_in_partition[i])

            #create a flat gradient for the parameters updated by this process
            single_grad_partition = self.get_flat_partition(
                self.params_in_partition[i],
                self.first_offset[i],
                self.partition_size[i],
                dtype=self.single_partition_of_fp32_groups[i].dtype)

            self.single_partition_of_fp32_groups[i].grad = single_grad_partition

            #release all the gradients since we have already created the necessary copies in dp_grad_partition
            self.free_grad_in_param_list(self.params_in_partition[i])

            single_partition_grad_groups.append(single_grad_partition)

        self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)

        self.optimizer.step()

        #get rid of the fp32 gradients. Not needed anymore
        for group in self.single_partition_of_fp32_groups:
            group.grad = None

        for fp16_partitions, fp32_partition in zip(
                self.parallel_partitioned_fp16_groups,
                self.single_partition_of_fp32_groups):
            fp16_partitions[partition_id].data.copy_(fp32_partition.data)

        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        #gather the updated weights from everyone
        for partitioned_params in self.parallel_partitioned_fp16_groups:
            if self.all_gather_partitions:
                # controllable memory-time tradeoff
                partition_numel = partitioned_params[partition_id].numel()
                num_shards = max(
                    1, partition_numel * dp_world_size // self.allgather_size)
                shard_size = partition_numel // num_shards
                num_elements = shard_size
                for shard_id in range(num_shards + 1):
                    if shard_id == num_shards:
                        # the extra iteration handles any remainder elements
                        if shard_size * num_shards >= partition_numel:
                            break
                        num_elements = partition_numel - shard_id * shard_size
                    shard_list = []
                    for dp_id in range(dp_world_size):
                        curr_shard = partitioned_params[dp_id].narrow(
                            0, shard_id * shard_size, num_elements)
                        shard_list.append(curr_shard)
                    dist.all_gather(shard_list,
                                    shard_list[partition_id],
                                    group=self.dp_process_group)
            else:
                #this should require less memory but should be faster
                for src, partitioned_param in enumerate(partitioned_params):
                    global_src = _get_global_rank(self.dp_process_group, src)
                    dist.broadcast(partitioned_param,
                                   global_src,
                                   group=self.dp_process_group)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow
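The all_gather_partitions branch above caps peak communication memory by splitting each partition into num_shards equal chunks plus a trailing remainder chunk (the shard_id == num_shards case) and gathering one chunk at a time. A small pure-Python sketch of that boundary arithmetic, with illustrative numbers:

def shard_bounds(partition_numel, dp_world_size, allgather_size):
    num_shards = max(1, partition_numel * dp_world_size // allgather_size)
    shard_size = partition_numel // num_shards
    bounds = [(shard_id * shard_size, shard_size) for shard_id in range(num_shards)]
    remainder = partition_numel - num_shards * shard_size
    if remainder > 0:
        # trailing partial shard, mirroring the shard_id == num_shards branch
        bounds.append((num_shards * shard_size, remainder))
    return bounds


# e.g. a 10-element partition, 4 data-parallel ranks, a gather budget of 16 elements
print(shard_bounds(10, 4, 16))  # -> [(0, 5), (5, 5)]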
Example #6
    def step(self, closure=None):
        """
        Not supporting closure.
        """

        if self.fused_adam_legacy:
            return self.step_fused_adam()

        COMPUTE_NORM = "compute_norm"
        OVERFLOW_CHECK = 'overflow_check'
        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
        UNSCALE_AND_CLIP = 'unscale_and_clip'
        BASIC_STEP = 'basic_step'
        UPDATE_FP16 = 'update_fp16'
        STEP_TIMERS = OVERFLOW_TIMERS + [
            UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16
        ]

        # First compute the norm for all groups so we know if there is overflow
        grads_groups_flat = []

        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device)
                    if p.grad is None else p.grad.to(data_type) for p in group
                ]))

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.start_timers([COMPUTE_NORM])
        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
        self.stop_timers([COMPUTE_NORM])

        self.start_timers([OVERFLOW_CHECK])
        self.overflow = self.overflow_checker.check_using_norm(
            [all_groups_norm])
        self.stop_timers([OVERFLOW_CHECK])

        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.cur_scale))
            self.log_timers(OVERFLOW_TIMERS)
            return self.overflow

        self.start_timers([UNSCALE_AND_CLIP])
        self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
        self.stop_timers([UNSCALE_AND_CLIP])

        self.start_timers([BASIC_STEP])
        self.optimizer.step()
        self.stop_timers([BASIC_STEP])

        #get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.start_timers([UPDATE_FP16])
        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)
        self.stop_timers([UPDATE_FP16])

        self.log_timers(STEP_TIMERS)

        return self.overflow
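Example #6 brackets each phase with self.start_timers / self.stop_timers / self.log_timers, thin wrappers around DeepSpeed's timing utility that are not shown here. A hedged, minimal stand-in with the same call shape, useful for reading the snippet outside the library:

import time
from collections import defaultdict


class StepTimersSketch:
    def __init__(self):
        self._start = {}
        self._elapsed = defaultdict(float)

    def start_timers(self, names):
        for name in names:
            self._start[name] = time.perf_counter()

    def stop_timers(self, names):
        for name in names:
            self._elapsed[name] += time.perf_counter() - self._start.pop(name)

    def log_timers(self, names):
        for name in names:
            print("{}: {:.3f} ms".format(name, 1000.0 * self._elapsed[name]))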