def hook(*ignore):
    # Gradient-reduction hook (closure over the parameter `p` and `self`).
    # Tensors larger than self._plan1 elements are sparsified (selection
    # kernel chosen by size via _plan2/_plan3) and exchanged with
    # allgather; everything else takes a dense momentum-SGD + allreduce
    # path.  Timing counters (mom/select/mask/pack/allreduce/pruning) are
    # accumulated along the way.
    assert p not in self._handles
    assert not p.grad.requires_grad
    name = self._parameter_names.get(p)
    p_size = np.prod(p.size())
    torch.cuda.synchronize()  # finish outstanding GPU work before timing
    begin_time = time.time()
    if self._use_allgather and p_size > self._plan1:
        torch.cuda.synchronize()
        begin_mom_time = time.time()
        # ---- momentum / residual accumulation -------------------------
        weight_decay = self._weight_decay  #group['weight_decay']
        momentum = self._momentum  #group['momentum']
        dampening = 0.0  #group['dampening']
        d_p = p.grad.data
        d_p.div_(hvd.size())  # pre-scale so the summed exchange averages
        if weight_decay != 0:
            d_p.add_(weight_decay, p.data)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                # fresh zero buffer: mul_ is a no-op, result is d_p
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(1 - dampening, d_p)
        #TODO
        # NOTE(review): `param_state` is only bound when momentum != 0;
        # with momentum == 0 the residue handling below raises NameError
        # — confirm momentum is always non-zero for this optimizer.
        if 'residue_buffer' not in param_state:
            rsd = param_state['residue_buffer'] = torch.zeros_like(p.data)
            rsd.add_(param_state['momentum_buffer'])
            if self._use_nesterov:
                rsd = rsd.add(momentum, d_p)
        else:
            rsd = param_state['residue_buffer']
            rsd.add_(param_state['momentum_buffer'])
            if self._use_nesterov:
                rsd = rsd.add(momentum, d_p)
        torch.cuda.synchronize()
        self.mom_time += time.time() - begin_mom_time
        # ---- sparse selection ----------------------------------------
        compressed_val = []
        compressed_idx = []
        torch.cuda.synchronize()
        begin_select_time = time.time()
        if 'flag' not in param_state:
            param_state['flag'] = 0
        if 'interval' not in param_state:
            param_state['interval'] = 10
        it = 0
        sparsity = 0.0
        # `flag` alternates between the top-k and bottom-k selections on
        # successive invocations; the kernel family depends on tensor size.
        if p_size > self._plan3:
            if param_state['flag'] == 1:
                compressed_val, compressed_idx, it, _, sparsity = \
                    select_bs_top(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 0
            else:
                compressed_val, compressed_idx, it, _, sparsity = \
                    select_bs_bottom(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 1
        elif p_size > self._plan2:
            if param_state['flag'] == 1:
                compressed_val, compressed_idx = \
                    select_trim_topk_mean(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 0
            else:
                compressed_val, compressed_idx = \
                    select_trim_lowk_mean(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 1
        else:
            if param_state['flag'] == 1:
                compressed_val, compressed_idx = \
                    select_topk_mean(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 0
            else:
                compressed_val, compressed_idx = \
                    select_lowk_mean(param_state['residue_buffer'], 0.001)
                param_state['flag'] = 1
        assert(len(compressed_idx) > 0)
        torch.cuda.synchronize()
        end_select_time = time.time()
        self.select_time += end_select_time - begin_select_time
        # ---- mask build: 1 everywhere except the selected entries -----
        torch.cuda.synchronize()
        begin_mask_time = time.time()
        masks_size = self._masks[name].size()
        self._masks[name].zero_()
        self._masks[name] = self._masks[name].view(-1)
        self._masks[name][compressed_idx] = 1.0
        self._masks[name] = 1.0 - self._masks[name]
        self._masks[name] = self._masks[name].view(masks_size)
        if self._debug:
            # dense reference of what the sparse exchange should sum to
            self._v_ref[name] = param_state['residue_buffer'] * (1.0 - self._masks[name])
            allreduce_(self._v_ref[name], average = False)
        if hvd.size() == 1:
            # single worker: no exchange; apply the selected entries directly
            p.grad.data = param_state['residue_buffer'] * (1.0 - self._masks[name])
        # clear transmitted entries from both buffers
        param_state['residue_buffer'].mul_(self._masks[name])
        param_state['momentum_buffer'].mul_(self._masks[name])
        end_mask_time = time.time()
        self.mask_time += end_mask_time - begin_mask_time
        # ---- pack + async allgather ----------------------------------
        torch.cuda.synchronize()
        begin_pack_time = time.time()
        if hvd.size() > 1:
            if self._use_gpu:
                if p_size > self._plan3:
                    # message = [count, indices]; values sent as their mean
                    compressed_msg = torch.cat((\
                        torch.tensor([len(compressed_idx)]).type(torch.cuda.LongTensor),\
                        compressed_idx))
                    handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                    self._handles[p] = handle
                    handle = _allgather_async(torch.mean(compressed_val), self._compressed_val[name], name=name + "val")
                    self._handles_val[p] = handle
                else:
                    self._compressed_msg_size[name] = len(compressed_idx)
                    handle = _allgather_async(compressed_idx, self._compressed_idx[name], \
                        name = name+"idx")
                    self._handles[p] = handle
                    handle = _allgather_async(torch.mean(compressed_val), \
                        self._compressed_val[name], name=name+"val")
                    self._handles_val[p] = handle
        torch.cuda.synchronize()
        self.pack_time += time.time() - begin_pack_time
    else:
        # dense path: momentum SGD locally, then sum-allreduce
        torch.cuda.synchronize()
        begin_allreduce_time = time.time()
        p.grad.data.div_(hvd.size())
        p.grad.data.add_(torch.mul(p.data, self._weight_decay))
        param_state = self.state[p]
        if 'momentum_buffer' not in param_state:
            param_state['momentum_buffer'] = torch.zeros_like(p.data)
        if self._use_nesterov:
            param_state['momentum_buffer'] = torch.mul(torch.add(param_state['momentum_buffer'], p.grad.data), self._momentum)
            p.grad.data = param_state['momentum_buffer'] + p.grad.data
        else:
            param_state['momentum_buffer'] = self._momentum * param_state['momentum_buffer'] + p.grad.data
            p.grad.data = param_state['momentum_buffer']
        if hvd.size() > 1:
            handle = allreduce_async_(p.grad.data, average=False, name=name)
            self._handles[p] = handle
        torch.cuda.synchronize()
        self.allreduce_time += time.time() - begin_allreduce_time
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
def step(self, closure=None):
    """One optimizer step with local gradient clipping and DGC-style
    sparsification.

    Large tensors (> 1024 elements) accumulate into the residual ``V``,
    alternate between bottom-k / top-k selection per step, and exchange
    the selection via async allgather; small tensors use a dense
    averaged allreduce.  Interface matches ``torch.optim.Optimizer.step``.
    """
    # local clipping
    # DGC
    for group in self.param_groups:
        for p in group['params']:
            # fold weight decay into the gradient and pre-scale by world
            # size so the later sum-based exchange yields an average
            p.grad.data.add_(torch.mul(p.data, self._weight_decay))
            p.grad.data.div_(hvd.size())
        torch.nn.utils.clip_grad_norm_(group['params'],
                                       0.25 * hvd.size() ** -0.5)
        torch.cuda.synchronize()
        begin_time = time.time()
        dampening = 0.0  #group['dampening']
        for p in group['params']:
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                param_state = self.state[p]
                self._V[name].add_(p.grad.data)
                # fjr compress grad
                compressed_val = []
                compressed_idx = []
                torch.cuda.synchronize()
                begin_select_time = time.time()
                # `interval` alternates top-k / bottom-k selection each step
                if 'interval' not in param_state:
                    param_state['interval'] = 1
                if param_state['interval'] == 0:
                    compressed_val, compressed_idx, _, _, _ = \
                        select_bs_top(self._V[name], 0.001)
                    param_state['interval'] = 1
                else:
                    compressed_val, compressed_idx, _, _, _ = \
                        select_bs_bottom(self._V[name], 0.001)
                    param_state['interval'] = 0
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time
                if self._debug:
                    # NOTE(review): relies on self._masks[name] computed
                    # elsewhere; it is not refreshed in this method.
                    self._v_ref[name] = self._V[name] * (1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)
                # clear the transmitted entries from the residual
                torch.cuda.synchronize()
                begin_mask_time = time.time()
                V_size = self._masks[name].size()
                self._V[name] = self._V[name].view(-1)
                self._V[name][compressed_idx] = 0.0
                self._V[name] = self._V[name].view(V_size)
                torch.cuda.synchronize()
                self.mask_time += time.time() - begin_mask_time
                begin_pack_time = time.time()
                self._compressed_msg_size[name] = len(compressed_idx)
                if self._use_gpu:
                    # FIX: build each message only on the branch that sends
                    # it; previously the LongTensor message was always
                    # built and then discarded on the FloatTensor path.
                    if p_size == 1500 * 10000:
                        # single fused message: [count, indices, values]
                        compressed_msg = torch.cat([\
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),\
                            compressed_idx.type('torch.cuda.FloatTensor'), \
                            compressed_val])
                        handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                        self._handles[p] = handle
                    else:
                        # [count, indices] plus the mean of the values
                        compressed_msg = torch.cat([\
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.LongTensor'),\
                            compressed_idx])
                        handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                        self._handles[p] = handle
                        handle = _allgather_async(torch.mean(compressed_val), \
                            self._compressed_val[name], name=name + "val")
                        self._handles_val[p] = handle
                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time
            else:
                # small tensor: dense averaged allreduce
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
    self.synchronize()
    return super(self.__class__, self).step(closure)
def step(self, closure=None):
    """One optimizer step with local clipping and truncated-mean
    sparsification.

    Large tensors (> 1024 elements) run momentum/residual accumulation
    (``U``/``V``), alternate low-k / top-k truncated-mean selection, and
    allgather indices plus the mean of the selected values; small
    tensors take a dense momentum + averaged-allreduce path.  Interface
    matches ``torch.optim.Optimizer.step``.
    """
    # local clipping
    # DGC
    for group in self.param_groups:
        for p in group['params']:
            # fold weight decay in and pre-scale for the averaged exchange
            p.grad.data.add_(torch.mul(p.data, self._weight_decay))
            p.grad.data.div_(hvd.size())
        # FIX: use in-place clip_grad_norm_; the un-suffixed
        # clip_grad_norm is a deprecated alias (removed in newer
        # PyTorch), and the sibling step() implementations already use
        # the underscore form.
        torch.nn.utils.clip_grad_norm_(group['params'],
                                       0.25 * hvd.size() ** -0.5)
        torch.cuda.synchronize()
        begin_time = time.time()
        dampening = 0.0  #group['dampening']
        for p in group['params']:
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                # momentum (U) and residual (V) accumulation
                if self._momentum != 0:
                    if self._use_nesterov:
                        self._U[name] = torch.mul(
                            torch.add(self._U[name], p.grad.data),
                            self._momentum)
                        self._V[name] = self._V[name] + self._U[
                            name] + p.grad.data
                    else:
                        self._U[name] = self._momentum * self._U[
                            name] + p.grad.data
                        self._V[name] = self._V[name] + self._U[name]
                else:
                    self._V[name].add_(p.grad.data)
                compressed_val = []
                compressed_idx = []
                torch.cuda.synchronize()
                begin_select_time = time.time()
                # alternate which tail of V is transmitted each step
                if self._flag[name] == 0:
                    compressed_val, compressed_idx = \
                        select_lowk_truncated_mean(self._V[name], 0.001)
                    self._flag[name] = 1
                else:
                    compressed_val, compressed_idx = \
                        select_topk_truncated_mean(self._V[name], 0.001)
                    self._flag[name] = 0
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time
                if self._debug:
                    # rebuild mask (1 everywhere except selected entries)
                    # and compute a dense reference reduction
                    masks_size = self._masks[name].size()
                    self._masks[name].zero_()
                    self._masks[name] = self._masks[name].view(-1)
                    self._masks[name][compressed_idx] = 1.0
                    self._masks[name] = 1.0 - self._masks[name]
                    self._masks[name] = self._masks[name].view(masks_size)
                    self._v_ref[name] = self._V[name] * (1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)
                # clear transmitted entries from the residual
                V_size = self._masks[name].size()
                self._V[name] = self._V[name].view(-1)
                self._V[name][compressed_idx] = 0.0
                self._V[name] = self._V[name].view(V_size)
                if self._momentum != 0.0:
                    # NOTE(review): self._masks[name] is only refreshed in
                    # the debug branch above — confirm U masking is
                    # intended to use the (possibly stale) mask here.
                    self._U[name].mul_(self._masks[name])
                torch.cuda.synchronize()
                begin_comm_time = time.time()
                self._compressed_msg_size[name] = len(compressed_idx)
                handle = _allgather_async(compressed_idx,
                                          self._compressed_idx[name],
                                          name=name + "idx")
                self._handles[p] = handle
                handle = _allgather_async(torch.mean(compressed_val), \
                    self._compressed_val[name], name=name + "val")
                self._handles_val[p] = handle
                # FIX: the allgather branch started the comm timer but
                # never accumulated it, under-reporting comm_time (the
                # dense branch below and sibling implementations do
                # accumulate it).
                torch.cuda.synchronize()
                self.comm_time += time.time() - begin_comm_time
            else:
                # dense path: weight decay + momentum, then averaged
                # allreduce of V
                if self._weight_decay != 0.0:
                    p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                if self._momentum != 0:
                    if self._use_nesterov:
                        self._U[name] = torch.mul(
                            torch.add(self._U[name], p.grad.data),
                            self._momentum)
                        self._V[name] = self._V[name] + self._U[
                            name] + p.grad.data
                    else:
                        self._U[name] = self._momentum * self._U[
                            name] + p.grad.data
                        self._V[name] = self._V[name] + self._U[name]
                p.grad.data = self._V[name]
                torch.cuda.synchronize()
                begin_comm_time = time.time()
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
                torch.cuda.synchronize()
                self.comm_time += time.time() - begin_comm_time
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
    self.synchronize()
    return super(self.__class__, self).step(closure)
def hook(*ignore):
    # Gradient-reduction hook (closure over `p` and `self`).  Large
    # tensors (> 1024 elements) are sparsified with a periodically
    # re-estimated threshold (thdv3 every 10th call, cached fixed
    # threshold otherwise) and exchanged as three allgathers
    # (indices / values / count); small tensors use a dense
    # momentum + averaged-allreduce path.
    assert p not in self._handles
    assert not p.grad.requires_grad
    name = self._parameter_names.get(p)
    p_size = np.prod(p.size())
    torch.cuda.synchronize()  # settle GPU work before timing
    begin_time = time.time()
    if self._use_allgather and p_size > 1024:
        # ---- momentum / residual accumulation -------------------------
        weight_decay = self._weight_decay  #group['weight_decay']
        momentum = self._momentum  #group['momentum']
        dampening = 0.0  #group['dampening']
        nesterov = False  #group['nesterov']
        d_p = p.grad.data
        if weight_decay != 0:
            d_p.add_(weight_decay, p.data)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                # fresh zero buffer: mul_ is a no-op, result equals d_p
                buf = param_state[
                    'momentum_buffer'] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(1 - dampening, d_p)
        #TODO
        # if nesterov:
        #     d_p = d_p.add(momentum, buf)
        # else:
        #     d_p = buf
        # NOTE(review): `param_state` is only bound when momentum != 0;
        # with momentum == 0 the residue handling below raises NameError.
        if 'residue_buffer' not in param_state:
            rsd = param_state['residue_buffer'] = torch.zeros_like(
                p.data)
            rsd.add_(param_state['momentum_buffer'])
        else:
            rsd = param_state['residue_buffer']
            rsd.add_(param_state['momentum_buffer'])
        # ---- sparse selection ----------------------------------------
        compressed_val = []
        compressed_idx = []
        torch.cuda.synchronize()
        begin_select_time = time.time()
        if 'mid_store' not in param_state:
            param_state['mid_store'] = 0.0
        if 'interval' not in param_state:
            param_state['interval'] = 10
        it = 0
        sparsity = 0.0
        # every 10th call re-estimate the threshold (and cache it in
        # mid_store); otherwise reuse the cached fixed threshold
        if param_state['interval'] == 10:
            compressed_val, compressed_idx, it, param_state['mid_store'], sparsity = \
                select_top_k_thdv3(param_state['residue_buffer'], 0.001)
            param_state['interval'] = 0
        else:
            compressed_val, compressed_idx, sparsity = \
                select_top_k_fixthd(param_state['residue_buffer'], param_state['mid_store'])
            param_state['interval'] += 1
        assert (len(compressed_idx) > 0)
        torch.cuda.synchronize()
        end_select_time = time.time()
        self.select_time += end_select_time - begin_select_time
        # ---- mask build: 1 everywhere except the selected entries -----
        masks_size = self._masks[name].size()
        self._masks[name].zero_()
        self._masks[name] = self._masks[name].view(-1)
        self._masks[name][compressed_idx] = 1.0
        self._masks[name] = 1.0 - self._masks[name]
        self._masks[name] = self._masks[name].view(masks_size)
        if self._debug:
            # dense reference of what the sparse exchange should sum to
            self._v_ref[name] = param_state['residue_buffer'] * (
                1.0 - self._masks[name])
            allreduce_(self._v_ref[name], average=False)
        # clear the transmitted entries from both buffers
        param_state['residue_buffer'].mul_(self._masks[name])
        param_state['momentum_buffer'].mul_(self._masks[name])
        # ---- pack + async allgathers (indices, values, count) ---------
        torch.cuda.synchronize()
        begin_pack_time = time.time()
        handle = _allgather_async(compressed_idx, self._compressed_idx[name],
                                  name=name + 'idx')
        self._handles[p] = handle
        handle = _allgather_async(compressed_val, self._compressed_val[name],
                                  name=name + 'val')
        self._handles_val[p] = handle
        # count sent separately; per-rank lengths may differ
        handle = _allgather_async(torch.tensor([len(compressed_idx)]), self._compressed_len[name],
                                  name=name + 'len')
        self._handles_len[p] = handle
        torch.cuda.synchronize()
        self.pack_time += time.time() - begin_pack_time
    else:
        # dense path: weight decay + U/V momentum, averaged allreduce
        p.grad.data.add_(torch.mul(p.data, self._weight_decay))
        if self._use_nesterov:
            self._U[name] = torch.mul(
                torch.add(self._U[name], p.grad.data), self._momentum)
            self._V[name] = self._V[name] + self._U[name] + p.grad.data
        else:
            self._U[
                name] = self._momentum * self._U[name] + p.grad.data
            self._V[name] = self._V[name] + self._U[name]
        p.grad.data = self._V[name]
        handle = allreduce_async_(p.grad.data, average=True, name=name)
        self._handles[p] = handle
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
def hook(*ignore):
    """Gradient-reduction hook (closure over ``p`` and ``self``).

    Large tensors (> 1024 elements) accumulate into ``U``/``V``, select
    top-k indices above a threshold with their local mean, and allgather
    a single fused message ``[count, mean, indices]``; small tensors run
    dense momentum and an averaged allreduce.
    """
    assert p not in self._handles
    assert not p.grad.requires_grad
    name = self._parameter_names.get(p)
    p_size = np.prod(p.size())
    torch.cuda.synchronize()  # settle GPU work before timing
    begin_time = time.time()
    if self._use_allgather and p_size > 1024:
        # fjr compress grad
        p.grad.data.add_(torch.mul(p.data, self._weight_decay))
        p.grad.data.div_(hvd.size())
        if self._use_nesterov:
            self._U[name] = torch.mul(
                torch.add(self._U[name], p.grad.data), self._momentum)
            self._V[name] = self._V[name] + self._U[name] + p.grad.data
        else:
            self._U[
                name] = self._momentum * self._U[name] + p.grad.data
            self._V[name] = self._V[name] + self._U[name]
        compressed_val = []
        compressed_idx = []
        torch.cuda.synchronize()
        begin_select_time = time.time()
        # local mean of the selected entries plus their indices
        local_mean, compressed_idx = select_top_k_thd_mean(
            self._V[name], 0.001)
        torch.cuda.synchronize()
        end_select_time = time.time()
        self.select_time += end_select_time - begin_select_time
        # mask = 1 everywhere except the selected entries
        masks_size = self._masks[name].size()
        self._masks[name].zero_()
        self._masks[name] = self._masks[name].view(-1)
        self._masks[name][compressed_idx] = 1.0
        self._masks[name] = 1.0 - self._masks[name]
        self._masks[name] = self._masks[name].view(masks_size)
        if self._debug:
            self._v_ref[name] = self._V[name] * (1.0 - self._masks[name])
            allreduce_(self._v_ref[name], average=False)
        # clear transmitted entries from both buffers
        self._V[name].mul_(self._masks[name])
        self._U[name].mul_(self._masks[name])
        if self._use_gpu:
            # fused message: [count, local mean, indices] (all as float)
            compressed_msg = torch.cat(\
                [torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'), \
                 torch.tensor([local_mean]).type('torch.cuda.FloatTensor'), \
                 compressed_idx.type('torch.cuda.FloatTensor')])
        else:
            # NOTE(review): CPU path leaves compressed_msg undefined, so
            # the allgather below raises NameError — confirm this variant
            # is GPU-only before relying on it.
            pass
        handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
        self._handles[p] = handle
    else:
        # dense path: weight decay + U/V momentum, averaged allreduce
        p.grad.data.add_(torch.mul(p.data, self._weight_decay))
        if self._use_nesterov:
            self._U[name] = torch.mul(
                torch.add(self._U[name], p.grad.data), self._momentum)
            self._V[name] = self._V[name] + self._U[name] + p.grad.data
        else:
            self._U[
                name] = self._momentum * self._U[name] + p.grad.data
            self._V[name] = self._V[name] + self._U[name]
        p.grad.data = self._V[name]
        handle = allreduce_async_(p.grad.data, average=True, name=name)
        self._handles[p] = handle
    # FIX: the original ended with this synchronize/timing epilogue
    # duplicated verbatim, double-counting the interval in pruning_time;
    # it is now accumulated exactly once.
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
def step(self, closure=None):
    """One optimizer step with local gradient clipping and DGC-style
    sparsification.

    Large tensors (> 1024 elements) compute both the top-k and bottom-k
    tails of the residual ``V`` and transmit whichever has the larger
    mean magnitude (indices via allgather, values as their mean); small
    tensors use a dense averaged allreduce.  Interface matches
    ``torch.optim.Optimizer.step``.
    """
    # local clipping
    # DGC
    for group in self.param_groups:
        for p in group['params']:
            # fold weight decay in and pre-scale for the averaged exchange
            p.grad.data.add_(torch.mul(p.data, self._weight_decay))
            p.grad.data.div_(hvd.size())
        torch.nn.utils.clip_grad_norm_(group['params'],
                                       0.25 * hvd.size() ** -0.5)
        torch.cuda.synchronize()
        begin_time = time.time()
        dampening = 0.0  #group['dampening']
        for p in group['params']:
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                param_state = self.state[p]
                # fjr compress grad: momentum (U) and residual (V) update
                if self._use_nesterov:
                    self._U[name] = torch.mul(torch.add(self._U[name], p.grad.data), self._momentum)
                    self._V[name] = self._V[name] + self._U[name] + p.grad.data
                else:
                    self._U[name] = self._momentum * self._U[name] + p.grad.data
                    self._V[name] = self._V[name] + self._U[name]
                compressed_val = []
                compressed_idx = []
                torch.cuda.synchronize()
                begin_select_time = time.time()
                # compute both tails, keep the one with larger mean
                # magnitude
                compressed_val_top, compressed_idx_top, _, _, _ = \
                    select_bs_top(self._V[name], 0.001)
                compressed_val_low, compressed_idx_low, _, _, _ = \
                    select_bs_bottom(self._V[name], 0.001)
                # FIX: dropped `compressed_mean = 0.0`, an unused local
                # referenced only by previously commented-out code.
                if torch.mean(compressed_val_top) > -torch.mean(compressed_val_low):
                    compressed_val = compressed_val_top
                    compressed_idx = compressed_idx_top
                else:
                    compressed_val = compressed_val_low
                    compressed_idx = compressed_idx_low
                # mask = 1 everywhere except the selected entries
                masks_size = self._masks[name].size()
                self._masks[name].zero_()
                self._masks[name] = self._masks[name].view(-1)
                self._masks[name][compressed_idx] = 1.0
                self._masks[name] = 1.0 - self._masks[name]
                self._masks[name] = self._masks[name].view(masks_size)
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time
                if self._debug:
                    self._v_ref[name] = self._V[name] * (1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)
                # clear transmitted entries from both buffers
                self._V[name].mul_(self._masks[name])
                self._U[name].mul_(self._masks[name])
                torch.cuda.synchronize()
                begin_comm_time = time.time()
                self._compressed_msg_size[name] = len(compressed_idx)
                if self._use_gpu:
                    # message = [count, indices]; values sent as their mean
                    compressed_msg = torch.cat([\
                        torch.tensor([len(compressed_idx)]).type('torch.cuda.LongTensor'),\
                        compressed_idx])
                    handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                    self._handles[p] = handle
                    handle = _allgather_async(torch.mean(compressed_val), \
                        self._compressed_val[name], name=name + "val")
                    self._handles_val[p] = handle
                torch.cuda.synchronize()
                self.comm_time += time.time() - begin_comm_time
            else:
                # dense path (weight decay already applied in the
                # clipping loop above)
                if self._use_nesterov:
                    self._U[name] = torch.mul(torch.add(self._U[name], p.grad.data), self._momentum)
                    self._V[name] = self._V[name] + self._U[name] + p.grad.data
                else:
                    self._U[name] = self._momentum * self._U[name] + p.grad.data
                    self._V[name] = self._V[name] + self._U[name]
                p.grad.data = self._V[name]
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
    self.synchronize()
    return super(self.__class__, self).step(closure)
def hook(*ignore):
    # Gradient-reduction hook (closure over `p` and `self`).  Large
    # tensors (> 1024 elements) run momentum/residue accumulation in
    # param_state, alternate top-k / bottom-k binary-search selection,
    # and allgather [count, indices] plus the mean of the selected
    # values; small tensors take a dense momentum + allreduce path.
    assert p not in self._handles
    assert not p.grad.requires_grad
    name = self._parameter_names.get(p)
    p_size = np.prod(p.size())
    torch.cuda.synchronize()  # settle GPU work before timing
    begin_time = time.time()
    if self._use_allgather and p_size > 1024:
        # ---- momentum / residue accumulation --------------------------
        weight_decay = self._weight_decay  #group['weight_decay']
        momentum = self._momentum  #group['momentum']
        dampening = 0.0  #group['dampening']
        nesterov = False  #group['nesterov']
        d_p = p.grad.data
        d_p.div_(hvd.size())  # pre-scale so the summed exchange averages
        if weight_decay != 0:
            d_p.add_(weight_decay, p.data)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state['momentum_buffer']
                # dampening deliberately ignored here (see commented line)
                buf.mul_(momentum).add_(d_p)
                #buf.mul_(momentum).add_(1 - dampening, d_p)
        #TODO
        # if nesterov:
        #     d_p = d_p.add(momentum, buf)
        # else:
        #     d_p = buf
        # NOTE(review): `param_state` is only bound when momentum != 0;
        # momentum == 0 would make the residue handling raise NameError.
        if 'residue_buffer' not in param_state:
            rsd = param_state['residue_buffer'] = torch.zeros_like(p.data)
            rsd.add_(param_state['momentum_buffer'])
            if self._use_nesterov:
                rsd = rsd.add(momentum, d_p)
        else:
            rsd = param_state['residue_buffer']
            rsd.add_(param_state['momentum_buffer'])
            if self._use_nesterov:
                rsd = rsd.add(momentum, d_p)
        # ---- sparse selection (alternating top/bottom) ----------------
        compressed_val = []
        compressed_idx = []
        torch.cuda.synchronize()
        begin_select_time = time.time()
        if 'interval' not in param_state:
            param_state['interval'] = 1
        it = 0
        sparsity = 0.0
        if param_state['interval'] == 1:
            compressed_val, compressed_idx, it, _, sparsity = \
                select_bs_top(param_state['residue_buffer'], 0.001)
            param_state['interval'] = 0
        else:
            compressed_val, compressed_idx, it, _, sparsity = \
                select_bs_bottom(param_state['residue_buffer'], 0.001)
            param_state['interval'] = 1
        assert(len(compressed_idx) > 0)
        torch.cuda.synchronize()
        end_select_time = time.time()
        self.select_time += end_select_time - begin_select_time
        # ---- mask build: 1 everywhere except the selected entries -----
        masks_size = self._masks[name].size()
        self._masks[name].zero_()
        self._masks[name] = self._masks[name].view(-1)
        self._masks[name][compressed_idx] = 1.0
        self._masks[name] = 1.0 - self._masks[name]
        self._masks[name] = self._masks[name].view(masks_size)
        if self._debug:
            # reference reconstruction uses the MEAN of the selected
            # values at every selected position — matches the
            # mean-only exchange below
            self._v_ref[name] = torch.mean(compressed_val) \
                * (1.0 - self._masks[name])
            allreduce_(self._v_ref[name], average = False)
        if hvd.size() == 1:
            # single worker: decompress locally (mean value at each
            # selected position)
            p.grad.data = torch.mean(compressed_val) \
                * (1.0 - self._masks[name])
        # clear transmitted entries from both buffers
        param_state['residue_buffer'].mul_(self._masks[name])
        param_state['momentum_buffer'].mul_(self._masks[name])
        # ---- pack + async allgather ----------------------------------
        torch.cuda.synchronize()
        begin_pack_time = time.time()
        compressed_msg = []
        if hvd.size() > 1:
            if self._use_gpu:
                # message = [count, indices]; values sent as their mean
                compressed_msg= torch.cat((\
                    torch.tensor([len(compressed_idx)]).type(torch.cuda.LongTensor),\
                    compressed_idx))
                handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                self._handles[p] = handle
                handle = _allgather_async(torch.mean(compressed_val), self._compressed_val[name], name=name + "val")
                self._handles_val[p] = handle
        torch.cuda.synchronize()
        self.pack_time += time.time() - begin_pack_time
    else:
        # dense path: classic momentum SGD on the local gradient
        weight_decay = self._weight_decay  #group['weight_decay']
        momentum = self._momentum  #group['momentum']
        dampening = 0.0  #group['dampening']
        d_p = p.grad.data
        if weight_decay != 0:
            d_p.add_(weight_decay, p.data)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(1 - dampening, d_p)
            # NOTE(review): rebinding the local `d_p` below does NOT
            # write back into p.grad.data, so the allreduce transmits
            # the pre-momentum gradient — confirm whether this is
            # intended or the momentum result should be assigned to
            # p.grad.data as the sibling hooks do.
            if self._use_nesterov:
                d_p = d_p.add(momentum, buf)
            else:
                d_p = buf
        if hvd.size() > 1:
            handle = allreduce_async_(p.grad.data, average=True, name=name)
            self._handles[p] = handle
    torch.cuda.synchronize()
    end_time = time.time()
    self.pruning_time += end_time - begin_time
def step(self, closure=None):
    """One optimizer step with local gradient clipping and
    threshold-based sparsification.

    Large tensors (> 1024 elements) accumulate into the residual ``V``
    and select top-k entries — re-estimating the threshold every
    ``self._interval`` steps (thdv3) and reusing the cached threshold
    (fixthd) in between — then allgather one fused message
    ``[count, indices, values]``.  Small tensors use a dense averaged
    allreduce.  Interface matches ``torch.optim.Optimizer.step``.
    """
    # local clipping
    # DGC
    for group in self.param_groups:
        for p in group['params']:
            # fold weight decay in and pre-scale for the averaged exchange
            p.grad.data.add_(torch.mul(p.data, self._weight_decay))
            p.grad.data.div_(hvd.size())
        torch.nn.utils.clip_grad_norm_(group['params'],
                                       0.25 * hvd.size() ** -0.5)
        torch.cuda.synchronize()
        begin_time = time.time()
        dampening = 0.0  #group['dampening']
        for p in group['params']:
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                param_state = self.state[p]
                self._V[name].add_(p.grad.data)
                compressed_val = []
                compressed_idx = []
                torch.cuda.synchronize()
                begin_select_time = time.time()
                if 'mid_store' not in param_state:
                    param_state['mid_store'] = 0.0
                if 'interval' not in param_state:
                    param_state['interval'] = self._interval
                compressed_val = []
                compressed_idx = []
                # every self._interval steps re-estimate the threshold
                # (cached in mid_store); otherwise reuse it
                if param_state['interval'] == self._interval:
                    compressed_val, compressed_idx, it, param_state['mid_store'], sparsity = \
                        select_top_k_thdv3(self._V[name], 0.001)
                    param_state['interval'] = 0
                else:
                    compressed_val, compressed_idx, sparsity = \
                        select_top_k_fixthd(self._V[name], param_state['mid_store'])
                    param_state['interval'] += 1
                torch.cuda.synchronize()
                self.select_time += time.time() - begin_select_time
                if self._debug:
                    # NOTE(review): self._masks[name] is never recomputed
                    # in this method (the mask-building code was removed),
                    # so this debug reference uses a stale mask — and,
                    # unlike sibling implementations, multiplies by the
                    # mask rather than (1 - mask).  Verify before trusting
                    # debug output.
                    self._v_ref[name] = self._V[name] * self._masks[name]
                    allreduce_(self._v_ref[name], average = False)
                # clear transmitted entries from the residual
                torch.cuda.synchronize()
                begin_mask_time = time.time()
                V_size = self._masks[name].size()
                self._V[name] = self._V[name].view(-1)
                self._V[name][compressed_idx] = 0.0
                self._V[name] = self._V[name].view(V_size)
                torch.cuda.synchronize()
                self.mask_time += time.time() - begin_mask_time
                begin_pack_time = time.time()
                self._compressed_msg_size[name] = len(compressed_idx)
                if self._use_gpu:
                    # fused message: [count, indices, values] (all float)
                    compressed_msg = torch.cat([\
                        torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),\
                        compressed_idx.type('torch.cuda.FloatTensor'), \
                        compressed_val])
                    handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                    self._handles[p] = handle
                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time
            else:
                # small tensor: dense averaged allreduce
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
    torch.cuda.synchronize()
    self.pruning_time += time.time() - begin_time
    self.synchronize()
    return super(self.__class__, self).step(closure)