def step(self, closure=None, **kargs):
    """Run one synchronization step and return the number of bits transmitted.

    Centralized mode: all-reduce (average) the flattened gradients across the
    world, unpack them back into ``.grad``, then apply them to the model.
    Decentralized mode: apply the local gradient first, then replace the
    flattened parameters with a neighborhood-weighted average.

    Args:
        closure: unused, kept for the torch optimizer ``step`` interface.
        **kargs: must contain a ``"timer"`` context-manager factory.

    Returns:
        int: number of bits of the buffer that was communicated.
    """
    if self.conf.is_centralized:
        with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
            # Flatten all gradients into one contiguous buffer.
            grads, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=True
            )
            flatten_grads = TensorBuffer(grads)

        with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
            # Aggregate the gradients (average across the world).
            flatten_grads.buffer = self.world_aggregator._agg(
                flatten_grads.buffer, op="avg", distributed=self.conf.distributed
            )

        with kargs["timer"]("sync/unflatten_grad", epoch=self.conf.epoch_):
            # Write the averaged gradients back into the .grad tensors.
            flatten_grads.unpack(grads)

        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        # Get n_bits to transmit.
        n_bits = get_n_bits(flatten_grads.buffer)
    else:
        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            # Local update first; the parameter averaging follows.
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
            # first get and flatten all params.
            params, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=False
            )
            flatten_params = TensorBuffer(params)

        with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
            # prepare the sync.
            if self.conf.comm_device == "cpu":
                # BUG FIX: the original called
                # `flatten_params.buffer.cpu().detach_()` and discarded the
                # result, so the branch was a no-op and communication never
                # happened on CPU. Keep the detached CPU copy so the
                # aggregator receives a CPU tensor, as `comm_device` asks.
                flatten_params.buffer = flatten_params.buffer.detach().cpu()
            # then sync (neighborhood-weighted average of the parameters).
            flatten_params.buffer = self.decentralized_aggregator._agg(
                flatten_params.buffer, op="weighted"
            )

        with kargs["timer"]("sync/update_model", epoch=self.conf.epoch_):
            # finally unflatten back into the model parameters.
            # NOTE(review): assumes TensorBuffer.unpack copies into the
            # original (possibly GPU) tensors — verify against its impl.
            flatten_params.unpack(params)

        # Get n_bits to transmit.
        n_bits = get_n_bits(flatten_params.buffer)
    return n_bits
def _receive_models_from_selected_clients(self, selected_client_ids):
    """Receive the flattened local models from the selected clients.

    Builds one zeroed placeholder buffer per selected client (shaped like
    that client's architecture-specific model state), posts asynchronous
    receives for all of them, and waits until every transfer completes.

    Returns:
        dict mapping client id -> TensorBuffer holding the received model.
    """
    self.conf.logger.log(f"Master waits to receive the local models.")
    dist.barrier()

    # One zeroed placeholder per selected client to receive into.
    def _placeholder(client_id):
        arch = self.clientid2arch[client_id]
        tb = TensorBuffer(list(self.client_models[arch].state_dict().values()))
        tb.buffer = torch.zeros_like(tb.buffer)
        return tb

    flatten_local_models = {
        client_id: _placeholder(client_id) for client_id in selected_client_ids
    }

    # Post all receives asynchronously, then wait for completion.
    pending = [
        dist.irecv(tensor=flatten_local_models[client_id].buffer, src=world_id)
        for client_id, world_id in zip(selected_client_ids, self.world_ids)
    ]
    for request in pending:
        request.wait()

    dist.barrier()
    self.conf.logger.log(f"Master received all local models.")
    return flatten_local_models
def step(self, closure=None, **kargs):
    """One step of sign-compressed SGD with error feedback and local steps.

    Every call applies the local gradient update. Every `local_step`
    iterations (or always while `epoch_ < turn_on_local_step_from_epoch`)
    it also synchronizes: the model difference plus the error-feedback
    memory is compressed into a per-tensor scale and sign, the direction is
    exchanged, the scales are averaged across workers, and the consensus
    parameters are updated and copied back into the local model.

    Returns the number of bits transmitted (0 on purely local steps).
    """
    with kargs['timer']('sync', epoch=self.conf.epoch_):
        # do the local update steps.
        with kargs["timer"]("local_update", epoch=self.conf.epoch_):
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        # enter the global sync if it satisfies the condition.
        if (
            self.conf.epoch_ < self.turn_on_local_step_from_epoch
            or self.conf.local_index % self.local_step == 0
        ):
            with kargs["timer"]("get_params", epoch=self.conf.epoch_):
                # get params (flattened into a TensorBuffer).
                params, _ = comm.get_data(
                    self.param_groups, self.param_names, is_get_grad=False
                )
                params_tb = TensorBuffer(params)

            with kargs['timer']('memory_and_compress', epoch=self.conf.epoch_):
                # get the params difference w.r.t. previous synced model:
                # memory <- (consensus - current) + error-feedback residual.
                local_scale, local_sign = [], []
                for consensus_param, param, memory in zip(
                    self.consensus_params_tb, params_tb, self.memory_tb
                ):
                    memory.data.copy_(consensus_param - param + memory)

            # compress.
            with kargs["timer"]("directions", epoch=self.conf.epoch_):
                # NOTE(review): `exchange` is invoked on the raw memory
                # buffer BEFORE scaled_sign below mutates it — presumably an
                # async sign exchange; confirm the ordering is intentional.
                direction = exchange(self.memory_tb.buffer)  # signum

            with kargs['timer']('memory_and_compress', epoch=self.conf.epoch_):
                for consensus_param, param, memory in zip(
                    self.consensus_params_tb, params_tb, self.memory_tb
                ):
                    _local_scale, _local_sign = scaled_sign(memory)
                    local_scale.append(_local_scale)
                    local_sign.append(_local_sign)
                    # keep the compression residual for error feedback.
                    memory.data.copy_(memory - _local_scale * _local_sign)

            with kargs["timer"]("directions", epoch=self.conf.epoch_):
                global_direction = TB(self.memory_tb, direction)

            with kargs["timer"]("magnitudes", epoch=self.conf.epoch_):
                # average the per-tensor scales across the world.
                magnitudes_tb = TensorBuffer(local_scale)
                magnitudes_tb.buffer = self.world_aggregator._agg(
                    magnitudes_tb.buffer, "avg", distributed=self.conf.distributed
                )

            # unpack the synced info and update the consensus params.
            with kargs["timer"]("update_consensus", epoch=self.conf.epoch_):
                for update_magnitude, update_direction, consensus_param in zip(
                    magnitudes_tb, global_direction, self.consensus_params_tb
                ):
                    # consensus -= magnitude * direction
                    # (old-style Tensor.add_(scalar, tensor) overload,
                    #  deprecated in newer PyTorch).
                    consensus_param.add_(
                        -1.0, update_direction.mul(update_magnitude)
                    )

                # consistent the local models by assigning the consensus params.
                self.consensus_params_tb.unpack(params)

            n_bits = get_n_bits(magnitudes_tb.buffer)
        else:
            # purely local step: nothing was communicated.
            n_bits = 0
    return n_bits
def init_neighbor_hat_params(self):
    """Initialize the neighbor "hat" parameter estimates with zeros.

    Side effects: stores the flattened parameter shapes on ``self.shapes``
    and creates ``self.neighbor_hat_params`` with two independent zeroed
    copies — one keyed by this worker's rank, one keyed ``"memory"``.
    """
    params, self.shapes = comm.get_data(
        self.param_groups, self.param_names, is_get_grad=False
    )
    zero_tb = TensorBuffer(params)
    zero_tb.buffer = torch.zeros_like(zero_tb.buffer)

    # independent deep copies so the hat params and the memory can evolve
    # separately.
    self.neighbor_hat_params = {
        key: deepcopy(zero_tb) for key in (self.rank, "memory")
    }
def _init_neighbor_hat_params(conf, param_groups, param_names):
    """Build zero-initialized neighbor "hat" parameter estimates.

    Returns:
        tuple: (dict with two independent zeroed TensorBuffer copies —
        one keyed by this worker's rank and one keyed ``"memory"`` —
        and the flattened parameter shapes).
    """
    params, params_shapes = comm.get_data(
        param_groups, param_names, is_get_grad=False
    )
    template = TensorBuffer(params)
    template.buffer = torch.zeros_like(template.buffer)

    # deep copies keep the two entries independent of each other.
    neighbor_hat_params = {
        conf.graph.rank: deepcopy(template),
        "memory": deepcopy(template),
    }
    return neighbor_hat_params, params_shapes
def step(self, closure=None, **kargs):
    """One decentralized step with compressed model-difference gossip.

    The local gradient (with weight decay / momentum already folded in) is
    applied on top of the neighborhood-weighted hat parameters; the model
    difference is then compressed, exchanged with the neighbors, and the
    local model is reset to this worker's updated hat parameters.

    Returns:
        the number of bits communicated, as reported by the compressor.
    """
    # Fold weight decay and momentum into the gradients only (the model is
    # updated below via the hat-parameter combination).
    with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
        utils.apply_gradient(
            self.param_groups, self.state, apply_grad_to_model=False
        )

    # Flatten both the parameters and the gradients.
    with kargs["timer"]("grad.get_grads", epoch=self.conf.epoch_):
        params, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=False
        )
        flatten_params = TensorBuffer(params)
        grads, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=True
        )
        flatten_grads = TensorBuffer(grads)

    # Half step: neighborhood-weighted hat params minus lr * gradient.
    with kargs["timer"]("grad.apply_local_gradient", epoch=self.conf.epoch_):
        flatten_half_params = deepcopy(flatten_params)
        weighted_hat = sum(
            hat.buffer * self.neighbors_info[rank]
            for rank, hat in self.neighbor_hat_params.items()
        )
        lr = self.param_groups[0]["lr"]
        flatten_half_params.buffer = weighted_hat - lr * flatten_grads.buffer

    # Compress the model difference and exchange it with the neighbors.
    with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
        sync_buffer = {
            "original_shapes": self.shapes,
            "flatten_half_params": flatten_half_params,
            "flatten_params": flatten_params,
        }
        self.compressor.compress(sync_buffer)

    with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
        self.compressor.sync(sync_buffer)

    # Decompress into the hat params and adopt our own hat as local model.
    with kargs["timer"]("grad.unflatten_to_update", epoch=self.conf.epoch_):
        self.compressor.uncompress(sync_buffer, self.neighbor_hat_params)
        flatten_params.buffer = self.neighbor_hat_params[self.rank].buffer.clone()
        flatten_params.unpack(params)

    return sync_buffer["n_bits"]
def step(self, closure=None, **kargs):
    """One step of sign-based SGD with majority vote and local steps.

    Every call applies a local sign-SGD update (weight decay, momentum,
    optional Nesterov). Every `local_step` iterations (or always while
    `epoch_ < turn_on_local_step_from_epoch`) the difference between the
    consensus model and the local model is compressed into a per-tensor
    scale and sign, both are averaged across workers, and the consensus
    parameters are updated and written back into the local model.

    Returns the number of bits transmitted (0 on purely local steps).
    """
    # do the local update steps.
    with kargs["timer"]("sync.local_update", epoch=self.conf.epoch_):
        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            for p in group["params"]:
                # get param_state
                param_state = self.state[p]
                # get the gradient
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # add the weight decay and apply the momentum.
                # (old-style Tensor.add_(scalar, tensor) overload,
                #  deprecated in newer PyTorch.)
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                # apply the momentum.
                if momentum != 0:
                    if "momentum_buffer" not in param_state:
                        # first update: the buffer starts at zero, so after
                        # mul_/add_ this is effectively buf = d_p.
                        buf = param_state["momentum_buffer"] = torch.zeros_like(
                            p.data
                        )
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf
                # get the local sign and apply to the local model.
                p.data.add_(-group["lr"], torch.sign(d_p))

    # enter the global sync if it satisfies the condition.
    if (
        self.conf.epoch_ < self.turn_on_local_step_from_epoch
        or self.conf.local_index % self.local_step == 0
    ):
        with kargs["timer"]("sync.get_params", epoch=self.conf.epoch_):
            # get params (flattened into a TensorBuffer).
            params, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=False
            )
            params_tb = TensorBuffer(params)
            # get the params difference w.r.t. previous synced model.
            local_scale, local_sign = [], []
            for consensus_param, param in zip(self.consensus_params_tb, params_tb):
                _local_scale, _local_sign = scaled_sign(consensus_param - param)
                local_scale.append(_local_scale)
                local_sign.append(_local_sign)
            # concat the update magnitude and directions.
            magnitudes_tb = TensorBuffer(local_scale)
            directions_tb = TensorBuffer(local_sign)

        # sync and decompress.
        with kargs["timer"]("sync.sync_and_decompress", epoch=self.conf.epoch_):
            # sync the directions (majority vote via averaging the signs).
            directions_tb.buffer = self.world_aggregator._agg(
                directions_tb.buffer, "avg", distributed=self.conf.distributed
            )
            magnitudes_tb.buffer = self.world_aggregator._agg(
                magnitudes_tb.buffer, "avg", distributed=self.conf.distributed
            )

        # unpack the synced info and update the consensus params.
        with kargs["timer"]("sync.update_consensus", epoch=self.conf.epoch_):
            for update_magnitude, update_direction, consensus_param in zip(
                magnitudes_tb, directions_tb, self.consensus_params_tb
            ):
                # consensus -= magnitude * direction (old-style add_ overload).
                consensus_param.add_(-1.0, update_direction.mul(update_magnitude))
            # consistent the local models by assigning the consensus params.
            self.consensus_params_tb.unpack(params)

        n_bits = get_n_bits(directions_tb.buffer) + get_n_bits(magnitudes_tb.buffer)
    else:
        # purely local step: nothing was communicated.
        n_bits = 0
    return n_bits