def initialize(self, exp_config, resume):
    """Initialize or resume the federated server and hand init messages to all clients.

    :param exp_config: experiment configuration object forwarded to every client
    :param resume: when True, reload metrics/model from ``self.save_path`` and
        continue from the last checkpointed round; otherwise start fresh.
    """
    list_extra_params = self.get_init_extra_params()
    # Block until every expected client has connected.
    self.socket.wait_for_connections()
    if resume:
        print("Resuming server...")
        # Reload all checkpointed metric lists and the model itself.
        self.list_loss = load(os.path.join(self.save_path, "loss.pt"))
        self.list_acc = load(os.path.join(self.save_path, "accuracy.pt"))
        self.list_time_stamp = load(os.path.join(self.save_path, "time.pt"))
        self.list_model_size = load(os.path.join(self.save_path, "model_size.pt"))
        self.model = load(os.path.join(self.save_path, "model.pt"))
        # Sanity checks: loss/acc are appended together, as are time/model-size.
        num_loss_acc = len(self.list_loss)
        assert len(self.list_acc) == num_loss_acc
        num_evals = len(self.list_time_stamp)
        assert len(self.list_model_size) == num_evals
        if num_evals - num_loss_acc == 1:
            # The previous run saved time/size but crashed before saving
            # loss/acc for the last evaluation; redo that one evaluation.
            loss, acc = self.model.evaluate(self.test_loader)
            self.list_loss.append(loss)
            self.list_acc.append(acc)
        elif num_evals != num_loss_acc:
            # Any other mismatch means the checkpoint is inconsistent.
            raise RuntimeError("Cannot resume")
        # Last completed round implied by the number of evaluations.
        self.round = (num_evals - 1) * self.config.EVAL_DISP_INTERVAL
        assert self.round >= 0
        # Shift the clock so elapsed-time accounting continues seamlessly.
        self.start_time = timer() - self.list_time_stamp[-1]
        self.check_client_to_sparse()
        # (resume_flag, next_round, sparse_flag) tuple delivered to clients.
        resume_param = (True, self.round + 1, self.client_is_sparse)
        list_params = [(idx, exp_config, self.model, list_extra_params[idx], resume_param)
                       for idx in range(self.config.NUM_CLIENTS)]
        resume_msgs_to_client = [ServerToClientInitMessage(init_params) for init_params in list_params]
        self.socket.init_connections(resume_msgs_to_client)
        self.round += 1
        print("Server resumed")
        print(self)
    else:
        # Fresh run: empty metric lists and round counter at zero.
        self.list_loss = []
        self.list_acc = []
        self.list_time_stamp = []
        self.list_model_size = []
        self.start_time = timer() + self.init_time_offset
        self.round = 0
        mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
        self.model.eval()
        # (False, 0, False) = not resuming, round 0, model not sparse yet.
        list_init_params = [(idx, exp_config, self.model, list_extra_params[idx], (False, 0, False))
                            for idx in range(self.config.NUM_CLIENTS)]
        init_msgs_to_client = [ServerToClientInitMessage(init_params) for init_params in list_init_params]
        self.socket.init_connections(init_msgs_to_client)
        print("Server initialized")
        print(self)
def main():
    """Client entry point: receive the initial model/config from the server,
    then alternate local training and model exchange until the server
    signals termination, logging per-round computation/communication times.
    """
    client = ClientSocket(SERVER_ADDR, SERVER_PORT)
    init_msg = client.init_connections()
    # Server-chosen experiment parameters.
    pruning_type, n_pruning_levels, seed = init_msg.server_params
    model = init_msg.model
    model.train()
    data_indices = init_msg.slice_indices  # this client's slice of the training set
    n_local_updates = init_msg.n_local_updates
    batch_size = init_msg.batch_size
    pruning_type_name = PRUNING_TYPE_NAMES[pruning_type]
    # Pruning type 0 has no level component in its result path.
    if pruning_type == 0:
        path = os.path.join("results", EXP_NAME, CLIENT_NAME, pruning_type_name, "seed_" + str(seed))
    else:
        path = os.path.join("results", EXP_NAME, CLIENT_NAME, pruning_type_name,
                            "level_" + str(n_pruning_levels), "seed_" + str(seed))
    train_loader = get_train_loader(EXP_NAME, train_batch_size=MNIST.N_TRAIN, shuffle=False, flatten=True,
                                    train_set_indices=data_indices, one_hot=True, n_workers=16, pin_memory=True)
    train_iter = DataIterator(data_loader=train_loader, batch_size=batch_size)
    list_t_computation, list_t_communication = [], []
    print("CLIENT. PRUNING TYPE = {}, N_PRUNING_LEVELS = {}, SEED = {}.".format(pruning_type_name,
                                                                                n_pruning_levels, seed))
    # Write system status into files
    # status_writer_process = Process(target=_status_writer)
    # status_writer_process.start()
    while True:
        t_start = timer()
        # Gradients are zeroed once per round, accumulated over
        # n_local_updates batches, then applied in one step.
        # NOTE(review): indentation reconstructed from a collapsed source —
        # confirm apply_grad is meant to run once per round, not per batch.
        model.zero_grad()
        for _ in range(n_local_updates):
            inputs, labels = train_iter.get_next_batch()
            loss = model.loss(inputs, labels)
            loss.backward()
        model.apply_grad()
        t_comp = timer()  # computation finished; communication begins
        client.send_msg(ClientToServerUpdateMessage([model.state_dict(), batch_size]))
        update_msg = client.recv_update_msg()
        # Replace local weights with the server's aggregated model.
        state_dict = update_msg.state_dict
        model.load_state_dict(state_dict)
        t_end = timer()
        list_t_computation.append(t_comp - t_start)
        list_t_communication.append(t_end - t_comp)
        # Checkpoint timing stats every round so a crash loses nothing.
        mkdir_save(list_t_computation, os.path.join(path, "computation_time"))
        mkdir_save(list_t_communication, os.path.join(path, "communication_time"))
        terminate = update_msg.terminate
        if terminate:
            print("Task completed")
            break
    client.close()
def save_exp_config(self):
    """Persist the experiment's hyper-parameter configuration to disk."""
    exp_config = dict(
        exp_name=EXP_NAME,
        seed=args.seed,
        batch_size=CLIENT_BATCH_SIZE,
        num_local_updates=NUM_LOCAL_UPDATES,
        mdd=MAX_DEC_DIFF,
        init_lr=INIT_LR,
        ahl=ADJ_HALF_LIFE,
        use_adaptive=self.use_adaptive,
        client_selection=args.client_selection,
    )
    # The user count is only meaningful when client selection is enabled.
    if self.client_selection:
        exp_config["num_users"] = num_users
    mkdir_save(exp_config, os.path.join(self.save_path, "exp_config.pt"))
def __init__(self, args, config, model, save_interval=50):
    """Set up a reinitialization experiment from a previously trained adaptive model.

    Mode "r" reinitializes weights from the saved initial model while adopting
    the final model's structure; mode "rr" keeps this model's (differently
    seeded) weights but copies the final model's masks.

    :param args: parsed CLI arguments (uses experiment_name, mode, targeted, client_selection)
    :param config: experiment configuration object
    :param model: model instance; moved to the selected device
    :param save_interval: checkpoint frequency in rounds
    """
    self.config = config
    # Prefer GPU when available.
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    self.experiment_name = args.experiment_name
    self.save_path = os.path.join("results", config.EXP_NAME, args.experiment_name)
    self.save_interval = save_interval
    self.mode = args.mode
    # Only "r" (reinit) and "rr" (random reinit) are supported.
    assert self.mode in ["r", "rr"]
    self.model = model.to(self.device)
    # Folder name of the source adaptive run, e.g. "adaptive_target_cs".
    self.adaptive_folder = "adaptive{}{}".format(
        "_target" if args.targeted else "",
        "_cs" if args.client_selection else "")
    init_model_path = os.path.join("results", config.EXP_NAME, self.adaptive_folder, "init_model.pt")
    final_model_path = os.path.join("results", config.EXP_NAME, self.adaptive_folder, "model.pt")
    final_model = load(final_model_path)
    # reinit
    if self.mode == "r":
        self.model = load(init_model_path).to(self.device)
        self.model.reinit_from_model(final_model)
    # random reinit, using different seed for initialization but same mask
    elif self.mode == "rr":
        for layer, final_layer in zip(self.model.prunable_layers, final_model.prunable_layers):
            layer.mask = final_layer.mask.clone().to(layer.mask.device)
    else:
        # Unreachable given the assert above; kept as a defensive fallback.
        raise ValueError("Mode {} not supported".format(self.mode))
    # Zero out pruned weights so masked entries start inactive.
    with torch.no_grad():
        for layer in self.model.prunable_layers:
            layer.weight.mul_(layer.mask)
    disp_num_params(self.model)
    self.model.train()
    mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
    self.test_loader = None
    self.init_test_loader()
    self.init_clients()
def __init__(self, args, config, model, save_interval=50):
    """Set up a server that tunes the retained-parameter budget with an
    online gradient-estimation controller.

    :param args: parsed CLI arguments (uses experiment_name)
    :param config: experiment configuration object
    :param model: model instance; moved to the selected device
    :param save_interval: checkpoint frequency in rounds
    """
    self.config = config
    # Prefer GPU when available.
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.experiment_name = args.experiment_name
    self.save_path = os.path.join("results", config.EXP_NAME, args.experiment_name)
    self.save_interval = save_interval
    self.model = model.to(self.device)
    self.model.train()
    mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
    # Total parameter count; the controller's k values are bounded by it.
    self.num_all_params = self.model.nelement()
    # k = current retained-parameter budget; k_aux = auxiliary probe budget
    # starting at 90% of k.
    self.k = self.num_all_params
    self.k_aux = int(np.ceil(self.k * 0.9))
    # Controller step size is 0.2% of the total parameter count.
    self.control = OcoGradEstimation(np.round(self.num_all_params * 0.002), self.num_all_params)
    self.test_loader = None
    # Previous round's model, used by the controller to compare losses.
    self.prev_model = None
    self.init_test_loader()
    self.init_clients()
def __init__(self, args, config, model, save_interval=50):
    """Set up an iterative-pruning experiment server.

    The per-level prune rate is derived from the final density of a
    previously trained adaptive model so that NUM_ITERATIVE_PRUNING
    pruning levels reach that same final density:
    (1 - prune_rate) ** NUM_ITERATIVE_PRUNING == final density.

    :param args: parsed CLI arguments (uses experiment_name, client_selection)
    :param config: experiment configuration object
    :param model: model instance; moved to the selected device
    :param save_interval: checkpoint frequency in rounds
    """
    self.config = config
    # Prefer GPU when available.
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    self.experiment_name = args.experiment_name
    self.save_path = os.path.join("results", config.EXP_NAME, args.experiment_name)
    self.save_interval = save_interval
    adaptive_folder = "adaptive_cs" if args.client_selection else "adaptive"
    # Build the path with os.path.join (consistent with the rest of this
    # file and portable) instead of a hand-formatted "/"-separated string.
    final_model_path = os.path.join("results", config.EXP_NAME, adaptive_folder, "model.pt")
    self.prune_rate = 1 - load(final_model_path).density() ** (1 / config.NUM_ITERATIVE_PRUNING)
    self.model = model.to(self.device)
    self.model.train()
    mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
    self.test_loader = None
    self.init_test_loader()
    self.init_clients()
def __init__(self, args, config, model, save_interval=50):
    """Set up the adaptive-pruning federated server.

    :param args: parsed CLI arguments (uses experiment_name, use_adaptive, client_selection)
    :param config: experiment configuration object
    :param model: model instance; moved to the selected device
    :param save_interval: checkpoint frequency in rounds
    """
    self.config = config
    # Prefer GPU when available.
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    self.experiment_name = args.experiment_name
    self.save_path = os.path.join("results", config.EXP_NAME, args.experiment_name)
    self.save_interval = save_interval
    self.use_adaptive = args.use_adaptive
    self.client_selection = args.client_selection
    if self.use_adaptive:
        # Echo the adaptive-pruning hyper-parameters for the run log.
        print("Init max dec = {}. "
              "Adjustment dec half-life = {}. "
              "Adjustment interval = {}.".format(self.config.MAX_DEC_DIFF, self.config.ADJ_HALF_LIFE,
                                                 self.config.ADJ_INTERVAL))
    self.model = model.to(self.device)
    self.model.train()
    mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
    # Placeholders populated by the init_* helpers below.
    self.indices = None
    # "ip_" members belong to the initial-pruning stage.
    self.ip_train_loader = None
    self.ip_test_loader = None
    self.ip_optimizer_wrapper = None
    self.ip_control = None
    self.test_loader = None
    self.control = None
    self.init_test_loader()
    self.init_clients()
    self.init_control()
    self.init_ip_config()
    self.save_exp_config()
def main(self, idx, list_sd, list_num_proc, list_data_proc, lr, start, list_loss, list_acc, list_est_time,
         list_model_size):
    """Run one server round with the online k-budget controller: aggregate
    client updates, periodically evaluate, compare current/previous/probe
    losses to tune the retained-parameter budget k, and checkpoint results.

    NOTE(review): block nesting reconstructed from a collapsed source —
    confirm the controller section runs only on evaluation rounds.

    :param idx: current round index
    :param list_sd: client state dicts, one per client
    :param list_num_proc: samples processed per client (weights the average)
    :param list_data_proc: (inputs, labels) batches clients trained on; fed to the controller
    :param lr: current learning rate (display only)
    :param start: wall-clock start from timer(), for elapsed-time display
    :param list_loss: running loss list, mutated in place
    :param list_acc: running accuracy list, mutated in place
    :param list_est_time: running estimated-time list, mutated in place
    :param list_model_size: running model-size list, mutated in place
    :return: (per-layer masks, one state_dict copy per client)
    """
    total_num_proc = sum(list_num_proc)
    # Weighted FedAvg on increments: each client contributes
    # (client_param - server_param) scaled by its share of processed samples.
    with torch.no_grad():
        for key, param in self.model.state_dict().items():
            avg_inc_val = None
            for num_proc, state_dict in zip(list_num_proc, list_sd):
                if key in state_dict.keys():
                    mask = self.model.get_mask_by_name(key)
                    if mask is None:
                        inc_val = state_dict[key] - param
                    else:
                        # Masked layers: only unpruned server entries enter the delta.
                        inc_val = state_dict[key] - param * self.model.get_mask_by_name(key)
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val
            # Skip keys no client reported and BatchNorm bookkeeping buffers.
            if avg_inc_val is None or key.endswith("num_batches_tracked"):
                continue
            else:
                param.add_(avg_inc_val)
    if idx % self.config.EVAL_DISP_INTERVAL == 0:
        loss, acc = self.model.evaluate(self.test_loader)
        list_loss.append(loss)
        list_acc.append(acc)
        print("Round #{} (Experiment = {}).".format(idx, self.experiment_name))
        print("Loss/acc (at round {}) = {}/{}".format((len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL,
                                                      loss, acc))
        print("Estimated time = {}".format(sum(list_est_time)))
        print("Elapsed time = {}".format(timer() - start))
        print("Current lr = {}".format(lr))
        print("Current density = {}".format(self.model.density()))
    # control
    if len(list_acc) >= 2:
        # Concatenate all client batches into one evaluation batch.
        inputs = None
        labels = None
        for data in list_data_proc:
            inp, lab = data
            if inputs is None:
                inputs = inp
            else:
                inputs = torch.cat([inputs, inp], dim=0)
            if labels is None:
                labels = lab
            else:
                labels = torch.cat([labels, lab], dim=0)
        # Losses: previous-round model, current model, and the current model
        # probed at the smaller auxiliary budget k_aux.
        prev_loss = self.prev_model.evaluate([(inputs, labels)])[0]
        cur_loss = self.model.evaluate([(inputs, labels)])[0]
        retain_by_num(self.model, self.k_aux)
        cur_aux_loss = self.model.evaluate([(inputs, labels)])[0]
        cost = calc_cost(self.k, self.config, self.model)
        if prev_loss > cur_loss and prev_loss > cur_aux_loss:
            # Scale the auxiliary cost by the ratio of loss improvements.
            cost_aux = calc_cost(self.k_aux, self.config, self.model)
            cost_aux *= (prev_loss - cur_loss) / (prev_loss - cur_aux_loss)
        else:
            # No improvement at either budget: no valid auxiliary signal.
            cost_aux = None
        self.k, self.k_aux = self.control.tuning_k_grad_sign(self.k, self.k_aux, cost, cost_aux, idx + 1)
        self.k, self.k_aux = int(self.k), int(self.k_aux)
        # Re-prune the model to the newly chosen budget.
        retain_by_num(self.model, self.k)
    # Estimated round time: constant + per-layer linear model in active weights.
    est_time = self.config.TIME_CONSTANT
    for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
        est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)
    model_size = self.model.calc_num_all_active_params(True)
    list_est_time.append(est_time)
    list_model_size.append(model_size)
    if idx % self.save_interval == 0:
        mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
        mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
        mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
        mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
        mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))
    # Keep a snapshot for next round's loss comparison.
    self.prev_model = deepcopy(self.model)
    return [layer.mask for layer in self.model.prunable_layers], [self.model.state_dict()
                                                                  for _ in range(self.config.NUM_CLIENTS)]
def main(self, idx, list_sd, list_num_proc, lr, start, list_loss, list_acc, list_est_time, list_model_size):
    """Run one conventional aggregation round: weighted-average the client
    updates into the global model, periodically evaluate, and checkpoint.

    :param idx: current round index
    :param list_sd: client state dicts, one per client
    :param list_num_proc: samples processed per client (weights the average)
    :param lr: current learning rate (display only)
    :param start: wall-clock start from timer(), for elapsed-time display
    :param list_loss: running loss list, mutated in place
    :param list_acc: running accuracy list, mutated in place
    :param list_est_time: running estimated-time list, mutated in place
    :param list_model_size: running model-size list, mutated in place
    :return: (per-layer masks, one state_dict copy per client)
    """
    total_num_proc = sum(list_num_proc)
    # Weighted FedAvg on increments: each client contributes
    # (client_param - server_param) scaled by its share of processed samples.
    with torch.no_grad():
        for key, param in self.model.state_dict().items():
            avg_inc_val = None
            for num_proc, state_dict in zip(list_num_proc, list_sd):
                if key in state_dict.keys():
                    mask = self.model.get_mask_by_name(key)
                    if mask is None:
                        inc_val = state_dict[key] - param
                    else:
                        # Fix: reuse the mask already fetched above instead of
                        # a redundant second get_mask_by_name lookup.
                        inc_val = state_dict[key] - param * mask
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val
            # Skip keys no client reported and BatchNorm bookkeeping buffers.
            if avg_inc_val is None or key.endswith("num_batches_tracked"):
                continue
            param.add_(avg_inc_val)
    if idx % self.config.EVAL_DISP_INTERVAL == 0:
        loss, acc = self.model.evaluate(self.test_loader)
        list_loss.append(loss)
        list_acc.append(acc)
        print("Round #{} (Experiment = {}).".format(idx, self.experiment_name))
        print("Loss/acc (at round {}) = {}/{}".format(
            (len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss, acc))
        print("Estimated time = {}".format(sum(list_est_time)))
        print("Elapsed time = {}".format(timer() - start))
        print("Current lr = {}".format(lr))
    # Estimated round time: constant + per-layer linear model in active weights.
    est_time = self.config.TIME_CONSTANT
    for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
        est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)
    model_size = self.model.calc_num_all_active_params(True)
    list_est_time.append(est_time)
    list_model_size.append(model_size)
    if idx % self.save_interval == 0:
        mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
        mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
        mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
        mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
        mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))
    return [layer.mask for layer in self.model.prunable_layers], [
        self.model.state_dict() for _ in range(self.config.NUM_CLIENTS)
    ]
def main():
    """Standalone federated server: initialize all clients, then repeatedly
    average client models, evaluate asynchronously on a background thread,
    and broadcast the averaged model until the iteration or time budget
    is exhausted.

    NOTE(review): statement nesting reconstructed from a collapsed source —
    confirm the save block is meant to run every iteration.
    """
    # Wait for connections
    server = ServerSocket(SERVER_ADDR, SERVER_PORT, N_CLIENTS)
    server.wait_for_connections()
    torch.manual_seed(SEED)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Current device is {}.".format(device))
    list_loss, list_acc, list_time = [], [], []
    model = load_model()
    # Disjoint training-set slices, one per client.
    random_indices = random_split(MNIST.N_TRAIN, N_CLIENTS)
    test_loader = get_test_loader(EXP_NAME, test_batch_size=MNIST.N_TEST, shuffle=False, flatten=True,
                                  one_hot=True, n_workers=16, pin_memory=True)
    test_iter = DataIterator(data_loader=test_loader, batch_size=200, device=device)
    init_msgs = [
        ServerToClientInitMessage([
            model, torch.tensor(random_indices[idx]), N_LOCAL_UPDATES, CLIENT_BATCH_SIZE,
            (PRUNING_TYPE, N_PRUNING_LEVELS, SEED)
        ]) for idx in range(N_CLIENTS)
    ]
    server.init_connections(init_msgs)
    model.eval()
    print(
        "SERVER. PRUNING TYPE = {}, N_PRUNING_LEVELS = {}, PRUNING_PCT = {}, SEED = {}, N_ITERATIONS = {}."
        .format(PRUNING_TYPE_NAME, N_PRUNING_LEVELS, PRUNING_PCT, SEED, N_ITERATIONS))
    prev_thread = None  # handle of the in-flight async evaluation, if any
    t_start = timer()
    for idx in range(N_ITERATIONS):
        t_fed_start = timer()
        # Collect one update from every client and average uniformly.
        msgs = server.recv_update_msg_from_all()
        list_state_dict = [msg.state_dict for msg in msgs]
        avg_state_dict = model.state_dict().copy()
        for key in avg_state_dict.keys():
            new_val = None
            for state_dict in list_state_dict:
                if new_val is None:
                    new_val = state_dict[key]
                else:
                    new_val += state_dict[key]
            new_val /= N_CLIENTS
            avg_state_dict[key] = new_val
        model.load_state_dict(avg_state_dict)
        if idx % EVAL_INTERVAL == 0:
            # Asynchronously evaluate model
            # Only one evaluation thread runs at a time.
            if prev_thread is not None:
                prev_thread.join()
            t = Thread(target=eval_async, args=(model.evaluate, test_iter, list_loss, list_acc))
            t.start()
            prev_thread = t
        if idx % DISP_SAVE_INTERVAL == 0:
            print("Federation #{}".format(idx))
            # Metrics may still be empty before the first eval completes.
            if len(list_loss) != 0 and len(list_acc) != 0:
                loss, acc = list_loss[-1], list_acc[-1]
                print("Loss/acc at iteration {} = {}/{}".format(
                    (len(list_loss) - 1) * EVAL_INTERVAL, loss, acc))
            t = timer()
            print("Elapsed time = {}".format(t - t_start))
        # Stop on the last iteration or when the wall-clock budget runs out.
        terminate = True if idx == N_ITERATIONS - 1 or timer(
        ) - t_start >= MAX_TIME else False
        server.send_msg_to_all(
            ServerToClientUpdateMessage([model.state_dict(), terminate]))
        t_fed_end = timer()
        list_time.append(t_fed_end - t_fed_start)
        if terminate:
            # Make sure the last async evaluation has landed before saving.
            prev_thread.join()
        # Saving loss/acc
        mkdir_save(list_loss, os.path.join(DATA_DIR_PATH, "loss"))
        mkdir_save(list_acc, os.path.join(DATA_DIR_PATH, "accuracy"))
        mkdir_save(list_time, os.path.join(DATA_DIR_PATH, "time"))
        mkdir_save(model, os.path.join(DATA_DIR_PATH, "model"))
        if terminate:
            break
    print("Task completed")
    server.close()
def eval_async(func, iterator, loss_list: list, acc_list: list):
    """Thread target: evaluate the model, record loss/accuracy in the shared
    lists, and persist both lists to disk."""
    loss_val, acc_val = func(iterator)
    loss_list.append(loss_val)
    acc_list.append(acc_val)
    # Persist immediately so results survive a crash mid-run.
    mkdir_save(loss_list, os.path.join(DATA_DIR_PATH, "loss"))
    mkdir_save(acc_list, os.path.join(DATA_DIR_PATH, "accuracy"))
def main(self, idx, list_sd, list_num_proc, lr, start, list_loss, list_acc, list_est_time, list_model_size):
    """Run one server round of the SNIP baseline: aggregate client updates,
    and at round 0 derive a one-shot global pruning mask from |weight * grad|
    saliency (gradients recovered from the first-round update / lr).

    :param idx: current round index (round 0 triggers SNIP pruning)
    :param list_sd: client state dicts, one per client
    :param list_num_proc: samples processed per client (weights the average)
    :param lr: learning rate; used to recover gradients from round-0 increments
    :param start: wall-clock start from timer(), for elapsed-time display
    :param list_loss/list_acc/list_est_time/list_model_size: running metric lists, mutated in place
    :return: (per-layer masks, one state_dict copy per client)
    """
    total_num_proc = sum(list_num_proc)
    grad_dict = dict()    # round-0 estimated gradients per prunable weight
    weight_dict = dict()  # round-0 weight snapshots per prunable weight
    # Weighted FedAvg on increments, capturing saliency inputs at round 0.
    with torch.no_grad():
        for key, param in self.model.state_dict().items():
            avg_inc_val = None
            for num_proc, state_dict in zip(list_num_proc, list_sd):
                if key in state_dict.keys():
                    inc_val = state_dict[key] - param
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val
            # Skip keys no client reported and BatchNorm bookkeeping buffers.
            if avg_inc_val is None or key.endswith("num_batches_tracked"):
                continue
            else:
                if idx == 0 and key in dict(self.model.named_parameters()).keys() and key.endswith(
                        "weight") and key[:-7] in self.model.prunable_layer_prefixes:
                    # increment = -lr * grad (one aggregated step), so
                    # dividing by lr recovers a gradient estimate (sign aside).
                    grad_dict[key] = avg_inc_val / lr
                    weight_dict[key] = dict(self.model.named_parameters())[key].clone()
                param.add_(avg_inc_val)
    if idx == 0:
        # Global saliency = |w * g| over all prunable weights; keep the top
        # self.density fraction.
        abs_all_wg = None
        for (name_w, w), (name_g, g) in zip(weight_dict.items(), grad_dict.items()):
            assert name_w == name_g
            if abs_all_wg is None:
                abs_all_wg = (w * g).view(-1).abs()
            else:
                abs_all_wg = torch.cat([abs_all_wg, (w * g).view(-1).abs()], dim=0)
        threshold = abs_all_wg.sort(descending=True)[0][int(self.density * abs_all_wg.nelement())]
        for layer, layer_prefix in zip(self.model.prunable_layers, self.model.prunable_layer_prefixes):
            abs_layer_wg = (weight_dict[layer_prefix + ".weight"] * grad_dict[layer_prefix + ".weight"]).abs()
            layer.mask = abs_layer_wg >= threshold
        # Zero out the newly pruned weights.
        with torch.no_grad():
            for layer in self.model.prunable_layers:
                layer.weight *= layer.mask
        print("Snip pruning completed. Remaining params:")
        disp_num_params(self.model)
    if idx % self.config.EVAL_DISP_INTERVAL == 0:
        loss, acc = self.model.evaluate(self.test_loader)
        list_loss.append(loss)
        list_acc.append(acc)
        print("Round #{} (Experiment = {}).".format(idx, self.experiment_name))
        print("Loss/acc (at round {}) = {}/{}".format((len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL,
                                                      loss, acc))
        print("Estimated time = {}".format(sum(list_est_time)))
        print("Elapsed time = {}".format(timer() - start))
    # Estimated round time: constant + per-layer linear model in active weights.
    est_time = self.config.TIME_CONSTANT
    for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
        est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)
    model_size = self.model.calc_num_all_active_params(True)
    list_est_time.append(est_time)
    list_model_size.append(model_size)
    if idx % self.save_interval == 0:
        mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
        mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
        mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
        mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
        mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))
    return [layer.mask for layer in self.model.prunable_layers], [self.model.state_dict()
                                                                  for _ in range(self.config.NUM_CLIENTS)]
def main(self, idx, list_sd, list_num_proc, lr, list_accumulated_sgrad, start, list_loss, list_acc,
         list_est_time, list_model_size, is_adj_round, density_limit=None):
    """Run one adaptive-pruning server round: aggregate client updates,
    periodically evaluate, and on adjustment rounds feed accumulated
    squared gradients to the controller to re-adjust layer masks.

    :param idx: current round index
    :param list_sd: client state dicts, one per client
    :param list_num_proc: samples processed per client (weights the average)
    :param lr: current learning rate (display only)
    :param list_accumulated_sgrad: per-client dicts of accumulated squared gradients
    :param start: wall-clock start from timer(), for elapsed-time display
    :param list_loss/list_acc/list_est_time/list_model_size: running metric lists, mutated in place
    :param is_adj_round: whether the pruning adjustment runs this round
    :param density_limit: optional upper bound on model density after adjustment
    :return: (per-layer masks, one state_dict copy per client)
    """
    total_num_proc = sum(list_num_proc)
    # Weighted FedAvg on increments: each client contributes
    # (client_param - server_param) scaled by its share of processed samples.
    with torch.no_grad():
        for key, param in self.model.state_dict().items():
            avg_inc_val = None
            for num_proc, state_dict in zip(list_num_proc, list_sd):
                if key in state_dict.keys():
                    mask = self.model.get_mask_by_name(key)
                    if mask is None:
                        inc_val = state_dict[key] - param
                    else:
                        # Masked layers: only unpruned server entries enter the delta.
                        inc_val = state_dict[
                            key] - param * self.model.get_mask_by_name(key)
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val
            # Skip keys no client reported and BatchNorm bookkeeping buffers.
            if avg_inc_val is None or key.endswith("num_batches_tracked"):
                continue
            else:
                param.add_(avg_inc_val)
    if idx % self.config.EVAL_DISP_INTERVAL == 0:
        loss, acc = self.model.evaluate(self.test_loader)
        list_loss.append(loss)
        list_acc.append(acc)
        print("Round #{} (Experiment = {}).".format(
            idx, self.experiment_name))
        print("Loss/acc (at round #{}) = {}/{}".format(
            (len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss,
            acc))
        print("Estimated time = {}".format(sum(list_est_time)))
        print("Elapsed time = {}".format(timer() - start))
        print("Current lr = {}".format(lr))
    if self.use_adaptive and is_adj_round:
        alg_start = timer()
        # Feed every client's accumulated squared gradients to the controller.
        for d in list_accumulated_sgrad:
            for k, sg in d.items():
                self.control.accumulate(k, sg)
        print("Running adaptive pruning algorithm")
        # Allowed decrease shrinks with a half-life of ADJ_HALF_LIFE rounds.
        max_dec_diff = self.config.MAX_DEC_DIFF * (0.5**(
            idx / self.config.ADJ_HALF_LIFE))
        self.control.adjust(max_dec_diff, max_density=density_limit)
        print("Total alg time = {}. Max density = {}.".format(
            timer() - alg_start, density_limit))
        print("Num params:")
        disp_num_params(self.model)
    # Estimated round time: constant + per-layer linear model in active weights.
    est_time = self.config.TIME_CONSTANT
    for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
        est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)
    model_size = self.model.calc_num_all_active_params(True)
    list_est_time.append(est_time)
    list_model_size.append(model_size)
    if idx % self.save_interval == 0:
        mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
        mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
        mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
        mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
        mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))
    return [layer.mask for layer in self.model.prunable_layers], [
        self.model.state_dict() for _ in range(self.config.NUM_CLIENTS)
    ]
def save_exp(self):
    """Checkpoint all tracked metric lists and the current model to the save path."""
    checkpoints = (
        (self.list_loss, "loss.pt"),
        (self.list_acc, "accuracy.pt"),
        (self.list_time_stamp, "time.pt"),
        (self.list_model_size, "model_size.pt"),
        (self.model, "model.pt"),
    )
    for obj, filename in checkpoints:
        mkdir_save(obj, os.path.join(self.save_path, filename))