Example no. 1
    def handle_message_receive_model_from_client(self, msg_params):
        sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        # number of local training samples (the n_i in the aggregation formula)
        local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)

        # the aggregator records the parameters sent by each client together with the client id
        self.aggregator.add_local_trained_result(sender_id - 1, model_params, local_sample_number)
        # check whether all clients have sent their updates
        b_all_received = self.aggregator.check_whether_all_receive()
        logging.info("b_all_received = " + str(b_all_received))
        if b_all_received:
            # start the aggregation
            global_model_params = self.aggregator.aggregate()
            self.aggregator.test_on_all_clients(self.round_idx)

            # start the next round
            self.round_idx += 1
            if self.round_idx == self.round_num:
                self.finish()
                return

            # sample clients for the next round (uniform sampling; used when the client population is large)
            client_indexes = self.aggregator.client_sampling(self.round_idx, self.args.client_num_in_total,
                                                             self.args.client_num_per_round)
            print("size = %d" % self.size)
            if self.args.is_mobile == 1:
                print("transform_tensor_to_list")
                global_model_params = transform_tensor_to_list(global_model_params)

            for receiver_id in range(1, self.size):
                self.send_message_sync_model_to_client(receiver_id, global_model_params, client_indexes[receiver_id-1])
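The aggregate() call above is where the weighted FedAvg step happens. A minimal sketch of such an aggregation, assuming the aggregator keeps per-client state_dicts and sample counts (the n_i mentioned in the comment); model_dict, sample_num_dict and fedavg_aggregate are illustrative names, not the actual aggregator API:

def fedavg_aggregate(model_dict, sample_num_dict):
    # model_dict: client_index -> locally trained state_dict
    # sample_num_dict: client_index -> number of local samples (n_i)
    total_samples = sum(sample_num_dict.values())
    averaged_params = {}
    for idx, local_params in model_dict.items():
        w = sample_num_dict[idx] / total_samples
        for k, v in local_params.items():
            if k not in averaged_params:
                averaged_params[k] = v.float() * w
            else:
                averaged_params[k] += v.float() * w
    return averaged_params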
Example no. 2
    def train(self):
        self.model.to(self.device)
        # change to train mode
        self.model.train()

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.train_local):
                # logging.info(images.shape)
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                log_probs = self.model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
            if len(batch_loss) > 0:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
                # the reported loss is the running mean over all local epochs so far
                logging.info(
                    'Client {}. Local Training Epoch: {} \tLoss: {:.6f}'.format(
                        self.client_index, epoch,
                        sum(epoch_loss) / len(epoch_loss)))

        weights = self.model.cpu().state_dict()

        # transform Tensor to list
        if self.args.is_mobile == 1:
            weights = transform_tensor_to_list(weights)
        return weights, self.local_sample_number
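The is_mobile branch above relies on transform_tensor_to_list so the state_dict can be serialized for mobile clients. A minimal sketch of what such a helper typically does (an assumption about its behaviour, not the library's actual implementation):

def transform_tensor_to_list(model_params):
    # convert every tensor in the state_dict to a nested Python list,
    # which keeps the message JSON-serializable on the mobile channel
    for k in model_params.keys():
        model_params[k] = model_params[k].detach().cpu().numpy().tolist()
    return model_params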
    def handle_message_receive_model_from_client(self, msg_params):
        sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)

        self.aggregator.add_local_trained_result(sender_id - 1, model_params,
                                                 local_sample_number)
        b_all_received = self.aggregator.check_whether_all_receive()
        logging.info("b_all_received = " + str(b_all_received))
        if b_all_received:
            global_model_params = self.aggregator.aggregate()
            self.aggregator.test_on_all_clients(self.round_idx)

            # measure the target task accuracy
            self.aggregator.test_target_accuracy(self.round_idx)

            # start the next round
            self.round_idx += 1
            if self.round_idx == self.round_num:
                self.finish()
                return

            # sampling clients
            client_indexes = self.aggregator.client_sampling(
                self.round_idx, self.args.client_num_in_total,
                self.args.client_num_per_round)
            print("size = %d" % self.size)
            if self.args.is_mobile == 1:
                print("transform_tensor_to_list")
                global_model_params = transform_tensor_to_list(
                    global_model_params)

            for receiver_id in range(1, self.size):
                self.send_message_sync_model_to_client(
                    receiver_id, global_model_params,
                    client_indexes[receiver_id - 1])
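client_sampling only matters when client_num_in_total is larger than client_num_per_round. A common choice, sketched here under the assumption of uniform sampling without replacement with a round-indexed seed so every process draws the same subset:

import numpy as np

def client_sampling(round_idx, client_num_in_total, client_num_per_round):
    # when every client participates, no sampling is needed
    if client_num_in_total == client_num_per_round:
        return list(range(client_num_in_total))
    # uniform sampling without replacement; seeding with the round index
    # keeps the server and all clients in agreement on the selection
    np.random.seed(round_idx)
    num_clients = min(client_num_per_round, client_num_in_total)
    return np.random.choice(range(client_num_in_total), num_clients,
                            replace=False).tolist()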
Example no. 4
    def train(self):
        try:
            if self.args.client_num_in_total == self.args.client_num_per_round:
                #apply residue
                weights = self.model.cpu().state_dict()
                # add the stored quantization residue back to each weight/bias
                # tensor (BatchNorm parameters are left untouched)
                for k in self.quant_residue.keys():
                    if ('weight' in k or 'bias' in k) and 'bn' not in k:
                        weights[k] = copy.deepcopy(weights[k] +
                                                   self.quant_residue[k])
                self.update_model(weights)
            else:
                logging.info('residue compensation skipped')
            #logging.info('Before training starts..............')
            self.model.to(self.device)
            # change to train mode
            self.model.train()
            #logging.info('model init done !!!!!!!!!!!!!')
            epoch_loss = []
            for epoch in range(self.args.epochs):
                batch_loss = []
                try:
                    for batch_idx, (x, labels) in enumerate(self.train_local):

                        #logging.info('Beginning of training on one batch !!!!!!!!!!!!!!!!!!!!!!!!!!')

                        _iters = self.glb_epoch * len(
                            self.train_local) + batch_idx
                        #cyclic_period = int((self.args.comm_round * self.args.epochs * len(self.train_local)) // self.cyclic_period)
                        cyclic_period = int(
                            (self.args.comm_round * self.args.epochs *
                             len(self.train_local)) // 16)
                        #logging.info("cyclic period: "+ str(cyclic_period))
                        #cyclic_period = int((self.args.comm_round * self.args.epochs * len(self.train_local)) // self.cyclic_period)*2
                        #if self.glb_epoch % 2 == 0:
                        #    self.args.cyclic_num_bits_schedule=[8,32]
                        #else:
                        #    self.args.cyclic_num_bits_schedule=[4,32]

                        if (self.args.cyclic_num_bits_schedule[0] == 0
                                or self.args.cyclic_num_bits_schedule[1] == 0):
                            num_bits = 0
                        #elif self.glb_epoch>=self.args.comm_round-3:
                        #    num_bits = 8
                        # elif self.glb_epoch>=self.args.comm_round-10:
                        #     #self.args.inference_bits=32
                        #     self.args.cyclic_num_bits_schedule=[4,32]
                        #     cyclic_period = int((self.args.comm_round * len(self.train_local)) // 64)
                        #     offset=self.offset_finder(self.args.cyclic_num_bits_schedule[1],cyclic_period,len(self.train_local),self.lr_steps)
                        #     offseted_iters=min(max(0,_iters-offset),self.lr_steps)
                        #     num_bits = self.cyclic_adjust_precision(offseted_iters, cyclic_period)
                        else:
                            offset = self.offset_finder(
                                self.args.cyclic_num_bits_schedule[1],
                                cyclic_period, len(self.train_local),
                                self.lr_steps)
                            offseted_iters = min(max(0, _iters - offset),
                                                 self.lr_steps)
                            num_bits = self.cyclic_adjust_precision(
                                offseted_iters, cyclic_period)
                            # num_bits = self.cyclic_adjust_precision(_iters, cyclic_period)
                            #if num_bits == 32:
                            #    num_bits =0
                        #logging.info('Right before data moving starts!!!!!')
                        # logging.info(images.shape)
                        x, labels = x.to(self.device), labels.to(self.device)
                        self.optimizer.zero_grad()
                        # if epoch < 10 and  self.first_run:
                        #     log_probs = self.model(x, num_bits=0)
                        # else:
                        #logging.info('Right before training started !!!!!!!!!!!!!!!!!')
                        log_probs = self.model(x, num_bits=num_bits)
                        loss = self.criterion(log_probs, labels)
                        loss.backward()
                        # clip gradients (max_norm=0.9, infinity norm)
                        g_norm = nn.utils.clip_grad_norm_(
                            self.model.parameters(), 0.9, 'inf')
                        #logging.info(str(g_norm))
                        self.optimizer.step()
                        batch_loss.append(loss.item())

                        self.scheduler.step()

                        #logging.info('End of training on one batch !!!!!!!!!!!!!!!!!!!!!!!!!!')

                    self.glb_epoch += 1
                    if len(batch_loss) > 0:
                        epoch_loss.append(sum(batch_loss) / len(batch_loss))
                        logging.info(
                            'Client {}. Local Training Epoch: {} \tLoss: {:.6f}'.format(
                                self.client_index, epoch,
                                sum(epoch_loss) / len(epoch_loss)))
                except Exception as e:
                    logging.info(str(e))
            # if self.comm_round < (self.args.lr_decay_step_size+1):
            #     self.scheduler.step()
            self.comm_round += 1
            logging.info("===current learning rate===: " +
                         str(self.optimizer.param_groups[0]['lr']))
            logging.info("========= number of batches =======: " +
                         str(batch_idx + 1))
            #logging.info("========= Transmitted bits ========: "+str(num_bits))
            self.first_run = False

            weights = self.model.cpu().state_dict()

            # transform Tensor to list
            if self.args.is_mobile == 1:
                weights = transform_tensor_to_list(weights)
            latent_weight = copy.deepcopy(weights)
            logging.info('Quantizing model')
            if num_bits != 0:
                for k in weights.keys():
                    #if 'bias' in k:
                    #    logging.info(str(k))
                    try:
                        if 'weight' in k and 'bn' not in k and 'downsample.1' not in k:
                            print(k)
                            weight_qparams = calculate_qparams(
                                copy.deepcopy(weights[k]),
                                num_bits=num_bits,
                                flatten_dims=(1, -1),
                                reduce_dim=None)
                            weights[k] = quantize(copy.deepcopy(weights[k]),
                                                  qparams=weight_qparams)
                            self.quant_residue[k] = copy.deepcopy(
                                latent_weight[k] - weights[k])
                        elif 'bias' in k and 'bn' not in k and 'downsample.1' not in k:
                            weights[k] = quantize(copy.deepcopy(weights[k]),
                                                  num_bits=num_bits,
                                                  flatten_dims=(0, -1))
                            self.quant_residue[k] = copy.deepcopy(
                                latent_weight[k] - weights[k])
                    except Exception as e:
                        logging.info(str(k))
                        logging.info(str(e))
                        exc_type, exc_obj, exc_tb = sys.exc_info()
                        fname = os.path.split(
                            exc_tb.tb_frame.f_code.co_filename)[1]
                        logging.info(
                            str(exc_type) + " " + str(fname) + " " +
                            str(exc_tb.tb_lineno))

            logging.info('Sending model')
            return weights, self.local_sample_number, num_bits, latent_weight
        except Exception as e:
            logging.info(str(e))
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logging.info(
                str(exc_type) + " " + str(fname) + " " + str(exc_tb.tb_lineno))
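Example no. 4 wraps the plain training loop with quantization plus error feedback: before the next round the stored quant_residue is added back to the weights, and after training the weights are re-quantized to num_bits while the new quantization error is saved for the round after that. A stripped-down sketch of the quantize-and-store-residue step, using a simple symmetric uniform quantizer as a stand-in for calculate_qparams/quantize (assumes num_bits >= 2):

import torch

def quantize_state_dict(weights, num_bits):
    # quantize each tensor and record the quantization error (residue),
    # which gets added back to the model before the next local round
    q_weights, residue = {}, {}
    for k, w in weights.items():
        w = w.float()
        # symmetric uniform quantizer, standing in for calculate_qparams/quantize
        scale = w.abs().max() / (2 ** (num_bits - 1) - 1)
        q = torch.round(w / scale) * scale if scale > 0 else w.clone()
        q_weights[k] = q
        residue[k] = w - q  # error feedback term
    return q_weights, residue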