def run_experiment(xp, xp_count, n_experiments):
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v for k, v in optimizer_hp.items()})

    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)
    distill_data = torch.utils.data.Subset(
        distill_data,
        np.random.permutation(len(distill_data))[:hp["n_distill"]])

    client_loaders, test_loader = data.get_loaders(
        train_data, test_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"],
        batch_size=hp["batch_size"],
        n_data=None)
    distill_loader = torch.utils.data.DataLoader(
        distill_data, batch_size=128, shuffle=False)

    clients = [Client(model_fn, optimizer_fn, loader) for loader in client_loaders]
    server = Server(model_fn, lambda x: torch.optim.Adam(x, lr=0.001),
                    test_loader, distill_loader)
    server.load_model(path=args.CHECKPOINT_PATH, name=hp["pretrained"])

    # print model
    models.print_model(server.model)

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()

    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients, hp["participation_rate"])

        for client in tqdm(participating_clients):
            client.synchronize_with_server(server)
            train_stats = client.compute_weight_update(hp["local_epochs"])

        if hp["aggregate"]:
            server.aggregate_weight_updates(participating_clients)

        if hp["use_distillation"]:
            server.distill(participating_clients, hp["distill_epochs"],
                           compress=hp["compress"])

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

            xp.log({'communication_round': c_round,
                    'epochs': c_round * hp['local_epochs']})
            xp.log({key: clients[0].optimizer.__dict__['param_groups'][0][key]
                    for key in optimizer_hp})

            # Evaluate
            xp.log({"client_train_{}".format(key): value
                    for key, value in train_stats.items()})
            xp.log({"server_val_{}".format(key): value
                    for key, value in server.evaluate().items()})

            # Save results to Disk
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except:
                print("Saving results Failed!")

            # Timing
            e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
            print("Remaining Time (approx.):",
                  '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                  "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
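# Note on the optimizer factory pattern used in the experiment runners:
# `optimizer_fn = lambda x: optimizer(x, **{k: hp[k] if k in hp else v ...})` builds the
# local optimizer from its per-model defaults while letting any key that also appears in
# the experiment's hyperparameter dict `hp` override the default. A minimal self-contained
# sketch of the same merging logic (the names `make_optimizer_fn` and `defaults`, and the
# SGD example, are illustrative and not part of this repository):

import torch

def make_optimizer_fn(optimizer_cls, defaults, hp):
    """Return a factory building `optimizer_cls`, preferring values in `hp` over `defaults`."""
    merged = {k: hp[k] if k in hp else v for k, v in defaults.items()}
    return lambda params: optimizer_cls(params, **merged)

# Example: `hp` overrides the learning rate, the default momentum is kept.
#   opt_fn = make_optimizer_fn(torch.optim.SGD, {"lr": 0.01, "momentum": 0.9}, {"lr": 0.1})
#   opt = opt_fn(torch.nn.Linear(2, 2).parameters())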
def run_experiment(xp, xp_count, n_experiments):
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer_fn = models.get_model(hp["net"])

    client_data, server_data = data.get_data(hp["dataset"],
                                             n_clients=hp["n_clients"],
                                             alpha=hp["dirichlet_alpha"],
                                             path=DATA_PATH)
    for i, d in enumerate(client_data):
        d.subset_transform = data.get_x_transform(hp["x_transform"], i)
        d.label_transform = data.get_y_transform(hp["y_transform"], i)

    clients = [Client(model_fn, optimizer_fn, subset, hp["batch_size"],
                      layers=hp["layers"], idnum=i)
               for i, subset in enumerate(client_data)]
    server = Server(model_fn, server_data, layers=hp["layers"])
    server.load_model(path=CHECKPOINT_PATH, name=hp["pretrained"])

    # print model
    models.print_model(server.model)
    xp.log({"shared_layers": list(server.W.keys())})

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()

    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients, hp["participation_rate"])

        # Evaluate all clients before local training
        accs = []
        for i, client in enumerate(clients):
            client.synchronize_with_server(server)
            accs += [client.evaluate()["accuracy"]]
        xp.log({"client_accuracies": accs}, printout=False)
        xp.log({"mean_accuracy": np.mean(accs)})

        for client in tqdm(participating_clients):
            train_stats = client.compute_weight_update(hp["local_epochs"])

        # Evaluate all clients again after local training
        accs = []  # reset, so post-training accuracies are not mixed with the pre-training ones
        for i, client in enumerate(clients):
            accs += [client.evaluate()["accuracy"]]
        xp.log({"post_client_accuracies": accs}, printout=False)
        xp.log({"post_mean_accuracy": np.mean(accs)})

        server.aggregate_weight_updates(clients)

        for client in participating_clients:
            xp.log(client.compute_server_angle(server), printout=False)

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

            xp.log({'communication_round': c_round,
                    'epochs': c_round * hp['local_epochs']})

            # Evaluate
            #xp.log({"client_train_{}".format(key) : value for key, value in train_stats.items()})
            #for i, client in enumerate(clients):
            #    xp.log({"client_{}_val_{}".format(i, key) : value for key, value in client.evaluate().items()})

            # Save results to Disk
            try:
                xp.save_to_disc(path=RESULTS_PATH, name=hp['log_path'])
            except:
                print("Saving results Failed!")

            # Timing
            e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
            print("Remaining Time (approx.):",
                  '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                  "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))

        xp.log(server.compute_pairwise_angles_layerwise(clients), printout=False)
        xp.save_to_disc(path=RESULTS_PATH, name=hp['log_path'])

    # Save model to disk
    server.save_model(path=CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
def run_experiment(xp, xp_count, n_experiments):
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v for k, v in optimizer_hp.items()})

    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)
    all_distill_data_indexed = data.IdxSubset(all_distill_data, np.arange(100000))
    print(len(all_distill_data_indexed))

    np.random.seed(hp["random_seed"])

    # What fraction of the unlabeled data should be used for training the anomaly detector
    distill_data = data.IdxSubset(
        all_distill_data,
        np.random.permutation(len(all_distill_data))[:hp["n_distill"]])  # data used for distillation

    client_data, label_counts = data.get_client_data(
        train_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"])

    client_loaders = [data.DataMerger({'base': local_data},
                                      mixture_coefficients={'base': 1}, **hp)
                      for local_data in client_data]
    test_loader = data.DataMerger(
        {'base': data.IdxSubset(test_data, list(range(len(test_data))))},
        mixture_coefficients={'base': 1},
        batch_size=256)

    distill_loader = DataLoader(distill_data, batch_size=hp["batch_size"],
                                shuffle=True, num_workers=8)
    distill_dummy_loader = DataLoader(distill_data, batch_size=2048,
                                      shuffle=False, num_workers=8)
    all_distill_loader = DataLoader(all_distill_data_indexed,
                                    batch_size=hp["batch_size"], shuffle=False)

    clients = [Client(model_fn, optimizer_fn, loader, idnum=i, counts=counts,
                      distill_loader=distill_loader)
               for i, (loader, counts) in enumerate(zip(client_loaders, label_counts))]
    server = Server(model_fn, lambda x: torch.optim.Adam(x, lr=1e-3),
                    test_loader, distill_loader)

    models.print_model(server.model)

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()
    xp.log({"server_val_{}".format(key): value
            for key, value in server.evaluate().items()})

    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients, hp["participation_rate"])
        xp.log({"participating_clients":
                np.array([client.id for client in participating_clients])})

        for client in participating_clients:
            client.synchronize_with_server(server, c_round)
            train_stats = client.compute_weight_update(
                hp["local_epochs"],
                train_oulier_model=hp["aggregation_mode"] in ["FAD+S", "FAD+P+S"],
                c_round=c_round,
                max_c_round=hp["communication_rounds"],
                **hp)
            print(train_stats)

            if hp["save_softlabels"] and hp["aggregation_mode"] in [
                    "FDcup", "FDsample", "FDcupdown", "FDer", "FDquant", "FDquantdown"]:
                predictions = client.compute_prediction_matrix(distill_dummy_loader, argmax=True)
                xp.log({"client_{}_predictions".format(client.id): predictions})

        if hp["aggregation_mode"] in ["FA"]:
            server.aggregate_weight_updates(participating_clients)

        if hp["aggregation_mode"] in ["FD", "FDcup", "FDsample", "FDcupdown",
                                      "FDer", "FDquant", "FDquantdown"]:
            distill_mode = {"FD": "mean_probs",
                            "FDcup": "pate_up",
                            "FDsample": "sample",
                            "FDcupdown": "pate",
                            "FDer": "mean_logits_er",
                            "FDquant": ["quantized", hp["quantization_bits"]],
                            "FDquantdown": ["quantized", hp["quantization_bits"]]
                            }[hp["aggregation_mode"]]

            reset_model = True if hp["init_mode"] == "random" else False
            distill_stats = server.distill(participating_clients, hp["distill_iter"],
                                           mode=distill_mode, reset_model=reset_model)

            if hp["active"]:
                if hp["active"] == "random":
                    idcs = np.random.permutation(len(all_distill_data))[:hp["n_distill"]]
                if hp["active"] == "entropy":
                    mat = server.compute_prediction_matrix(all_distill_loader, argmax=False)
                    idcs = np.argsort(-np.sum(-mat * np.log(mat), axis=1))[:hp["n_distill"]]
                if hp["active"] == "certainty":
                    mat = server.compute_prediction_matrix(all_distill_loader, argmax=False)
                    idcs = np.argsort(np.max(mat, axis=1))[:hp["n_distill"]]
                if hp["active"] == "margin":
                    mat = server.compute_prediction_matrix(all_distill_loader, argmax=False)
                    idcs = np.argsort(np.diff(np.sort(mat, axis=1)[:, -2:],
                                              axis=1).flatten())[:hp["n_distill"]]

                distill_data = data.IdxSubset(all_distill_data, idcs)
                distill_loader = DataLoader(distill_data, batch_size=128, shuffle=True)
                server.distill_loader = distill_loader

            if hp["init_mode"] == "co_distill":
                if hp["save_softlabels"] and hp["aggregation_mode"] in [
                        "FDcup", "FDsample", "FDcupdown", "FDer", "FDquant", "FDquantdown"]:
                    predictions = server.compute_prediction_matrix(distill_dummy_loader, argmax=True)
                    xp.log({"server_predictions": predictions})

                server.co_distill(hp["co_distill_iter"],
                                  quantization_bits=hp["quantization_bits_down"]
                                  if hp["aggregation_mode"] == "FDquantdown" else None)

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

            xp.log({'communication_round': c_round,
                    'epochs': c_round * hp['local_epochs']})
            xp.log({key: clients[0].optimizer.__dict__['param_groups'][0][key]
                    for key in optimizer_hp})

            # Evaluate
            xp.log({"server_val_{}".format(key): value
                    for key, value in server.evaluate().items()})

            # Save results to Disk
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except:
                print("Saving results Failed!")

            # Timing
            e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
            print("Remaining Time (approx.):",
                  '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                  "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
def run_experiment(xp, xp_count, n_experiments):
    t0 = time.time()
    print(xp)
    hp = xp.hyperparameters

    num_classes = {"cifar10": 10, "cifar100": 100}[hp["dataset"]]
    model_names = [model_name for model_name, k in hp["models"].items() for _ in range(k)]

    optimizer, optimizer_hp = getattr(torch.optim, hp["local_optimizer"][0]), hp["local_optimizer"][1]
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v for k, v in optimizer_hp.items()})
    distill_optimizer, distill_optimizer_hp = getattr(torch.optim, hp["distill_optimizer"][0]), hp["distill_optimizer"][1]
    distill_optimizer_fn = lambda x: distill_optimizer(
        x, **{k: hp[k] if k in hp else v for k, v in distill_optimizer_hp.items()})

    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)

    np.random.seed(hp["random_seed"])
    torch.manual_seed(hp["random_seed"])

    n_distill = int(hp["n_distill_frac"] * len(all_distill_data))
    # Use a single permutation for both subsets so the distillation and public data are disjoint
    permutation = np.random.permutation(len(all_distill_data))
    distill_data = data.IdxSubset(all_distill_data, permutation[:n_distill], return_index=True)
    public_data = data.IdxSubset(all_distill_data, permutation[n_distill:len(all_distill_data)], return_index=False)
    print(len(distill_data), len(public_data))

    client_loaders, test_loader = data.get_loaders(
        train_data, test_data,
        n_clients=len(model_names),
        alpha=hp["alpha"],
        batch_size=hp["batch_size"],
        n_data=None,
        num_workers=0,
        seed=hp["random_seed"])
    distill_loader = torch.utils.data.DataLoader(distill_data, batch_size=hp["distill_batch_size"],
                                                 shuffle=True, num_workers=8)
    public_loader = torch.utils.data.DataLoader(public_data, batch_size=128,
                                                shuffle=True, num_workers=8)

    clients = [Client(model_name, optimizer_fn, loader, idnum=i, num_classes=num_classes)
               for i, (loader, model_name) in enumerate(zip(client_loaders, model_names))]
    server = Server(np.unique(model_names), distill_optimizer_fn, test_loader,
                    distill_loader, num_classes=num_classes)
    for client in clients:
        client.public_loader = public_loader
        client.distill_loader = distill_loader

    # print model
    models.print_model(clients[0].model)

    if "P" in hp["aggregation_mode"] or hp["aggregation_mode"] == "FedAUX":
        for model_name, model in server.model_dict.items():
            pretrained = hp["pretrained"] if hp["pretrained"] else "{}_{}.pth".format(model_name, hp["distill_dataset"])

            loaded_state = torch.load(args.CHECKPOINT_PATH + pretrained, map_location='cpu')
            loaded_layers = [key for key in loaded_state if key in model.state_dict()]
            model.load_state_dict(loaded_state, strict=False)

            for client in clients:
                client.synchronize_with_server(server)
            print("Successfully loaded layers {} from".format(loaded_layers), pretrained)

    if hp["aggregation_mode"] == "FedAUX":
        print("Computing Scores...")
        for client in clients:
            client.scores = client.extract_features_and_compute_scores(
                client.loader, public_loader, distill_loader,
                lambda_reg=hp["lambda_reg_score"], eps_delt=hp["eps_delt"])
            if hp["save_scores"]:
                xp.log({"client_{}_scores".format(client.id): client.scores.detach().cpu().numpy()})

    if "L" in hp["aggregation_mode"]:
        for client in clients:
            for name, param in client.model.named_parameters():
                if "classification_layer" not in name:
                    param.requires_grad = False
        for model in server.model_dict.values():
            for name, param in model.named_parameters():
                if "classification_layer" not in name:
                    param.requires_grad = False

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()
    xp.log({"prep_time": t1 - t0})
    xp.log({"server_val_{}".format(key): value
            for key, value in server.evaluate_ensemble().items()})

    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients, hp["participation_rate"])
        xp.log({"participating_clients": np.array([c.id for c in participating_clients])})

        for client in participating_clients:
            client.synchronize_with_server(server)
            train_stats = client.compute_weight_update(
                hp["local_epochs"],
                lambda_fedprox=hp["lambda_fedprox"] if "PROX" in hp["aggregation_mode"] else 0.0)
            print(train_stats)

        # Averaging
        server.aggregate_weight_updates(participating_clients)
        avg_stats = server.evaluate_ensemble()
        xp.log({"averaging_{}".format(key): value for key, value in avg_stats.items()})

        if hp["aggregation_mode"] in ["FedDF", "FedAUX", "FedDF+P"]:
            distill_mode = "weighted_logits_precomputed" if hp["aggregation_mode"] == "FedAUX" else "mean_logits"
            distill_stats = server.distill(participating_clients, hp["distill_epochs"],
                                           mode=distill_mode, num_classes=num_classes)
            xp.log({"distill_{}".format(key): value for key, value in distill_stats.items()})

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

            xp.log({'communication_round': c_round,
                    'epochs': c_round * hp['local_epochs']})
            xp.log({key: clients[0].optimizer.__dict__['param_groups'][0][key]
                    for key in optimizer_hp})

            # Evaluate
            xp.log({"server_val_{}".format(key): value
                    for key, value in server.evaluate_ensemble().items()})
            xp.log({"epoch_time": (time.time() - t1) / c_round})

            # Save results to Disk
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except:
                print("Saving results Failed!")

            # Timing
            e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
            print("Remaining Time (approx.):",
                  '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                  "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
def run_experiment(xp, xp_count, n_experiments): print(xp) hp = xp.hyperparameters model_fn, optimizer_fn = models.get_model(hp["net"]) compression_fn = comp.get_compression(hp["compression"]) client_data, server_data = data.get_data(hp["dataset"], n_clients=hp["n_clients"], alpha=hp["dirichlet_alpha"]) clients = [ Client(model_fn, optimizer_fn, subset, hp["batch_size"], idnum=i) for i, subset in enumerate(client_data) ] server = Server(model_fn, server_data) server.load_model(path="checkpoints/", name=hp["pretrained"]) # print model models.print_model(server.model) # Start Distributed Training Process print("Start Distributed Training..\n") t1 = time.time() for c_round in range(1, hp["communication_rounds"] + 1): participating_clients = server.select_clients(clients, hp["participation_rate"]) for client in tqdm(participating_clients): client.synchronize_with_server(server) train_stats = client.compute_weight_update(hp["local_epochs"]) client.compress_weight_update(compression_fn, accumulate=hp["accumulate"]) server.aggregate_weight_updates(clients) # Logging if xp.is_log_round(c_round): print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments)) xp.log({ 'communication_round': c_round, 'epochs': c_round * hp['local_epochs'] }) # Evaluate xp.log({ "client_train_{}".format(key): value for key, value in train_stats.items() }) xp.log({ "server_val_{}".format(key): value for key, value in server.evaluate().items() }) # Save results to Disk try: xp.save_to_disc(path="results/", name=hp['log_path']) except: print("Saving results Failed!") # Timing e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round)) print( "Remaining Time (approx.):", '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60), "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100)) # Save model to disk server.save_model(path="checkpoints/", name=hp["save_model"]) # Delete objects to free up GPU memory del server clients.clear() torch.cuda.empty_cache()
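# Note on `compress_weight_update(..., accumulate=...)` above: the compression operator
# returned by comp.get_compression is defined elsewhere in this repository. With
# accumulate=True the usual pattern is error feedback: whatever the compressor removes
# from the update is kept locally and added back before the next compression step.
# A hedged sketch with top-k sparsification as a stand-in compressor (all names below
# are illustrative assumptions, not this repository's API):

import torch

def topk_compress(tensor, k):
    """Keep the k largest-magnitude entries of `tensor`, zero out the rest."""
    flat = tensor.flatten()
    values, idx = torch.topk(flat.abs(), k)
    out = torch.zeros_like(flat)
    out[idx] = flat[idx]
    return out.view_as(tensor)

def compress_with_error_feedback(update, residual, k):
    """Compress `update + residual`; return the compressed update and the new residual."""
    corrected = update + residual
    compressed = topk_compress(corrected, k)
    return compressed, corrected - compressed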
def run_experiment(xp, xp_count, n_experiments, args, seed=0):
    logger.debug(xp)
    hp = xp.hyperparameters

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    if "local_optimizer" in hp:
        optimizer = getattr(torch.optim, hp["local_optimizer"][0])
        optimizer_hp = hp["local_optimizer"][1]
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v for k, v in optimizer_hp.items()})

    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)

    np.random.seed(seed)

    # What fraction of the unlabeled data should be used for training the anomaly detector
    distill_data = data.IdxSubset(
        all_distill_data,
        np.random.permutation(len(all_distill_data))[:hp["n_distill"]])  # data used for distillation
    distill_loader = torch.utils.data.DataLoader(distill_data, batch_size=128, shuffle=True)

    client_data, label_counts = data.get_client_data(
        train_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"])

    if hp["aggregation_mode"] in ["FD+S", "FAD+S", "FAD+P+S"]:
        public_data = data.IdxSubset(
            all_distill_data,
            np.random.permutation(len(all_distill_data))[hp["n_distill"]:len(all_distill_data)])  # data used to train the outlier detector
        public_loader = torch.utils.data.DataLoader(public_data, batch_size=128, shuffle=True)

        logger.debug(
            "Using {} public data points for distillation and {} public data points for local training.\n"
            .format(len(distill_data), len(public_data)))

        client_loaders = [data.DataMerger({'base': local_data, 'public': public_data}, **hp)
                          for local_data in client_data]
    else:
        client_loaders = [data.DataMerger({'base': local_data}, **hp)
                          for local_data in client_data]

    test_loader = data.DataMerger(
        {'base': data.IdxSubset(test_data, list(range(len(test_data))))},
        mixture_coefficients={'base': 1},
        batch_size=100)
    distill_loader = DataLoader(distill_data, batch_size=128, shuffle=True)

    clients = [Client(model_fn, optimizer_fn, loader, idnum=i, counts=counts,
                      distill_loader=distill_loader)
               for i, (loader, counts) in enumerate(zip(client_loaders, label_counts))]
    server = Server(model_fn,
                    lambda x: getattr(torch.optim, hp["distill_optimizer"][0])(x, **hp["distill_optimizer"][1]),
                    test_loader, distill_loader)

    # Modes that use pretrained representation
    if hp["aggregation_mode"] in ["FAD+P", "FAD+P+S"] and os.path.isfile(args.CHECKPOINT_PATH + hp["pretrained"]):
        for device in clients + [server]:
            device.model.load_state_dict(
                torch.load(args.CHECKPOINT_PATH + hp["pretrained"], map_location='cpu'),
                strict=False)
        logger.debug(f"Successfully loaded model from {hp['pretrained']}")

    """
    # Train shallow Outlier detectors
    if hp["aggregation_mode"] in ["FAD+S", "FAD+P+S"]:
        feature_extractor = model_fn().cuda()
        feature_extractor.load_state_dict(torch.load(args.CHECKPOINT_PATH+hp["pretrained"], map_location='cpu'), strict=False)
        feature_extractor.eval()
        for client in clients:
            client.feature_extractor = feature_extractor

        print("Train Outlier Detectors")
        for client in tqdm(clients):
            client.train_outlier_detector(hp["outlier_model"][0], distill_loader, **hp["outlier_model"][1])
    """

    averaging_stats = {"accuracy": 0.0}
    eval_results = None

    models.print_model(server.model)

    # Start Distributed Training Process
    logger.info("Start Distributed Training..\n")
    t1 = time.time()
    xp.log({"server_val_{}".format(key): value
            for key, value in server.evaluate().items()})

    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients, hp["participation_rate"])
        xp.log({"participating_clients":
                np.array([client.id for client in participating_clients])})

        for client in tqdm(participating_clients):
            client.synchronize_with_server(server, c_round)
            train_stats = client.compute_weight_update(
                hp["local_epochs"],
                train_oulier_model=hp["aggregation_mode"] in ["FAD+S", "FAD+P+S"],
                c_round=c_round,
                max_c_round=hp["communication_rounds"],
                **hp)
            logger.debug(train_stats)

        if hp["aggregation_mode"] in ["FA", "FAD", "FAD+P", "FAD+S", "FAD+P+S"]:
            server.aggregate_weight_updates(participating_clients)
            averaging_stats = server.evaluate()
            xp.log({"parameter_averaging_{}".format(key): value
                    for key, value in averaging_stats.items()})

        if hp["aggregation_mode"] in ["FD", "FAD", "FAD+P", "FAD+S", "FAD+P+S"]:
            distill_mode = hp["distill_mode"] if hp["aggregation_mode"] in ["FD+S", "FAD+S", "FAD+P+S"] else "mean_logits"
            distill_stats = server.distill(participating_clients, hp["distill_epochs"],
                                           mode=distill_mode,
                                           acc0=averaging_stats["accuracy"],
                                           fallback=hp["fallback"])
            xp.log({"distill_{}".format(key): value
                    for key, value in distill_stats.items()})

        # Logging
        if xp.is_log_round(c_round):
            logger.debug("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

            xp.log({'communication_round': c_round,
                    'epochs': c_round * hp['local_epochs']})
            xp.log({key: clients[0].optimizer.__dict__['param_groups'][0][key]
                    for key in optimizer_hp})

            # Evaluate
            eval_results = server.evaluate()
            xp.log({"server_val_{}".format(key): value
                    for key, value in eval_results.items()})

            # Save results to Disk
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except:
                logger.error("Saving results Failed!")

            # Timing
            e = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
            logger.debug(
                f"Remaining Time (approx.): {e // 3600:02d}:{e % 3600 // 60:02d}:{e % 60:02d} [{c_round/hp['communication_rounds']*100:.2f}%]\n")

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    try:
        return 1 - eval_results['accuracy'] if eval_results else 1 - server.evaluate()['accuracy']
    finally:
        # Delete objects to free up GPU memory
        del server
        clients.clear()
        del clients
        del xp
        gc.collect()
def main():
    parser = argparse.ArgumentParser(description='Simulate FL for noniid clients')
    parser.add_argument('--num_client', type=int, default=10, help='number of clients')
    parser.add_argument('--num_group', type=int, default=2, help='number of client groups')
    parser.add_argument('--rounds', type=int, default=50, help='training rounds')
    parser.add_argument('--local_epoch', type=int, default=1, help='local training epochs')
    parser.add_argument('--noniid_alpha', type=float, default=1.0, help='DIRICHLET_ALPHA')
    parser.add_argument('--frac', type=float, default=0.5, help='fraction of participating clients per round')
    parser.add_argument('--eps1', type=float, default=0.4, help='mean update-norm threshold below which a cluster is considered converged')
    parser.add_argument('--eps2', type=float, default=1.6, help='max update-norm threshold above which a converged cluster is split')
    args = parser.parse_args()
    print(args)

    # initialize data, clients, server, logger
    dataset = ClientDataGenerater(args.num_client, args.num_group, args.noniid_alpha)
    dataset.load_emnist()
    print('Finish loading data')

    noniid_train, noniid_test = dataset.rotate_noniid()
    clients = [Client(ConvNet, lambda x: torch.optim.SGD(x, lr=0.1, momentum=0.9), dat, idnum=i)
               for i, dat in enumerate(noniid_train)]
    server = Server(ConvNet, noniid_test)
    cfl_stats = ExperimentLogger()

    cluster_indices = [np.arange(len(clients)).astype("int")]
    client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]
    cfl_stats.log({'args': args})
    print('Finish setting up')

    local_epoch = args.local_epoch
    round_interval = 0
    for c_round in range(1, args.rounds + 1):

        if c_round == 1:
            for client in clients:
                client.synchronize_with_server(server)

        participating_clients = server.select_clients(clients, frac=args.frac)

        # clients train locally
        for client in participating_clients:
            # average loss for a client
            train_stats = client.compute_weight_update(epochs=local_epoch)
            client.reset()

        similarities = server.compute_pairwise_similarities(clients)

        cluster_indices_new = []
        round_interval += 1
        for idc in cluster_indices:
            max_norm = server.compute_max_update_norm([clients[i] for i in idc])
            mean_norm = server.compute_mean_update_norm([clients[i] for i in idc])
            print(max_norm, mean_norm)

            # mean update norm smaller than eps1: the cluster has started to converge
            # max norm of the clients' updates larger than eps2: the cluster needs to be split
            if mean_norm < args.eps1 and max_norm > args.eps2 and len(idc) > 2:  # and round_interval > 20
                # server.cache_model(idc, clients[idc[0]].W, acc_clients)
                c1, c2 = server.cluster_clients(similarities[idc][:, idc])
                cluster_indices_new += [c1, c2]
                round_interval = 0
                cfl_stats.log({"split": c_round})
            else:
                cluster_indices_new += [idc]

        cluster_indices = cluster_indices_new
        client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]

        server.aggregate_clusterwise(client_clusters)

        acc_clients = [client.evaluate() for client in clients]
        log_dict = {"acc_clients": acc_clients, "mean_norm": mean_norm, "max_norm": max_norm,
                    "rounds": c_round, "clusters": cluster_indices}
        cfl_stats.log(log_dict)
        print(log_dict)
        display_train_stats(cfl_stats, args.eps1, args.eps2, args.rounds)

    print("Clustering result: ", cluster_indices)
    for idc in cluster_indices:
        server.cache_model(idc, clients[idc[0]].W, acc_clients)
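# Note on the eps1/eps2 split criterion above: a cluster is split when its averaged
# weight update has (almost) converged (mean norm < eps1) while individual clients still
# disagree strongly (max norm > eps2); the pairwise cosine similarities of the updates
# are then used to bipartition the cluster. A minimal self-contained sketch of the two
# norms and the similarity matrix (the helper name `cluster_split_check` and the dummy
# update vectors in the usage example are illustrative, not this repository's API):

import numpy as np

def cluster_split_check(updates, eps1=0.4, eps2=1.6):
    """Given a list of flattened client weight updates, return whether the cluster should
    be split, together with the pairwise cosine similarities used for clustering."""
    U = np.stack(updates)
    mean_norm = np.linalg.norm(U.mean(axis=0))    # norm of the averaged update
    max_norm = np.linalg.norm(U, axis=1).max()    # largest individual update norm
    norms = np.linalg.norm(U, axis=1, keepdims=True)
    similarities = (U @ U.T) / (norms * norms.T)
    return (mean_norm < eps1 and max_norm > eps2), similarities

# Example: cluster_split_check([np.random.randn(1000) for _ in range(10)])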