Example #1
    def BMUF(self, itr, cp):
        step_flag = (itr != 0 and itr % cp == 0)
        if step_flag:

            for group in self.param_groups:
                lr = group['lr']
                for p in group['params']:
                    param_state = self.state[p]
                    old_data = param_state['anchor_model']

                    if 'global_momentum_buffer' not in param_state:
                        # First synchronization: initialize the global momentum
                        # buffer as the local update scaled by 1/lr.
                        buf = param_state['global_momentum_buffer'] = torch.clone(p.data).detach()
                        buf.sub_(old_data)
                        buf.div_(-lr)
                    else:
                        # buf = gmf * buf - (p - anchor) / lr
                        buf = param_state['global_momentum_buffer']
                        buf.mul_(self.gmf).sub_(p.data, alpha=1/lr).add_(old_data, alpha=1/lr)

                    # Take a momentum step on the anchor model and pre-scale it
                    # so the all-reduce sum below yields the average.
                    old_data.add_(buf, alpha=-lr)
                    old_data.div_(self.size)

            # Average the anchor models across workers, then load the result
            # back into the parameters below.
            communicate(self.comm_buf, dist.all_reduce)
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    old_data = param_state['anchor_model']
                    p.data.copy_(old_data)
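
All of these examples call a communicate(tensors, op) helper that is not part of this listing. A rough, hypothetical sketch of what such a helper could look like (flatten the tensors, run one collective, copy the results back; the flattening utilities and exact signature here are assumptions, not taken from the original code):

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def communicate(tensors, communication_op):
    # Hypothetical helper: pack the list of tensors into one flat buffer so a
    # single collective call (e.g. dist.all_reduce) covers every parameter.
    flat = _flatten_dense_tensors(tensors)
    communication_op(flat)
    # Copy the reduced values back into the original tensors in place.
    for t, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        t.copy_(synced)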
Example #2
    def average(self, weight=0, tau_eff=0):
        if weight == 0:
            weight = self.ratio
        if tau_eff == 0:
            # Effective number of local steps, aggregated over all workers.
            if self.mu != 0:
                tau_eff_cuda = torch.tensor(self.local_steps *
                                            self.ratio).cuda()
            else:
                tau_eff_cuda = torch.tensor(self.local_normalizing_vec *
                                            self.ratio).cuda()
            dist.all_reduce(tau_eff_cuda, op=dist.ReduceOp.SUM)
            tau_eff = tau_eff_cuda.item()

        param_list = []
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                # Normalize the accumulated local update and scale it by this
                # worker's aggregation weight; the all-reduce below sums these
                # weighted contributions across workers.
                scale = tau_eff / self.local_normalizing_vec
                param_state['cum_grad'].mul_(weight * scale)
                param_list.append(param_state['cum_grad'])

        communicate(param_list, dist.all_reduce)

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                param_state = self.state[p]

                if self.gmf != 0:
                    # Apply global (server-side) momentum to the aggregated update.
                    if 'global_momentum_buffer' not in param_state:
                        buf = param_state[
                            'global_momentum_buffer'] = torch.clone(
                                param_state['cum_grad']).detach()
                        buf.div_(lr)
                    else:
                        buf = param_state['global_momentum_buffer']
                        buf.mul_(self.gmf).add_(param_state['cum_grad'],
                                                alpha=1 / lr)
                    param_state['old_init'].sub_(buf, alpha=lr)
                else:
                    param_state['old_init'].sub_(param_state['cum_grad'])

                p.data.copy_(param_state['old_init'])
                param_state['cum_grad'].zero_()

                # Reinitialize momentum buffer
                if 'momentum_buffer' in param_state:
                    param_state['momentum_buffer'].zero_()

        self.local_counter = 0
        self.local_normalizing_vec = 0
        self.local_steps = 0
Example #3
    def elastic_average(self, itr, cp):
        step_flag = (itr != 0 and itr % cp == 0)
        if step_flag:
            beta = 1/self.size - self.alpha - self.alpha**2/(1-self.alpha)
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    buf = param_state['anchor_model']

                    # Pull the local model toward the anchor, then update the
                    # anchor so the all-reduce below yields the elastic average.
                    p.data.mul_(1-self.alpha).add_(buf, alpha=self.alpha)
                    buf.mul_(beta).add_(p.data, alpha=self.alpha/(1-self.alpha))
                    
            communicate(self.comm_buf, dist.all_reduce)
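
Examples #1, #3, and #5 also assume that every parameter has an 'anchor_model' buffer in the optimizer state and that self.comm_buf collects those buffers for the collective call. That initialization is not shown in this listing; one hedged guess at what it might look like inside the optimizer's constructor:

        # Hypothetical setup (not from the original code): keep a copy of the
        # model at the last synchronization point and collect those copies so
        # communicate(self.comm_buf, dist.all_reduce) can average them.
        self.comm_buf = []
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['anchor_model'] = torch.clone(p.data).detach()
                self.comm_buf.append(param_state['anchor_model'])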
Example #4
    def average(self):
        param_list = []
        for group in self.param_groups:
            for p in group['params']:
                # Pre-scale by this worker's aggregation weight so the
                # all-reduce sum below yields a weighted average.
                p.data.mul_(self.ratio)
                param_list.append(p.data)

        communicate(param_list, dist.all_reduce)

        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['old_init'] = torch.clone(p.data).detach()
                # Reinitialize momentum buffer
                if 'momentum_buffer' in param_state:
                    param_state['momentum_buffer'].zero_()
Example #5
    def average(self):
        step_flag = (self.itr != 0 and self.itr % self.cp == 0)
        self.itr += 1
        if step_flag:
            if self.gmf == 0:
                # simple average
                param_list = []
                for group in self.param_groups:
                    for p in group['params']:
                        p.data.div_(self.size)
                        param_list.append(p.data)
                communicate(param_list, dist.all_reduce)
            else:
                # simple average + global momentum
                for group in self.param_groups:
                    lr = group['lr']
                    for p in group['params']:
                        param_state = self.state[p]
                        old_data = param_state['anchor_model']

                        if 'global_momentum_buffer' not in param_state:
                            # First synchronization: initialize the global
                            # momentum buffer from the local update, scaled by 1/lr.
                            buf = param_state[
                                'global_momentum_buffer'] = torch.clone(
                                    p.data).detach()
                            buf.sub_(old_data)
                            buf.div_(-lr)
                        else:
                            # buf = gmf * buf - (p - anchor) / lr
                            buf = param_state['global_momentum_buffer']
                            buf.mul_(self.gmf).sub_(
                                p.data, alpha=1 / lr).add_(
                                    old_data, alpha=1 / lr)

                        # Momentum step on the anchor model, pre-scaled so the
                        # all-reduce sum yields the average.
                        old_data.add_(buf, alpha=-lr)
                        old_data.div_(self.size)

                communicate(self.comm_buf, dist.all_reduce)
                for group in self.param_groups:
                    for p in group['params']:
                        param_state = self.state[p]
                        old_data = param_state['anchor_model']
                        p.data.copy_(old_data)
Example #6
def _async_all_reduce_(buff, buf_ready, comm_finish):
    # Background communication loop: wait until the buffer is marked ready,
    # all-reduce it, then signal that communication has finished.
    while True:
        buf_ready.wait()
        communicate(buff, dist.all_reduce)
        buf_ready.clear()
        comm_finish.set()
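
Example #6 is meant to run on its own thread, with buf_ready and comm_finish acting like threading.Event objects. A minimal usage sketch, assuming torch.distributed is already initialized and comm_buf stands in for the real list of tensors (none of this wiring appears in the original snippet):

import threading
import torch

comm_buf = [torch.zeros(1)]          # placeholder for the real tensor list
buf_ready = threading.Event()
comm_finish = threading.Event()
comm_finish.set()                    # nothing is in flight yet

# Run the communication loop on a daemon thread.
threading.Thread(target=_async_all_reduce_,
                 args=(comm_buf, buf_ready, comm_finish),
                 daemon=True).start()

# At a synchronization point in the training loop:
comm_finish.wait()                   # previous all-reduce must have finished
comm_finish.clear()
# ... fill comm_buf with the tensors to be averaged ...
buf_ready.set()                      # hand the buffer to the background thread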