def DefaultConfig():
    seed = 123123123
    lr = 0.01
    name = 'clustering_test'
    total_fedavg_rounds = 50
    client_fraction = 0.1
    local_epochs = 1
    batch_size = 16
    num_clients = 27
    num_classes = 7  # HAM10000 has seven skin-lesion classes
    train_args = TrainArgs(max_epochs=local_epochs,
                           min_epochs=local_epochs,
                           progress_bar_refresh_rate=0)
    train_cluster_args = TrainArgs(max_epochs=3,
                                   min_epochs=3,
                                   progress_bar_refresh_rate=0)
    dataset = 'ham10k'
    partitioner_class = AlternativePartitioner
    optimizer_args = OptimizerArgs(optim.SGD, lr=lr)
    model_args = ModelArgs(MobileNetV2Lightning,
                           optimizer_args=optimizer_args,
                           num_classes=num_classes)
    initialization_rounds = [25, 50]  # FedAvg rounds to run before clustering
    linkage_mech = 'ward'
    criterion = 'distance'
    dis_metric = 'euclidean'
    max_value_criterion = [100, 200, 300]  # distance thresholds to sweep
    mean = (0.485, 0.456, 0.406)  # ImageNet channel means/stds for normalization
    std = (0.229, 0.224, 0.225)
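
# The clustering parameters above mirror scipy.cluster.hierarchy terminology
# ('ward' linkage, 'euclidean' metric, 'distance' criterion). A minimal sketch
# of the assumed correspondence; the helper name and the `updates` input
# (flattened client model updates, shape [num_clients, num_weights]) are
# hypothetical, and the repository's partitioner may differ in detail.
def _scipy_clustering_sketch(updates, max_value):
    from scipy.cluster.hierarchy import fcluster, linkage

    # Build the dendrogram with Ward linkage over euclidean distances, then
    # cut it at the distance threshold `max_value` (= max_value_criterion).
    Z = linkage(updates, method='ward', metric='euclidean')
    return fcluster(Z, t=max_value, criterion='distance')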
def create_mnist_experiment_context(name: str,
                                    local_epochs: int,
                                    batch_size: int,
                                    lr: float,
                                    client_fraction: float,
                                    dataset_name: str,
                                    num_classes: int,
                                    fixed_logger_version=None,
                                    no_progress_bar=False,
                                    cluster_args: Optional[ClusterArgs] = None):
    logger.debug('creating experiment context ...')
    optimizer_args = OptimizerArgs(optim.SGD, lr=lr)
    model_args = ModelArgs(CNNMnistLightning,
                           num_classes=num_classes,
                           optimizer_args=optimizer_args)
    train_args_dict = {'max_epochs': local_epochs, 'min_epochs': local_epochs}
    if no_progress_bar:
        train_args_dict['progress_bar_refresh_rate'] = 0
    training_args = TrainArgs(**train_args_dict)
    context = FedAvgExperimentContext(name=name,
                                      client_fraction=client_fraction,
                                      local_epochs=local_epochs,
                                      lr=lr,
                                      batch_size=batch_size,
                                      optimizer_args=optimizer_args,
                                      model_args=model_args,
                                      train_args=training_args,
                                      dataset_name=dataset_name)
    experiment_specification = f'{context}'
    experiment_specification += f'_{optimizer_args}'
    if cluster_args is not None:
        context.cluster_args = cluster_args
        experiment_specification += f'_{cluster_args}'
    experiment_logger = create_tensorboard_logger(context.name,
                                                  experiment_specification,
                                                  fixed_logger_version)
    context.experiment_logger = experiment_logger
    return context
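
# A minimal usage sketch (the wrapper name is hypothetical and the argument
# values are illustrative only): building a FedAvg context for a 10-class
# MNIST run without clustering and with the progress bar suppressed.
def _example_mnist_context_sketch():
    return create_mnist_experiment_context(name='mnist_fedavg',
                                           local_epochs=1,
                                           batch_size=16,
                                           lr=0.01,
                                           client_fraction=0.1,
                                           dataset_name='mnist',
                                           num_classes=10,
                                           no_progress_bar=True)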
def run_hierarchical_clustering(local_evaluation_steps,
                                seed,
                                lr,
                                name,
                                total_fedavg_rounds,
                                cluster_initialization_rounds,
                                client_fraction,
                                local_epochs,
                                batch_size,
                                num_clients,
                                sample_threshold,
                                num_label_limit,
                                train_args,
                                dataset,
                                partitioner_class,
                                linkage_mech,
                                criterion,
                                dis_metric,
                                max_value_criterion,
                                reallocate_clients,
                                threshold_min_client_cluster,
                                use_colored_images,
                                use_pattern,
                                train_cluster_args=None,
                                mean=None,
                                std=None):
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    global_tag_local = 'global_performance_personalized'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN
    if dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')
    # Allow scalars for swept hyperparameters
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]
    if not hasattr(lr, '__iter__'):
        lr = [lr]
    for cf in client_fraction:
        for lr_i in lr:
            optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            model_args = ModelArgs(MobileNetV2Lightning,
                                   optimizer_args=optimizer_args,
                                   num_classes=7)
            fedavg_context = FedAvgExperimentContext(name=name,
                                                     client_fraction=cf,
                                                     local_epochs=local_epochs,
                                                     lr=lr_i,
                                                     batch_size=batch_size,
                                                     optimizer_args=optimizer_args,
                                                     model_args=model_args,
                                                     train_args=train_args,
                                                     dataset_name=dataset)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name, experiment_specification)
            fedavg_context.experiment_logger = experiment_logger
            for init_rounds, max_value in generate_configuration(
                    cluster_initialization_rounds, max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                server = FedAvgServer('initial_server',
                                      fedavg_context.model_args,
                                      fedavg_context)
                server.overwrite_model_state(round_model_state)
                logger.info('initializing clients ...')
                clients = initialize_clients_fn(fedavg_context, fed_dataset,
                                                server.model.state_dict())
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': total_fedavg_rounds - init_rounds
                }
                if partitioner_class == DatadependentPartitioner:
                    clustering_dataset = load_femnist_colored_dataset(
                        str((REPO_ROOT / 'data').absolute()),
                        num_clients=num_clients,
                        batch_size=batch_size,
                        sample_threshold=sample_threshold)
                    dataloader = load_n_of_each_class(
                        clustering_dataset,
                        n=5,
                        tabu=list(fed_dataset.train_data_local_dict.keys()))
                    cluster_args = ClusterArgs(
                        partitioner_class,
                        linkage_mech=linkage_mech,
                        criterion=criterion,
                        dis_metric=dis_metric,
                        max_value_criterion=max_value,
                        plot_dendrogram=False,
                        reallocate_clients=reallocate_clients,
                        threshold_min_client_cluster=threshold_min_client_cluster,
                        dataloader=dataloader,
                        **round_configuration)
                else:
                    cluster_args = ClusterArgs(
                        partitioner_class,
                        linkage_mech=linkage_mech,
                        criterion=criterion,
                        dis_metric=dis_metric,
                        max_value_criterion=max_value,
                        plot_dendrogram=False,
                        reallocate_clients=reallocate_clients,
                        threshold_min_client_cluster=threshold_min_client_cluster,
                        **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}'
                experiment_logger = create_tensorboard_logger(
                    fedavg_context.name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger
                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=model_args,
                                               context=fedavg_context)
                federated_round_fn = partial(run_fedavg_round,
                                             training_args=train_args,
                                             client_fraction=cf)
                after_post_clustering_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            'post_clustering')
                ]
                after_clustering_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger)
                ]
                after_federated_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            ['final hierarchical', global_tag])
                ]
                after_clustering_fn = [
                    partial(log_cluster_distribution,
                            experiment_logger,
                            num_classes=fed_dataset.class_num),
                    partial(log_sample_images_from_each_client,
                            experiment_logger)
                ]
                after_federated_round_fn = [
                    partial(log_personalized_global_cluster_performance,
                            experiment_logger,
                            ['final hierarchical personalized',
                             global_tag_local], local_evaluation_steps)
                ]
                run_fedavg_hierarchical(
                    server, clients, cluster_args, initial_train_fn,
                    federated_round_fn, create_aggregator_fn,
                    after_post_clustering_evaluation,
                    after_clustering_round_evaluation,
                    after_federated_round_evaluation, after_clustering_fn,
                    after_federated_round=after_federated_round_fn)
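
# `generate_configuration` is defined elsewhere in the repository. A minimal
# sketch of the assumed contract (the sketch name is hypothetical): the loops
# above expect every (init_rounds, max_value) combination, i.e. the Cartesian
# product of the initialization-round list and the distance-threshold list.
def _generate_configuration_sketch(init_rounds_list, max_value_criterion_list):
    from itertools import product

    # e.g. [25, 50] x [100, 200, 300] yields six experiment configurations
    yield from product(init_rounds_list, max_value_criterion_list)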
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients=True;
        # results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> use SGD as inner optimizer; False -> use Adam
        rp_adam_betas,  # used only if rp_sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')
    # Allow scalars for swept hyperparameters
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    input_channels = 3 if use_colored_images else 1
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True
            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)
            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=hc_threshold_min_client_cluster,
                    **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = (
                    f'{fedavg_context}_{cluster_args}_{reptile_context}')
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger
                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                if len(trained_participants) != len(clients):
                    raise ValueError('not all clients successfully '
                                     'participated in the clustering round')

                # Cluster participants by their model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                _cluster_clients_dic = dict()
                for cluster_id, participants in cluster_clients_dic.items():
                    _cluster_clients_dic[cluster_id] = [
                        c._name for c in participants
                    ]
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic,
                                         62)  # 62 = FEMNIST class count

                # Initialize cluster models
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    intermediate_cluster_server = create_aggregator_fn(
                        'cluster_server' + cluster_id)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model
                        .state_dict())
                    # create_aggregator_fn('cluster_server' + cluster_id)
                    # cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = cluster_server

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]
                RANDOM = random.Random(seed)

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items():
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=(i * reptile_context.meta_batch_size
                                           % len(participants)),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context
                            .get_inner_training_args(),
                            meta_training_args=reptile_context
                            .get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on train and test clients
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []
                        for cluster_id, participants in \
                                cluster_clients_dic.items():
                            # Test on all clients inside the cluster
                            reptile_train_step(
                                aggregator=cluster_server_dic[cluster_id],
                                participants=participants,
                                inner_training_args=reptile_context
                                .get_inner_training_args(eval=True),
                                evaluation_mode=True)
                            result = evaluate_local_models(
                                participants=participants)
                            loss = result.get('test/loss')
                            acc = result.get('test/acc')
                            # Log
                            if after_round_evaluation is not None:
                                for c in after_round_evaluation:
                                    c(experiment_logger,
                                      f'cluster_{cluster_id}', loss, acc,
                                      global_step)
                            loss_list = loss.tolist()
                            acc_list = acc.tolist()
                            global_loss.extend(loss_list if isinstance(
                                loss_list, list) else [loss_list])
                            global_acc.extend(acc_list if isinstance(
                                acc_list, list) else [acc_list])
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, 'mean_over_all_clients',
                                  Tensor(global_loss), Tensor(global_acc),
                                  global_step)
                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []
                    for cluster_id, participants in cluster_clients_dic.items():
                        # Test on all clients inside the cluster
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=participants,
                            inner_training_args=reptile_context
                            .get_inner_training_args(eval=True),
                            evaluation_mode=True)
                        result = evaluate_local_models(
                            participants=participants)
                        loss = result.get('test/loss')
                        acc = result.get('test/acc')
                        print(f'Cluster {cluster_id} '
                              f'({len(participants)} part.): '
                              f'loss = {loss}, acc = {acc}')
                        loss_list = loss.tolist()
                        acc_list = acc.tolist()
                        global_loss.extend(loss_list if isinstance(
                            loss_list, list) else [loss_list])
                        global_acc.extend(acc_list if isinstance(
                            acc_list, list) else [acc_list])
                        # Log
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, f'cluster_{cluster_id}',
                                  loss, acc, reptile_context.num_meta_steps)
                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)
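
# `cyclerange` is used above for cyclic meta-batch sampling but defined
# elsewhere in the repository. A minimal sketch of the assumed contract (the
# sketch name is hypothetical): starting at `start`, return `interval`
# consecutive indices modulo `total_len`, so meta-batches wrap around the
# participant list instead of running past its end.
def _cyclerange_sketch(start, interval, total_len):
    return [(start + offset) % total_len for offset in range(interval)]

# Example (illustrative): _cyclerange_sketch(start=4, interval=3, total_len=5)
# returns [4, 0, 1].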
def __init__(self, name: str, dataset_name: str, swap_labels: bool,
             num_classes_per_client: int, num_shots_per_class: int, seed: int,
             model_class, sgd: bool, adam_betas: tuple,
             num_clients_train: int, num_clients_test: int,
             meta_batch_size: int, num_meta_steps: int,
             meta_learning_rate_initial: float,
             meta_learning_rate_final: float, eval_interval: int,
             num_eval_clients_training: int, do_final_evaluation: bool,
             num_eval_clients_final: int, inner_batch_size: int,
             inner_learning_rate: float, num_inner_steps: int,
             num_inner_steps_eval: int):
    self.name = name
    self.seed = seed

    # Arguments pertaining to the data set
    self.dataset_name = dataset_name
    self.swap_labels = swap_labels
    self.num_classes_per_client = num_classes_per_client
    self.num_shots_per_class = num_shots_per_class
    self.num_clients_train = num_clients_train
    self.num_clients_test = num_clients_test

    # Model arguments
    self.model_class = model_class
    self.sgd = sgd
    self.adam_betas = adam_betas

    # Arguments for inner training
    self.num_inner_steps = num_inner_steps
    self.inner_learning_rate = inner_learning_rate
    self.inner_batch_size = inner_batch_size
    if self.sgd:
        inner_optimizer_args = OptimizerArgs(
            optimizer_class=torch.optim.SGD,
            lr=self.inner_learning_rate)
    else:
        inner_optimizer_args = OptimizerArgs(
            optimizer_class=torch.optim.Adam,
            lr=self.inner_learning_rate,
            betas=self.adam_betas)
    if self.dataset_name == 'omniglot':
        self.inner_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=inner_optimizer_args,
            num_classes=self.num_classes_per_client)
    elif self.dataset_name == 'ham10k':
        self.inner_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=inner_optimizer_args,
            num_classes=7)
    else:
        self.inner_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=inner_optimizer_args)

    # Arguments for meta training
    self.num_meta_steps = num_meta_steps
    self.meta_learning_rate_initial = meta_learning_rate_initial
    self.meta_learning_rate_final = meta_learning_rate_final
    self.meta_batch_size = meta_batch_size  # number of clients per round
    if self.dataset_name == 'omniglot':
        self.meta_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=OptimizerArgs(  # dummy optimizer args
                optimizer_class=torch.optim.SGD),
            num_classes=self.num_classes_per_client)
    elif self.dataset_name == 'ham10k':
        self.meta_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=OptimizerArgs(  # dummy optimizer args
                optimizer_class=torch.optim.SGD),
            num_classes=7)
    else:
        self.meta_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=OptimizerArgs(  # dummy optimizer args
                optimizer_class=torch.optim.SGD))

    # Arguments for evaluation
    self.num_inner_steps_eval = num_inner_steps_eval
    self.eval_interval = eval_interval
    self.num_eval_clients_training = num_eval_clients_training
    if num_eval_clients_training > num_clients_train or \
            (num_clients_test
             and num_eval_clients_training > num_clients_test):
        raise ValueError(
            'num_eval_clients_training must be less than or equal to '
            'num_clients_train and num_clients_test')
    self.do_final_evaluation = do_final_evaluation
    if do_final_evaluation and (
            num_eval_clients_final > num_clients_train or
            (num_clients_test
             and num_eval_clients_final > num_clients_test)):
        raise ValueError(
            'num_eval_clients_final must be less than or equal to '
            'num_clients_train and num_clients_test')
    self.num_eval_clients_final = num_eval_clients_final

    self.weighted_aggregation: bool = True
    self._cluster_args = None
    self._experiment_logger = None
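
# The training loop above calls `get_meta_training_args(frac_done=...)` with
# the fraction of completed meta steps. A minimal sketch of the assumed
# schedule, shown as a hypothetical free function rather than the class's
# actual method: linear interpolation of the meta learning rate from
# `meta_learning_rate_initial` to `meta_learning_rate_final`, as in standard
# Reptile implementations.
def _meta_learning_rate_sketch(initial: float, final: float,
                               frac_done: float) -> float:
    # frac_done = i / num_meta_steps, in [0, 1)
    return initial * (1 - frac_done) + final * frac_done

# Example (illustrative): with initial=1.0 and final=0.0, the meta learning
# rate halfway through training is _meta_learning_rate_sketch(1.0, 0.0, 0.5),
# i.e. 0.5.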