def init_ip_config(self):
    self.ip_train_loader = get_data_loader(EXP_NAME, data_type="train", batch_size=CLIENT_BATCH_SIZE,
                                           shuffle=True, num_workers=8, user_list=[0], pin_memory=True)
    self.ip_test_loader = get_data_loader(EXP_NAME, data_type="test", num_workers=8, pin_memory=True)

    ip_optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.ip_optimizer_wrapper = OptimizerWrapper(self.model, ip_optimizer)
    self.ip_control = ControlModule(model=self.model, config=config)
def init_optimizer(self):
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    self.optimizer_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=STEP_SIZE,
                                                   gamma=0.5 ** (STEP_SIZE / LR_HALF_LIFE))
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.optimizer_scheduler)
def init_optimizer(self):
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.optimizer_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=1,
                                                   gamma=0.5 ** (1 / LR_HALF_LIFE))
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.optimizer_scheduler)
def init_optimizer(self):
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer)
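# OptimizerWrapper itself is not part of this excerpt. The sketch below is a hypothetical,
# minimal version that only matches how the wrapper is used here: step(inputs, labels) runs one
# forward/backward/update pass and returns the per-parameter gradients (as expected by the
# initial-pruning loop further below), lr_scheduler_step() advances the optional scheduler, and
# get_last_lr() reports the current learning rate. The loss function is an assumption; the real
# class in the repository may differ.
import torch.nn as nn


class OptimizerWrapperSketch:
    def __init__(self, model, optimizer, lr_scheduler=None, loss_fn=None):
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()  # assumed criterion

    def step(self, inputs, labels):
        # one local update: forward, backward, parameter step
        self.optimizer.zero_grad()
        loss = self.loss_fn(self.model(inputs), labels)
        loss.backward()
        self.optimizer.step()
        # gradients returned in the same order as model.named_parameters()
        return [p.grad.detach().clone() if p.grad is not None else None
                for _, p in self.model.named_parameters()]

    def lr_scheduler_step(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def get_last_lr(self):
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_last_lr()[0]
        return self.optimizer.param_groups[0]["lr"]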
          server_adjust_interval))
server_loader = get_data_loader(config.EXP_NAME, data_type="train", batch_size=config.CLIENT_BATCH_SIZE,
                                shuffle=True, num_workers=8, user_list=[0], pin_memory=True)

server_inputs, server_outputs = [], []
for _ in range(num_pre_batch):
    inp, out = server_loader.get_next_batch()
    server_inputs.append(inp)
    server_outputs.append(out)

server_optimizer = SGD(model.parameters(), lr=config.INIT_LR)
server_optimizer_wrapper = OptimizerWrapper(model, server_optimizer)
server_control = ControlModule(model=model, config=config)

prev_density, prev_num, prev_ind = None, 5, []
for server_i in range(1, server_pruning_rounds + 1):
    for server_inp, server_out in zip(server_inputs, server_outputs):
        list_grad = server_optimizer_wrapper.step(server_inp, server_out)
        for (key, param), g in zip(model.named_parameters(), list_grad):
            assert param.size() == g.size()
            server_control.accumulate(key, g ** 2)

    if server_i % server_adjust_interval == 0:
        server_control.adjust(config.MAX_DEC_DIFF)
        cur_density = disp_num_params(model)
        if prev_density is not None:
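# ControlModule is also external to this excerpt. The hypothetical sketch below covers only the
# interface exercised above: accumulate(key, g ** 2) sums squared gradients per named parameter
# as an importance estimate, and adjust(max_dec_diff) is the hook where the adaptive pruning
# step would use those sums to grow or shrink each layer's mask. The repository's actual
# adjustment rule is not reproduced here.
class ControlModuleSketch:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.accum_sgrad = dict()  # parameter name -> accumulated squared gradient

    def accumulate(self, key, sgrad):
        if key in self.accum_sgrad:
            self.accum_sgrad[key] += sgrad
        else:
            self.accum_sgrad[key] = sgrad.clone()

    def adjust(self, max_dec_diff):
        # placeholder: rank weights by accumulated squared gradient and update each layer's mask
        # subject to max_dec_diff, then reset the accumulators for the next interval
        self.accum_sgrad = dict()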
class Client(ABC):
    """
    Abstract base class for a federated-learning client: it receives the model and experiment
    configuration from the server, runs local updates, optionally accumulates squared gradients
    for adaptive pruning, and converts its model to a sparse representation when instructed.
    """

    def __init__(self, network_config, max_try=100):
        self.network_config = network_config
        self.socket = ClientSocket(network_config.SERVER_ADDR, network_config.SERVER_PORT)
        self.train_loader = None

        init_msg = self.socket.init_connections(max_try)
        self.client_id = init_msg.client_id
        self.exp_config = init_msg.exp_config
        torch.manual_seed(self.exp_config.seed)

        # self.save_path = os.path.join("results", "exp_{}".format(self.exp_config.exp_name),
        #                               self.exp_config.save_dir_name, "client_{}".format(self.client_id))

        self.model = init_msg.model
        self.model.train()

        self.optimizer = self.exp_config.optimizer_class(params=self.model.parameters(),
                                                         **self.exp_config.optimizer_params)
        self.lr_scheduler = None
        if self.exp_config.lr_scheduler_class is not None:
            self.lr_scheduler = self.exp_config.lr_scheduler_class(optimizer=self.optimizer,
                                                                   **self.exp_config.lr_scheduler_params)
        self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.lr_scheduler)

        if self.exp_config.use_adaptive:
            self.dict_extra_sgrad = dict()
            self.accum_dense_grad = dict()

        self.is_adj_round = False
        self.is_sparse = False
        self.terminate = False

        self.parse_init_extra_params(init_msg.extra_params)

        resume, cur_round, resume_to_sparse = init_msg.resume_params
        self.initialize(resume, cur_round, resume_to_sparse)

    @torch.no_grad()
    def load_state_dict(self, state_dict):
        param_dict = dict(self.model.named_parameters())
        buffer_dict = dict(self.model.named_buffers())
        for key, param in {**param_dict, **buffer_dict}.items():
            if key in state_dict.keys():
                if state_dict[key].size() != param.size():
                    # sparse param with values only
                    param._values().copy_(state_dict[key])
                elif state_dict[key].is_sparse:
                    # sparse param at adjustment round
                    param.copy_(state_dict[key])
                    # need to reload the mask in this case
                    param.mask.copy_(state_dict[key].mask)
                else:
                    param.copy_(state_dict[key])

    def initialize(self, resume, cur_round, resume_to_sparse):
        if resume:
            print("Resuming client...")
            # move the optimizer (lr scheduler) to the right position
            for _ in range(cur_round * self.exp_config.num_local_updates):
                self.optimizer_wrapper.lr_scheduler_step()

            # move the train loader to the right position
            remaining_batches = cur_round * self.exp_config.num_local_updates
            num_batches_epoch = len(self.train_loader)
            while remaining_batches >= num_batches_epoch:
                self.train_loader.skip_epoch()
                remaining_batches -= num_batches_epoch
            for _ in range(remaining_batches):
                self.train_loader.get_next_batch()

            if resume_to_sparse:
                self.convert_to_sparse()

            print("Client resumed")
        else:
            print("Client initialized")

    @abstractmethod
    def parse_init_extra_params(self, extra_params):
        # Initialize train_loader, etc.
        pass

    def cleanup_state_dict_to_server(self) -> dict:
        """
        Clean up the state dict before processing, e.g. remove entries or transpose tensors.
        To be overridden by subclasses.
        """
        clean_state_dict = copy_dict(self.model.state_dict())  # not deepcopy

        if self.is_sparse:
            for layer, prefix in zip(self.model.param_layers, self.model.param_layer_prefixes):
                key = prefix + ".bias"
                if isinstance(layer, SparseLinear) and key in clean_state_dict.keys():
                    clean_state_dict[key] = clean_state_dict[key].view(-1)

            del_list = []
            del_suffix = "placeholder"
            for key in clean_state_dict.keys():
                if key.endswith(del_suffix):
                    del_list.append(key)

            for del_key in del_list:
                del clean_state_dict[del_key]

        return clean_state_dict

    @torch.no_grad()
    def process_state_dict_to_server(self) -> dict:
        """
        Process the state dict before sending it to the server, e.g. keep sparse values only and
        attach extra squared gradients in adjustment rounds:
            if not self.is_sparse:   send dense parameters
            elif self.is_adj_round:  send sparse values + extra squared-gradient values
            else:                    send sparse values only
        To be overridden by subclasses.
        """
        clean_state_dict = self.cleanup_state_dict_to_server()

        if self.is_sparse:
            for key, param in clean_state_dict.items():
                if param.is_sparse:
                    clean_state_dict[key] = param._values()

        if self.is_adj_round:
            clean_state_dict.update(self.dict_extra_sgrad)
            self.dict_extra_sgrad = dict()

        return clean_state_dict

    def convert_to_sparse(self):
        self.model = self.model.to_sparse()
        old_lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
        self.optimizer = self.exp_config.optimizer_class(params=self.model.parameters(),
                                                         **self.exp_config.optimizer_params)
        if self.exp_config.lr_scheduler_class is not None:
            lr_scheduler_state_dict = deepcopy(self.lr_scheduler.state_dict())
            self.lr_scheduler = self.exp_config.lr_scheduler_class(optimizer=self.optimizer,
                                                                   **self.exp_config.lr_scheduler_params)
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
        self.optimizer.param_groups[0]["lr"] = old_lr
        self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.lr_scheduler)

        self.is_sparse = True
        print("Model converted to sparse")

    def accumulate_dense_grad_round(self):
        for key, param in self.model.named_parameters():
            if hasattr(param, "is_sparse_param"):
                if key in self.accum_dense_grad.keys():
                    self.accum_dense_grad[key] += param.dense.grad
                else:
                    self.accum_dense_grad[key] = param.dense.grad

    def accumulate_sgrad(self, num_proc_data):
        prefix = "extra."
        for key, param in self.accum_dense_grad.items():
            pkey = prefix + key
            if pkey in self.dict_extra_sgrad.keys():
                self.dict_extra_sgrad[pkey] += (param ** 2) * num_proc_data
            else:
                self.dict_extra_sgrad[pkey] = (param ** 2) * num_proc_data

            if self.is_adj_round:
                # keep only the squared gradients of currently pruned (masked-out) positions
                param_mask = dict(self.model.named_parameters())[key].mask == 0.
                self.dict_extra_sgrad[pkey] = self.dict_extra_sgrad[pkey].masked_select(param_mask)

    def main(self):
        num_proc_data = 0
        for _ in range(self.exp_config.num_local_updates):
            inputs, labels = self.train_loader.get_next_batch()
            self.optimizer_wrapper.step(inputs, labels)

            if self.exp_config.use_adaptive:
                self.accumulate_dense_grad_round()

            num_proc_data += len(inputs)

        if self.exp_config.use_adaptive:
            self.accumulate_sgrad(num_proc_data)
            self.accum_dense_grad = dict()

        lr = self.optimizer_wrapper.get_last_lr()

        state_dict_to_server = self.process_state_dict_to_server()
        msg_to_server = ClientToServerUpdateMessage((state_dict_to_server, num_proc_data, lr))
        self.socket.send_msg(msg_to_server)

        update_msg = self.socket.recv_update_msg()
        self.is_adj_round = update_msg.adjustment
        if not self.is_sparse and update_msg.to_sparse:
            self.convert_to_sparse()

        state_dict_received = update_msg.state_dict
        self.load_state_dict(state_dict_received)

        self.optimizer_wrapper.lr_scheduler_step()

        terminate = update_msg.terminate
        if terminate:
            self.socket.send_ack_msg()
            self.socket.close()
            self.terminate = True
            print("Task completed")

        return terminate
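# A hypothetical concrete client, sketched only to show how the abstract pieces above fit
# together: parse_init_extra_params() builds the train loader from the extra parameters sent by
# the server, and a small driver loop keeps calling main() until the server signals termination.
# The layout of extra_params and the entry-point wiring are assumptions for illustration;
# Client, get_data_loader, CLIENT_BATCH_SIZE and network_config come from the repository.
class ExampleClient(Client):
    def parse_init_extra_params(self, extra_params):
        # assumption: the first entry carries the user ids assigned to this client
        user_list = extra_params[0]
        self.train_loader = get_data_loader(self.exp_config.exp_name, data_type="train",
                                            batch_size=CLIENT_BATCH_SIZE, shuffle=True,
                                            num_workers=8, user_list=user_list, pin_memory=True)


if __name__ == "__main__":
    client = ExampleClient(network_config)  # network_config as defined in the experiment's configs
    while not client.main():                # main() returns True once the server sends terminate
        pass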