def BMUF(self, itr, cp):
    # Communicate only every cp local steps.
    step_flag = (itr != 0 and itr % cp == 0)
    if step_flag:
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                param_state = self.state[p]
                old_data = param_state['anchor_model']
                if 'global_momentum_buffer' not in param_state:
                    # First round: buf = (anchor - p) / lr
                    buf = param_state['global_momentum_buffer'] = torch.clone(p.data).detach()
                    buf.sub_(old_data)
                    buf.div_(-lr)
                else:
                    # buf = gmf * buf + (anchor - p) / lr
                    buf = param_state['global_momentum_buffer']
                    buf.mul_(self.gmf).sub_(p.data, alpha=1 / lr).add_(old_data, alpha=1 / lr)
                # Apply the slow-momentum step to the anchor, pre-scaled
                # so the all_reduce sum below yields the average.
                old_data.add_(buf, alpha=-lr)
                old_data.div_(self.size)
        communicate(self.comm_buf, dist.all_reduce)
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                old_data = param_state['anchor_model']
                p.data.copy_(old_data)
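To see what the momentum branch computes, here is a minimal single-tensor sketch of the slow-momentum arithmetic (single process; gmf, lr, and the tensor values are hypothetical, and the all_reduce average is omitted):

import torch

gmf, lr = 0.9, 0.1                 # hypothetical hyperparameters
anchor = torch.tensor([1.0, 2.0])  # model at the last synchronization
p = torch.tensor([0.7, 1.6])       # model after cp local steps
buf = torch.zeros_like(p)          # global momentum buffer

# buf = gmf * buf + (anchor - p) / lr
buf.mul_(gmf).sub_(p, alpha=1 / lr).add_(anchor, alpha=1 / lr)
# anchor = anchor - lr * buf, the slow-momentum step
anchor.add_(buf, alpha=-lr)
print(anchor)  # tensor([0.7000, 1.6000]): with buf starting at zero,
               # the first slow step simply adopts the local model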
def average(self, weight=0, tau_eff=0):
    if weight == 0:
        weight = self.ratio
    if tau_eff == 0:
        if self.mu != 0:
            tau_eff_cuda = torch.tensor(self.local_steps * self.ratio).cuda()
        else:
            tau_eff_cuda = torch.tensor(self.local_normalizing_vec * self.ratio).cuda()
        # Effective number of local steps, summed over all workers.
        dist.all_reduce(tau_eff_cuda, op=dist.ReduceOp.SUM)
        tau_eff = tau_eff_cuda.item()

    param_list = []
    for group in self.param_groups:
        for p in group['params']:
            param_state = self.state[p]
            # Normalize the accumulated local update, then weight it
            # for the all_reduce sum.
            scale = tau_eff / self.local_normalizing_vec
            param_state['cum_grad'].mul_(weight * scale)
            param_list.append(param_state['cum_grad'])
    communicate(param_list, dist.all_reduce)

    for group in self.param_groups:
        lr = group['lr']
        for p in group['params']:
            param_state = self.state[p]
            if self.gmf != 0:
                # Global (slow) momentum on the aggregated update.
                if 'global_momentum_buffer' not in param_state:
                    buf = param_state['global_momentum_buffer'] = torch.clone(param_state['cum_grad']).detach()
                    buf.div_(lr)
                else:
                    buf = param_state['global_momentum_buffer']
                    buf.mul_(self.gmf).add_(param_state['cum_grad'], alpha=1 / lr)
                param_state['old_init'].sub_(buf, alpha=lr)
            else:
                param_state['old_init'].sub_(param_state['cum_grad'])
            p.data.copy_(param_state['old_init'])
            param_state['cum_grad'].zero_()
            # Reinitialize momentum buffer
            if 'momentum_buffer' in param_state:
                param_state['momentum_buffer'].zero_()

    self.local_counter = 0
    self.local_normalizing_vec = 0
    self.local_steps = 0
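The scaling above implements normalized averaging: each worker divides its accumulated update by its own local_normalizing_vec before the weighted sum, and tau_eff restores a common effective step count. A small two-worker sketch with each all_reduce replaced by an explicit sum (ratios, step counts, and gradients are hypothetical):

import torch

ratios = [0.25, 0.75]                    # data fractions, sum to 1
local_norm = [5.0, 20.0]                 # local_normalizing_vec per worker
cum_grads = [torch.tensor([0.5]), torch.tensor([2.0])]

# tau_eff = sum_i ratio_i * local_normalizing_vec_i (first all_reduce)
tau_eff = sum(r * n for r, n in zip(ratios, local_norm))

# Each worker scales its accumulated update; the sum over workers
# stands in for the second all_reduce.
agg = sum(r * (tau_eff / n) * g
          for r, n, g in zip(ratios, local_norm, cum_grads))
print(tau_eff, agg)  # the global step is then old_init - agg (gmf == 0)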
def elastic_average(self, itr, cp):
    # Communicate only every cp local steps.
    step_flag = (itr != 0 and itr % cp == 0)
    if step_flag:
        # Chosen so that the all_reduce sum of the anchors below
        # reproduces the EASGD center update
        # z <- (1 - size * alpha) * z + alpha * sum_i x_i.
        beta = 1 / self.size - self.alpha - self.alpha ** 2 / (1 - self.alpha)
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                buf = param_state['anchor_model']
                # Pull the local model toward the anchor ...
                p.data.mul_(1 - self.alpha).add_(buf, alpha=self.alpha)
                # ... and fold the updated local model into the anchor.
                buf.mul_(beta).add_(p.data, alpha=self.alpha / (1 - self.alpha))
        communicate(self.comm_buf, dist.all_reduce)
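The otherwise opaque choice of beta is what makes the summed anchors reproduce the canonical EASGD center update z' = z + alpha * sum_i (x_i - z). A scalar check with hypothetical values:

size, alpha = 4, 0.05
beta = 1 / size - alpha - alpha ** 2 / (1 - alpha)

z = 1.0                            # shared anchor before the round
xs = [0.8, 1.1, 1.3, 0.9]          # local models on the 4 workers

# Per-worker update from elastic_average, then the all_reduce sum.
anchors = []
for x in xs:
    x_new = (1 - alpha) * x + alpha * z          # local pull toward z
    anchors.append(beta * z + alpha / (1 - alpha) * x_new)

easgd = z + alpha * sum(x - z for x in xs)       # canonical center update
print(abs(sum(anchors) - easgd) < 1e-12)         # True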
def average(self):
    # Pre-scale by this worker's weight so the all_reduce sum
    # yields a weighted average of the local models.
    param_list = []
    for group in self.param_groups:
        for p in group['params']:
            p.data.mul_(self.ratio)
            param_list.append(p.data)
    communicate(param_list, dist.all_reduce)

    for group in self.param_groups:
        for p in group['params']:
            param_state = self.state[p]
            param_state['old_init'] = torch.clone(p.data).detach()
            # Reinitialize momentum buffer
            if 'momentum_buffer' in param_state:
                param_state['momentum_buffer'].zero_()
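Pre-scaling each local model by self.ratio turns the subsequent all_reduce sum into a weighted average: assuming the ratios sum to one across workers (e.g., each worker's share of the data), the reduced tensor is sum_i ratio_i * p_i. The result is stored as old_init, the reference point against which the next round of local updates is measured.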
def average(self):
    step_flag = (self.itr != 0 and self.itr % self.cp == 0)
    self.itr += 1
    if step_flag:
        if self.gmf == 0:
            # simple average
            param_list = []
            for group in self.param_groups:
                for p in group['params']:
                    p.data.div_(self.size)
                    param_list.append(p.data)
            communicate(param_list, dist.all_reduce)
        else:
            # simple average + global momentum (the same slow-momentum
            # update as BMUF above)
            for group in self.param_groups:
                lr = group['lr']
                for p in group['params']:
                    param_state = self.state[p]
                    old_data = param_state['anchor_model']
                    if 'global_momentum_buffer' not in param_state:
                        # First round: buf = (anchor - p) / lr
                        buf = param_state['global_momentum_buffer'] = torch.clone(p.data).detach()
                        buf.sub_(old_data)
                        buf.div_(-lr)
                    else:
                        # buf = gmf * buf + (anchor - p) / lr
                        buf = param_state['global_momentum_buffer']
                        buf.mul_(self.gmf).sub_(p.data, alpha=1 / lr).add_(old_data, alpha=1 / lr)
                    old_data.add_(buf, alpha=-lr)
                    old_data.div_(self.size)
            communicate(self.comm_buf, dist.all_reduce)
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    old_data = param_state['anchor_model']
                    p.data.copy_(old_data)
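A hypothetical call pattern for this variant; model, loader, criterion, optimizer, and epochs are placeholders, and the process group is assumed to already be initialized:

for epoch in range(epochs):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
        # No-op on most calls; averages (plus global momentum when
        # gmf != 0) on every cp-th iteration, tracked internally.
        optimizer.average()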
def _async_all_reduce_(buff, buf_ready, comm_finish):
    # Background communication loop: wait until the main thread signals
    # that buff is ready, all_reduce it, then signal completion.
    while True:
        buf_ready.wait()
        communicate(buff, dist.all_reduce)
        buf_ready.clear()
        comm_finish.set()
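A hedged sketch of the handshake this loop presumably expects: the main thread fills the buffer, raises buf_ready, and later blocks on comm_finish before touching the buffer again (comm_buf stands for the buffer being reduced; the process group is assumed initialized):

import threading

buf_ready = threading.Event()
comm_finish = threading.Event()
comm_finish.set()  # no communication in flight initially

threading.Thread(target=_async_all_reduce_,
                 args=(comm_buf, buf_ready, comm_finish),
                 daemon=True).start()

# At each communication round on the main thread:
comm_finish.wait()   # previous all_reduce has finished
comm_finish.clear()
# ... copy the tensors to reduce into comm_buf ...
buf_ready.set()      # wake the communication thread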