def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
    """Finish the pending syncs, then fold the received compressed
    messages into the neighbors' hat parameters.

    Args:
        sync_buffer: dict holding the in-flight sync requests
            ("sync_reqs_1"/"sync_reqs_2"), the synced norms/signs per
            rank, and the flattened placeholders used for decoding.
        neighbor_hat_params: dict mapping neighbor rank -> TensorBuffer
            of hat params; also holds a "memory" entry accumulating the
            weighted neighborhood average.
        neighbors_info: dict mapping neighbor rank -> mixing weight.
    """
    # Wait until both asynchronous communications have completed.
    self.aggregator_fn.complete_wait(sync_buffer["sync_reqs_1"])
    self.aggregator_fn.complete_wait(sync_buffer["sync_reqs_2"])

    # Uncompress each neighbor's message and update the hat params.
    for rank, weight in neighbors_info.items():
        # Fall back to the "memory" buffer when the rank has no own entry.
        hat_params = neighbor_hat_params[
            rank if rank in neighbor_hat_params else "memory"
        ]

        # Recover the message on the device of the hat params.
        sync_buffer["flatten_norms"].buffer = comm.recover_device(
            sync_buffer["synced_flatten_norms"][rank],
            device=hat_params.buffer.device,
        )
        sync_buffer["flatten_directions"].buffer = self.compressor_fn.uncompress(
            comm.recover_device(
                sync_buffer["synced_signs"][rank],
                device=hat_params.buffer.device,
            ),
            sync_buffer["sign_size"],
        )

        # Update neighbor_hat_params and the shared "memory" average.
        for hat_param, hat_param_memory, norm, sign in zip(
            hat_params,
            neighbor_hat_params["memory"],
            sync_buffer["flatten_norms"],
            sync_buffer["flatten_directions"],
        ):
            # Reconstruct the dense update from (norm, sign) pairs.
            _update = norm / sign.nelement() * sign
            if rank in neighbor_hat_params:
                hat_param.add_(_update)
            # NOTE: keyword `alpha` form; the positional
            # add_(scalar, tensor) overload is deprecated in PyTorch.
            hat_param_memory.add_(_update, alpha=weight)
def decompress(self, sync_buffer):
    """Decompress every peer's synced (norm, sign) message, accumulate
    the reconstructed updates into the synced grads, and average them.

    Returns:
        The averaged `synced_grads_tb` TensorBuffer.
    """
    device = sync_buffer["synced_grads_tb"].buffer.device

    for peer in range(self.world_size):
        # Skip our own contribution; it is already in the buffer.
        if peer == self.rank:
            continue

        # Rebuild the gradient norms as a TensorBuffer on our device.
        norms = comm.recover_device(
            sync_buffer["synced_grad_norms"][peer], device=device
        )
        norms_tb = TensorBuffer(norms)

        # Rebuild the sign tensor and wrap it in a matching layout.
        decoded_signs = self.compressor_fn.uncompress(
            comm.recover_device(sync_buffer["synced_signs"][peer], device=device),
            sync_buffer["sign_size"],
        )
        signs_tb = copy.deepcopy(sync_buffer["synced_grads_tb"])
        signs_tb.buffer = decoded_signs

        # Accumulate this peer's reconstructed update in place.
        for norm, sign, synced_grad in zip(
            norms_tb, signs_tb, sync_buffer["synced_grads_tb"]
        ):
            synced_grad.add_(norm * sign / synced_grad.nelement())

    # Average across all participants.
    sync_buffer["synced_grads_tb"].buffer /= self.world_size * 1.0
    return sync_buffer["synced_grads_tb"]
def uncompress(self, sync_buffer, neighbors_info):
    """Decompress each neighbor's (norm, sign) message and accumulate
    the weighted neighborhood information into one TensorBuffer.

    Returns:
        A TensorBuffer (same layout as `params_tb`) holding the
        aggregated consensus update.
    """
    target_device = sync_buffer["params_tb"].buffer.device

    # Zero-initialized accumulator mirroring the param layout.
    aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
    aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

    for rank in neighbors_info.keys():
        # Move the synced pieces onto the params' device.
        norms = comm.recover_device(
            sync_buffer["synced_param_norms"][rank], device=target_device
        )
        decoded_signs = self.compressor_fn.uncompress(
            comm.recover_device(
                sync_buffer["synced_signs"][rank], device=target_device
            ),
            sync_buffer["sign_size"],
        )

        # Wrap both in tensor buffers mirroring the param layout.
        norms_tb = TensorBuffer(norms)
        signs_tb = deepcopy(sync_buffer["params_tb"])
        signs_tb.buffer = decoded_signs

        # Mixing coefficient; the self-rank weight is discounted once.
        coeff = self.consensus_stepsize * (
            neighbors_info[rank] - (1 if rank == self.rank else 0)
        )
        for _info, _norm, _sign in zip(aggregated_info_tb, norms_tb, signs_tb):
            _info.add_(coeff * (_norm / _sign.nelement() * _sign))
    return aggregated_info_tb
def decompress(self, sync_buffer):
    """Wait for the sync, sum every rank's decompressed sign message,
    then reduce by majority vote or plain averaging.

    Returns:
        A TensorBuffer (same layout as `grads_tb`) with the reduced update.
    """
    self.aggregator_fn.complete_wait(sync_buffer["sync_req"])

    # Zero-initialized accumulator mirroring the gradient layout.
    synced_updates_tb = deepcopy(sync_buffer["grads_tb"])
    synced_updates_tb.buffer = torch.zeros_like(synced_updates_tb.buffer)

    target_device = sync_buffer["grads_tb"].buffer.device
    for rank in range(self.world_size):
        # Decode this rank's sign message onto our device and accumulate.
        synced_updates_tb.buffer += self.compressor_fn.uncompress(
            comm.recover_device(
                sync_buffer["synced_signs"][rank], device=target_device
            ),
            sync_buffer["sign_size"],
        )

    # Reduce: either a majority vote over signs or a plain average.
    if self.majority_vote:
        synced_updates_tb.buffer = torch.sign(synced_updates_tb.buffer)
    else:
        synced_updates_tb.buffer /= self.world_size * 1.0
    return synced_updates_tb
def _recover_info(
    self, flatten_params, synced_message, message_size, selected_shapes, shapes
):
    """Recover the averaged update from the synced message and apply
    one SGD step to the flattened parameters.

    Args:
        flatten_params: flattened parameter tensor to update.
        synced_message: per-node packed messages (compressed mode) or a
            dense summed tensor (uncompressed mode).
        message_size: total packed length; values are the first half,
            indices the second half.
        selected_shapes, shapes: layout info for the compressor.

    Returns:
        A new flattened parameter tensor after the update.
    """
    # The packed message stores values first, indices second.
    _message_size = int(message_size / 2)

    if self.is_compress_op:
        # Scatter-add each node's sparse (value, index) pairs.
        empty_grads = torch.zeros_like(flatten_params)
        for message in synced_message:
            q_values, q_indices = self.compressor_fn.uncompress(
                message[:_message_size],
                message[_message_size:],
                selected_shapes,
                shapes,
            )
            empty_grads[q_indices] += q_values
        # Average over all participating nodes.
        _update = empty_grads / self.n_nodes
    else:
        # Dense path: the message is already the summed update.
        _update = synced_message / self.n_nodes

    # Update flatten_params (assumes the same lr across param groups).
    # NOTE: keyword `alpha` form; the positional add(scalar, tensor)
    # overload is deprecated in PyTorch.
    updated_flatten_params = flatten_params.add(
        recover_device(_update, device=flatten_params.device),
        alpha=-self.param_groups[0]["lr"],
    )
    return updated_flatten_params
def uncompress(self, sync_buffer, neighbors_info):
    """Decompress each neighbor's sparse (values, indices) message and
    accumulate the weighted neighborhood information into one buffer.

    Returns:
        A TensorBuffer (same layout as `params_tb`) with the aggregate.
    """
    aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
    aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

    # The packed message stores values first, indices second.
    half = int(sync_buffer["sycned_message_size"] / 2)
    target_device = sync_buffer["params_tb"].buffer.device

    for rank in neighbors_info.keys():
        message = comm.recover_device(
            sync_buffer["synced_message"][rank], device=target_device
        )
        # Split, then let the compressor resolve any padding imbalance
        # between values and indices.
        q_values, q_indices = self.compressor_fn.uncompress(
            message[:half],
            message[half:],
            sync_buffer["selected_shapes"],
            sync_buffer["original_shapes"],
        )

        # Weighted mixing; the self-rank weight is discounted once.
        coeff = self.consensus_stepsize * (
            neighbors_info[rank] - (1 if rank == self.rank else 0)
        )
        aggregated_info_tb.buffer[q_indices] += coeff * q_values
    return aggregated_info_tb
def uncompress(self, sync_buffer, neighbor_hat_params, local_index):
    """Decompress each neighbor's sparse message and blend it into the
    corresponding hat params with a 2/local_index moving-average step.

    Args:
        sync_buffer: dict with the packed per-rank messages, the packed
            message size, and the compressor layout info.
        neighbor_hat_params: dict mapping rank -> TensorBuffer of hat
            params, updated in place.
        local_index: step counter controlling the averaging rate.
    """
    # The packed message stores values first, indices second.
    sycned_message_size = int(sync_buffer["sycned_message_size"] / 2)

    for rank, hat_params in neighbor_hat_params.items():
        _message = comm.recover_device(
            sync_buffer["synced_message"][rank], device=hat_params.buffer.device
        )
        values = _message[:sycned_message_size]
        indices = _message[sycned_message_size:]

        # Deal with unbalanced values/indices.
        q_values, q_indices = self.compressor_fn.uncompress(
            values,
            indices,
            sync_buffer["selected_shapes"],
            sync_buffer["original_shapes"],
        )

        # hat <- (1 - 2/t) * hat + (2/t) * q_values on the touched entries.
        # NOTE: keyword `alpha` form; the positional add(scalar, tensor)
        # overload is deprecated in PyTorch.
        step = 2 / local_index
        hat_params.buffer[q_indices] = (
            hat_params.buffer[q_indices].mul(1 - step).add(q_values, alpha=step)
        )
def uncompress(self, sync_buffer, neighbor_hat_params):
    """Add each neighbor's recovered dense message onto its hat params."""
    for rank, hat_params in neighbor_hat_params.items():
        # Move the synced message onto the hat params' device first.
        message = comm.recover_device(
            sync_buffer["synced_message"][rank],
            device=hat_params.buffer.device,
        )
        # In-place accumulation into the flattened hat params.
        hat_params.buffer.add_(message)
def uncompress(self, sync_buffer, neighbor_hat_params):
    """Decompress each neighbor's (norm, sign) message and add the
    reconstructed update onto that neighbor's hat parameters.

    Args:
        sync_buffer: dict with per-rank synced norms/signs plus the
            flattened placeholders used for decoding.
        neighbor_hat_params: dict mapping rank -> TensorBuffer of hat
            params, updated in place.
    """
    for rank, hat_params in neighbor_hat_params.items():
        # Recover the message on the device of the hat params.
        sync_buffer["flatten_norms"].buffer = comm.recover_device(
            sync_buffer["synced_flatten_norms"][rank],
            device=hat_params.buffer.device,
        )
        sync_buffer["flatten_updates"].buffer = self.compressor_fn.uncompress(
            comm.recover_device(
                sync_buffer["synced_signs"][rank],
                device=hat_params.buffer.device,
            ),
            sync_buffer["sign_size"],
        )

        # Update hat_params in place with the reconstructed update.
        for hat_param, norm, sign in zip(
            hat_params,
            sync_buffer["flatten_norms"],
            sync_buffer["flatten_updates"],
        ):
            # NOTE: the positional add_(scalar, tensor) overload is
            # deprecated in PyTorch (and `norm` is a tensor here), so
            # form the scaled update explicitly before adding.
            hat_param.add_(norm / sign.nelement() * sign)
def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
    """Wait for the sync, then apply each neighbor's recovered values to
    its hat params and to the weighted "memory" average."""
    # Block until the pending communication finishes.
    self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

    memory_tb = neighbor_hat_params["memory"]
    for rank, weight in neighbors_info.items():
        has_own_entry = rank in neighbor_hat_params
        # Ranks without their own entry fall back to the memory buffer.
        hat_params = neighbor_hat_params[rank] if has_own_entry else memory_tb

        # Move the synced values onto the right device.
        q_values = comm.recover_device(
            sync_buffer["synced_message"][rank],
            device=hat_params.buffer.device,
        )

        # Update neighbor_hat_params.
        if has_own_entry:
            hat_params.buffer += q_values
        memory_tb.buffer += weight * q_values
def uncompress(self, sync_buffer, neighbors_info):
    """Accumulate the dense synced messages from the neighborhood into a
    single weighted consensus buffer.

    Returns:
        A TensorBuffer (same layout as `params_tb`) with the aggregate.
    """
    aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
    aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

    target_device = sync_buffer["params_tb"].buffer.device
    for rank in neighbors_info.keys():
        # Move the synced message onto the params' device.
        message = comm.recover_device(
            sync_buffer["synced_message"][rank], device=target_device
        )
        # Weighted mixing; the self-rank weight is discounted once.
        coeff = self.consensus_stepsize * (
            neighbors_info[rank] - (1 if rank == self.rank else 0)
        )
        aggregated_info_tb.buffer.add_(coeff * message)
    return aggregated_info_tb
def _uncompress_helper(
    self,
    _hat_params,
    _rank,
    synced_message,
    sycned_message_size,
    selected_shapes,
    original_shapes,
):
    """Recover one rank's packed message and split it into quantized
    values and their indices.

    Returns:
        (q_values, q_indices) as produced by the compressor.
    """
    # Move the packed message onto the hat params' device.
    message = comm.recover_device(
        synced_message[_rank], device=_hat_params.buffer.device
    )
    # Layout: first half values, second half indices.
    values = message[:sycned_message_size]
    indices = message[sycned_message_size:]
    # The compressor resolves any padding imbalance between the halves.
    q_values, q_indices = self.compressor_fn.uncompress(
        values, indices, selected_shapes, original_shapes
    )
    return q_values, q_indices