def run_reptile_experiment(
        name, dataset, swap_labels, classes, shots, seed, model_class, sgd,
        adam_betas, num_clients_train, num_clients_test, meta_batch_size,
        num_meta_steps, meta_learning_rate_initial, meta_learning_rate_final,
        eval_interval, num_eval_clients_training, do_final_evaluation,
        num_eval_clients_final, inner_batch_size, inner_learning_rate,
        num_inner_steps, num_inner_steps_eval, mean=None, std=None):
    fix_random_seeds(seed)
    fed_dataset_test = None
    if dataset == 'femnist':
        fed_dataset_train = load_femnist_dataset(
            data_dir=str((REPO_ROOT / 'data').absolute()),
            num_clients=num_clients_train,
            batch_size=inner_batch_size,
            random_seed=seed)
    elif dataset == 'omniglot':
        fed_dataset_train, fed_dataset_test = load_omniglot_datasets(
            data_dir=str((REPO_ROOT / 'data' / 'omniglot').absolute()),
            num_clients_train=num_clients_train,
            num_clients_test=num_clients_test,
            num_classes_per_client=classes,
            num_shots_per_class=shots,
            inner_batch_size=inner_batch_size,
            random_seed=seed)
    elif dataset == 'ham10k':
        fed_dataset_train = load_ham10k_federated(
            partitions=num_clients_train,
            batch_size=inner_batch_size,
            mean=mean,
            std=std)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Grid search over inner learning rates and inner step counts
    for lr in inner_learning_rate:
        for _is in num_inner_steps:
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=swap_labels,
                num_classes_per_client=classes,
                num_shots_per_class=shots,
                seed=seed,
                model_class=model_class,
                sgd=sgd,
                adam_betas=adam_betas,
                num_clients_train=num_clients_train,
                num_clients_test=num_clients_test,
                meta_batch_size=meta_batch_size,
                num_meta_steps=num_meta_steps,
                meta_learning_rate_initial=meta_learning_rate_initial,
                meta_learning_rate_final=meta_learning_rate_final,
                eval_interval=eval_interval,
                num_eval_clients_training=num_eval_clients_training,
                do_final_evaluation=do_final_evaluation,
                num_eval_clients_final=num_eval_clients_final,
                inner_batch_size=inner_batch_size,
                inner_learning_rate=lr,
                num_inner_steps=_is,
                # Note: evaluation is tied to the training step count here;
                # the function's `num_inner_steps_eval` argument is unused.
                num_inner_steps_eval=_is)
            experiment_specification = f'{reptile_context}'
            experiment_logger = create_tensorboard_logger(
                reptile_context.name, experiment_specification)
            reptile_context.experiment_logger = experiment_logger
            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger)
            ]
            run_reptile(context=reptile_context,
                        dataset_train=fed_dataset_train,
                        dataset_test=fed_dataset_test,
                        initial_model_state=None,
                        after_round_evaluation=log_after_round_evaluation_fns)
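# Hypothetical invocation sketch for `run_reptile_experiment`. All argument
# values below are illustrative assumptions, not defaults from this repo, and
# `OmniglotLightning` stands in for whichever LightningModule the experiment
# actually uses. Note that `inner_learning_rate` and `num_inner_steps` must
# be iterables, since the function grid-searches over them.
#
# run_reptile_experiment(
#     name='omniglot_reptile', dataset='omniglot', swap_labels=False,
#     classes=5, shots=5, seed=123, model_class=OmniglotLightning,
#     sgd=True, adam_betas=None, num_clients_train=1000,
#     num_clients_test=200, meta_batch_size=5, num_meta_steps=3000,
#     meta_learning_rate_initial=1.0, meta_learning_rate_final=0.0,
#     eval_interval=100, num_eval_clients_training=50,
#     do_final_evaluation=True, num_eval_clients_final=200,
#     inner_batch_size=10, inner_learning_rate=[0.001, 0.01],
#     num_inner_steps=[5], num_inner_steps_eval=[5])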
def clustering_test(mean, std, seed, lr, local_epochs, client_fraction,
                    optimizer_args, total_fedavg_rounds, batch_size,
                    num_clients, model_args, train_args, train_cluster_args,
                    initialization_rounds, partitioner_class, linkage_mech,
                    criterion, dis_metric, max_value_criterion):
    fix_random_seeds(seed)
    fed_dataset = load_ham10k_federated(partitions=num_clients,
                                        batch_size=batch_size,
                                        mean=mean,
                                        std=std)
    initialize_clients_fn = initialize_ham10k_clients
    fedavg_context = FedAvgExperimentContext(name='ham10k_clustering',
                                             client_fraction=client_fraction,
                                             local_epochs=local_epochs,
                                             lr=lr,
                                             batch_size=batch_size,
                                             optimizer_args=optimizer_args,
                                             model_args=model_args,
                                             train_args=train_args,
                                             dataset_name='ham10k')
    experiment_specification = f'{fedavg_context}'
    experiment_logger = create_tensorboard_logger(fedavg_context.name,
                                                  experiment_specification)
    log_dataset_distribution(experiment_logger, 'full dataset', fed_dataset)
    server, clients = run_fedavg(context=fedavg_context,
                                 num_rounds=total_fedavg_rounds,
                                 dataset=fed_dataset,
                                 save_states=True,
                                 restore_state=True,
                                 evaluate_rounds=False,
                                 initialize_clients_fn=initialize_clients_fn)
    for init_rounds in initialization_rounds:
        # Load the saved FedAvg model state for this initialization round
        round_model_state = load_fedavg_state(fedavg_context, init_rounds)
        overwrite_participants_models(round_model_state, clients)
        run_fedavg_train_round(round_model_state,
                               training_args=train_cluster_args,
                               participants=clients)
        for max_value in max_value_criterion:
            # Initialize the cluster configuration
            round_configuration = {
                'num_rounds_init': init_rounds,
                'num_rounds_cluster': total_fedavg_rounds - init_rounds
            }
            cluster_args = ClusterArgs(partitioner_class,
                                       linkage_mech=linkage_mech,
                                       criterion=criterion,
                                       dis_metric=dis_metric,
                                       max_value_criterion=max_value,
                                       plot_dendrogram=False,
                                       reallocate_clients=False,
                                       threshold_min_client_cluster=-1,
                                       **round_configuration)
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name,
                f'{experiment_specification}{cluster_args}')
            # `ClusterArgs` acts as a factory for the configured partitioner
            partitioner = cluster_args()
            cluster_clients_dic = partitioner.cluster(clients, server)
            # HAM10k has 7 classes
            log_cluster_distribution(experiment_logger, cluster_clients_dic,
                                     7)
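# `ClusterArgs` is imported from elsewhere in the repo. Judging from the call
# site above (`partitioner = cluster_args()`), it is assumed to act as a
# configured factory: it stores the partitioner class together with its
# keyword arguments and instantiates the partitioner when called. A minimal
# sketch of that assumed pattern (hypothetical name, not part of the repo):
class _PartitionerFactorySketch:
    def __init__(self, partitioner_class, **kwargs):
        self.partitioner_class = partitioner_class
        self.kwargs = kwargs

    def __call__(self):
        # Build the partitioner from the stored configuration
        return self.partitioner_class(**self.kwargs)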
def run_hierarchical_clustering(
        local_evaluation_steps, seed, lr, name, total_fedavg_rounds,
        cluster_initialization_rounds, client_fraction, local_epochs,
        batch_size, num_clients, sample_threshold, num_label_limit,
        train_args, dataset, partitioner_class, linkage_mech, criterion,
        dis_metric, max_value_criterion, reallocate_clients,
        threshold_min_client_cluster, use_colored_images, use_pattern,
        train_cluster_args=None, mean=None, std=None):
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    global_tag_local = 'global_performance_personalized'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN
    if dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')
    # Allow scalar or iterable hyperparameters
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]
    if not hasattr(lr, '__iter__'):
        lr = [lr]
    for cf in client_fraction:
        for lr_i in lr:
            optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            model_args = ModelArgs(MobileNetV2Lightning,
                                   optimizer_args=optimizer_args,
                                   num_classes=7)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=optimizer_args,
                model_args=model_args,
                train_args=train_args,
                dataset_name=dataset)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name, experiment_specification)
            fedavg_context.experiment_logger = experiment_logger
            for init_rounds, max_value in generate_configuration(
                    cluster_initialization_rounds, max_value_criterion):
                # Load the saved FedAvg model state for this round
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                server = FedAvgServer('initial_server',
                                      fedavg_context.model_args,
                                      fedavg_context)
                server.overwrite_model_state(round_model_state)
                logger.info('initializing clients ...')
                clients = initialize_clients_fn(fedavg_context, fed_dataset,
                                                server.model.state_dict())
                overwrite_participants_models(round_model_state, clients)
                # Initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': total_fedavg_rounds - init_rounds
                }
                if partitioner_class == DatadependentPartitioner:
                    clustering_dataset = load_femnist_colored_dataset(
                        str((REPO_ROOT / 'data').absolute()),
                        num_clients=num_clients,
                        batch_size=batch_size,
                        sample_threshold=sample_threshold)
                    dataloader = load_n_of_each_class(
                        clustering_dataset,
                        n=5,
                        tabu=list(fed_dataset.train_data_local_dict.keys()))
                    cluster_args = ClusterArgs(
                        partitioner_class,
                        linkage_mech=linkage_mech,
                        criterion=criterion,
                        dis_metric=dis_metric,
                        max_value_criterion=max_value,
                        plot_dendrogram=False,
                        reallocate_clients=reallocate_clients,
                        threshold_min_client_cluster=
                        threshold_min_client_cluster,
                        dataloader=dataloader,
                        **round_configuration)
                else:
                    cluster_args = ClusterArgs(
                        partitioner_class,
                        linkage_mech=linkage_mech,
                        criterion=criterion,
                        dis_metric=dis_metric,
                        max_value_criterion=max_value,
                        plot_dendrogram=False,
                        reallocate_clients=reallocate_clients,
                        threshold_min_client_cluster=
                        threshold_min_client_cluster,
                        **round_configuration)
                # Create a new logger for the cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}'
                experiment_logger = create_tensorboard_logger(
                    fedavg_context.name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger
                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=model_args,
                                               context=fedavg_context)
                federated_round_fn = partial(run_fedavg_round,
                                             training_args=train_args,
                                             client_fraction=cf)
                after_post_clustering_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            'post_clustering')
                ]
                after_clustering_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger)
                ]
                after_federated_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            ['final hierarchical', global_tag])
                ]
                after_clustering_fn = [
                    partial(log_cluster_distribution, experiment_logger,
                            num_classes=fed_dataset.class_num),
                    partial(log_sample_images_from_each_client,
                            experiment_logger)
                ]
                after_federated_round_fn = [
                    partial(log_personalized_global_cluster_performance,
                            experiment_logger,
                            ['final hierarchical personalized',
                             global_tag_local], local_evaluation_steps)
                ]
                run_fedavg_hierarchical(
                    server, clients, cluster_args, initial_train_fn,
                    federated_round_fn, create_aggregator_fn,
                    after_post_clustering_evaluation,
                    after_clustering_round_evaluation,
                    after_federated_round_evaluation, after_clustering_fn,
                    after_federated_round=after_federated_round_fn)
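# `generate_configuration` is imported from elsewhere in the repo. Its call
# sites pair every initialization-round count with every clustering
# threshold, so it is presumably the Cartesian product of the two lists. A
# minimal sketch of that assumption (hypothetical name, not part of the
# repo):
from itertools import product


def _generate_configuration_sketch(init_rounds_list, max_value_criteria):
    """Yield (init_rounds, max_value) pairs, one per combination."""
    yield from product(init_rounds_list, max_value_criteria)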
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients =
        # True; results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> use SGD as inner optimizer; False -> use Adam
        rp_adam_betas,  # used only if rp_sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval,
        mean=None,
        std=None):
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN
    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    elif dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')
    # Allow scalar or iterable hyperparameters
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    input_channels = 3 if use_colored_images else 1
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(
                CNNLightning,
                optimizer_args=fedavg_optimizer_args,
                input_channels=input_channels,
                only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True
            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns,
                initialize_clients_fn=initialize_clients_fn)
            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # Load the saved FedAvg model state for this round
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # Initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=
                    hc_threshold_min_client_cluster,
                    **round_configuration)
                # Create a new logger for the cluster experiment
                experiment_specification = (
                    f'{fedavg_context}_{cluster_args}_{reptile_context}')
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger
                initial_train_fn = partial(
                    run_fedavg_train_round,
                    round_model_state,
                    training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                if len(trained_participants) != len(clients):
                    raise ValueError('not all clients successfully '
                                     'participated in the clustering round')

                # Cluster participants by their model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # Collect client names per cluster
                _cluster_clients_dic = {
                    cluster_id: [c._name for c in participants]
                    for cluster_id, participants in
                    cluster_clients_dic.items()
                }
                # FEMNIST has 62 classes
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize one Reptile meta-model server per cluster,
                # seeded with the FedAvg aggregate of the cluster's clients
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    intermediate_cluster_server = create_aggregator_fn(
                        f'cluster_server{cluster_id}')
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model
                        .state_dict())
                    cluster_server_dic[cluster_id] = cluster_server

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]
                RANDOM = random.Random(seed)

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in \
                            cluster_clients_dic.items():
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            # Cycle deterministically through the cluster's
                            # participants across meta steps
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=(i *
                                           reptile_context.meta_batch_size %
                                           len(participants)),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context
                            .get_inner_training_args(),
                            meta_training_args=reptile_context
                            .get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on train and test clients
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []
                        for cluster_id, participants in \
                                cluster_clients_dic.items():
                            # Test on all clients inside the cluster
                            reptile_train_step(
                                aggregator=cluster_server_dic[cluster_id],
                                participants=participants,
                                inner_training_args=reptile_context
                                .get_inner_training_args(eval=True),
                                evaluation_mode=True)
                            result = evaluate_local_models(
                                participants=participants)
                            loss = result.get('test/loss')
                            acc = result.get('test/acc')
                            # Log per-cluster metrics
                            if after_round_evaluation is not None:
                                for c in after_round_evaluation:
                                    c(experiment_logger,
                                      f'cluster_{cluster_id}', loss, acc,
                                      global_step)
                            loss_list = loss.tolist()
                            acc_list = acc.tolist()
                            global_loss.extend(loss_list if isinstance(
                                loss_list, list) else [loss_list])
                            global_acc.extend(acc_list if isinstance(
                                acc_list, list) else [acc_list])
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, 'mean_over_all_clients',
                                  Tensor(global_loss), Tensor(global_acc),
                                  global_step)
                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at the end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []
                    for cluster_id, participants in \
                            cluster_clients_dic.items():
                        # Final evaluation on all clients inside the cluster
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=participants,
                            inner_training_args=reptile_context
                            .get_inner_training_args(eval=True),
                            evaluation_mode=True)
                        result = evaluate_local_models(
                            participants=participants)
                        loss = result.get('test/loss')
                        acc = result.get('test/acc')
                        print(f'Cluster {cluster_id} ({len(participants)} '
                              f'part.): loss = {loss}, acc = {acc}')
                        loss_list = loss.tolist()
                        acc_list = acc.tolist()
                        global_loss.extend(loss_list if isinstance(
                            loss_list, list) else [loss_list])
                        global_acc.extend(acc_list if isinstance(
                            acc_list, list) else [acc_list])
                        # Log per-cluster metrics
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, f'cluster_{cluster_id}',
                                  loss, acc, reptile_context.num_meta_steps)
                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)
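# `cyclerange` is imported from elsewhere in the repo. From its usage above
# (selecting `meta_batch_size` participants starting at a rotating offset),
# it is assumed to yield a wrap-around window of indices. A minimal sketch of
# that assumption (hypothetical name, not part of the repo):
def _cyclerange_sketch(start, interval, total_len):
    """Return `interval` consecutive indices beginning at `start`, wrapping
    around a sequence of length `total_len`."""
    return [(start + j) % total_len for j in range(interval)]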