Ejemplo n.º 1
0
 def init_ip_config(self):
     """Build the ``ip_*`` data loaders, optimizer wrapper and control module.

     The train loader is created with ``user_list=[0]``, shuffling and
     ``CLIENT_BATCH_SIZE``; the test loader only overrides the worker count
     and memory pinning.  A plain SGD optimizer at ``INIT_LR`` is wrapped in
     an ``OptimizerWrapper`` and a ``ControlModule`` is created for the model.

     NOTE(review): ``config`` is not defined inside this method; presumably it
     is a module-level name -- confirm against the surrounding file.
     """
     self.ip_train_loader = get_data_loader(
         EXP_NAME,
         data_type="train",
         batch_size=CLIENT_BATCH_SIZE,
         shuffle=True,
         num_workers=8,
         user_list=[0],
         pin_memory=True,
     )
     self.ip_test_loader = get_data_loader(
         EXP_NAME,
         data_type="test",
         num_workers=8,
         pin_memory=True,
     )

     # The optimizer itself is only reachable through the wrapper afterwards.
     net = self.model
     self.ip_optimizer_wrapper = OptimizerWrapper(net, SGD(net.parameters(), lr=INIT_LR))
     self.ip_control = ControlModule(model=net, config=config)
Ejemplo n.º 2
0
    def convert_to_sparse(self):
        """Switch the model to its sparse form and rebuild the optimizer stack.

        A new optimizer is created for the converted model's parameters and the
        learning rate of the old optimizer's first parameter group is carried
        over.  If a scheduler class is configured, a new scheduler is attached
        to the new optimizer and restored from a deep copy of the previous
        scheduler's state.  Finally the optimizer wrapper is rebuilt and
        ``is_sparse`` is set.
        """
        self.model = self.model.to_sparse()

        # Capture the learning rate before the optimizer is replaced.
        carried_lr = self.optimizer.state_dict()["param_groups"][0]["lr"]

        cfg = self.exp_config
        self.optimizer = cfg.optimizer_class(params=self.model.parameters(), **cfg.optimizer_params)

        if cfg.lr_scheduler_class is not None:
            saved_scheduler_state = deepcopy(self.lr_scheduler.state_dict())
            self.lr_scheduler = cfg.lr_scheduler_class(optimizer=self.optimizer, **cfg.lr_scheduler_params)
            self.lr_scheduler.load_state_dict(saved_scheduler_state)

        # Re-applied only after the scheduler has been rebuilt and restored.
        self.optimizer.param_groups[0]["lr"] = carried_lr
        self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.lr_scheduler)

        self.is_sparse = True
        print("Model converted to sparse")
Ejemplo n.º 3
0
 def init_optimizer(self):
     """Create the SGD optimizer, its step-decay scheduler and their wrapper.

     SGD uses ``INIT_LR``, ``MOMENTUM`` and ``WEIGHT_DECAY``.  The scheduler
     is a ``StepLR`` with ``step_size=STEP_SIZE`` and
     ``gamma = 0.5 ** (STEP_SIZE / LR_HALF_LIFE)``, which halves the learning
     rate every ``LR_HALF_LIFE`` scheduler steps (assuming ``lr_scheduler`` is
     ``torch.optim.lr_scheduler`` -- confirm against the file's imports).
     """
     sgd = SGD(
         self.model.parameters(),
         lr=INIT_LR,
         momentum=MOMENTUM,
         weight_decay=WEIGHT_DECAY,
     )
     self.optimizer = sgd

     gamma = 0.5 ** (STEP_SIZE / LR_HALF_LIFE)
     scheduler = lr_scheduler.StepLR(sgd, step_size=STEP_SIZE, gamma=gamma)
     self.optimizer_scheduler = scheduler

     self.optimizer_wrapper = OptimizerWrapper(self.model, sgd, scheduler)
Ejemplo n.º 4
0
    def __init__(self, network_config, max_try=100):
        """Connect to the server and set up training state from its init message.

        Args:
            network_config: Object exposing ``SERVER_ADDR`` and ``SERVER_PORT``.
            max_try: Forwarded to ``ClientSocket.init_connections``; presumably
                the maximum number of connection attempts -- confirm there.
        """
        self.network_config = network_config
        self.socket = ClientSocket(network_config.SERVER_ADDR, network_config.SERVER_PORT)
        self.train_loader = None

        handshake = self.socket.init_connections(max_try)
        self.client_id = handshake.client_id

        cfg = handshake.exp_config
        self.exp_config = cfg

        torch.manual_seed(cfg.seed)

        # self.save_path = os.path.join("results", "exp_{}".format(self.exp_config.exp_name),
        #                               self.exp_config.save_dir_name, "client_{}".format(self.client_id))

        net = handshake.model
        self.model = net
        net.train()

        # Optimizer and (optional) scheduler are built from the classes and
        # keyword arguments carried by the experiment configuration.
        self.optimizer = cfg.optimizer_class(params=net.parameters(), **cfg.optimizer_params)
        self.lr_scheduler = None
        if cfg.lr_scheduler_class is not None:
            self.lr_scheduler = cfg.lr_scheduler_class(optimizer=self.optimizer, **cfg.lr_scheduler_params)
        self.optimizer_wrapper = OptimizerWrapper(net, self.optimizer, self.lr_scheduler)

        # These accumulators only exist when adaptive mode is enabled.
        if cfg.use_adaptive:
            self.dict_extra_sgrad = {}
            self.accum_dense_grad = {}

        self.is_adj_round = False
        self.is_sparse = False
        self.terminate = False
        self.parse_init_extra_params(handshake.extra_params)

        do_resume, start_round, to_sparse_on_resume = handshake.resume_params
        self.initialize(do_resume, start_round, to_sparse_on_resume)
Ejemplo n.º 5
0
 def init_optimizer(self):
     """Create the SGD optimizer, a per-step decay scheduler and their wrapper.

     ``StepLR`` is used with ``step_size=1``, so every scheduler step scales
     the learning rate by ``0.5 ** (1 / LR_HALF_LIFE)``; the rate therefore
     halves after ``LR_HALF_LIFE`` scheduler steps (assuming ``lr_scheduler``
     is ``torch.optim.lr_scheduler`` -- confirm against the file's imports).
     """
     sgd = SGD(self.model.parameters(), lr=INIT_LR)
     self.optimizer = sgd

     gamma = 0.5 ** (1 / LR_HALF_LIFE)
     scheduler = lr_scheduler.StepLR(sgd, step_size=1, gamma=gamma)
     self.optimizer_scheduler = scheduler

     self.optimizer_wrapper = OptimizerWrapper(self.model, sgd, scheduler)
Ejemplo n.º 6
0
 def init_optimizer(self):
     """Create a plain SGD optimizer at ``INIT_LR`` and wrap it (no scheduler is passed)."""
     sgd = SGD(self.model.parameters(), lr=INIT_LR)
     self.optimizer = sgd
     self.optimizer_wrapper = OptimizerWrapper(self.model, sgd)
Ejemplo n.º 7
0
        server_adjust_interval))
    server_loader = get_data_loader(config.EXP_NAME,
                                    data_type="train",
                                    batch_size=config.CLIENT_BATCH_SIZE,
                                    shuffle=True,
                                    num_workers=8,
                                    user_list=[0],
                                    pin_memory=True)
    server_inputs, server_outputs = [], []
    for _ in range(num_pre_batch):
        inp, out = server_loader.get_next_batch()
        server_inputs.append(inp)
        server_outputs.append(out)

    server_optimizer = SGD(model.parameters(), lr=config.INIT_LR)
    server_optimizer_wrapper = OptimizerWrapper(model, server_optimizer)
    server_control = ControlModule(model=model, config=config)

    prev_density, prev_num, prev_ind = None, 5, []
    for server_i in range(1, server_pruning_rounds + 1):
        for server_inp, server_out in zip(server_inputs, server_outputs):
            list_grad = server_optimizer_wrapper.step(server_inp, server_out)
            for (key, param), g in zip(model.named_parameters(), list_grad):
                assert param.size() == g.size()
                server_control.accumulate(key, g**2)

        if server_i % server_adjust_interval == 0:
            server_control.adjust(config.MAX_DEC_DIFF)
            cur_density = disp_num_params(model)

            if prev_density is not None:
Ejemplo n.º 8
0
class Client(ABC):
    """Abstract client that trains a server-provided model and syncs with the server.

    On construction the client connects to the server, receives its id, the
    experiment configuration and the initial model, and builds the optimizer
    (plus optional LR scheduler) described by that configuration.  Each call
    to ``main`` runs ``exp_config.num_local_updates`` optimizer steps, sends
    the processed state dict to the server and loads the state dict received
    in response.

    Subclasses must implement ``parse_init_extra_params``.  ``initialize``
    reads ``self.train_loader`` when resuming, and ``main`` reads it on every
    round, so that hook is where it is expected to be set.
    """

    def __init__(self, network_config, max_try=100):
        """Connect to the server and set up training state from its init message.

        Args:
            network_config: Object exposing ``SERVER_ADDR`` and ``SERVER_PORT``.
            max_try: Forwarded to ``ClientSocket.init_connections``; presumably
                the maximum number of connection attempts -- confirm there.
        """
        self.network_config = network_config
        self.socket = ClientSocket(network_config.SERVER_ADDR, network_config.SERVER_PORT)
        self.train_loader = None

        init_msg = self.socket.init_connections(max_try)
        self.client_id = init_msg.client_id

        self.exp_config = init_msg.exp_config

        torch.manual_seed(self.exp_config.seed)

        # self.save_path = os.path.join("results", "exp_{}".format(self.exp_config.exp_name),
        #                               self.exp_config.save_dir_name, "client_{}".format(self.client_id))

        self.model = init_msg.model
        self.model.train()

        self.optimizer = self.exp_config.optimizer_class(params=self.model.parameters(),
                                                         **self.exp_config.optimizer_params)
        self.lr_scheduler = None
        if self.exp_config.lr_scheduler_class is not None:
            self.lr_scheduler = self.exp_config.lr_scheduler_class(optimizer=self.optimizer,
                                                                   **self.exp_config.lr_scheduler_params)
        self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.lr_scheduler)

        # These accumulators only exist in adaptive mode.
        # NOTE(review): process_state_dict_to_server() reads dict_extra_sgrad
        # whenever is_adj_round is set; presumably the server only requests
        # adjustment rounds in adaptive mode -- confirm.
        if self.exp_config.use_adaptive:
            self.dict_extra_sgrad = dict()
            self.accum_dense_grad = dict()

        self.is_adj_round = False
        self.is_sparse = False
        self.terminate = False
        self.parse_init_extra_params(init_msg.extra_params)

        resume, cur_round, resume_to_sparse = init_msg.resume_params
        self.initialize(resume, cur_round, resume_to_sparse)

    @torch.no_grad()
    def load_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into the model in place.

        Every named parameter and buffer of the model whose key appears in
        ``state_dict`` is overwritten; other entries are left untouched.

        Args:
            state_dict: Mapping from parameter/buffer name to tensor.
        """
        param_dict = dict(self.model.named_parameters())
        buffer_dict = dict(self.model.named_buffers())
        for key, param in {**param_dict, **buffer_dict}.items():
            if key in state_dict.keys():
                if state_dict[key].size() != param.size():
                    # sparse param with value only
                    param._values().copy_(state_dict[key])
                elif state_dict[key].is_sparse:
                    # sparse param at adjustment round
                    param.copy_(state_dict[key])
                    # need to reload mask in this case
                    param.mask.copy_(state_dict[key].mask)
                else:
                    param.copy_(state_dict[key])

    def initialize(self, resume, cur_round, resume_to_sparse):
        """Fast-forward scheduler and data loader when resuming a run.

        Args:
            resume: When falsy, nothing is done apart from printing a message.
            cur_round: Number of rounds already completed; the scheduler and
                the train loader are advanced by
                ``cur_round * exp_config.num_local_updates`` steps/batches.
            resume_to_sparse: When truthy (and resuming), the model is
                converted to sparse after fast-forwarding.
        """
        if resume:
            print("Resuming client...")
            # move optimizer to the right position
            for _ in range(cur_round * self.exp_config.num_local_updates):
                self.optimizer_wrapper.lr_scheduler_step()

            # move train loader to the right position
            remaining_batches = cur_round * self.exp_config.num_local_updates
            num_batches_epoch = len(self.train_loader)
            # Whole epochs are skipped in one call each; only the remainder is
            # consumed batch by batch.
            while remaining_batches >= num_batches_epoch:
                self.train_loader.skip_epoch()
                remaining_batches -= num_batches_epoch
            for _ in range(remaining_batches):
                self.train_loader.get_next_batch()

            if resume_to_sparse:
                self.convert_to_sparse()

            print("Client resumed")
        else:
            print("Client initialized")

    @abstractmethod
    def parse_init_extra_params(self, extra_params):
        """Consume the ``extra_params`` of the server's init message.

        Called from ``__init__`` before ``initialize``.
        """
        # Initialize train_loader, etc.
        pass

    def cleanup_state_dict_to_server(self) -> dict:
        """
        Clean up state dict before process, e.g. remove entries, transpose.
        To be overridden by subclasses.

        When the model is sparse, the bias entry of every ``SparseLinear``
        layer is flattened with ``view(-1)`` and every key ending in
        ``"placeholder"`` is removed.  The dense case returns the copy as is.

        Returns:
            The cleaned copy of the model's state dict.
        """
        clean_state_dict = copy_dict(self.model.state_dict())  # not deepcopy
        if self.is_sparse:
            for layer, prefix in zip(self.model.param_layers, self.model.param_layer_prefixes):
                key = prefix + ".bias"
                if isinstance(layer, SparseLinear) and key in clean_state_dict.keys():
                    clean_state_dict[key] = clean_state_dict[key].view(-1)

            del_list = []
            del_suffix = "placeholder"
            for key in clean_state_dict.keys():
                if key.endswith(del_suffix):
                    del_list.append(key)

            # Deleted in a second pass so the dict is not resized while iterated.
            for del_key in del_list:
                del clean_state_dict[del_key]

        return clean_state_dict

    @torch.no_grad()
    def process_state_dict_to_server(self) -> dict:
        """
        Process state dict before sending to server, e.g. keep values only, extra param in adjustment round.
        if not self.is_sparse: send dense
        elif self.is_adj_round: send sparse values + extra grad values
        else: send sparse values only
        To be overridden by subclasses.

        Side effect: in an adjustment round ``self.dict_extra_sgrad`` is merged
        into the result and then reset to an empty dict.

        Returns:
            The state dict to transmit.
        """
        clean_state_dict = self.cleanup_state_dict_to_server()

        if self.is_sparse:
            # Only existing keys are reassigned, so iterating while assigning is safe.
            for key, param in clean_state_dict.items():
                if param.is_sparse:
                    clean_state_dict[key] = param._values()

        if self.is_adj_round:
            clean_state_dict.update(self.dict_extra_sgrad)
            self.dict_extra_sgrad = dict()

        return clean_state_dict

    def convert_to_sparse(self):
        """Switch the model to its sparse form and rebuild the optimizer stack.

        A new optimizer is created for the converted model's parameters and the
        learning rate of the old optimizer's first parameter group is carried
        over.  If a scheduler class is configured, a new scheduler is attached
        to the new optimizer and restored from a deep copy of the previous
        scheduler's state.
        """
        self.model = self.model.to_sparse()
        # Capture the learning rate before the optimizer is replaced.
        old_lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
        self.optimizer = self.exp_config.optimizer_class(params=self.model.parameters(),
                                                         **self.exp_config.optimizer_params)
        if self.exp_config.lr_scheduler_class is not None:
            lr_scheduler_state_dict = deepcopy(self.lr_scheduler.state_dict())
            self.lr_scheduler = self.exp_config.lr_scheduler_class(optimizer=self.optimizer,
                                                                   **self.exp_config.lr_scheduler_params)
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
        # Re-applied only after the scheduler has been rebuilt and restored.
        self.optimizer.param_groups[0]["lr"] = old_lr
        self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer, self.lr_scheduler)

        self.is_sparse = True

        print("Model converted to sparse")

    def accumulate_dense_grad_round(self):
        """Add the current dense gradients of sparse parameters to the accumulator.

        Only parameters carrying an ``is_sparse_param`` attribute are
        considered; their ``param.dense.grad`` is summed into
        ``self.accum_dense_grad`` under the parameter's name.
        """
        for key, param in self.model.named_parameters():
            if hasattr(param, "is_sparse_param"):
                grad = param.dense.grad
                if key in self.accum_dense_grad:
                    self.accum_dense_grad[key] += grad
                else:
                    # Store a copy: keeping a reference would make the later
                    # in-place ``+=`` modify the parameter's own gradient
                    # tensor, and in-place gradient zeroing by the optimizer
                    # would wipe the accumulated value.
                    self.accum_dense_grad[key] = grad.clone()

    def accumulate_sgrad(self, num_proc_data):
        """Fold the accumulated dense gradients into the squared-gradient statistics.

        For every accumulated gradient ``g`` the value ``g ** 2 * num_proc_data``
        is added to ``self.dict_extra_sgrad`` under the key ``"extra." + name``.
        In an adjustment round each entry is then reduced to the positions
        where the parameter's ``mask`` equals zero.

        Args:
            num_proc_data: Weight applied to the squared gradients; ``main``
                passes the number of samples processed in the round.
        """
        prefix = "extra."
        # Built once instead of once per key; only needed in adjustment rounds.
        named_params = dict(self.model.named_parameters()) if self.is_adj_round else None
        for key, param in self.accum_dense_grad.items():
            pkey = prefix + key
            if pkey in self.dict_extra_sgrad:
                self.dict_extra_sgrad[pkey] += (param ** 2) * num_proc_data
            else:
                self.dict_extra_sgrad[pkey] = (param ** 2) * num_proc_data

            if self.is_adj_round:
                param_mask = named_params[key].mask == 0.
                self.dict_extra_sgrad[pkey] = self.dict_extra_sgrad[pkey].masked_select(param_mask)

    def main(self):
        """Run one round: local training, upload, download, scheduler step.

        Returns:
            The ``terminate`` flag of the server's update message.  When it is
            truthy the client acknowledges, closes the socket and sets
            ``self.terminate``.
        """
        num_proc_data = 0
        for _ in range(self.exp_config.num_local_updates):
            inputs, labels = self.train_loader.get_next_batch()
            self.optimizer_wrapper.step(inputs, labels)

            if self.exp_config.use_adaptive:
                self.accumulate_dense_grad_round()

            num_proc_data += len(inputs)

        if self.exp_config.use_adaptive:
            self.accumulate_sgrad(num_proc_data)
            self.accum_dense_grad = dict()

        lr = self.optimizer_wrapper.get_last_lr()

        state_dict_to_server = self.process_state_dict_to_server()
        msg_to_server = ClientToServerUpdateMessage((state_dict_to_server, num_proc_data, lr))
        self.socket.send_msg(msg_to_server)

        update_msg = self.socket.recv_update_msg()
        self.is_adj_round = update_msg.adjustment
        # Convert before loading so the received state matches the model layout.
        if not self.is_sparse and update_msg.to_sparse:
            self.convert_to_sparse()

        state_dict_received = update_msg.state_dict
        self.load_state_dict(state_dict_received)

        self.optimizer_wrapper.lr_scheduler_step()

        terminate = update_msg.terminate
        if terminate:
            self.socket.send_ack_msg()
            self.socket.close()
            self.terminate = True
            print("Task completed")

        return terminate