def run_reptile(context: str, initial_model_state=None):
    """Run Reptile meta-learning on Omniglot, configured via CLI arguments.

    Parses command-line arguments, loads Omniglot train/test client datasets,
    initializes clients and a ``ReptileServer``, then performs
    ``args.meta_iters`` meta-training steps. Every ``args.eval_interval``
    steps it evaluates the meta-model on one randomly chosen train client and
    one randomly chosen test client. After training it evaluates on a random
    sample of ``args.eval_samples`` clients from each set. All metrics go to
    a TensorBoard logger whose run name encodes the hyperparameters.

    Args:
        context: Experiment name/tag used for the logger, clients and server.
        initial_model_state: Optional state dict to initialize the server's
            meta-model with (forwarded to ``ReptileServer``).
    """
    args = argument_parser().parse_args()
    # Dedicated RNG seeded from the CLI so client sampling is reproducible.
    RANDOM = random.Random(args.seed)

    # TODO: Possibly implement logic using ReptileExperimentContext
    reptile_args = ReptileTrainingArgs(
        model_class=OmniglotLightning,
        sgd=args.sgd,
        inner_learning_rate=args.learning_rate,
        num_inner_steps=args.inner_iters,
        num_inner_steps_eval=args.eval_iters,
        log_every_n_steps=3,
        meta_learning_rate_initial=args.meta_step,
        meta_learning_rate_final=args.meta_step_final,
        num_classes_per_client=args.classes
    )
    # Run name encodes all relevant hyperparameters so TensorBoard runs are
    # distinguishable; '.' is stripped from floats to keep the name path-safe.
    experiment_logger = create_tensorboard_logger(
        'reptile',
        (
            f"{context};seed{args.seed};"
            f"train-clients{args.train_clients};"
            f"{args.classes}-way{args.shots}-shot;"
            f"ib{args.inner_batch}ii{args.inner_iters}"
            f"ilr{str(args.learning_rate).replace('.', '')}"
            f"ms{str(args.meta_step).replace('.', '')}"
            f"mb{args.meta_batch}ei{args.eval_iters}"
            f"{'sgd' if args.sgd else 'adam'}"
        )
    )

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        random_seed=args.seed
    )

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_clients(omniglot_train_clients,
                                       reptile_args.get_inner_model_args(),
                                       context, experiment_logger)
    test_clients = initialize_clients(omniglot_test_clients,
                                      reptile_args.get_inner_model_args(),
                                      context, experiment_logger)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=reptile_args.get_meta_model_args(),
        context=context,  # TODO: Change to ReptileExperimentContext
        initial_model_state=initial_model_state
    )

    # Perform training
    for i in range(args.meta_iters):
        if args.meta_batch == -1:
            # meta_batch == -1 means: use all training clients in every step.
            meta_batch = train_clients
        else:
            # Cycle through the client list in windows of size meta_batch.
            # NOTE(review): other call sites invoke
            # cyclerange(start=..., interval=..., total_len=...), while here the
            # second positional argument is an *end* index. Since
            # (i+1)*mb % n == (i*mb % n + mb) % n the produced window is the
            # same either way, but confirm this positional call still matches
            # cyclerange's current signature.
            meta_batch = [
                train_clients[k] for k in cyclerange(
                    i*args.meta_batch % len(train_clients),
                    (i+1)*args.meta_batch % len(train_clients),
                    len(train_clients)
                )
            ]
        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / args.meta_iters
            )
        )
        # Evaluation on train and test clients
        if i % args.eval_interval == 0:
            # train-test set
            # Pick one train client at random and test on it
            k = RANDOM.randrange(len(train_clients))
            client = [train_clients[k]]
            # evaluation_mode=True: fine-tune a copy for eval without
            # updating the meta-model (presumably — confirm in
            # reptile_train_step).
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
            # test-test set
            # Pick one test client at random and test on it
            k = RANDOM.randrange(len(test_clients))
            client = [test_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
        logger.info('finished training round')

    # Final evaluation on a sample of train/test clients
    for label, client_set in zip(['Train', 'Test'],
                                 [train_clients, test_clients]):
        eval_sample = RANDOM.sample(client_set, args.eval_samples)
        reptile_train_step(
            aggregator=server,
            participants=eval_sample,
            inner_training_args=reptile_args.get_inner_training_args(eval=True),
            evaluation_mode=True
        )
        result = evaluate_local_models(participants=eval_sample)
        log_loss_and_acc('global_model', result.get('test/loss'),
                         result.get('test/acc'), experiment_logger,
                         global_step=args.meta_iters)
        # NOTE(review): global_step=0 here while log_loss_and_acc above uses
        # args.meta_iters — possibly intentional (single final scalar), but
        # verify the inconsistency is wanted.
        experiment_logger.experiment.add_scalar(
            f'final_{label}_acc',
            torch.mean(result.get('test/acc')),
            global_step=0
        )
        print(f"{label} accuracy: {torch.mean(result.get('test/acc'))}")
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,  #
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    """Combine FedAvg pretraining, hierarchical clustering and Reptile.

    Pipeline (for every combination of ``hc_client_fraction`` x ``hc_lr`` x
    ``hc_cluster_initialization_rounds`` x ``hc_max_value_criterion``):

    1. Load the (optionally colored) FEMNIST federated dataset.
    2. Run FedAvg for ``max(hc_cluster_initialization_rounds)`` rounds.
    3. Restore the FedAvg state at ``init_rounds``, run one local training
       round, and hierarchically cluster the clients by their model updates.
    4. Initialize one ``ReptileServer`` per cluster (seeded with the
       FedAvg-aggregated cluster model) and run ``rp_num_meta_steps`` Reptile
       meta-steps inside each cluster, with periodic and final evaluation.

    Parameters prefixed ``hc_`` configure the FedAvg/clustering stage,
    parameters prefixed ``rp_`` configure the Reptile stage; the remaining
    parameters configure the dataset and experiment identity.

    Raises:
        ValueError: If ``dataset`` is not 'femnist', or if not all clients
            participated in the pre-clustering training round.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'

    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        # num_label_limit == -1 means: keep all labels.
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Allow scalars for these two hyperparameters by wrapping them in a list,
    # so the grid-search loops below work uniformly.
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    input_channels = 3 if use_colored_images else 1

    # The full-dataset distribution is logged only once across all runs.
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(
                CNNLightning,
                optimizer_args=fedavg_optimizer_args,
                input_channels=input_channels,
                only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)

            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            # Train FedAvg far enough for the largest init-round setting;
            # intermediate states are saved so they can be restored below.
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)

            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=hc_threshold_min_client_cluster,
                    **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}_{reptile_context}'
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                # Clustering of participants by model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # Name-only view of the clustering (currently unused beyond
                # construction — kept for debugging/inspection).
                _cluster_clients_dic = dict()
                for cluster_id, participants in cluster_clients_dic.items():
                    _cluster_clients_dic[cluster_id] = [
                        c._name for c in participants
                    ]
                # 62 = number of FEMNIST classes used for the distribution plot.
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize cluster models
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # NOTE(review): string concatenation assumes cluster_id is
                    # a str; an int id would raise TypeError — confirm the
                    # partitioner's key type.
                    intermediate_cluster_server = create_aggregator_fn(
                        'cluster_server' + cluster_id)
                    intermediate_cluster_server.aggregate(participants)
                    # Seed the Reptile server with the FedAvg-aggregated
                    # cluster model.
                    cluster_server = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model.
                        state_dict())
                    #create_aggregator_fn('cluster_server' + cluster_id)
                    #cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = cluster_server

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]
                RANDOM = random.Random(seed)

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        if reptile_context.meta_batch_size == -1:
                            # -1 means: use every cluster member each step.
                            meta_batch = participants
                        else:
                            # Cycle through the cluster members in windows of
                            # meta_batch_size.
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=i * reptile_context.meta_batch_size %
                                    len(participants),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context.
                            get_inner_training_args(),
                            meta_training_args=reptile_context.
                            get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on train and test clients
                    if i % reptile_context.eval_interval == 0:
                        # Offset by the FedAvg init rounds so the curves
                        # continue the pretraining timeline.
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []
                        for cluster_id, participants in cluster_clients_dic.items(
                        ):
                            # Test on all clients inside clusters
                            reptile_train_step(
                                aggregator=cluster_server_dic[cluster_id],
                                participants=participants,
                                inner_training_args=reptile_context.
                                get_inner_training_args(eval=True),
                                evaluation_mode=True)
                            result = evaluate_local_models(
                                participants=participants)
                            loss = result.get('test/loss')
                            acc = result.get('test/acc')

                            # Log
                            if after_round_evaluation is not None:
                                for c in after_round_evaluation:
                                    c(experiment_logger,
                                      f'cluster_{cluster_id}', loss, acc,
                                      global_step)
                            # tolist() may yield a scalar for 0-d tensors;
                            # normalize to a list before extending.
                            loss_list = loss.tolist()
                            acc_list = acc.tolist()
                            global_loss.extend(loss_list if isinstance(
                                loss_list, list) else [loss_list])
                            global_acc.extend(acc_list if isinstance(
                                acc_list, list) else [acc_list])

                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, 'mean_over_all_clients',
                                  Tensor(global_loss), Tensor(global_acc),
                                  global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []
                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        # Final evaluation on train and test clients
                        # Test on all clients inside clusters
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=participants,
                            inner_training_args=reptile_context.
                            get_inner_training_args(eval=True),
                            evaluation_mode=True)
                        result = evaluate_local_models(
                            participants=participants)
                        loss = result.get('test/loss')
                        acc = result.get('test/acc')
                        print(
                            f'Cluster {cluster_id} ({len(participants)} part.): loss = {loss}, acc = {acc}'
                        )
                        loss_list = loss.tolist()
                        acc_list = acc.tolist()
                        global_loss.extend(loss_list if isinstance(
                            loss_list, list) else [loss_list])
                        global_acc.extend(acc_list if isinstance(
                            acc_list, list) else [acc_list])

                        # Log
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, f'cluster_{cluster_id}',
                                  loss, acc, reptile_context.num_meta_steps)

                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)
def run_reptile(context: ReptileExperimentContext,
                dataset_train: FederatedDatasetData,
                dataset_test: FederatedDatasetData,
                initial_model_state,
                after_round_evaluation):
    """Run Reptile meta-learning on pre-built federated datasets.

    Initializes clients from ``dataset_train`` (and ``dataset_test`` when
    provided), sets up a ``ReptileServer``, and performs
    ``context.num_meta_steps`` meta-training steps. Every
    ``context.eval_interval`` steps, and optionally once more at the end,
    a (sub)sample of train and test clients is evaluated and the results
    are handed to the ``after_round_evaluation`` callbacks.

    Args:
        context: Experiment configuration (seed, batch sizes, step counts,
            model args, evaluation settings).
        dataset_train: Federated dataset used to build training clients.
        dataset_test: Optional federated dataset for test clients; may be
            ``None``, in which case only train clients are evaluated.
        initial_model_state: Optional state dict to initialize the server's
            meta-model with.
        after_round_evaluation: Optional iterable of callbacks invoked as
            ``c(prefix, train_loss, train_acc, test_loss, test_acc, step)``.
    """
    # Dedicated RNG so evaluation-client sampling is reproducible per seed.
    RANDOM = random.Random(context.seed)

    # Randomly swap labels
    if context.swap_labels:
        # 62 = label-space size used for swapping (FEMNIST-sized — TODO
        # confirm this is intended for every dataset passed in here).
        dataset_train = swap_labels(
            fed_dataset=dataset_train,
            max_classes_per_client=62,
            random_seed=context.seed
        )
        if dataset_test is not None:
            dataset_test = swap_labels(
                fed_dataset=dataset_test,
                max_classes_per_client=62,
                random_seed=context.seed
            )

    # Set up clients
    train_clients = initialize_clients(
        dataset=dataset_train,
        model_args=context.inner_model_args,
        context=context.name,
        experiment_logger=context.experiment_logger
    )
    if dataset_test is not None:
        test_clients = initialize_clients(
            dataset=dataset_test,
            model_args=context.inner_model_args,
            context=context.name,
            experiment_logger=context.experiment_logger
        )
    else:
        test_clients = None

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=context.meta_model_args,
        context=context.name,
        initial_model_state=initial_model_state
    )

    # Perform training
    for i in range(context.num_meta_steps):
        if context.meta_batch_size == -1:
            # -1 means: use all training clients in every meta step.
            meta_batch = train_clients
        else:
            # Cycle through the client list in windows of meta_batch_size.
            meta_batch = [
                train_clients[k] for k in cyclerange(
                    start=i*context.meta_batch_size % len(train_clients),
                    interval=context.meta_batch_size,
                    total_len=len(train_clients)
                )
            ]
        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=context.get_inner_training_args(),
            meta_training_args=context.get_meta_training_args(
                frac_done=i / context.num_meta_steps
            )
        )
        # Evaluation on train and test clients
        if i % context.eval_interval == 0:
            # Pick train / test clients at random and test on them
            losses, accs = [], []
            # Order matters: index 0 = train metrics, index 1 = test metrics
            # (None placeholders when there are no test clients).
            for client_set in [train_clients, test_clients]:
                if client_set:
                    if context.num_eval_clients_training == -1:
                        clients = client_set
                    else:
                        clients = RANDOM.sample(
                            client_set,
                            context.num_eval_clients_training
                        )
                    reptile_train_step(
                        aggregator=server,
                        participants=clients,
                        inner_training_args=context.get_inner_training_args(eval=True),
                        evaluation_mode=True
                    )
                    result = evaluate_local_models(participants=clients)
                    losses.append(result.get('test/loss'))
                    accs.append(result.get('test/acc'))
                else:
                    losses.append(None)
                    accs.append(None)
            # Log
            if after_round_evaluation is not None:
                for c in after_round_evaluation:
                    c('', losses[0], accs[0], losses[1], accs[1], i)
        logger.info('finished training round')

    if context.do_final_evaluation:
        # Final evaluation on subsample of train / test clients
        losses, accs = [], []
        for client_set in [train_clients, test_clients]:
            if client_set:
                if context.num_eval_clients_final == -1:
                    eval_clients = client_set
                else:
                    eval_clients = RANDOM.sample(
                        client_set, context.num_eval_clients_final
                    )
                reptile_train_step(
                    aggregator=server,
                    participants=eval_clients,
                    inner_training_args=context.get_inner_training_args(eval=True),
                    evaluation_mode=True
                )
                result = evaluate_local_models(participants=eval_clients)
                losses.append(result.get('test/loss'))
                accs.append(result.get('test/acc'))
            else:
                losses.append(None)
                accs.append(None)
        # Log
        if after_round_evaluation is not None:
            for c in after_round_evaluation:
                c('final_', losses[0], accs[0], losses[1], accs[1], 0)