Exemple #1
0
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl'):
        """Set up optimizer hyper-parameters and the compressed-communication backend.

        Arguments:
            params (iterable): parameters to optimize or dicts defining
                parameter groups.
            deepspeed: engine object; its ``mpu`` attribute is handed to the
                NCCL backend and the presence of
                ``pipeline_enable_backward_allreduce`` on it selects pipeline mode.
            lr, bias_correction, betas, eps, weight_decay, max_grad_norm:
                stored as the per-group defaults.
            freeze_step (int): stored on the optimizer as ``self.freeze_step``.
            eps_inside_sqrt (bool): recorded as ``self.eps_mode`` (0 if True, else 1).
            amsgrad (bool): must be False.
            cuda_aware (bool): forwarded to the MPI backend.
            comm_backend_name (str): ``'nccl'`` or ``'mpi'``.

        Raises:
            RuntimeError: if ``amsgrad`` is True.
            AssertionError: if torch.distributed is not initialized, or the
                NCCL backend is requested with torch older than 1.8.
            ValueError: if ``comm_backend_name`` is neither ``'nccl'`` nor ``'mpi'``.
        """

        if amsgrad:
            raise RuntimeError(
                '1-bit Adam does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)

        super(OnebitAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        assert (dist.is_initialized())

        # Timing / bookkeeping counters.
        self.comm_time = 0.0
        self.step_time = 0.0
        self.ave_step = 1
        self.bk_time = 0.0

        self.deepspeed = deepspeed
        self.adam_freeze_key = False
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None

        if self.comm_backend_name == 'nccl':
            # Compare (major, minor) as a tuple: checking the two numbers
            # independently wrongly rejects releases such as torch 2.0.
            torch_version = tuple(
                int(part) for part in torch.__version__.split('.')[:2])
            assert torch_version >= (1, 8), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
            assert dist.is_initialized(
            ), "Please initialize the torch distributed backend."
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(
                self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)

        else:
            # Fail with a clear message instead of an AttributeError on None below.
            raise ValueError(
                f"Unsupported comm_backend_name '{self.comm_backend_name}'; "
                "expected 'nccl' or 'mpi'.")

        self.size = self.comm_backend_handle.size

        # Equals lcm(size, 8); step() pads buffers to a multiple of size * divider.
        self.divider = int(self.size * 8 / np.gcd(self.size, 8))
Exemple #2
0
class OnebitAdam(torch.optim.Optimizer):
    """Implements the 1-bit Adam algorithm. Currently GPU-only.
    For usage example please see https://www.deepspeed.ai/tutorials/onebit-adam/
    For technical details please read https://arxiv.org/abs/2102.02888

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        deepspeed (object, optional): engine object. Its ``mpu`` and
            ``local_rank`` attributes are read, and its
            ``enable_backward_allreduce`` /
            ``pipeline_enable_backward_allreduce`` flag is toggled when the
            optimizer switches between warmup and compression. (default: None)
        lr (float, optional): learning rate. (default: 1e-3)
        freeze_step (int, optional): Number of steps for warmup (uncompressed)
            stage before we start using compressed communication. (default 100000)
        bias_correction (boolean, optional): stored in the parameter-group
            defaults. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): stored in the parameter-group
            defaults. (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 1-bit Adam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
            NOTE(review): this class only records the flag as ``eps_mode``;
            ``step`` always adds eps outside the square root -- confirm intent.
        cuda_aware (boolean, optional): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
    .. _Adam\\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl'):
        """Store hyper-parameters and create the communication backend.

        Raises:
            RuntimeError: if ``amsgrad`` is True.
            AssertionError: if torch.distributed is not initialized, or the
                NCCL backend is requested with torch older than 1.8.
            ValueError: if ``comm_backend_name`` is neither 'nccl' nor 'mpi'.
        """

        if amsgrad:
            raise RuntimeError(
                '1-bit Adam does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)

        super(OnebitAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        assert (dist.is_initialized())

        # Timing / bookkeeping counters.
        self.comm_time = 0.0
        self.step_time = 0.0
        self.ave_step = 1
        self.bk_time = 0.0

        self.deepspeed = deepspeed
        # True while momentum is exchanged through compressed allreduce.
        self.adam_freeze_key = False
        # Becomes True once the first (bootstrap) call to step() has completed.
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None

        if self.comm_backend_name == 'nccl':
            assert self._torch_version_at_least(
                torch.__version__, 1, 8
            ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
            assert dist.is_initialized(
            ), "Please initialize the torch distributed backend."
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(
                self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)

        else:
            # Fail with a clear message instead of an AttributeError on None below.
            raise ValueError(
                f"Unsupported comm_backend_name '{self.comm_backend_name}'; "
                "expected 'nccl' or 'mpi'.")

        self.size = self.comm_backend_handle.size

        # Equals lcm(size, 8); step() pads buffers to a multiple of size * divider.
        self.divider = int(self.size * 8 / np.gcd(self.size, 8))

    @staticmethod
    def _torch_version_at_least(version, major, minor):
        """Return True if ``version`` ("X.Y[.Z...]") is at least ``major.minor``."""
        parts = version.split('.')
        # Tuple comparison: comparing major and minor independently would
        # wrongly reject e.g. "2.0" when the minimum is "1.8".
        return (int(parts[0]), int(parts[1])) >= (major, minor)

    def step(self, closure=None, grads=None):
        """Performs a single optimization step.

        The first call is a bootstrap step: it allocates the error-feedback
        buffers, runs the communication path once and returns without
        changing any parameter.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif not isinstance(grads[0], list):
            grads_group = [grads]
        else:
            grads_group = grads

        # State of the most recently processed parameter. It stays None when
        # no parameter had a gradient, so the code after the loops can tell.
        state = None

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        '1-bit Adam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Allocate error-feedback buffers on the bootstrap step, and
                # again in the compression stage whenever they are missing.
                if not self.initialize or (self.adam_freeze_key
                                           and 'worker_error' not in state):
                    state['tensor_size'] = torch.numel(p.data)
                    state['corrected_tensor_size'] = state['tensor_size']

                    # Pad the size up to a multiple of size * divider.
                    if state['tensor_size'] % (self.size * self.divider) != 0:
                        state['corrected_tensor_size'] += (
                            (self.size * self.divider) -
                            (state['tensor_size'] %
                             (self.size * self.divider)))
                    state['server_chunk_size'] = state[
                        'corrected_tensor_size'] // self.size
                    torch.cuda.empty_cache()
                    state['worker_error'] = torch.zeros(
                        state['corrected_tensor_size'], device=p.device)
                    state['server_error'] = torch.zeros(
                        state['server_chunk_size'], device=p.device)
                    torch.cuda.empty_cache()
                    self.adam_freeze_key = True
                    if not self.initialize and torch.distributed.get_rank(
                    ) == 0:
                        print("Cupy Buffers Initialized Successfully.")

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if self.adam_freeze_key is False:
                    # Warmup stage: both moments are updated from the gradient.
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad,
                                                    grad,
                                                    value=1 - beta2)
                    grad = None
                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                else:
                    if 'non_freeze' in group and group['non_freeze'] is True:
                        # Groups flagged 'non_freeze' average the gradient
                        # with a plain allreduce and keep updating both moments.
                        dist.all_reduce(grad)
                        grad.mul_(1 / dist.get_world_size())
                        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        exp_avg_sq.mul_(beta2).addcmul_(grad,
                                                        grad,
                                                        value=1 - beta2)
                        grad = None
                    else:
                        # Only the first moment is updated here; exp_avg_sq is
                        # left untouched in this branch.
                        if self.initialize is True:
                            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        grad = None

                        if self.size > 1:
                            exp_avg.set_(
                                self.comm_backend_handle.compressed_allreduce(
                                    exp_avg, state['worker_error'],
                                    state['server_error'],
                                    self.deepspeed.local_rank))
                        # Because 1-bit compression cannot represent exact zero, it is required to
                        # provide a momentum mask for those params that have constant exact zeros in their
                        # momentums, otherwise the compression error would keep accumulating.
                        # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
                        # always have exact zeros in its momentum for row 129 to 512, because it only
                        # learns up to seq length 128 while the model supports up to 512 seq length.
                        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
                        if 'exp_avg_mask' in group:
                            if exp_avg.device != group['exp_avg_mask'].device:
                                group['exp_avg_mask'] = group[
                                    'exp_avg_mask'].to(device=exp_avg.device)
                            exp_avg.mul_(group['exp_avg_mask'])

                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                if self.initialize:
                    if group['weight_decay'] > 0.0:
                        update += group['weight_decay'] * p.data
                    with torch.no_grad():
                        p.add_(-group['lr'] * update)

            if not self.initialize:
                print('Pop out errors', flush=True)
                # NOTE(review): only the buffers of the last processed
                # parameter are dropped here; earlier parameters of the group
                # keep theirs -- confirm this is intended.
                if state is not None:
                    state.pop('worker_error', None)
                    state.pop('server_error', None)

        if not self.initialize:
            # End of the bootstrap step: fall back to the warmup stage.
            self.adam_freeze_key = False
            self.initialize = True
            print(
                f"Finished the initialization step at rank {torch.distributed.get_rank()}"
            )
            return loss

        if self.adam_freeze_key is False:
            # The step counter of the last processed parameter decides when
            # warmup ends; skipped when no parameter had a gradient.
            if state is not None and state['step'] >= self.freeze_step:
                print('OnebitAdam - starting compressed communication')
                self.adam_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss

    def load_state_dict(self, state_dict):
        """
        Overrides load_state_dict() to add special handling when loading checkpoints

        Keeps the optimizer's current ``exp_avg_mask`` entries, selects the
        warmup or compression stage from the restored step counter of the
        first parameter, and drops any restored compression error buffers.

        Arguments:
            state_dict (dict): optimizer state; its ``'param_groups'`` entries
                are modified in place before being loaded.
        """
        # Because at different stage exp_avg_mask may change (e.g.,
        # BERT pre-training seqlen 128 and 512 ), we don't use the exp_avg_mask
        # in checkpoints but always use the one user provided in training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading checkpoint
        for i, group in enumerate(self.param_groups):
            saved_group = state_dict['param_groups'][i]
            if 'exp_avg_mask' in group:
                saved_group['exp_avg_mask'] = group['exp_avg_mask']
            elif 'exp_avg_mask' in saved_group:
                saved_group.pop('exp_avg_mask')
        super().load_state_dict(state_dict)
        if self.state[self.param_groups[0]['params']
                      [0]]['step'] < self.freeze_step:
            if torch.distributed.get_rank() == 0:
                print(
                    "Checkpoint loaded and OnebitAdam warmup stage starts/continues."
                )
            if self.adam_freeze_key is True:
                self.adam_freeze_key = False
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = True
                else:
                    self.deepspeed.enable_backward_allreduce = True
        else:
            if torch.distributed.get_rank() == 0:
                print(
                    "Checkpoint loaded and OnebitAdam compression stage starts/continues."
                )
            if self.adam_freeze_key is False:
                self.adam_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False
        # We reset the compression errors when loading checkpoints for 3 reasons:
        # 1) The worker and server error at each GPU are distinct, so in current implementation
        # only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
        # If we want to save them correctly we need O(num_gpu*model_size) memory in order to
        # gather all the error, which is a very large memory requirement. It's possible to save
        # them in a distributed way, but it will make the checkpoint saving/loading much more complicated.
        # 2) Even if we are able to save the compression errors correctly, you need to have the
        # exact same number of GPUs in order to load them correctly.
        # 3) We verified on BERT pre-training that occasionally resetting the compression error
        # at checkpoint loading does not affect the convergence.
        # However, please avoid frequent checkpoint loading which could break the error
        # compensation mechanism thus affect the convergence.
        for group in self.param_groups:
            for p in group['params']:
                self.state[p].pop('worker_error', None)
                self.state[p].pop('server_error', None)
Exemple #3
0
class OnebitLamb(torch.optim.Optimizer):
    """Implements the 1-bit Lamb algorithm. Currently GPU-only.
    For usage example please see https://www.deepspeed.ai/tutorials/onebit-lamb/
    For technical details please see our paper https://arxiv.org/abs/2104.06069.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        freeze_step (int, optional): Number of steps for warmup (uncompressed)
            stage before we start using compressed communication. (default 100000)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
        min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 1-bit Lamb!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        cuda_aware (boolean, optional): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
        coeff_beta (float, optional): coefficient used for computing
            running averages of lamb coefficient (default: 0.9) note that you may want to
            increase or decrease this beta depending on the freeze_step you choose, as
            1/(1 - coeff_beta) should be smaller than or equal to freeze_step
        factor_max (float, optional): maximum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 4.0)
        factor_min (float, optional): minimum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 0.5)
        factor_threshold (float, optional): threshold of how much the scaling factor can
            fluctuate between steps (default: 0.1)
    .. _Large Batch Optimization for Deep Learning\\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _Adam\\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl',
                 coeff_beta=0.9,
                 factor_max=4.0,
                 factor_min=0.5,
                 factor_threshold=0.1):
        """Store hyper-parameters, create the communication backend and
        initialise the empty fused-momentum / error-buffer containers.

        Raises:
            RuntimeError: if ``amsgrad`` is True.
            AssertionError: if torch.distributed is not initialized, or the
                NCCL backend is requested with torch older than 1.8.
            ValueError: if ``comm_backend_name`` is neither 'nccl' nor 'mpi'.
        """

        if amsgrad:
            raise RuntimeError(
                '1-bit Lamb does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)

        super(OnebitLamb, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        assert (dist.is_initialized())

        self.deepspeed = deepspeed
        self.lamb_freeze_key = False
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.coeff_beta = coeff_beta
        self.factor_max = factor_max
        self.factor_min = factor_min
        self.factor_threshold = factor_threshold
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None

        if self.comm_backend_name == 'nccl':
            # Compare (major, minor) as a tuple: checking the two numbers
            # independently wrongly rejects releases such as torch 2.0.
            torch_version = tuple(
                int(part) for part in torch.__version__.split('.')[:2])
            assert torch_version >= (1, 8), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Lamb. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
            assert dist.is_initialized(
            ), "Please initialize the torch distributed backend."
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(
                self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)

        else:
            # Fail with a clear message instead of an AttributeError on None below.
            raise ValueError(
                f"Unsupported comm_backend_name '{self.comm_backend_name}'; "
                "expected 'nccl' or 'mpi'.")

        self.size = self.comm_backend_handle.size

        # Equals lcm(size, 8); step() pads the fused buffer to a multiple of
        # size * divider.
        self.divider = int(self.size * 8 / np.gcd(self.size, 8))

        # Containers filled in by step().
        self.exp_avg_flat = []
        self.dummy_exp_avg = {}
        self.corrected_tensor_sizes = []
        self.server_chunk_sizes = []
        self.worker_errors = []
        self.server_errors = []

        self.lamb_coeffs = []

    def step(self, closure=None, grads=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        #remove the previous stats
        del self.lamb_coeffs[:]

        if self.lamb_freeze_key:
            exp_avg_last_step = []
            for group in self.param_groups:
                exp_avg_last_step.append([
                    self.state[p]['exp_avg'].detach().clone()
                    for p in group['params']
                ])
            if 'scaling_coeff' not in self.state[self.param_groups[0]['params']
                                                 [0]]:
                # Compute the scaling_coeff for each momentum at the end of warmup stage.
                # This is used to reduce compression error during compression stage.
                momentum_scales = []
                for group in self.param_groups:
                    momentum_scales.append([(
                        torch.norm(self.state[p]['exp_avg']) /
                        np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                            for p in group['params']])
                united_scale = sum([sum(x) for x in momentum_scales]) / sum(
                    [len(x) for x in momentum_scales])
                for i, group in enumerate(self.param_groups):
                    for j, p in enumerate(group['params']):
                        self.state[p][
                            'scaling_coeff'] = united_scale / momentum_scales[
                                i][j]

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        '1-bit Lamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0 or (len(state) == 1
                                       and 'scaling_coeff' in state.keys()):
                    state['step'] = 0
                    state['lamb_coeff_freeze'] = 0.0
                    state['last_factor'] = 1.0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    state['exp_avg_sq_fresh'] = torch.zeros_like(p.data)

                if not self.initialize:
                    self.lamb_freeze_key = True

                exp_avg, exp_avg_sq, exp_avg_sq_fresh = state[
                    'exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh']
                beta1, beta2 = group['betas']
                max_coeff = group['max_coeff']
                min_coeff = group['min_coeff']

                state['step'] += 1

                if self.lamb_freeze_key is False:
                    # warmup stage, baseline Lamb optimization
                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    if state['step'] == self.freeze_step:
                        exp_avg_sq_fresh.data = exp_avg_sq.detach().clone()
                    grad = None
                    if self.initialize:
                        weight_norm = p.data.pow(2).sum().sqrt()
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
                        if group['weight_decay'] > 0.0:
                            update += group['weight_decay'] * p.data
                        update_norm = update.pow(2).sum().sqrt()
                        lamb_coeff = 1.0
                        if weight_norm != 0 and update_norm != 0:
                            lamb_coeff = (weight_norm / update_norm).item()
                            if lamb_coeff > max_coeff:
                                lamb_coeff = max_coeff
                            if lamb_coeff < min_coeff:
                                lamb_coeff = min_coeff
                        if lamb_coeff != 1.0:
                            state[
                                'lamb_coeff_freeze'] = self.coeff_beta * state[
                                    'lamb_coeff_freeze'] + (
                                        1 - self.coeff_beta) * lamb_coeff
                        self.lamb_coeffs.append(lamb_coeff)
                        with torch.no_grad():
                            p.add_(-group['lr'] * lamb_coeff * update)
                else:
                    # compression stage, update each momentum locally, then
                    # communicate based on the compressed_allreduce below
                    if self.initialize:
                        exp_avg.mul_(beta1).add_(1 - beta1, grad)
                        exp_avg.mul_(self.state[p]['scaling_coeff'])
                    grad = None

        # init fused momentum
        if len(self.exp_avg_flat) == 0:
            momentum_groups = []
            tensor_size = 0
            for group in self.param_groups:
                for p in group['params']:
                    momentum_groups.append(self.state[p]['exp_avg'])
                    tensor_size += torch.numel(p.data)
            corrected_tensor_size = tensor_size
            if tensor_size % (self.size * self.divider) != 0:
                difference = ((self.size * self.divider) -
                              (tensor_size % (self.size * self.divider)))
                corrected_tensor_size += difference
                self.dummy_exp_avg[0] = torch.zeros(
                    difference, device=momentum_groups[0].data.device)
                momentum_groups.append(self.dummy_exp_avg[0])
            self.corrected_tensor_sizes.append(corrected_tensor_size)
            self.server_chunk_sizes.append(corrected_tensor_size // self.size)

            self.exp_avg_flat.append(
                _flatten_dense_tensors(
                    [p.detach().clone() for p in momentum_groups]))
            updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0],
                                                      momentum_groups)
            for p, q in zip(momentum_groups, updated_params):
                p.data = q.data

        if self.initialize and len(self.worker_errors) == 0:
            torch.cuda.empty_cache()
            for i in range(len(self.exp_avg_flat)):
                self.worker_errors.append(
                    torch.zeros(self.corrected_tensor_sizes[i],
                                device=self.exp_avg_flat[i].device))
                self.server_errors.append(
                    torch.zeros(self.server_chunk_sizes[i],
                                device=self.exp_avg_flat[i].device))
            torch.cuda.empty_cache()

        if self.lamb_freeze_key:
            if self.size > 1:
                for i in range(len(self.exp_avg_flat)):
                    if not self.initialize:
                        torch.cuda.empty_cache()
                        self.worker_errors.append(
                            torch.zeros(self.corrected_tensor_sizes[i],
                                        device=self.exp_avg_flat[i].device))
                        self.server_errors.append(
                            torch.zeros(self.server_chunk_sizes[i],
                                        device=self.exp_avg_flat[i].device))
                        torch.cuda.empty_cache()
                        if dist.get_rank() == 0:
                            print("Cupy Buffers Initialized Successfully.")

                        self.comm_backend_handle.compressed_allreduce(
                            self.exp_avg_flat[i], self.worker_errors[0],
                            self.server_errors[0], self.deepspeed.local_rank)

                        if dist.get_rank() == 0:
                            print('Pop out errors', flush=True)
                        del self.worker_errors[:]
                        del self.server_errors[:]
                    else:
                        self.comm_backend_handle.compressed_allreduce(
                            self.exp_avg_flat[i], self.worker_errors[i],
                            self.server_errors[i], self.deepspeed.local_rank)

        if self.lamb_freeze_key and self.initialize:
            for i, group in enumerate(self.param_groups):
                bias_correction = 1 if group['bias_correction'] else 0

                for j, p in enumerate(group['params']):
                    state = self.state[p]
                    exp_avg, exp_avg_sq, exp_avg_sq_fresh = state[
                        'exp_avg'], state['exp_avg_sq'], state[
                            'exp_avg_sq_fresh']
                    beta1, beta2 = group['betas']
                    exp_avg.div_(self.state[p]['scaling_coeff'])
                    # Because 1-bit compression cannot represent exact zero, it is required to
                    # provide a momentum mask for those params that have constant exact zeros in their
                    # momentums, otherwise the compression error would keep accumulating.
                    # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
                    # always have exact zeros in its momentum for row 129 to 512, because it only
                    # learns up to seq length 128 while the model supports up to 512 seq length.
                    # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py about how
                    # to add this exp_avg_mask for BERT pre-training.)
                    if 'exp_avg_mask' in group:
                        if exp_avg.device != group['exp_avg_mask'].device:
                            group['exp_avg_mask'] = group['exp_avg_mask'].to(
                                device=exp_avg.device)
                        exp_avg.mul_(group['exp_avg_mask'])

                    grad_reconstruct = (
                        (exp_avg - exp_avg_last_step[i][j] * beta1) /
                        (1 - beta1))
                    exp_avg_sq_fresh.mul_(beta2).addcmul_(
                        1 - beta2, grad_reconstruct, grad_reconstruct)
                    denom = exp_avg_sq.sqrt() + group['eps']
                    update_prelim = exp_avg / denom

                    if group['weight_decay'] > 0.0:
                        update = update_prelim + group['weight_decay'] * p.data
                    else:
                        update = update_prelim

                    lamb_coeff = 1.0
                    update_norm = update.pow(2).sum().sqrt()
                    denom_real = exp_avg_sq_fresh.sqrt() + group['eps']
                    factor = (denom / denom_real).max().item()
                    if group['weight_decay'] > 0.0:
                        update_ratio = min(1.0,
                                           (update_prelim.pow(2).sum().sqrt() /
                                            update_norm).item())
                        factor = factor * update_ratio + (1.0 - update_ratio)
                    if factor > self.factor_max:
                        factor = self.factor_max
                    if factor < self.factor_min:
                        factor = self.factor_min
                    if factor > state['last_factor'] * (1.0 +
                                                        self.factor_threshold):
                        factor = state['last_factor'] * (1.0 +
                                                         self.factor_threshold)
                    if factor < state['last_factor'] * (1.0 -
                                                        self.factor_threshold):
                        factor = state['last_factor'] * (1.0 -
                                                         self.factor_threshold)
                    state['last_factor'] = factor
                    lamb_coeff = state['lamb_coeff_freeze'] * factor
                    self.lamb_coeffs.append(lamb_coeff)
                    with torch.no_grad():
                        p.add_(-group['lr'] * lamb_coeff * update)
            del exp_avg_last_step[:]
            exp_avg_last_step = None

        if not self.initialize:
            self.lamb_freeze_key = False
            self.initialize = True
            print(
                f"Finished the initialization step at rank {dist.get_rank()}")
            return loss

        if self.lamb_freeze_key is False:
            if state['step'] >= self.freeze_step:
                print('OnebitLamb - starting compressed communication')
                self.lamb_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss

    def load_state_dict(self, state_dict):
        """
        Restore optimizer state from a checkpoint, with 1-bit LAMB specific fix-ups.

        The user-provided ``exp_avg_mask`` always wins over whatever the
        checkpoint holds, the fused momentum bookkeeping is discarded, the
        warmup/compression stage is re-derived from the restored step count,
        and the compression error buffers are emptied.
        """

        def _set_backward_allreduce(enabled):
            # The flag to toggle depends on whether a pipeline engine is in use.
            flag_name = ('pipeline_enable_backward_allreduce'
                         if self.using_pipeline else
                         'enable_backward_allreduce')
            setattr(self.deepspeed, flag_name, enabled)

        # Because at different stage exp_avg_mask may change (e.g.,
        # BERT pre-training seqlen 128 and 512 ), we don't use the exp_avg_mask
        # in checkpoints but always use the one user provided in training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading checkpoint
        for idx, live_group in enumerate(self.param_groups):
            saved_group = state_dict['param_groups'][idx]
            if 'exp_avg_mask' in live_group:
                saved_group['exp_avg_mask'] = live_group['exp_avg_mask']
            elif 'exp_avg_mask' in saved_group:
                saved_group.pop('exp_avg_mask')
        super().load_state_dict(state_dict)

        # need to reset the fused momentum since loading states will break the linking
        del self.exp_avg_flat[:]
        self.dummy_exp_avg.clear()
        del self.corrected_tensor_sizes[:]
        del self.server_chunk_sizes[:]

        first_param = self.param_groups[0]['params'][0]
        still_warming_up = self.state[first_param]['step'] < self.freeze_step
        if still_warming_up:
            if dist.get_rank() == 0:
                print(
                    "Checkpoint loaded and OnebitLamb warmup stage starts/continues."
                )
            if self.lamb_freeze_key is True:
                self.lamb_freeze_key = False
                _set_backward_allreduce(True)
            # Drop the per-parameter values that only apply to the compression stage.
            for live_group in self.param_groups:
                for param in live_group['params']:
                    param_state = self.state[param]
                    param_state['lamb_coeff_freeze'] = 0.0
                    param_state['last_factor'] = 1.0
                    param_state.pop('scaling_coeff', None)
        else:
            if dist.get_rank() == 0:
                print(
                    "Checkpoint loaded and OnebitLamb compression stage starts/continues."
                )
            if self.lamb_freeze_key is False:
                self.lamb_freeze_key = True
                _set_backward_allreduce(False)

        # We reset the compression errors when loading checkpoints for 3 reasons:
        # 1) The worker and server error at each GPU are distinct, so in current implementation
        # only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
        # If we want to save them correctly we need O(num_gpu*model_size) memory in order to
        # gather all the error, which is a very large memory requirement. It's possible to save
        # them in a distributed way, but it will make the checkpoint saving/loading much more complicated.
        # 2) Even if we are able to save the compression errors correctly, you need to have the
        # exact same number of GPUs in order to load them correctly.
        # 3) We verified on BERT pre-training that occasionally resetting the compression error
        # at checkpoint loading does not affect the convergence.
        # However, please avoid frequent checkpoint loading which could break the error
        # compensation mechanism thus affect the convergence.
        del self.worker_errors[:]
        del self.server_errors[:]

    def get_lamb_coeffs(self):
        """Return the ``lamb_coeffs`` list held by this optimizer (the same object, not a copy)."""
        return self.lamb_coeffs
# Example #4
# 0
class ZeroOneAdam(torch.optim.Optimizer):
    r"""Implements the 0/1 Adam algorithm. Currently GPU-only.
    For usage example please see https://www.deepspeed.ai/tutorials/zero-one-adam/
    For technical details please read https://arxiv.org/abs/2202.06009
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        deepspeed (object, optional): the DeepSpeed engine. Its ``local_rank``, ``mpu``
            and ``enable_backward_allreduce`` / ``pipeline_enable_backward_allreduce``
            attributes are read or written by this optimizer. (default: None)
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (boolean, optional): stored in the parameter group defaults.
            (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): stored in the parameter group defaults.
            (default: 0)
        var_freeze_step (int, optional): The latest step to update the variance,
            using the notation from https://arxiv.org/abs/2202.06009, it denotes the
            max{i|i in T_v}. Note that this is different from the freeze step from the
            1-bit Adam. The var_freeze_step is usually the end of the learning rate warmup
            and thus does not require tuning. (default: 100000)
        var_update_scaler (int, optional): The interval to update the variance. Note that
            the update policy for variance follows an exponential rule, where var_update_scaler
            denotes the kappa in the 0/1 Adam paper. (default: 16)
        local_step_scaler (int, optional): The interval to scale the local steps interval
            according to the learning rate policy. (default: 32678)
        local_step_clipper (int, optional): The largest interval for local steps with
            learning rate policy. This corresponds to the variable H in the 0/1 Adam paper.
            (default: 16)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 0/1 Adam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        cuda_aware (boolean, required): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
    Raises:
        RuntimeError: if ``amsgrad`` is requested.
        ValueError: if ``comm_backend_name`` is neither 'nccl' nor 'mpi'.
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9,
                        0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 var_freeze_step=100000,
                 var_update_scaler=16,
                 local_step_scaler=32678,
                 local_step_clipper=16,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl'):

        if amsgrad:
            raise RuntimeError('0/1 Adam does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)

        super(ZeroOneAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        assert (dist.is_initialized())

        self.deepspeed = deepspeed
        # False until the first call to step() (the buffer warm-up step) has completed.
        self.initialize = False
        self.cuda_aware = cuda_aware
        self.using_pipeline = False

        self.var_freeze_step = var_freeze_step
        self.var_update_scaler = var_update_scaler
        self.local_step_scaler = local_step_scaler
        self.local_step_clipper = local_step_clipper
        # True once the variance is frozen and the local-step policy is active.
        self.freeze_key = False
        self.reinitial_error_buffer = False

        self.comm_backend_name = comm_backend_name

        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None

        if self.comm_backend_name == 'nccl':
            TORCH_MAJOR = int(torch.__version__.split('.')[0])
            TORCH_MINOR = int(torch.__version__.split('.')[1])
            # Compare (major, minor) as a tuple: checking the two numbers
            # independently would wrongly reject e.g. torch 2.0 (minor 0 < 8).
            assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 8), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
            assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(self.deepspeed,
                                          'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)

        else:
            # Without this the handle stays None and the failure surfaces below
            # as an opaque AttributeError on ``.size``.
            raise ValueError(
                f"Unsupported comm_backend_name '{comm_backend_name}' for 0/1 Adam; "
                "expected 'nccl' or 'mpi'.")

        self.size = self.comm_backend_handle.size

        # Buffers handed to compressed_allreduce are padded to a multiple of
        # size * divider (see the corrected_tensor_size computation in step()).
        self.divider = int(self.size * 8 / np.gcd(self.size, 8))

    def step(self, closure=None, grads=None):
        """Performs a single optimization step.

        The very first call only allocates the error/communication buffers and
        exercises the compressed allreduce once; no parameter is updated then.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('0/1 Adam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # (Re)build the policy counters and error buffers on the warm-up
                # step and whenever the buffers were dropped (after the warm-up
                # pop below or after load_state_dict()).
                if not self.initialize or 'worker_error' not in state.keys():
                    # Some scalars to help scale the variance update/local step policies
                    state['var_interval'] = 1
                    state['var_counter'] = 0
                    state['local_step_interval'] = 1
                    state['local_step_counter'] = 0
                    state['lrs'] = 0
                    state['tensor_size'] = torch.numel(p.data)
                    state['corrected_tensor_size'] = state['tensor_size']

                    # Pad the size up to the next multiple of size * divider.
                    if state['tensor_size'] % (self.size * self.divider) != 0:
                        state['corrected_tensor_size'] += ((self.size * self.divider) -
                                                           (state['tensor_size'] %
                                                            (self.size * self.divider)))
                    state['server_chunk_size'] = state[
                        'corrected_tensor_size'] // self.size
                    torch.cuda.empty_cache()
                    state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
                                                        device=p.device)
                    state['server_error'] = torch.zeros(state['server_chunk_size'],
                                                        device=p.device)
                    # Accumulation of momentum, i.e., the u variable in the 0/1 Adam paper
                    state['momentum_accumulator'] = torch.zeros_like(p.data)
                    torch.cuda.empty_cache()
                    # self.freeze_key = True
                    if not self.initialize and torch.distributed.get_rank() == 0:
                        print("Cupy Buffers Initialized Successfully.")

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                comm_buffer = state['momentum_accumulator']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if self.initialize:
                    if self.freeze_key is False:
                        if state['step'] % state['var_interval'] == 0:
                            # Variance-update step: refresh both moments from
                            # the gradient as it is.
                            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        else:
                            # NOTE(review): with a single worker (size == 1) neither
                            # moment is updated on this branch - confirm that is intended.
                            if self.size > 1:
                                with torch.no_grad():
                                    grad_onebit = self.comm_backend_handle.compressed_allreduce(
                                        grad,
                                        state['worker_error'],
                                        state['server_error'],
                                        self.deepspeed.local_rank)
                                    if 'exp_avg_mask' in group:
                                        if grad_onebit.device != group[
                                                'exp_avg_mask'].device:
                                            group['exp_avg_mask'] = group[
                                                'exp_avg_mask'].to(
                                                    device=grad_onebit.device)
                                        grad_onebit.mul_(group['exp_avg_mask'])
                                    exp_avg.mul_(beta1).add_(grad_onebit, alpha=1 - beta1)
                    else:
                        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        # Sum of learning rates applied since the last sync; used
                        # below to turn the accumulated update back into a momentum.
                        state['lrs'] += group['lr']
                    grad = None

                if not self.initialize:
                    # Warm-up step: run one compressed allreduce on the (zero)
                    # accumulator so the communication buffers get exercised.
                    if self.size > 1:
                        comm_buffer.set_(
                            self.comm_backend_handle.compressed_allreduce(
                                comm_buffer,
                                state['worker_error'],
                                state['server_error'],
                                self.deepspeed.local_rank))
                        if 'exp_avg_mask' in group:
                            if comm_buffer.device != group['exp_avg_mask'].device:
                                group['exp_avg_mask'] = group['exp_avg_mask'].to(
                                    device=comm_buffer.device)
                            comm_buffer.mul_(group['exp_avg_mask'])

                if self.initialize:
                    update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
                    if group['weight_decay'] > 0.0:
                        update += group['weight_decay'] * p.data
                    with torch.no_grad():
                        p.data.add_(-group['lr'] * update)
                        if self.freeze_key is True:
                            comm_buffer.add_(-group['lr'] * update)
                    if state['step'] % state[
                            'local_step_interval'] == 0 and self.freeze_key:
                        with torch.no_grad():
                            # Undo the local updates, synchronize the accumulated
                            # update across workers, then re-apply the synced one.
                            p.data.add_(-1 * comm_buffer)
                            comm_buffer.mul_(exp_avg_sq.sqrt() + group['eps'])
                            if self.size > 1:
                                comm_buffer.copy_(
                                    self.comm_backend_handle.compressed_allreduce(
                                        comm_buffer,
                                        state['worker_error'],
                                        state['server_error'],
                                        self.deepspeed.local_rank))
                                if 'exp_avg_mask' in group:
                                    if comm_buffer.device != group['exp_avg_mask'].device:
                                        group['exp_avg_mask'] = group['exp_avg_mask'].to(
                                            device=comm_buffer.device)
                                    comm_buffer.mul_(group['exp_avg_mask'])
                            exp_avg.zero_().add_(comm_buffer / state['lrs'], alpha=-1)
                            p.data.add_(comm_buffer / (exp_avg_sq.sqrt() + group['eps']))
                            comm_buffer.zero_()

                            state['lrs'] = 0

                    # According to 0/1 Adam theory, a fixed variance would allow more accurate estimation of momentum
                    # However, in practice, we can also disable the manual freezing of variance, since the interval of
                    # updating variance will increase exponentially, so that it has negligible effect on the estimation.
                    if self.freeze_key is False:
                        if state['step'] % state['var_interval'] == 0:
                            state['var_counter'] += 1
                            if state['var_counter'] == self.var_update_scaler:
                                state['var_counter'] = 0
                                state['var_interval'] *= 2
                        # Turn the engine's backward allreduce on only when the
                        # next step is a variance-update step.
                        if (state['step'] + 1) % state['var_interval'] == 0:
                            if self.using_pipeline:
                                self.deepspeed.pipeline_enable_backward_allreduce = True
                            else:
                                self.deepspeed.enable_backward_allreduce = True
                        else:
                            if self.using_pipeline:
                                self.deepspeed.pipeline_enable_backward_allreduce = False
                            else:
                                self.deepspeed.enable_backward_allreduce = False
                    else:
                        state['local_step_counter'] += 1
                        if state['local_step_counter'] == self.local_step_scaler:
                            state['local_step_counter'] = 0
                            state['local_step_interval'] = min(
                                self.local_step_clipper,
                                state['local_step_interval'] * 2)

            if not self.initialize:
                print('Pop out errors', flush=True)
                self.freeze_key = False
                # NOTE(review): ``state`` is whatever the inner loop touched last,
                # so only that parameter's buffers are dropped here.
                state.pop('worker_error')
                state.pop('server_error')

        if not self.initialize:
            self.initialize = True
            print(
                f"Finished the initialization step at rank {torch.distributed.get_rank()}"
            )
            return loss

        # A first parameter that never received a gradient has no 'step' entry
        # yet; treat it as step 0 instead of raising KeyError.
        first_state = self.state[self.param_groups[0]['params'][0]]
        if first_state.get('step', 0) > self.var_freeze_step:
            self.freeze_key = True
            if self.using_pipeline:
                self.deepspeed.pipeline_enable_backward_allreduce = False
            else:
                self.deepspeed.enable_backward_allreduce = False

        if self.freeze_key is True and self.reinitial_error_buffer is False:
            # We need to reinitialize the error buffers when local step > 1 since
            # the errors will be logged for different metrics (gradient vs. accumulated momentum).
            for group in self.param_groups:
                for p in group['params']:
                    # Parameters that have not been stepped own no buffers.
                    param_state = self.state.get(p)
                    if not param_state:
                        continue
                    if 'worker_error' in param_state:
                        param_state['worker_error'].zero_()
                    if 'server_error' in param_state:
                        param_state['server_error'].zero_()
            self.reinitial_error_buffer = True

        return loss

    def load_state_dict(self, state_dict):
        """
        Overrides load_state_dict() to add special handling when loading checkpoints

        After the base class has restored the state, the freeze flag and the
        engine's backward-allreduce flag are re-derived from the restored step
        count, and the error buffers and momentum accumulator are dropped so
        that step() rebuilds them.
        """
        # Because at different stage exp_avg_mask may change (e.g.,
        # BERT pre-training seqlen 128 and 512 ), we don't use the exp_avg_mask
        # in checkpoints but always use the one user provided in training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading checkpoint
        for i, group in enumerate(self.param_groups):
            if 'exp_avg_mask' in group:
                state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
            elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[
                    'param_groups'][i]:
                state_dict['param_groups'][i].pop('exp_avg_mask')
        super().load_state_dict(state_dict)
        first_state = self.state[self.param_groups[0]['params'][0]]
        if first_state['step'] < self.var_freeze_step:
            # step() reads ``freeze_key``; setting only ``var_freeze_key`` here
            # left the restored stage without any effect.
            self.freeze_key = False
            if (first_state['step'] + 1) % first_state['var_interval'] == 0:
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = True
                else:
                    self.deepspeed.enable_backward_allreduce = True
            else:
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False
        else:
            self.freeze_key = True
            if self.using_pipeline:
                self.deepspeed.pipeline_enable_backward_allreduce = False
            else:
                self.deepspeed.enable_backward_allreduce = False
        # Kept as an alias for code that still reads the previous attribute name.
        self.var_freeze_key = self.freeze_key
        self.reinitial_error_buffer = False
        for group in self.param_groups:
            for p in group['params']:
                if 'worker_error' in self.state[p]:
                    self.state[p].pop('worker_error')
                if 'server_error' in self.state[p]:
                    self.state[p].pop('server_error')
                if 'momentum_accumulator' in self.state[p]:
                    self.state[p].pop('momentum_accumulator')
from mpi4py import MPI
import torch
import deepspeed.comm as dist
import numpy as np
import deepspeed

from deepspeed.runtime.comm.mpi import MpiBackend

# MPI world communicator, plus the number of processes in it and this
# process's rank within it.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Initialize DeepSpeed's distributed layer with the NCCL backend.
deepspeed.init_distributed(dist_backend='nccl')

# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

# Pick this process's CUDA device: rank modulo the number of visible devices.
device = torch.device('cuda', rank % torch.cuda.device_count())


# A simulated compression function using deepspeed.comm
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(
        2.0)