def setUp(self) -> None:
    """Build a deterministic FedAvg aggregation fixture.

    Creates ``nb_clients`` (3) clients, each holding ``nb_weight_arrays`` (6)
    float arrays of shape ``(8, 12)``. Every entry of client ``i``'s arrays
    equals ``1 + i`` (client 0 -> 1.0, client 1 -> 2.0, client 2 -> 3.0).

    Sets:
        self.weight_summarizer: the ``fed_learn.FedAvg`` instance under test.
        self.all_clients_weights: list (per client) of lists of arrays.
        self.avg_weights: result of
            ``self.weight_summarizer.process(self.all_clients_weights)``.
    """
    self.weight_summarizer = fed_learn.FedAvg()
    nb_clients = 3
    nb_weight_arrays = 6
    # Values are deliberately constant per client (not random), so expected
    # aggregates can be worked out by hand in the tests. A new array is
    # allocated for every slot, so no two entries share memory.
    self.all_clients_weights = [
        [np.ones((8, 12)) + client_idx for _ in range(nb_weight_arrays)]
        for client_idx in range(nb_clients)
    ]
    self.avg_weights = self.weight_summarizer.process(self.all_clients_weights)
# Experiment bookkeeping: the experiment folder is
# <directory of this file>/experiments/<args.name>.
experiment_folder_path = Path(__file__).resolve().parent / "experiments" / args.name
experiment = fed_learn.Experiment(experiment_folder_path, args.overwrite_experiment)
experiment.serialize_args(args)
tf_scalar_logger = experiment.create_scalar_logger()

# Training settings passed to the server for its clients (epochs / batch size).
client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size}


def model_fn():
    """Return a new CNN from fed_learn.create_model_cnn.

    Input shape is (28, 28, 1); the second argument (10) is presumably the
    number of output classes -- confirm against create_model_cnn.
    """
    return fed_learn.create_model_cnn((28, 28,1), 10, init_with_imagenet=False, learning_rate=args.learning_rate)


# The server receives the model factory and the FedAvg aggregator.
# NOTE(review): args.clients / args.fraction presumably mean total client
# count and per-round participation fraction -- confirm in fed_learn.Server.
weight_summarizer = fed_learn.FedAvg()
server = fed_learn.Server(model_fn, weight_summarizer, args.clients, args.fraction)

# Load initial model weights into the server only when a weights file is given.
weight_path = args.weights_file
if weight_path is not None:
    server.load_model_weights(weight_path)

server.update_client_train_params(client_train_params)
server.create_clients_plus()

# Fashion-MNIST train/test splits; pixels are cast to float32 and divided by
# 255 (assumes raw 0-255 integer pixel values -- confirm for `datasets`).
# NOTE(review): no channel axis is added here although model_fn expects
# (28, 28, 1) inputs -- confirm a reshape happens later.
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255