Example #1
    def client_update(self, global_model, global_init_model, round_index):
        self.elapsed_comm_rounds += 1
        print(f'***** Client #{self.client_id} *****', flush=True)
        self.model = copy_model(global_model,
                                self.args.dataset, self.args.arch,
                                dict(self.model.named_buffers()))

        num_pruned, num_params = get_prune_summary(self.model)
        cur_prune_rate = num_pruned / num_params

        eval_score = evaluate(self.model,
                              self.test_loader,
                              verbose=self.args.test_verbosity)

        if (eval_score['Accuracy'][0] > self.args.acc_thresh
                and cur_prune_rate < self.args.prune_percent):
            # Add 0.001 to make sure we clear the target prune_percent;
            # this may not be needed.
            prune_fraction = min(
                self.args.prune_step,
                0.001 + self.args.prune_percent - cur_prune_rate)
            prune_fixed_amount(self.model,
                               prune_fraction,
                               verbose=self.args.prune_verbosity,
                               glob=True)
            self.model = copy_model(global_init_model, self.args.dataset,
                                    self.args.arch,
                                    dict(self.model.named_buffers()))
        losses = []
        accuracies = []
        for i in range(self.args.client_epoch):
            train_score = train(round_index,
                                self.client_id,
                                i,
                                self.model,
                                self.train_loader,
                                lr=self.args.lr,
                                verbose=self.args.train_verbosity)

            losses.append(train_score['Loss'][-1].data.item())
            accuracies.append(train_score['Accuracy'][-1])

        mask_log_path = f'{self.args.log_folder}/round{round_index}/c{self.client_id}.mask'
        client_mask = dict(self.model.named_buffers())
        log_obj(mask_log_path, client_mask)

        num_pruned, num_params = get_prune_summary(self.model)
        cur_prune_rate = num_pruned / num_params
        prune_step = math.floor(num_params * self.args.prune_step)
        print(
            f"num_pruned {num_pruned}, num_params {num_params}, cur_prune_rate {cur_prune_rate}, prune_step: {prune_step}"
        )

        self.losses[round_index:] = np.array(losses)
        self.accuracies[round_index:] = np.array(accuracies)
        self.prune_rates[round_index:] = cur_prune_rate

        return copy_model(self.model, self.args.dataset, self.args.arch)
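
Example #1 (and the snippets below) relies on `evaluate` and `train` helpers that return metrics as dicts of lists. As an illustration only, not the repository's actual code, here is a minimal sketch of an `evaluate` that reports a single-element 'Accuracy' list, so that `eval_score['Accuracy'][0]` above works:

import torch

@torch.no_grad()
def evaluate(model, data_loader, verbose=True):
    # Sketch: top-1 accuracy over the loader, returned as a
    # single-element list so score['Accuracy'][0] works.
    model.eval()
    device = next(model.parameters()).device
    correct, total = 0, 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        correct += (model(inputs).argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    accuracy = correct / total
    if verbose:
        print(f'Accuracy: {accuracy:.4f}')
    return {'Accuracy': [accuracy]}
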
Example #2
    def client_update(self, global_model, global_init_model, comm_round):
        self.elapsed_comm_rounds += 1
        print(f'***** Client #{self.client_id} *****', flush=True)
        self.model = copy_model(global_model, self.args.dataset,
                                self.args.arch)

        losses = []
        accuracies = []
        for epoch in range(self.args.client_epoch):
            train_score = train(comm_round,
                                self.client_id,
                                epoch,
                                self.model,
                                self.train_loader,
                                lr=self.args.lr,
                                verbose=self.args.train_verbosity)
            losses.append(train_score['Loss'][-1].data.item())
            accuracies.append(train_score['Accuracy'][-1])

        num_pruned, num_params = get_prune_summary(self.model)
        cur_prune_rate = num_pruned / num_params
        print(
            f"num_pruned {num_pruned}, num_params {num_params}, cur_prune_rate {cur_prune_rate}"
        )
        self.losses[comm_round:] = np.array(losses)
        self.accuracies[comm_round:] = np.array(accuracies)
        self.prune_rates[comm_round:] = cur_prune_rate
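
A matching hedged sketch of the `train` helper these examples call: the signature is taken from the calls above, while the single local epoch of plain SGD and the per-batch metric lists (with 'Loss' entries kept as tensors, so `train_score['Loss'][-1].data.item()` works) are assumptions:

import torch
import torch.nn.functional as F

def train(comm_round, client_id, epoch, model, train_loader, lr=0.01,
          verbose=True):
    # Sketch: one local epoch of SGD over the client's data.
    model.train()
    device = next(model.parameters()).device
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    score = {'Loss': [], 'Accuracy': []}
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        optimizer.step()
        # Keep losses as tensors so callers can do ['Loss'][-1].data.item().
        score['Loss'].append(loss.detach())
        score['Accuracy'].append(
            (output.argmax(dim=1) == targets).float().mean().item())
    if verbose:
        print(f'round {comm_round} client {client_id} epoch {epoch}: '
              f'loss {score["Loss"][-1].item():.4f}')
    return score
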
Example #3
    def server_update(self):
        self.elapsed_comm_rounds += 1
        self.global_models.train()

        for comm_round in range(self.comm_rounds):
            selected_clients = np.random.choice(
                self.num_clients,
                max(int(self.frac * self.num_clients), 1),
                replace=False)
            print('-------------------------------------', flush=True)
            print(
                f'Communication Round #{comm_round} Clients={selected_clients}',
                flush=True)
            print('-------------------------------------', flush=True)
            for c in [self.clients[i] for i in selected_clients]:
                c.client_update(self.global_models, self.global_init_model,
                                comm_round)

            new_model = fed_avg([c.model for c in self.clients],
                                self.args.dataset, self.args.arch,
                                self.client_data_num)
            # fed_avg clobbers the mask, so we need to copy it back into the global model
            global_buffers = dict(self.global_models.named_buffers())
            for name, buffer in new_model.named_buffers():
                buffer.data.copy_(global_buffers[name])
            self.global_models = new_model

            # server accuracies are not useful for Genesis
            self.accuracies[comm_round] = 0
            # gather client accuracies
            for k, m in enumerate(self.clients):
                if k in selected_clients:
                    self.client_accuracies[k][comm_round] = m.evaluate()
                elif comm_round > 0:
                    self.client_accuracies[k][comm_round] = \
                        self.client_accuracies[k][comm_round - 1]

            print(
                f"End of round accuracy: all={self.client_accuracies[:, comm_round].mean()}, "
                f"participating={self.client_accuracies[selected_clients, comm_round].mean()}"
            )

            # prune global model if appropriate
            num_pruned, num_params = get_prune_summary(self.global_models)
            cur_prune_rate = num_pruned / num_params
            if self.client_accuracies[:, comm_round].mean() > self.args.acc_thresh \
                    and cur_prune_rate < self.args.prune_percent:
                prune_fixed_amount(self.global_models,
                                   self.args.prune_step,
                                   verbose=self.args.prune_verbosity)
                self.global_models = copy_model(
                    self.global_init_model, self.args.dataset, self.args.arch,
                    dict(self.global_models.named_buffers()))
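
Examples #1 to #3 pass dicts of named buffers into `copy_model` so the pruning masks survive the copy. A minimal sketch of the behavior they assume; the deepcopy stands in for whatever dataset/arch-driven model factory the real helper uses:

import copy

def copy_model(source_model, dataset, arch, buffers=None):
    # Sketch: clone the model; dataset/arch are kept only for signature
    # parity (the real helper presumably rebuilds the architecture from them).
    new_model = copy.deepcopy(source_model)
    if buffers is not None:
        own_buffers = dict(new_model.named_buffers())
        for name, buf in buffers.items():
            # Overwrite e.g. weight_mask buffers so pruning carries over.
            own_buffers[name].data.copy_(buf.data)
    return new_model
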
Example #4
    def aggr(self, models, clients, *args, **kwargs):
        print("----------Averaging Models--------")
        weights_per_client = np.array([client.num_data for client in clients],
                                      dtype=np.float32)
        weights_per_client /= np.sum(weights_per_client)
        aggr_model = fed_avg(models=models,
                             weights=weights_per_client,
                             device=self.args.device)
        pruned_summary, _, _ = get_prune_summary(aggr_model, name='weight')
        print(tabulate(pruned_summary, headers='keys', tablefmt='github'))

        # Re-register the sparsity as a prune mask: prune each layer by
        # exactly its current number of zeroed weights.
        prune_params = get_prune_params(aggr_model)
        for param, name in prune_params:
            zeroed_weights = torch.eq(getattr(param, name).data,
                                      0.00).sum().float()
            prune.l1_unstructured(param, name, int(zeroed_weights))

        return aggr_model
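
`fed_avg` here takes explicit per-client weights. A sketch of a data-size-weighted parameter average under that assumption (weights already sum to 1, as computed above):

import copy
import torch

def fed_avg(models, weights, device='cpu'):
    # Sketch: weighted average of every state_dict entry.
    avg_model = copy.deepcopy(models[0]).to(device)
    avg_state = avg_model.state_dict()
    with torch.no_grad():
        for key in avg_state:
            avg_state[key] = sum(
                float(w) * m.state_dict()[key].float().to(device)
                for m, w in zip(models, weights))
    # load_state_dict copies with casting, so integer buffers such as
    # num_batches_tracked are converted back on load.
    avg_model.load_state_dict(avg_state)
    return avg_model
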
Example #5
def client_update_method1(client_self, global_model, global_init_model):
    print(f'***** Client #{client_self.client_id} *****', flush=True)
    # Checking if the client object has been properly initialized
    assert isinstance(client_self.model,
                      nn.Module), "A model must be a PyTorch module"
    assert 0 <= client_self.args.prune_percent <= 1, "The prune percentage must be between 0 and 1"
    assert client_self.args.client_epoch, '"args" must contain a "client_epoch" field'
    assert client_self.test_loader, "test_loader field does not exist. Check if the client is initialized correctly"
    assert client_self.train_loader, "train_loader field does not exist. Check if the client is initialized correctly"
    assert isinstance(
        client_self.train_loader,
        torch.utils.data.DataLoader), "train_loader must be a DataLoader type"
    assert isinstance(
        client_self.test_loader,
        torch.utils.data.DataLoader), "test_loader must be a DataLoader type"

    client_self.model = copy_model(global_model, client_self.args.dataset,
                                   client_self.args.arch)

    num_pruned, num_params = get_prune_summary(client_self.model)
    cur_prune_rate = num_pruned / num_params
    prune_step = math.floor(num_params * client_self.args.prune_step)

    for i in range(client_self.args.client_epoch):
        print(f'Epoch {i + 1}')
        train(client_self.model,
              client_self.train_loader,
              lr=client_self.args.lr,
              verbose=client_self.args.train_verbosity)

    score = evaluate(client_self.model,
                     client_self.test_loader,
                     verbose=client_self.args.test_verbosity)

    if (score['Accuracy'][0] > client_self.args.acc_thresh
            and cur_prune_rate < client_self.args.prune_percent):
        prune_fixed_amount(client_self.model,
                           prune_step,
                           verbose=client_self.args.prune_verbosity)

    return copy_model(client_self.model, client_self.args.dataset,
                      client_self.args.arch)
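
`prune_fixed_amount` is called with a fraction in Example #1 (`glob=True`) and with an absolute weight count here. A hedged sketch on top of torch.nn.utils.prune, whose `amount` argument accepts both forms:

import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_fixed_amount(model, amount, verbose=True, glob=False):
    # Sketch: collect the prunable (module, parameter-name) pairs.
    params = [(m, 'weight') for m in model.modules()
              if isinstance(m, (nn.Linear, nn.Conv2d))]
    if glob:
        # Rank weights across all layers at once.
        prune.global_unstructured(params,
                                  pruning_method=prune.L1Unstructured,
                                  amount=amount)
    else:
        # Prune each layer independently by `amount` (fraction or count).
        for module, name in params:
            prune.l1_unstructured(module, name, amount=amount)
    if verbose:
        zeros = sum(int((getattr(m, n) == 0).sum()) for m, n in params)
        total = sum(getattr(m, n).numel() for m, n in params)
        print(f'pruned {zeros}/{total} weights ({zeros / total:.2%})')
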
Example #6
    def update(self, *args, **kwargs):
        """
            Interface to server and clients
        """
        self.elapsed_comm_rounds += 1
        self.prev_model = copy_model(self.model, self.args.device)
        print('-----------------------------', flush=True)
        print(f'| Communication Round: {self.elapsed_comm_rounds}  | ',
              flush=True)
        print('-----------------------------', flush=True)
        _, num_pruned, num_total = get_prune_summary(self.model)

        prune_percent = num_pruned / num_total
        # The global model is pruned at a fixed frequency, with a fixed
        # pruning step, until it reaches the server prune threshold.
        if (self.args.server_prune
                and self.elapsed_comm_rounds % self.args.server_prune_freq == 0
                and prune_percent < self.args.server_prune_threshold):

            # prune the model using super_mask
            self.prune()
            # reinitialize model with std.dev of init_model
            self.reinit()

        client_idxs = np.random.choice(
            self.num_clients,
            int(self.args.frac_clients_per_round * self.num_clients),
            replace=False,
        )
        clients = [self.clients[i] for i in client_idxs]

        # upload model to selected clients
        self.upload(clients)

        # call training loop on all clients
        for client in clients:
            client.update()

        # download models from selected clients
        models, accs = self.download(clients)

        avg_accuracy = np.mean(accs, axis=0, dtype=np.float32)
        print('-----------------------------', flush=True)
        print(f'| Average Accuracy: {avg_accuracy}  | ', flush=True)
        print('-----------------------------', flush=True)

        # compute the averaged model; aggr also re-prunes it so that
        # already-zeroed weights stay masked
        aggr_model = self.aggr(models, clients)

        # copy the aggregated model's params into self.model (buffers stay the same)
        self.model = aggr_model

        _, num_pruned, num_total = get_prune_summary(self.model)
        prune_percent = num_pruned / num_total

        wandb.log({
            "client_avg_acc": avg_accuracy,
            "comm_round": self.elapsed_comm_rounds,
            "global_prune_percent": prune_percent
        })

        print('Saving global model')
        torch.save(
            self.model.state_dict(),
            f"./checkpoints/server_model_{self.elapsed_comm_rounds}.pt")
Example #7
    def update(self) -> None:
        """
            Interface to Server
        """
        print(f"\n----------Client:{self.idx} Update---------------------")
        print(f'----------User Class ids: {self.class_idxs}------------')
        print(f"Evaluating Global model ")
        metrics = self.eval(self.global_model)
        accuracy = metrics['Accuracy'][0]
        print(f'Global model accuracy: {accuracy}')

        prune_summary, num_zeros, num_global = get_prune_summary(
            model=self.global_model, name='weight')
        prune_rate = num_zeros / num_global
        print(f'Global model prune percentage: {prune_rate}')

        if self.cur_prune_rate < self.args.prune_threshold:
            if accuracy > self.eita:
                self.cur_prune_rate = min(self.cur_prune_rate + self.args.prune_step,
                                          self.args.prune_threshold)
                if self.cur_prune_rate > prune_rate:
                    l1_prune(model=self.global_model,
                             amount=self.cur_prune_rate,
                             name='weight',
                             verbose=self.args.prune_verbose)
                    # reinitialize model with init_params
                    source_params = dict(
                        self.global_init_model.named_parameters())
                    for name, param in self.global_model.named_parameters():
                        param.data.copy_(source_params[name].data)
                    self.prune_rates.append(self.cur_prune_rate)
                else:
                    # Re-prune using the downloaded global model (important).
                    # REVIEW: rather than pruning each layer by the original
                    # global pruned %, prune each layer by its own original
                    # pruned %.
                    params_to_prune = get_prune_params(self.global_model)
                    for param, name in params_to_prune:
                        amount = torch.eq(getattr(param, name),
                                          0.00).sum().float()
                        prune.l1_unstructured(param, name, amount=int(amount))
                    self.prune_rates.append(prune_rate)

                self.model = self.global_model
                self.eita = self.eita_hat

            else:
                # Re-prune using the downloaded global model (important).
                # REVIEW: rather than pruning each layer by the original
                # global pruned %, prune each layer by its own original
                # pruned %.
                params_to_prune = get_prune_params(self.global_model)
                for param, name in params_to_prune:
                    amount = torch.eq(getattr(param, name), 0.00).sum().float()
                    prune.l1_unstructured(param, name, amount=int(amount))
                self.eita *= self.alpha
                self.model = self.global_model
                self.prune_rates.append(prune_rate)
        else:
            if self.cur_prune_rate > prune_rate:
                l1_prune(model=self.global_model,
                         amount=self.cur_prune_rate,
                         name='weight',
                         verbose=self.args.prune_verbose)
                source_params = dict(self.global_init_model.named_parameters())
                for name, param in self.global_model.named_parameters():
                    param.data.copy_(source_params[name].data)
                self.prune_rates.append(self.cur_prune_rate)
            else:
                # Re-prune using the downloaded global model (not important).
                params_to_prune = get_prune_params(self.global_model)
                for param, name in params_to_prune:
                    amount = torch.eq(getattr(param, name), 0.00).sum().float()
                    prune.l1_unstructured(param, name, amount=int(amount))
                self.prune_rates.append(prune_rate)

            self.model = self.global_model

        print(f"\nTraining local model")
        self.train(self.elapsed_comm_rounds)

        print(f"\nEvaluating Trained Model")
        metrics = self.eval(self.model)
        print(f'Trained model accuracy: {metrics["Accuracy"][0]}')

        wandb.log({f"{self.idx}_cur_prune_rate": self.cur_prune_rate})
        wandb.log({f"{self.idx}_eita": self.eita})
        wandb.log({f"{self.idx}_percent_pruned": self.prune_rates[-1]})

        for key, value in metrics.items():
            if isinstance(value, list):
                wandb.log({f"{self.idx}_{key}": value[0]})
            else:
                wandb.log({f"{self.idx}_{key}": value})

        self.save(self.model)
        self.elapsed_comm_rounds += 1
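
The zero-preserving re-prune appears three times in Example #7 (and once in Example #4's aggr). A sketch of factoring it into a helper, reusing the examples' own `get_prune_params`, which presumably yields (module, 'weight') pairs:

import torch
import torch.nn.utils.prune as prune

def reprune_zeroed(model):
    # Re-prune each layer by exactly its current number of zeroed weights,
    # so the downloaded sparsity pattern is re-registered as a prune mask.
    for module, name in get_prune_params(model):
        n_zeros = int(torch.eq(getattr(module, name), 0.0).sum())
        prune.l1_unstructured(module, name, amount=n_zeros)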