Example #1
0
    def initialize(self, exp_config, resume):
        """Wait for all clients to connect and send each of them an init message.

        With ``resume`` true, the loss/accuracy/time-stamp/model-size histories
        and the model are loaded from ``self.save_path``. If exactly one
        evaluation is missing from the loss/accuracy history, it is computed
        now. The current round and start time are then reconstructed, and the
        clients are told to resume from the next round. Otherwise the histories
        start empty, the initial model is saved as ``init_model.pt`` and the
        clients receive a fresh (non-resume) init message.

        Raises:
            RuntimeError: on resume, if the evaluation history and the
                loss/accuracy history differ by anything other than 0 or 1
                entries.
        """
        extra_params = self.get_init_extra_params()

        self.socket.wait_for_connections()

        def _send_init_messages(resume_flag):
            # Every client gets the same config and model; only its index and
            # its own entry of extra_params differ.
            messages = []
            for client_idx in range(self.config.NUM_CLIENTS):
                payload = (client_idx, exp_config, self.model, extra_params[client_idx], resume_flag)
                messages.append(ServerToClientInitMessage(payload))
            self.socket.init_connections(messages)

        if resume:
            print("Resuming server...")
            saved_files = (("list_loss", "loss.pt"),
                           ("list_acc", "accuracy.pt"),
                           ("list_time_stamp", "time.pt"),
                           ("list_model_size", "model_size.pt"),
                           ("model", "model.pt"))
            for attr_name, file_name in saved_files:
                setattr(self, attr_name, load(os.path.join(self.save_path, file_name)))

            n_evaluated = len(self.list_loss)
            assert len(self.list_acc) == n_evaluated

            n_recorded = len(self.list_time_stamp)
            assert len(self.list_model_size) == n_recorded

            n_missing = n_recorded - n_evaluated
            if n_missing == 1:
                # The last recorded round has no loss/accuracy yet: evaluate it now.
                loss, acc = self.model.evaluate(self.test_loader)
                self.list_loss.append(loss)
                self.list_acc.append(acc)
            elif n_missing != 0:
                raise RuntimeError("Cannot resume")

            self.round = (n_recorded - 1) * self.config.EVAL_DISP_INTERVAL
            assert self.round >= 0
            # Shift the clock so that elapsed time continues from the last saved stamp.
            self.start_time = timer() - self.list_time_stamp[-1]

            self.check_client_to_sparse()
            _send_init_messages((True, self.round + 1, self.client_is_sparse))

            self.round += 1

            print("Server resumed")
        else:
            self.list_loss = []
            self.list_acc = []
            self.list_time_stamp = []
            self.list_model_size = []
            self.start_time = timer() + self.init_time_offset
            self.round = 0
            mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))
            self.model.eval()

            _send_init_messages((False, 0, False))

            print("Server initialized")

        print(self)
Example #2
0
def main():
    """Run one federated-learning client until the server signals termination.

    The client receives the model and its training configuration from the
    server. Each round it then:

    * performs ``n_local_updates`` local steps (``loss.backward()`` followed
      by ``model.apply_grad()``);
    * sends its state dict and batch size to the server;
    * loads the state dict the server sends back.

    Per-round computation and communication times are saved under
    ``results/<EXP_NAME>/<CLIENT_NAME>/<pruning type>/...``. The socket is
    closed even if a round raises.
    """
    client = ClientSocket(SERVER_ADDR, SERVER_PORT)
    try:
        init_msg = client.init_connections()

        pruning_type, n_pruning_levels, seed = init_msg.server_params
        model = init_msg.model
        model.train()
        data_indices = init_msg.slice_indices
        n_local_updates = init_msg.n_local_updates
        batch_size = init_msg.batch_size
        pruning_type_name = PRUNING_TYPE_NAMES[pruning_type]

        if pruning_type == 0:
            path = os.path.join("results", EXP_NAME, CLIENT_NAME, pruning_type_name, "seed_" + str(seed))
        else:
            path = os.path.join("results", EXP_NAME, CLIENT_NAME, pruning_type_name, "level_" + str(n_pruning_levels),
                                "seed_" + str(seed))

        train_loader = get_train_loader(EXP_NAME, train_batch_size=MNIST.N_TRAIN, shuffle=False, flatten=True,
                                        train_set_indices=data_indices, one_hot=True, n_workers=16, pin_memory=True)

        train_iter = DataIterator(data_loader=train_loader, batch_size=batch_size)
        list_t_computation, list_t_communication = [], []
        print("CLIENT. PRUNING TYPE = {}, N_PRUNING_LEVELS = {}, SEED = {}.".format(pruning_type_name, n_pruning_levels,
                                                                                    seed))

        while True:
            t_start = timer()
            for _ in range(n_local_updates):
                # Clear gradients before every step. With a single zero_grad() per
                # round, backward() would add each step's gradient onto all earlier
                # ones. This is harmless if apply_grad() already clears them.
                model.zero_grad()
                inputs, labels = train_iter.get_next_batch()
                loss = model.loss(inputs, labels)
                loss.backward()
                model.apply_grad()

            t_comp = timer()

            client.send_msg(ClientToServerUpdateMessage([model.state_dict(), batch_size]))
            update_msg = client.recv_update_msg()
            model.load_state_dict(update_msg.state_dict)

            t_end = timer()
            list_t_computation.append(t_comp - t_start)
            list_t_communication.append(t_end - t_comp)

            mkdir_save(list_t_computation, os.path.join(path, "computation_time"))
            mkdir_save(list_t_communication, os.path.join(path, "communication_time"))

            if update_msg.terminate:
                print("Task completed")
                break
    finally:
        # Release the connection on normal termination and on errors alike.
        client.close()
Example #3
0
 def save_exp_config(self):
     """Save this run's experiment settings to ``exp_config.pt`` in ``self.save_path``.

     Values come from module-level constants, ``args`` and this instance;
     ``num_users`` is added only when ``self.client_selection`` is truthy.
     """
     exp_config = dict(
         exp_name=EXP_NAME,
         seed=args.seed,
         batch_size=CLIENT_BATCH_SIZE,
         num_local_updates=NUM_LOCAL_UPDATES,
         mdd=MAX_DEC_DIFF,
         init_lr=INIT_LR,
         ahl=ADJ_HALF_LIFE,
         use_adaptive=self.use_adaptive,
         client_selection=args.client_selection,
     )
     if self.client_selection:
         exp_config.update(num_users=num_users)
     target = os.path.join(self.save_path, "exp_config.pt")
     mkdir_save(exp_config, target)
Example #4
0
    def __init__(self, args, config, model, save_interval=50):
        """Build a re-initialisation baseline from a finished adaptive-pruning run.

        The final model of the adaptive run is loaded from
        ``results/<EXP_NAME>/<adaptive folder>/model.pt``. What happens next
        depends on ``args.mode``:

        * ``"r"``: load that run's ``init_model.pt`` and call
          ``reinit_from_model(final_model)`` on it;
        * ``"rr"``: keep the given ``model`` (a differently seeded
          initialisation) and copy only the final model's masks onto its
          prunable layers.

        In both modes the masked-out weights are zeroed. The model is then
        saved as ``init_model.pt`` under this experiment's save path, and the
        test loader and clients are initialised.

        Raises:
            ValueError: if ``args.mode`` is neither ``"r"`` nor ``"rr"``.
        """
        self.config = config
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.experiment_name = args.experiment_name
        self.save_path = os.path.join("results", config.EXP_NAME,
                                      args.experiment_name)
        self.save_interval = save_interval
        self.mode = args.mode
        # Validate up front, before any file is loaded. An assert here would be
        # stripped under ``python -O`` and left the ValueError path unreachable.
        if self.mode not in ("r", "rr"):
            raise ValueError("Mode {} not supported".format(self.mode))

        self.model = model.to(self.device)
        self.adaptive_folder = "adaptive{}{}".format(
            "_target" if args.targeted else "",
            "_cs" if args.client_selection else "")
        adaptive_dir = os.path.join("results", config.EXP_NAME,
                                    self.adaptive_folder)
        init_model_path = os.path.join(adaptive_dir, "init_model.pt")
        final_model_path = os.path.join(adaptive_dir, "model.pt")
        final_model = load(final_model_path)

        if self.mode == "r":
            # reinit: start from the adaptive run's own initial model
            self.model = load(init_model_path).to(self.device)
            self.model.reinit_from_model(final_model)
        else:
            # random reinit, using different seed for initialization but same mask
            for layer, final_layer in zip(self.model.prunable_layers,
                                          final_model.prunable_layers):
                layer.mask = final_layer.mask.clone().to(layer.mask.device)

        # Zero the weights that the mask removes.
        with torch.no_grad():
            for layer in self.model.prunable_layers:
                layer.weight.mul_(layer.mask)

        disp_num_params(self.model)

        self.model.train()
        mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))

        self.test_loader = None

        self.init_test_loader()
        self.init_clients()
Example #5
0
    def __init__(self, args, config, model, save_interval=50):
        """Set up the server for the OCO-controlled sparsity experiment.

        The model is moved to the first GPU if one is available (else the CPU),
        put in training mode and saved as ``init_model.pt`` under
        ``results/<EXP_NAME>/<experiment_name>``. The retained-parameter count
        ``k`` starts at the full parameter count, and ``k_aux`` at 90% of it,
        rounded up. An ``OcoGradEstimation`` controller is created, then the
        test loader and the clients are initialised.
        """
        use_cuda = torch.cuda.is_available()
        self.config = config
        self.device = torch.device("cuda:0" if use_cuda else "cpu")
        self.experiment_name = args.experiment_name
        self.save_interval = save_interval
        self.save_path = os.path.join("results", config.EXP_NAME, self.experiment_name)

        self.model = model.to(self.device)
        self.model.train()
        mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))

        total_params = self.model.nelement()
        self.num_all_params = total_params
        self.k = total_params
        self.k_aux = int(np.ceil(total_params * 0.9))
        self.control = OcoGradEstimation(np.round(total_params * 0.002), total_params)

        self.test_loader = None
        self.prev_model = None

        self.init_test_loader()
        self.init_clients()
Example #6
0
    def __init__(self, args, config, model, save_interval=50):
        """Set up an iterative-pruning baseline matched to an adaptive run's final density.

        The final model of the adaptive run is read from
        ``results/<EXP_NAME>/<adaptive folder>/model.pt``. The adaptive folder
        is ``adaptive_cs`` when ``args.client_selection`` is set, otherwise
        ``adaptive``. The per-iteration rate is
        ``1 - density ** (1 / config.NUM_ITERATIVE_PRUNING)``. Assuming each
        pruning iteration keeps a ``1 - prune_rate`` fraction, this reaches
        that final density after ``NUM_ITERATIVE_PRUNING`` iterations. The
        model is then saved as ``init_model.pt``, and the test loader and
        clients are initialised.
        """
        self.config = config
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.experiment_name = args.experiment_name
        self.save_path = os.path.join("results", config.EXP_NAME,
                                      args.experiment_name)
        self.save_interval = save_interval
        adaptive_folder = "adaptive_cs" if args.client_selection else "adaptive"
        # Build the path with os.path.join, like every other results path here.
        adaptive_model_path = os.path.join("results", config.EXP_NAME,
                                           adaptive_folder, "model.pt")
        final_density = load(adaptive_model_path).density()
        self.prune_rate = 1 - final_density**(1 / config.NUM_ITERATIVE_PRUNING)

        self.model = model.to(self.device)
        self.model.train()
        mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))

        self.test_loader = None

        self.init_test_loader()
        self.init_clients()
Example #7
0
    def __init__(self, args, config, model, save_interval=50):
        """Set up the adaptive-pruning server.

        Stores the run options (``use_adaptive``, ``client_selection``). When
        adaptive pruning is on, the adjustment hyper-parameters are printed.
        The model is moved to the chosen device, put in training mode and saved
        as ``init_model.pt`` under ``results/<EXP_NAME>/<experiment_name>``.
        Several attributes are reset to ``None``. Finally, the test loader,
        clients, controller and ``ip`` configuration are initialised and the
        experiment config is saved.
        """
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.config = config
        self.device = torch.device(device_name)
        self.experiment_name = args.experiment_name
        self.save_path = os.path.join("results", config.EXP_NAME, args.experiment_name)
        self.save_interval = save_interval
        self.use_adaptive = args.use_adaptive
        self.client_selection = args.client_selection

        if self.use_adaptive:
            template = "Init max dec = {}. Adjustment dec half-life = {}. Adjustment interval = {}."
            print(template.format(self.config.MAX_DEC_DIFF, self.config.ADJ_HALF_LIFE, self.config.ADJ_INTERVAL))

        self.model = model.to(self.device)
        self.model.train()
        mkdir_save(self.model, os.path.join(self.save_path, "init_model.pt"))

        # Placeholders; presumably populated by the init_* calls below (not visible here).
        self.indices = None
        self.ip_train_loader = self.ip_test_loader = None
        self.ip_optimizer_wrapper = self.ip_control = None
        self.test_loader = self.control = None

        self.init_test_loader()
        self.init_clients()
        self.init_control()
        self.init_ip_config()
        self.save_exp_config()
Example #8
0
    def main(self, idx, list_sd, list_num_proc, list_data_proc, lr, start, list_loss, list_acc, list_est_time,
             list_model_size):
        """Aggregate one round of client updates, tune the retained-parameter count, and log/save progress.

        Args:
            idx: Index of the current round.
            list_sd: Client state dicts for this round.
            list_num_proc: Per-client aggregation weights. Client ``i``
                contributes ``list_num_proc[i] / sum(list_num_proc)`` of its
                increment over the (masked) server value.
            list_data_proc: Per-client ``(inputs, labels)`` pairs. They are
                concatenated and used to compare three losses: the previous
                round's model, the current model, and the model pruned to
                ``self.k_aux`` parameters.
            lr: Current learning rate (printed only).
            start: Start timestamp used for the elapsed-time printout.
            list_loss, list_acc, list_est_time, list_model_size: History lists, appended in place.

        Returns:
            ``(masks, state_dicts)``: the mask of every prunable layer, and one
            ``self.model.state_dict()`` per client.
        """
        total_num_proc = sum(list_num_proc)

        with torch.no_grad():
            for key, param in self.model.state_dict().items():
                # The mask does not depend on the client, so look it up once per key
                # instead of twice per client.
                mask = self.model.get_mask_by_name(key)
                server_val = param if mask is None else param * mask
                avg_inc_val = None
                for num_proc, state_dict in zip(list_num_proc, list_sd):
                    if key not in state_dict:
                        continue
                    inc_val = state_dict[key] - server_val
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val

                # num_batches_tracked counters are not averaged.
                if avg_inc_val is None or key.endswith("num_batches_tracked"):
                    continue
                param.add_(avg_inc_val)

        if idx % self.config.EVAL_DISP_INTERVAL == 0:
            loss, acc = self.model.evaluate(self.test_loader)
            list_loss.append(loss)
            list_acc.append(acc)

            print("Round #{} (Experiment = {}).".format(idx, self.experiment_name))
            print("Loss/acc (at round {}) = {}/{}".format((len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss,
                                                          acc))
            print("Estimated time = {}".format(sum(list_est_time)))
            print("Elapsed time = {}".format(timer() - start))
            print("Current lr = {}".format(lr))
            print("Current density = {}".format(self.model.density()))

        # control (relies on self.prev_model, which is set at the end of every call)
        if len(list_acc) >= 2:
            # Concatenate all clients' data in one call instead of growing the
            # tensors pairwise, which copies the accumulated prefix every time.
            inputs = torch.cat([inp for inp, _ in list_data_proc], dim=0)
            labels = torch.cat([lab for _, lab in list_data_proc], dim=0)
            prev_loss = self.prev_model.evaluate([(inputs, labels)])[0]
            cur_loss = self.model.evaluate([(inputs, labels)])[0]

            # Measure the loss at the auxiliary size k_aux. The model is pruned
            # to the tuned k further below.
            retain_by_num(self.model, self.k_aux)
            cur_aux_loss = self.model.evaluate([(inputs, labels)])[0]

            cost = calc_cost(self.k, self.config, self.model)
            if prev_loss > cur_loss and prev_loss > cur_aux_loss:
                cost_aux = calc_cost(self.k_aux, self.config, self.model)
                cost_aux *= (prev_loss - cur_loss) / (prev_loss - cur_aux_loss)
            else:
                cost_aux = None
            self.k, self.k_aux = self.control.tuning_k_grad_sign(self.k, self.k_aux, cost, cost_aux, idx + 1)
            self.k, self.k_aux = int(self.k), int(self.k_aux)
            retain_by_num(self.model, self.k)

        est_time = self.config.TIME_CONSTANT
        for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
            est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)

        model_size = self.model.calc_num_all_active_params(True)
        list_est_time.append(est_time)
        list_model_size.append(model_size)

        if idx % self.save_interval == 0:
            mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
            mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
            mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
            mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
            mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))

        self.prev_model = deepcopy(self.model)

        return [layer.mask for layer in self.model.prunable_layers], [self.model.state_dict() for _ in
                                                                      range(self.config.NUM_CLIENTS)]
Example #9
0
    def main(self, idx, list_sd, list_num_proc, lr, start, list_loss, list_acc,
             list_est_time, list_model_size):
        """Aggregate one round of client updates and log/save progress.

        Args:
            idx: Index of the current round.
            list_sd: Client state dicts for this round.
            list_num_proc: Per-client aggregation weights; client ``i``
                contributes ``list_num_proc[i] / sum(list_num_proc)`` of its
                increment over the (masked) server value.
            lr: Current learning rate (printed only).
            start: Start timestamp used for the elapsed-time printout.
            list_loss, list_acc, list_est_time, list_model_size: History
                lists, appended in place.

        Returns:
            ``(masks, state_dicts)``: the mask of every prunable layer and one
            ``self.model.state_dict()`` per client.
        """
        total_num_proc = sum(list_num_proc)

        with torch.no_grad():
            for key, param in self.model.state_dict().items():
                # The mask does not depend on the client, so look it up once
                # per key instead of twice per client.
                mask = self.model.get_mask_by_name(key)
                server_val = param if mask is None else param * mask
                avg_inc_val = None
                for num_proc, state_dict in zip(list_num_proc, list_sd):
                    if key not in state_dict:
                        continue
                    inc_val = state_dict[key] - server_val
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val

                # num_batches_tracked counters are not averaged.
                if avg_inc_val is None or key.endswith("num_batches_tracked"):
                    continue
                param.add_(avg_inc_val)

        if idx % self.config.EVAL_DISP_INTERVAL == 0:
            loss, acc = self.model.evaluate(self.test_loader)
            list_loss.append(loss)
            list_acc.append(acc)

            print("Round #{} (Experiment = {}).".format(
                idx, self.experiment_name))
            print("Loss/acc (at round {}) = {}/{}".format(
                (len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss,
                acc))
            print("Estimated time = {}".format(sum(list_est_time)))
            print("Elapsed time = {}".format(timer() - start))
            print("Current lr = {}".format(lr))

        est_time = self.config.TIME_CONSTANT
        for layer, comp_coeff in zip(self.model.prunable_layers,
                                     self.config.COMP_COEFFICIENTS):
            est_time += layer.num_weight * (comp_coeff +
                                            self.config.COMM_COEFFICIENT)

        model_size = self.model.calc_num_all_active_params(True)
        list_est_time.append(est_time)
        list_model_size.append(model_size)

        if idx % self.save_interval == 0:
            mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
            mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
            mkdir_save(list_est_time,
                       os.path.join(self.save_path, "est_time.pt"))
            mkdir_save(list_model_size,
                       os.path.join(self.save_path, "model_size.pt"))
            mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))

        return [layer.mask for layer in self.model.prunable_layers], [
            self.model.state_dict() for _ in range(self.config.NUM_CLIENTS)
        ]
Example #10
0
def main():
    """Run the federated-learning server loop.

    Setup:

    * wait for ``N_CLIENTS`` connections;
    * send every client the initial model, its random slice of the training
      set and the training configuration.

    Then, for up to ``N_ITERATIONS`` rounds:

    * average the clients' state dicts into the model;
    * every ``EVAL_INTERVAL`` rounds, evaluate a snapshot of the model on a
      background thread;
    * send the averaged state dict back together with a termination flag;
    * save the loss/accuracy/round-time histories and the model under
      ``DATA_DIR_PATH``.

    The loop stops after the last iteration or once ``MAX_TIME`` seconds have
    elapsed. The server socket is closed even if a round raises.
    """
    from copy import deepcopy

    # Wait for connections
    server = ServerSocket(SERVER_ADDR, SERVER_PORT, N_CLIENTS)
    try:
        server.wait_for_connections()
        torch.manual_seed(SEED)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Current device is {}.".format(device))

        list_loss, list_acc, list_time = [], [], []
        model = load_model()
        random_indices = random_split(MNIST.N_TRAIN, N_CLIENTS)
        test_loader = get_test_loader(EXP_NAME,
                                      test_batch_size=MNIST.N_TEST,
                                      shuffle=False,
                                      flatten=True,
                                      one_hot=True,
                                      n_workers=16,
                                      pin_memory=True)
        test_iter = DataIterator(data_loader=test_loader,
                                 batch_size=200,
                                 device=device)

        init_msgs = [
            ServerToClientInitMessage([
                model,
                torch.tensor(random_indices[idx]), N_LOCAL_UPDATES,
                CLIENT_BATCH_SIZE, (PRUNING_TYPE, N_PRUNING_LEVELS, SEED)
            ]) for idx in range(N_CLIENTS)
        ]

        server.init_connections(init_msgs)
        model.eval()

        print(
            "SERVER. PRUNING TYPE = {}, N_PRUNING_LEVELS = {}, PRUNING_PCT = {}, SEED = {}, N_ITERATIONS = {}."
            .format(PRUNING_TYPE_NAME, N_PRUNING_LEVELS, PRUNING_PCT, SEED,
                    N_ITERATIONS))

        prev_thread = None
        t_start = timer()
        for idx in range(N_ITERATIONS):
            t_fed_start = timer()
            msgs = server.recv_update_msg_from_all()

            list_state_dict = [msg.state_dict for msg in msgs]
            n_received = len(list_state_dict)
            avg_state_dict = model.state_dict().copy()
            for key in avg_state_dict.keys():
                # Sum into a fresh tensor so the received client tensors are not
                # modified in place.
                new_val = list_state_dict[0][key].clone()
                for state_dict in list_state_dict[1:]:
                    new_val += state_dict[key]
                if new_val.is_floating_point():
                    new_val /= n_received
                else:
                    # In-place true division is rejected for integer tensors (e.g.
                    # a BatchNorm num_batches_tracked counter). Divide out of place
                    # and cast back to the original dtype.
                    new_val = (new_val / n_received).to(new_val.dtype)
                avg_state_dict[key] = new_val

            model.load_state_dict(avg_state_dict)

            if idx % EVAL_INTERVAL == 0:
                # Asynchronously evaluate model
                if prev_thread is not None:
                    prev_thread.join()
                # Evaluate a frozen copy: next round's load_state_dict() overwrites
                # the live model's parameters while this thread may still be running.
                eval_model = deepcopy(model)
                t = Thread(target=eval_async,
                           args=(eval_model.evaluate, test_iter, list_loss, list_acc))
                t.start()
                prev_thread = t

            if idx % DISP_SAVE_INTERVAL == 0:
                print("Federation #{}".format(idx))
                if len(list_loss) != 0 and len(list_acc) != 0:
                    loss, acc = list_loss[-1], list_acc[-1]
                    print("Loss/acc at iteration {} = {}/{}".format(
                        (len(list_loss) - 1) * EVAL_INTERVAL, loss, acc))
                    t = timer()
                    print("Elapsed time = {}".format(t - t_start))

            terminate = idx == N_ITERATIONS - 1 or timer() - t_start >= MAX_TIME

            server.send_msg_to_all(
                ServerToClientUpdateMessage([model.state_dict(), terminate]))
            t_fed_end = timer()
            list_time.append(t_fed_end - t_fed_start)
            if terminate and prev_thread is not None:
                # Make sure the last evaluation is recorded before the final save.
                prev_thread.join()

            # Saving loss/acc
            # NOTE(review): eval_async also saves loss/accuracy from the evaluation
            # thread, so these two writes can overlap with it on the same files.
            mkdir_save(list_loss, os.path.join(DATA_DIR_PATH, "loss"))
            mkdir_save(list_acc, os.path.join(DATA_DIR_PATH, "accuracy"))
            mkdir_save(list_time, os.path.join(DATA_DIR_PATH, "time"))
            mkdir_save(model, os.path.join(DATA_DIR_PATH, "model"))

            if terminate:
                break

        print("Task completed")
    finally:
        server.close()
Example #11
0
def eval_async(func, iterator, loss_list: list, acc_list: list):
    """Evaluate via ``func(iterator)``, record the result and persist both histories.

    ``func`` must return a ``(loss, accuracy)`` pair. Both lists are appended in
    place and then saved as ``loss`` and ``accuracy`` under ``DATA_DIR_PATH``.
    """
    loss_val, acc_val = func(iterator)
    loss_list.append(loss_val)
    acc_list.append(acc_val)
    for history, file_name in ((loss_list, "loss"), (acc_list, "accuracy")):
        mkdir_save(history, os.path.join(DATA_DIR_PATH, file_name))
Example #12
0
    def main(self, idx, list_sd, list_num_proc, lr, start, list_loss, list_acc, list_est_time, list_model_size):
        """Aggregate client updates; on round 0, also prune the model SNIP-style.

        Each client's increment ``state_dict[key] - param`` is weighted by
        ``num_proc / sum(list_num_proc)`` and added to the server model.

        On round 0, the pruning step works as follows:

        * for every prunable layer's weight, the averaged increment divided by
          ``lr`` is used as a gradient estimate ``g``;
        * each layer's mask keeps the weights whose ``|w * g|`` is among the
          top ``self.density`` fraction across all prunable layers, where
          ``w`` is the pre-update weight;
        * the removed weights are zeroed.

        Args:
            idx: Index of the current round (pruning happens only when it is 0).
            list_sd: Client state dicts for this round.
            list_num_proc: Per-client aggregation weights.
            lr: Learning rate, used to turn the averaged increment into a gradient estimate.
            start: Start timestamp used for the elapsed-time printout.
            list_loss, list_acc, list_est_time, list_model_size: History lists, appended in place.

        Returns:
            ``(masks, state_dicts)``: the mask of every prunable layer, and one
            ``self.model.state_dict()`` per client.
        """
        total_num_proc = sum(list_num_proc)
        # Only needed on the pruning round; build the name -> parameter map once
        # instead of twice per state-dict key.
        named_params = dict(self.model.named_parameters()) if idx == 0 else {}

        grad_dict = dict()
        weight_dict = dict()
        with torch.no_grad():
            for key, param in self.model.state_dict().items():
                avg_inc_val = None
                for num_proc, state_dict in zip(list_num_proc, list_sd):
                    if key in state_dict:
                        inc_val = state_dict[key] - param

                        if avg_inc_val is None:
                            avg_inc_val = num_proc / total_num_proc * inc_val
                        else:
                            avg_inc_val += num_proc / total_num_proc * inc_val

                if avg_inc_val is None or key.endswith("num_batches_tracked"):
                    continue
                # key[:-7] strips the trailing ".weight".
                if idx == 0 and key in named_params and key.endswith(
                        "weight") and key[:-7] in self.model.prunable_layer_prefixes:
                    grad_dict[key] = avg_inc_val / lr
                    # Clone before add_() below so the pre-update weight is kept.
                    weight_dict[key] = named_params[key].clone()
                param.add_(avg_inc_val)

        if idx == 0:
            # Flatten |w * g| of all prunable weights into one vector with a single cat.
            abs_all_wg = torch.cat([(weight_dict[name] * grad_dict[name]).view(-1).abs() for name in weight_dict],
                                   dim=0)

            # Keep the top `density` fraction. Clamp to [1, n], so that density == 1
            # keeps everything instead of indexing past the end. The threshold is
            # the num_keep-th largest score, so exactly num_keep scores (plus ties)
            # survive rather than num_keep + 1.
            num_scores = abs_all_wg.nelement()
            num_keep = min(max(int(self.density * num_scores), 1), num_scores)
            threshold = torch.topk(abs_all_wg, num_keep, sorted=True)[0][-1]
            for layer, layer_prefix in zip(self.model.prunable_layers, self.model.prunable_layer_prefixes):
                name = layer_prefix + ".weight"
                abs_layer_wg = (weight_dict[name] * grad_dict[name]).abs()
                layer.mask = abs_layer_wg >= threshold

            with torch.no_grad():
                for layer in self.model.prunable_layers:
                    layer.weight *= layer.mask

            print("Snip pruning completed. Remaining params:")
            disp_num_params(self.model)

        if idx % self.config.EVAL_DISP_INTERVAL == 0:
            loss, acc = self.model.evaluate(self.test_loader)
            list_loss.append(loss)
            list_acc.append(acc)

            print("Round #{} (Experiment = {}).".format(idx, self.experiment_name))
            print("Loss/acc (at round {}) = {}/{}".format((len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss,
                                                          acc))
            print("Estimated time = {}".format(sum(list_est_time)))
            print("Elapsed time = {}".format(timer() - start))

        est_time = self.config.TIME_CONSTANT
        for layer, comp_coeff in zip(self.model.prunable_layers, self.config.COMP_COEFFICIENTS):
            est_time += layer.num_weight * (comp_coeff + self.config.COMM_COEFFICIENT)

        model_size = self.model.calc_num_all_active_params(True)
        list_est_time.append(est_time)
        list_model_size.append(model_size)

        if idx % self.save_interval == 0:
            mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
            mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
            mkdir_save(list_est_time, os.path.join(self.save_path, "est_time.pt"))
            mkdir_save(list_model_size, os.path.join(self.save_path, "model_size.pt"))
            mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))

        return [layer.mask for layer in self.model.prunable_layers], [self.model.state_dict() for _ in
                                                                      range(self.config.NUM_CLIENTS)]
Example #13
0
    def main(self,
             idx,
             list_sd,
             list_num_proc,
             lr,
             list_accumulated_sgrad,
             start,
             list_loss,
             list_acc,
             list_est_time,
             list_model_size,
             is_adj_round,
             density_limit=None):
        """Aggregate client updates and, on adjustment rounds, run adaptive pruning.

        Args:
            idx: Index of the current round.
            list_sd: Client state dicts for this round.
            list_num_proc: Per-client aggregation weights; client ``i``
                contributes ``list_num_proc[i] / sum(list_num_proc)`` of its
                increment over the (masked) server value.
            lr: Current learning rate (printed only).
            list_accumulated_sgrad: Per-client dicts whose ``(key, value)``
                items are fed to ``self.control.accumulate``.
            start: Start timestamp used for the elapsed-time printout.
            list_loss, list_acc, list_est_time, list_model_size: History
                lists, appended in place.
            is_adj_round: Whether to run the adjustment this round (only
                when ``self.use_adaptive`` is set).
            density_limit: Passed to ``self.control.adjust`` as
                ``max_density``.

        Returns:
            ``(masks, state_dicts)``: the mask of every prunable layer and one
            ``self.model.state_dict()`` per client.
        """
        total_num_proc = sum(list_num_proc)

        with torch.no_grad():
            for key, param in self.model.state_dict().items():
                # The mask does not depend on the client, so look it up once
                # per key instead of twice per client.
                mask = self.model.get_mask_by_name(key)
                server_val = param if mask is None else param * mask
                avg_inc_val = None
                for num_proc, state_dict in zip(list_num_proc, list_sd):
                    if key not in state_dict:
                        continue
                    inc_val = state_dict[key] - server_val
                    if avg_inc_val is None:
                        avg_inc_val = num_proc / total_num_proc * inc_val
                    else:
                        avg_inc_val += num_proc / total_num_proc * inc_val

                # num_batches_tracked counters are not averaged.
                if avg_inc_val is None or key.endswith("num_batches_tracked"):
                    continue
                param.add_(avg_inc_val)

        if idx % self.config.EVAL_DISP_INTERVAL == 0:
            loss, acc = self.model.evaluate(self.test_loader)
            list_loss.append(loss)
            list_acc.append(acc)

            print("Round #{} (Experiment = {}).".format(
                idx, self.experiment_name))
            print("Loss/acc (at round #{}) = {}/{}".format(
                (len(list_loss) - 1) * self.config.EVAL_DISP_INTERVAL, loss,
                acc))
            print("Estimated time = {}".format(sum(list_est_time)))
            print("Elapsed time = {}".format(timer() - start))
            print("Current lr = {}".format(lr))

        if self.use_adaptive and is_adj_round:
            alg_start = timer()

            for d in list_accumulated_sgrad:
                for k, sg in d.items():
                    self.control.accumulate(k, sg)

            print("Running adaptive pruning algorithm")
            # The allowed decrease halves every ADJ_HALF_LIFE rounds.
            max_dec_diff = self.config.MAX_DEC_DIFF * (0.5**(
                idx / self.config.ADJ_HALF_LIFE))
            self.control.adjust(max_dec_diff, max_density=density_limit)
            print("Total alg time = {}. Max density = {}.".format(
                timer() - alg_start, density_limit))
            print("Num params:")
            disp_num_params(self.model)

        est_time = self.config.TIME_CONSTANT
        for layer, comp_coeff in zip(self.model.prunable_layers,
                                     self.config.COMP_COEFFICIENTS):
            est_time += layer.num_weight * (comp_coeff +
                                            self.config.COMM_COEFFICIENT)

        model_size = self.model.calc_num_all_active_params(True)
        list_est_time.append(est_time)
        list_model_size.append(model_size)

        if idx % self.save_interval == 0:
            mkdir_save(list_loss, os.path.join(self.save_path, "loss.pt"))
            mkdir_save(list_acc, os.path.join(self.save_path, "accuracy.pt"))
            mkdir_save(list_est_time,
                       os.path.join(self.save_path, "est_time.pt"))
            mkdir_save(list_model_size,
                       os.path.join(self.save_path, "model_size.pt"))
            mkdir_save(self.model, os.path.join(self.save_path, "model.pt"))

        return [layer.mask for layer in self.model.prunable_layers], [
            self.model.state_dict() for _ in range(self.config.NUM_CLIENTS)
        ]
Example #14
0
 def save_exp(self):
     """Write the loss, accuracy, time-stamp and model-size histories and the model to ``self.save_path``."""
     artifacts = (
         ("list_loss", "loss.pt"),
         ("list_acc", "accuracy.pt"),
         ("list_time_stamp", "time.pt"),
         ("list_model_size", "model_size.pt"),
         ("model", "model.pt"),
     )
     for attr_name, file_name in artifacts:
         mkdir_save(getattr(self, attr_name), os.path.join(self.save_path, file_name))