Ejemplo n.º 1
0
def log_after_round_evaluation(
        experiment_logger,
        tag: str,
        loss_train_test: Tensor,
        acc_train_test: Tensor,
        loss_test_test: Tensor,
        acc_test_test: Tensor,
        step: int
    ):
    """Log loss/accuracy for the train-test split and, if available, the test-test split.

    The train-test metrics are always logged under ``f'{tag}train-test'``.
    The test-test metrics are logged under ``f'{tag}test-test'`` only when both
    ``loss_test_test`` and ``acc_test_test`` are not ``None``.

    Args:
        experiment_logger: forwarded unchanged to ``log_loss_and_acc`` and
            ``log_goal_test_acc``.
        tag: prefix prepended verbatim (no separator) to each split name.
        loss_train_test: loss forwarded for the train-test split.
        acc_train_test: accuracy forwarded for the train-test split.
        loss_test_test: loss for the test-test split, or ``None`` to skip it.
        acc_test_test: accuracy for the test-test split, or ``None`` to skip it.
        step: step value forwarded to both logging helpers.
    """
    # Collect (split suffix, loss, accuracy) triples; test-test is optional.
    splits = [('train-test', loss_train_test, acc_train_test)]
    if loss_test_test is not None and acc_test_test is not None:
        splits.append(('test-test', loss_test_test, acc_test_test))

    for suffix, split_loss, split_acc in splits:
        split_tag = f'{tag}{suffix}'
        log_loss_and_acc(split_tag, split_loss, split_acc, experiment_logger,
                         step)
        log_goal_test_acc(split_tag, split_acc, experiment_logger, step)
def log_after_round_evaluation(experiment_logger, tags, loss: Tensor,
                               acc: Tensor, step: int):
    """Log a confusion-matrix heatmap plus loss/accuracy metrics for one or more tags.

    Args:
        experiment_logger: logger whose ``experiment.add_image`` receives the
            heatmap image; also forwarded to ``log_loss_and_acc`` and
            ``log_goal_test_acc``.
        tags: a single tag, or a list/tuple of tags. A single tag is wrapped
            into a one-element list.
        loss: loss value(s) forwarded to ``log_loss_and_acc``.
        acc: accuracy value(s) forwarded to ``log_loss_and_acc`` and
            ``log_goal_test_acc``.
        step: step value used for the image and both metric helpers.

    Confusion-matrix logging is best-effort: any exception raised while
    building or logging the heatmap is reported via ``logger.error`` and
    swallowed, so the loss/accuracy logging below always runs.
    """
    # isinstance (rather than an exact type check) also accepts list
    # subclasses and tuples as collections of tags.
    if not isinstance(tags, (list, tuple)):
        tags = [tags]
    try:
        global_confusion_matrix = GlobalConfusionMatrix()
        if global_confusion_matrix.has_data:
            matrix = global_confusion_matrix.compute()
            for tag in tags:
                image = generate_confusion_matrix_heatmap(matrix, title=tag)
                experiment_logger.experiment.add_image(tag, image.numpy(),
                                                       step)
    except Exception as e:
        # Best-effort: a failing heatmap must not block metric logging.
        # The exception text is interpolated into the message. Passing it as a
        # bare extra argument has no placeholder to fill, so std `logging`
        # reports a formatting error instead and the detail is lost.
        logger.error(f'failed to log confusion matrix: {e}')

    for tag in tags:
        log_loss_and_acc(tag, loss, acc, experiment_logger, step)
        log_goal_test_acc(tag, acc, experiment_logger, step)
def log_after_round_evaluation(experiment_logger, tag: str, loss: Tensor,
                               acc: Tensor, step: int):
    """Log ``loss``/``acc`` under ``tag`` and forward ``acc`` to the goal-accuracy logger.

    Args:
        experiment_logger: forwarded to both logging helpers.
        tag: metric tag used by both helpers.
        loss: loss value(s) forwarded to ``log_loss_and_acc``.
        acc: accuracy value(s) forwarded to both helpers.
        step: step value forwarded to both helpers.
    """
    # Both helpers end with the same (logger, step) pair.
    trailing = (experiment_logger, step)
    log_loss_and_acc(tag, loss, acc, *trailing)
    log_goal_test_acc(tag, acc, *trailing)
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,  # forwarded to ClusterArgs(reallocate_clients=...)
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
    hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    """Run FedAvg warm-up, then hierarchical clustering and per-cluster Reptile training.

    For every client fraction in ``hc_client_fraction`` and every FedAvg
    learning rate in ``hc_lr``:

    1. FedAvg runs for ``max(hc_cluster_initialization_rounds)`` rounds on the
       loaded dataset, with ``save_states=True`` / ``restore_state=True``.
    2. For each ``(init_rounds, max_value)`` pair from ``generate_configuration``:
       - the FedAvg state for ``init_rounds`` is loaded into the clients;
       - one training round is run with ``run_fedavg_train_round``;
       - the clients are partitioned with the partitioner built from
         ``ClusterArgs``.
    3. Each cluster gets a ``ReptileServer`` whose initial state is the FedAvg
       aggregate of that cluster's participants.
    4. Reptile meta-training then runs for ``reptile_context.num_meta_steps``
       steps. Every ``eval_interval`` steps, and optionally at the end, each
       cluster is evaluated and logged to a TensorBoard logger.

    Args:
        seed: passed to ``fix_random_seeds`` and the Reptile context.
        name: experiment name used for contexts and TensorBoard loggers.
        dataset: only ``'femnist'`` is supported.
        num_label_limit: ``-1`` disables ``scratch_labels``; any other value is
            forwarded to it.
        use_colored_images: selects the colored FEMNIST loader and sets the
            model's ``input_channels`` to 3 (otherwise 1).
        hc_lr, hc_max_value_criterion: scalars or iterables; values without
            ``__iter__`` are wrapped into one-element lists.
        hc_client_fraction: iterated directly, so it must already be iterable.
        hc_cluster_initialization_rounds: passed to ``max`` and to
            ``generate_configuration``; presumably an iterable of round
            counts (confirm).
        rp_meta_batch_size: ``-1`` means every participant of a cluster takes
            part in each meta step.
        Remaining ``hc_*`` / ``rp_*`` arguments are forwarded to
        ``FedAvgExperimentContext``, ``ClusterArgs`` and
        ``ReptileExperimentContext`` respectively.

    Raises:
        ValueError: if ``dataset`` is unknown, or if not every client was
            returned by the pre-clustering training round.

    NOTE(review):
        - ``_cluster_clients_dic`` and ``RANDOM`` are created but never used in
          the visible code.
        - ``'cluster_server' + cluster_id`` requires ``cluster_id`` to be a str.
        - The final ``overall_mean`` is logged at step 0, whereas per-cluster
          final metrics use ``num_meta_steps``; confirm this is intended.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'

    # Load the federated dataset; only FEMNIST (optionally colored) is handled.
    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept scalar hyperparameters by wrapping them into one-element lists.
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    input_channels = 3 if use_colored_images else 1
    # The dataset distribution is logged only once, to the first logger created.
    data_distribution_logged = False
    # NOTE(review): hc_client_fraction is not wrapped like hc_lr above, so a
    # scalar value would fail here.
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            # Logger for the FedAvg warm-up, keyed by the FedAvg context only.
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            # FedAvg evaluation callbacks, bound to the 'fedavg' tag and to
            # global_tag (presumably invoked by run_fedavg after each round).
            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            # Warm up with FedAvg up to the largest requested init round;
            # states are saved so load_fedavg_state can restore earlier rounds.
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)

            # One clustering + Reptile experiment per configuration pair.
            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=
                    hc_threshold_min_client_cluster,
                    **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}_{reptile_context}'
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                # Route the FedAvg context's logging to the cluster experiment.
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                # Clustering needs an update from every client, so abort if
                # any client is missing from the returned list.
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                # Clustering of participants by model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # NOTE(review): this cluster id -> client-name mapping is not
                # used anywhere below.
                _cluster_clients_dic = dict()
                for cluster_id, participants in cluster_clients_dic.items():
                    _cluster_clients_dic[cluster_id] = [
                        c._name for c in participants
                    ]
                # NOTE(review): 62 is presumably the FEMNIST class count — confirm.
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize cluster models
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # FedAvg-aggregate the cluster's participants, then seed a
                    # ReptileServer with the aggregated weights.
                    # NOTE(review): '+' requires cluster_id to be a str.
                    intermediate_cluster_server = create_aggregator_fn(
                        'cluster_server' + cluster_id)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model.
                        state_dict())
                    #create_aggregator_fn('cluster_server' + cluster_id)
                    #cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = cluster_server

                # REPTILE TRAINING INSIDE CLUSTERS
                # Callbacks called as c(logger, tag, loss, acc, step).
                after_round_evaluation = [log_after_round_evaluation]
                # NOTE(review): RANDOM is not used in the visible code.
                RANDOM = random.Random(seed)

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items(
                    ):

                        # -1 means the whole cluster forms the meta batch;
                        # otherwise take meta_batch_size participants via
                        # cyclerange, starting at (i * meta_batch_size) mod
                        # cluster size.
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=i * reptile_context.meta_batch_size %
                                    len(participants),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context.
                            get_inner_training_args(),
                            meta_training_args=reptile_context.
                            get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on train and test clients
                    if i % reptile_context.eval_interval == 0:
                        # Offset by init_rounds so Reptile steps follow the
                        # FedAvg warm-up rounds on the step axis.
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []

                        for cluster_id, participants in cluster_clients_dic.items(
                        ):
                            # Test on all clients inside clusters
                            reptile_train_step(
                                aggregator=cluster_server_dic[cluster_id],
                                participants=participants,
                                inner_training_args=reptile_context.
                                get_inner_training_args(eval=True),
                                evaluation_mode=True)
                            result = evaluate_local_models(
                                participants=participants)
                            loss = result.get('test/loss')
                            acc = result.get('test/acc')

                            # Log
                            if after_round_evaluation is not None:
                                for c in after_round_evaluation:
                                    c(experiment_logger,
                                      f'cluster_{cluster_id}', loss, acc,
                                      global_step)
                            # tolist() may yield a scalar rather than a list;
                            # normalize before extending the global lists.
                            loss_list = loss.tolist()
                            acc_list = acc.tolist()
                            global_loss.extend(loss_list if isinstance(
                                loss_list, list) else [loss_list])
                            global_acc.extend(acc_list if isinstance(
                                acc_list, list) else [acc_list])

                        # Metrics pooled over every cluster's clients.
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, 'mean_over_all_clients',
                                  Tensor(global_loss), Tensor(global_acc),
                                  global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []

                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        # Final evaluation on train and test clients
                        # Test on all clients inside clusters
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=participants,
                            inner_training_args=reptile_context.
                            get_inner_training_args(eval=True),
                            evaluation_mode=True)
                        result = evaluate_local_models(
                            participants=participants)
                        loss = result.get('test/loss')
                        acc = result.get('test/acc')
                        print(
                            f'Cluster {cluster_id} ({len(participants)} part.): loss = {loss}, acc = {acc}'
                        )

                        loss_list = loss.tolist()
                        acc_list = acc.tolist()
                        global_loss.extend(loss_list if isinstance(
                            loss_list, list) else [loss_list])
                        global_acc.extend(acc_list if isinstance(
                            acc_list, list) else [acc_list])

                        # Log (at step num_meta_steps, not init_rounds + step)
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, f'cluster_{cluster_id}',
                                  loss, acc, reptile_context.num_meta_steps)

                    # NOTE(review): logged at step 0, unlike the per-cluster
                    # final metrics above — confirm this is intended.
                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)