# imports assumed by this file (module paths follow the ChocoSGD-style
# layout and are an assumption; adjust them to your project structure).
from copy import deepcopy

import torch

import pcode.utils.communication as comm
from pcode.utils.sparsification import get_n_bits, QuantizationCompressor
from pcode.utils.tensor_buffer import TensorBuffer


class ECDQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters.
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # quantize each flattened local update (to be transmitted).
        quantized_values = []
        for flatten_updated_param in sync_buffer["flatten_updated_params"]:
            _quantized_values = self.compressor_fn.compress(
                flatten_updated_param, self.comm_op, self.quantize_level, self.is_biased
            )
            quantized_values.append(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer
        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=True
        )

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, local_index):
        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            # map the received message back to this worker's device.
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank], device=hat_params.buffer.device
            )

            # ECD-PSGD time-varying averaging:
            # hat_x <- (1 - 2/t) * hat_x + (2/t) * message.
            hat_params.buffer.mul_(1 - 2 / local_index).add_(
                _message, alpha=2 / local_index
            )
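
# hedged sketch (not part of the original source): a minimal QSGD-style
# stochastic quantizer illustrating what `QuantizationCompressor.compress`
# is assumed to do with `quantize_level` bits. the function name and the
# exact treatment of `is_biased` are illustrative, not the repo's API.
def _reference_quantize(x, quantize_level=8, is_biased=False):
    s = (1 << quantize_level) - 1  # number of grid levels per unit norm.
    norm = x.norm()
    if norm == 0:
        return torch.zeros_like(x)
    level = x.abs() / norm * s  # continuous level in [0, s].
    lower = level.floor()
    if is_biased:
        rounded = level.round()  # deterministic (biased) rounding.
    else:
        # randomized rounding keeps the quantizer unbiased in expectation.
        rounded = lower + torch.bernoulli(level - lower)
    return x.sign() * norm * rounded / s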
class DeepSqueezeQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs,
    ):
        # assign the common hyper-parameters.
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # quantize the local parameters and keep a compressed copy of the model.
        quantized_values = []
        local_compressed_params_tb = deepcopy(sync_buffer["params_tb"])
        local_compressed_params_tb.buffer = torch.zeros_like(
            local_compressed_params_tb.buffer
        )

        for param, local_compressed_param in zip(
            sync_buffer["params_tb"], local_compressed_params_tb
        ):
            # quantize.
            _quantized_values = self.compressor_fn.compress(
                param, self.comm_op, self.quantize_level, self.is_biased
            )
            quantized_values.append(_quantized_values)

            # update the local compressed params.
            local_compressed_param.data.copy_(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits
        return local_compressed_params_tb

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer
        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=True
        )

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and accumulate the weighted neighbor messages.
        for rank in neighbors_info.keys():
            # map the received message back to this worker's device.
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )

            # DeepSqueeze consensus step: weight each message by
            # (w_ij - delta_ij) and scale by the consensus step size.
            aggregated_info_tb.buffer.add_(
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * _message
            )
        return aggregated_info_tb
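
# hedged sketch (not part of the original source): the consensus correction
# computed by `DeepSqueezeQuantizationCompressor.uncompress`, written out on
# plain tensors. `messages` maps a neighbor rank to its flattened quantized
# message; all names below are illustrative.
def _reference_deepsqueeze_step(
    param_buffer, messages, neighbors_info, self_rank, consensus_stepsize
):
    correction = torch.zeros_like(param_buffer)
    for rank, weight in neighbors_info.items():
        # subtract 1 on the diagonal: locally the mixing matrix is (W - I).
        mixing = weight - (1.0 if rank == self_rank else 0.0)
        correction += consensus_stepsize * mixing * messages[rank]
    # the caller adds this correction to its parameters.
    return correction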
class CHOCOQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters.
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

        # define gossip_stream. both branches of the original conditional
        # picked the current CUDA stream, so the assignment is unconditional.
        self.gossip_stream = torch.cuda.current_stream()

    def pipeline(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # run compress -> sync -> uncompress on the gossip stream.
        with torch.cuda.stream(self.gossip_stream):
            try:
                self.compress(sync_buffer)
                self.sync(sync_buffer)
                self.uncompress(sync_buffer, neighbor_hat_params, neighbors_info)
            except RuntimeError as e:
                print("Error: {}".format(e))

    def compress(self, sync_buffer):
        # quantize the difference between the local model and its public copy.
        quantized_values = []
        for half_param, hat_param in zip(
            sync_buffer["flatten_params"], sync_buffer["flatten_hat_params"]
        ):
            _quantized_values = self.compressor_fn.compress(
                half_param - hat_param,
                self.comm_op,
                self.quantize_level,
                self.is_biased,
            )
            quantized_values.append(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer
        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync (non-blocking; the requests are completed in `uncompress`).
        sync_message_reqs, synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=False
        )

        # update sync_buffer.
        sync_buffer["sync_reqs"] = sync_message_reqs
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait for the sync to finish.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

        for rank, weight in neighbors_info.items():
            hat_params = neighbor_hat_params[
                rank if rank in neighbor_hat_params else "memory"
            ]
            hat_params_memory = neighbor_hat_params["memory"]

            # recover correct values/indices.
            q_values = comm.recover_device(
                sync_buffer["synced_message"][rank], device=hat_params.buffer.device
            )

            # replay the quantized difference onto the stored copy of the
            # neighbor's public model, and accumulate the weighted message.
            if rank in neighbor_hat_params:
                hat_params.buffer += q_values
            hat_params_memory.buffer += weight * q_values
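
# hedged sketch (not part of the original source): the CHOCO-Gossip
# bookkeeping behind `CHOCOQuantizationCompressor.uncompress` on plain
# tensors. every worker replays each neighbor's quantized difference onto
# its stored copy of that neighbor's public model, and accumulates the
# weighted messages into the shared "memory" buffer used by the later
# gossip averaging step. the names below are illustrative.
def _reference_choco_uncompress(neighbor_hats, memory, messages, neighbors_info):
    for rank, weight in neighbors_info.items():
        if rank in neighbor_hats:
            neighbor_hats[rank] += messages[rank]  # hat_x_j <- hat_x_j + q_j.
        memory += weight * messages[rank]  # weighted sum of all messages.
    return neighbor_hats, memory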