    def __init_opt(self):
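        # Optional FedOpt-style server optimizer (interpretation based on the
        # config names): when SERVER_OPT is set, __train_one_round applies the
        # averaged client update through this optimizer instead of directly
        # overwriting the global model.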
        if self.config.SERVER_OPT is not None:
            self.server_opt = TorchOptRepo.name2cls(self.config.SERVER_OPT)(
                self.model.parameters(),
                lr=self.config.SERVER_LEARNING_RATE,
                **self.config.SERVER_OPT_ARGS,
            )
        else:
            self.server_opt = None
    def switch_to_sgd(self, lr):
        # TODO: move this cleanup into __del__
        self.__delete_objects_tmp_files()
        # Drop the current optimizer and any pending optimizer state so that a
        # fresh SGD instance is built with the new learning rate; the weight
        # decay setting is carried over unchanged.
        self.opt_cls = TorchOptRepo.name2cls("SGD")
        self.opt = None
        self.__opt_state_to_be_loaded = None
        self.opt_cls_param = {
            "lr": lr,
            "weight_decay": self.opt_cls_param["weight_decay"],
        }
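
    # Usage sketch (hypothetical values): switch_to_sgd(0.01) resets self.opt
    # to None above, so the next training call presumably rebuilds the
    # optimizer lazily as SGD with the new learning rate.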
    def __train_one_round(self, curr_round: int):
        self.model.train()

        client_sample = self.__select_clients()

        comm_avg_model_state = None
        comm_avg_opt_state = None
        if self.config.SCAFFOLD:
            comm_c = None
            logging.info("SCAFFOLD: comm_c initialized")

        with ElapsedTime("Training clients"):
            for i, client in enumerate(client_sample):
                client.set_model(self.model.state_dict())
                if (self.config.CLIENT_OPT_STRATEGY == "avg") and (
                    self.avg_opt_state is not None
                ):
                    client.set_opt_state(self.avg_opt_state)

                if self.config.SCAFFOLD:
                    client.set_server_c(self.c)
                    model_state, opt_state, c = client.train_round(
                        self.config.N_EPOCH_PER_CLIENT, curr_round
                    )
                    comm_c = commulative_avg_params(comm_c, c, i)
                    logging.info("SCAFFOLD: c received from client")
                else:
                    model_state, opt_state = client.train_round(
                        self.config.N_EPOCH_PER_CLIENT, curr_round
                    )

                comm_avg_model_state = commulative_avg_model_state_dicts(
                    comm_avg_model_state, model_state, i
                )
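                # commulative_avg_model_state_dicts presumably keeps a running
                # mean over the clients seen so far:
                # avg_i = (avg_{i-1} * i + x_i) / (i + 1),
                # so no per-client list has to be held in memory.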

                if self.config.CLIENT_OPT_STRATEGY == "avg":
                    if comm_avg_opt_state is not None:
                        comm_avg_opt_state = [
                            commulative_avg_model_state_dicts(opt_s[0], opt_s[1], i)
                            for opt_s in zip(comm_avg_opt_state, opt_state)
                        ]
                    else:
                        comm_avg_opt_state = opt_state

        with ElapsedTime("Setting gradients"):
            if self.server_opt is not None:
                self.server_opt.zero_grad()
                server_opt_state = self.server_opt.state_dict()
                self.__set_model_grads(comm_avg_model_state)
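                # The optimizer is re-instantiated and its state reloaded just
                # below; presumably __set_model_grads swaps the model's
                # parameter tensors, which would leave the old optimizer
                # pointing at stale parameters.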
                self.server_opt = TorchOptRepo.name2cls(self.config.SERVER_OPT)(
                    self.model.parameters(),
                    lr=self.config.SERVER_LEARNING_RATE,
                    **self.config.SERVER_OPT_ARGS,
                )
                self.server_opt.load_state_dict(server_opt_state)
                self.server_opt.step()
            else:
                self.__log("setting avg model state")
                self.model.load_state_dict(comm_avg_model_state)
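            # SCAFFOLD server control-variate update (Karimireddy et al.,
            # 2020): c <- c + (|S| / N) * avg_delta_c, where comm_c is the
            # average over the sampled clients (assumption: clients return
            # control-variate deltas rather than absolute values).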
            if self.config.SCAFFOLD:
                self.c = lambda2_params(
                    self.c,
                    comm_c,
                    lambda a, b: a + (b * len(client_sample) / self.config.N_CLIENTS),
                )
                logging.info("SCAFFOLD: server c updated")

        self.avg_opt_state = comm_avg_opt_state
    def __init__(
        self,
        experiment: Experiment,
        config: TorchFederatedLearnerConfig,
        config_technical: TorchFederatedLearnerTechnicalConfig,
    ) -> None:
        """Initialises the training.

        Arguments:
            experiment {Experiment} -- Comet.ml experiment object for online logging.
            config {TorchFederatedLearnerConfig} -- Training configuration description.
            config_technical {TorchFederatedLearnerTechnicalConfig} -- Technical
                configuration (e.g. where client models and optimizer state are kept).
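
        Example (sketch; MyLearner is a hypothetical concrete subclass
        implementing get_model_cls, get_loss and load_data):

            learner = MyLearner(experiment, config, config_technical)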
        """
        super().__init__()
        if config.SEED is not None:
            random.seed(config.SEED)
            np.random.seed(config.SEED)
            th.manual_seed(config.SEED)
            th.backends.cudnn.deterministic = True
            th.backends.cudnn.benchmark = False

        self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
        self.experiment = experiment
        self.config = config
        self.config.set_defaults()
        self.experiment.log_parameters(self.config.flatten())
        self.config_technical = config_technical
        self.config_technical.check()
        self.experiment.log_parameters(self.config_technical.flatten())
        self.PATH = self.tmp_dir / f"{self.experiment.id}_checkpoints"
        os.makedirs(self.PATH, exist_ok=True)

        model_cls = self.get_model_cls()
        self.model = model_cls().to(self.device)
        self.__init_opt()
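        # Running average of client optimizer states; only populated when
        # CLIENT_OPT_STRATEGY == "avg" (see __train_one_round).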
        self.avg_opt_state = None
        if self.config.SCAFFOLD:
            self.c = lambda_params(self.model.parameters(), th.zeros_like)
            logging.info("SCAFFOLD: server c initialized")

        self.train_loader_list, self.test_loader, self.random_acc = self.load_data()
        self.n_train_batches = len(self.train_loader_list[0])
        logging.info(f"Number of training batches: {self.n_train_batches}")

        TorchClient.reset_ID_counter()
        self.clients = [
            TorchClient(
                self,
                model_cls=model_cls,
                is_keep_model_on_gpu=not self.config_technical.STORE_MODEL_IN_RAM,
                is_store_opt_on_disk=self.config_technical.STORE_OPT_ON_DISK,
                loss=self.get_loss(),
                dataloader=loader,
                device=self.device,
                opt_cls=TorchOptRepo.name2cls(self.config.CLIENT_OPT),
                opt_cls_param={
                    "lr": self.config.CLIENT_LEARNING_RATE,
                    "weight_decay": self.config.CLIENT_OPT_L2,
                },
                is_maintaine_opt_state=config.CLIENT_OPT_STRATEGY == "nothing",
                exp_id=experiment.id,
                is_scaffold=self.config.SCAFFOLD,
            )
            for loader in self.train_loader_list
        ]
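
    # Usage sketch (hypothetical; names outside this snippet are assumptions):
    # a concrete subclass implements get_model_cls, get_loss and load_data,
    # and a public training loop drives __train_one_round once per
    # communication round, e.g.
    #
    #     learner = MyLearner(experiment, config, config_technical)
    #     learner.train()  # assumed public entry point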