コード例 #1
0
def run_reptile_experiment(
    name,
    dataset,
    swap_labels,
    classes,
    shots,
    seed,
    model_class,
    sgd,
    adam_betas,
    num_clients_train,
    num_clients_test,
    meta_batch_size,
    num_meta_steps,
    meta_learning_rate_initial,
    meta_learning_rate_final,
    eval_interval,
    num_eval_clients_training,
    do_final_evaluation,
    num_eval_clients_final,
    inner_batch_size,
    inner_learning_rate,
    num_inner_steps,
    num_inner_steps_eval,
    mean=None,
    std=None
):
    """Run federated Reptile once per (inner learning rate, inner steps) pair.

    The federated dataset is loaded once. Then ``run_reptile`` is called for
    every combination of ``inner_learning_rate`` x ``num_inner_steps``. Each
    run gets its own ``ReptileExperimentContext`` and TensorBoard logger.

    Args:
        name: Experiment name, used for the context and the logger.
        dataset: One of ``'femnist'``, ``'omniglot'`` or ``'ham10k'``.
        swap_labels, seed, model_class, sgd, adam_betas, num_clients_train,
        num_clients_test, meta_batch_size, num_meta_steps,
        meta_learning_rate_initial, meta_learning_rate_final, eval_interval,
        num_eval_clients_training, do_final_evaluation,
        num_eval_clients_final, inner_batch_size: Forwarded unchanged to
            ``ReptileExperimentContext``. Some are also passed to the
            dataset loaders.
        classes: Forwarded as ``num_classes_per_client`` (and to the
            omniglot loader).
        shots: Forwarded as ``num_shots_per_class`` (and to the omniglot
            loader).
        inner_learning_rate: A learning rate, or an iterable of learning
            rates to sweep.
        num_inner_steps: An inner step count, or an iterable of counts to
            sweep.
        num_inner_steps_eval: Currently NOT used. Each run evaluates with the
            same number of inner steps it trains with.
        mean, std: Passed to ``load_ham10k_federated``; only used for
            ``'ham10k'``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    fix_random_seeds(seed)

    # Accept a single value as well as a sweep, like the hierarchical
    # clustering entry points do for their swept settings.
    if not hasattr(inner_learning_rate, '__iter__'):
        inner_learning_rate = [inner_learning_rate]
    if not hasattr(num_inner_steps, '__iter__'):
        num_inner_steps = [num_inner_steps]

    # Only the omniglot loader returns separate test clients; for the other
    # datasets run_reptile receives dataset_test=None.
    fed_dataset_test = None
    if dataset == 'femnist':
        fed_dataset_train = load_femnist_dataset(
            data_dir=str((REPO_ROOT / 'data').absolute()),
            num_clients=num_clients_train,
            batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'omniglot':
        fed_dataset_train, fed_dataset_test = load_omniglot_datasets(
            data_dir=str((REPO_ROOT / 'data' / 'omniglot').absolute()),
            num_clients_train=num_clients_train,
            num_clients_test=num_clients_test,
            num_classes_per_client=classes,
            num_shots_per_class=shots,
            inner_batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'ham10k':
        fed_dataset_train = load_ham10k_federated(
            partitions=num_clients_train,
            batch_size=inner_batch_size,
            mean=mean,
            std=std
        )
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    for lr in inner_learning_rate:
        for inner_steps in num_inner_steps:
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=swap_labels,
                num_classes_per_client=classes,
                num_shots_per_class=shots,
                seed=seed,
                model_class=model_class,
                sgd=sgd,
                adam_betas=adam_betas,
                num_clients_train=num_clients_train,
                num_clients_test=num_clients_test,
                meta_batch_size=meta_batch_size,
                num_meta_steps=num_meta_steps,
                meta_learning_rate_initial=meta_learning_rate_initial,
                meta_learning_rate_final=meta_learning_rate_final,
                eval_interval=eval_interval,
                num_eval_clients_training=num_eval_clients_training,
                do_final_evaluation=do_final_evaluation,
                num_eval_clients_final=num_eval_clients_final,
                inner_batch_size=inner_batch_size,
                inner_learning_rate=lr,
                num_inner_steps=inner_steps,
                # NOTE(review): the num_inner_steps_eval argument is ignored;
                # evaluation reuses the training step count. Confirm intent.
                num_inner_steps_eval=inner_steps
            )

            # One logger per configuration, named after the context's string
            # representation.
            experiment_logger = create_tensorboard_logger(
                reptile_context.name, f'{reptile_context}'
            )
            reptile_context.experiment_logger = experiment_logger

            run_reptile(
                context=reptile_context,
                dataset_train=fed_dataset_train,
                dataset_test=fed_dataset_test,
                initial_model_state=None,
                after_round_evaluation=[
                    partial(log_after_round_evaluation, experiment_logger)
                ]
            )
コード例 #2
0
def clustering_test(mean, std, seed, lr, local_epochs, client_fraction,
                    optimizer_args, total_fedavg_rounds, batch_size,
                    num_clients, model_args, train_args, train_cluster_args,
                    initialization_rounds, partitioner_class, linkage_mech,
                    criterion, dis_metric, max_value_criterion):
    """Cluster HAM10k clients at several FedAvg checkpoints and log the clusters.

    FedAvg is run on HAM10k via ``run_fedavg`` for ``total_fedavg_rounds``
    rounds, with ``save_states``/``restore_state`` enabled. Then, for each
    entry of ``initialization_rounds``:

    - the FedAvg state stored for that round is loaded into every client;
    - ``run_fedavg_train_round`` is run on them with ``train_cluster_args``;
    - for every threshold in ``max_value_criterion``, the clients are
      clustered and the per-cluster label distribution is logged to a
      dedicated TensorBoard logger.

    No further training happens after clustering.

    Args:
        mean, std: Passed to ``load_ham10k_federated``.
        seed: Seed for ``fix_random_seeds``.
        lr, local_epochs, client_fraction, optimizer_args, batch_size,
        model_args, train_args: FedAvg experiment context settings.
        total_fedavg_rounds: Number of FedAvg rounds run before clustering.
        num_clients: Number of HAM10k partitions (clients).
        train_cluster_args: Training args for the round before clustering.
        initialization_rounds: A round, or an iterable of rounds, whose
            stored FedAvg state is clustered.
        partitioner_class, linkage_mech, criterion, dis_metric: Passed to
            ``ClusterArgs``.
        max_value_criterion: A threshold, or an iterable of thresholds,
            passed to ``ClusterArgs`` as ``max_value_criterion``.
    """
    fix_random_seeds(seed)

    # Accept a single value as well as a sweep for the looped settings.
    if not hasattr(initialization_rounds, '__iter__'):
        initialization_rounds = [initialization_rounds]
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]

    fed_dataset = load_ham10k_federated(partitions=num_clients,
                                        batch_size=batch_size,
                                        mean=mean,
                                        std=std)

    fedavg_context = FedAvgExperimentContext(name='ham10k_clustering',
                                             client_fraction=client_fraction,
                                             local_epochs=local_epochs,
                                             lr=lr,
                                             batch_size=batch_size,
                                             optimizer_args=optimizer_args,
                                             model_args=model_args,
                                             train_args=train_args,
                                             dataset_name='ham10k')
    experiment_specification = f'{fedavg_context}'
    experiment_logger = create_tensorboard_logger(fedavg_context.name,
                                                  experiment_specification)

    log_dataset_distribution(experiment_logger, 'full dataset', fed_dataset)

    server, clients = run_fedavg(context=fedavg_context,
                                 num_rounds=total_fedavg_rounds,
                                 dataset=fed_dataset,
                                 save_states=True,
                                 restore_state=True,
                                 evaluate_rounds=False,
                                 initialize_clients_fn=initialize_ham10k_clients)

    for init_rounds in initialization_rounds:
        # Reset every client to the FedAvg state stored for `init_rounds`,
        # then run one local training round before clustering.
        round_model_state = load_fedavg_state(fedavg_context, init_rounds)
        overwrite_participants_models(round_model_state, clients)
        run_fedavg_train_round(round_model_state,
                               training_args=train_cluster_args,
                               participants=clients)

        for max_value in max_value_criterion:
            cluster_args = ClusterArgs(
                partitioner_class,
                linkage_mech=linkage_mech,
                criterion=criterion,
                dis_metric=dis_metric,
                max_value_criterion=max_value,
                plot_dendrogram=False,
                reallocate_clients=False,
                threshold_min_client_cluster=-1,
                num_rounds_init=init_rounds,
                num_rounds_cluster=total_fedavg_rounds - init_rounds)

            # A separate logger for each (init_rounds, threshold) configuration.
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name,
                f'{experiment_specification}{cluster_args}')
            partitioner = cluster_args()
            # NOTE(review): `server` is the one returned by run_fedavg, which
            # presumably holds the state after the last FedAvg round rather
            # than after `init_rounds`. Confirm that the partitioner does not
            # depend on it.
            cluster_clients_dic = partitioner.cluster(clients, server)
            # Use the dataset's class count instead of hard-coding 7.
            log_cluster_distribution(experiment_logger,
                                     cluster_clients_dic,
                                     num_classes=fed_dataset.class_num)
def run_hierarchical_clustering(local_evaluation_steps,
                                seed,
                                lr,
                                name,
                                total_fedavg_rounds,
                                cluster_initialization_rounds,
                                client_fraction,
                                local_epochs,
                                batch_size,
                                num_clients,
                                sample_threshold,
                                num_label_limit,
                                train_args,
                                dataset,
                                partitioner_class,
                                linkage_mech,
                                criterion,
                                dis_metric,
                                max_value_criterion,
                                reallocate_clients,
                                threshold_min_client_cluster,
                                use_colored_images,
                                use_pattern,
                                train_cluster_args=None,
                                mean=None,
                                std=None):
    """Run FedAvg followed by hierarchical clustering on HAM10k over a grid.

    For every client fraction and learning rate, a FedAvg experiment context
    is built (``MobileNetV2Lightning``, SGD, ``num_classes=7``). For every
    ``(init_rounds, max_value)`` pair produced by ``generate_configuration``:

    - the FedAvg state stored for ``init_rounds`` is loaded into a new server
      and into freshly initialized clients;
    - ``run_fedavg_hierarchical`` is called with
      ``total_fedavg_rounds - init_rounds`` cluster rounds and the logging
      callbacks defined below.

    Args:
        local_evaluation_steps: Forwarded to
            ``log_personalized_global_cluster_performance``.
        seed: Seed for ``fix_random_seeds``.
        lr: A learning rate, or an iterable of learning rates to sweep.
        name: Experiment name for the context and the loggers.
        total_fedavg_rounds: Total number of rounds. Rounds after
            ``init_rounds`` are configured as cluster rounds.
        cluster_initialization_rounds: Passed to ``generate_configuration``.
        client_fraction: A client fraction, or an iterable of fractions to
            sweep.
        local_epochs, batch_size: FedAvg context settings.
        num_clients: Number of HAM10k partitions (clients).
        sample_threshold: Only used when loading the clustering dataset for
            ``DatadependentPartitioner``.
        num_label_limit, use_colored_images, use_pattern: Not used here.
        train_args: FedAvg context training args. Also used for the
            per-cluster rounds.
        dataset: Must be ``'ham10k'``.
        partitioner_class, linkage_mech, criterion, dis_metric,
        reallocate_clients, threshold_min_client_cluster: ``ClusterArgs``
            settings.
        max_value_criterion: A threshold, or an iterable of thresholds.
        train_cluster_args: Training args for the round run before
            clustering.
        mean, std: Passed to ``load_ham10k_federated``.

    Raises:
        ValueError: If ``dataset`` is not ``'ham10k'``.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    global_tag_local = 'global_performance_personalized'
    if dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept a single value as well as a sweep for every looped setting.
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]
    if not hasattr(lr, '__iter__'):
        lr = [lr]
    if not hasattr(client_fraction, '__iter__'):
        client_fraction = [client_fraction]

    for cf in client_fraction:
        for lr_i in lr:
            optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            model_args = ModelArgs(MobileNetV2Lightning,
                                   optimizer_args=optimizer_args,
                                   num_classes=7)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=optimizer_args,
                model_args=model_args,
                train_args=train_args,
                dataset_name=dataset)
            # Context-level logger; it is replaced per configuration below.
            fedavg_context.experiment_logger = create_tensorboard_logger(
                fedavg_context.name, f'{fedavg_context}')

            for init_rounds, max_value in generate_configuration(
                    cluster_initialization_rounds, max_value_criterion):
                # Restore the FedAvg state stored for `init_rounds` into a new
                # server and newly initialized clients.
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                server = FedAvgServer('initial_server',
                                      fedavg_context.model_args,
                                      fedavg_context)
                server.overwrite_model_state(round_model_state)
                logger.info('initializing clients ...')
                clients = initialize_clients_fn(fedavg_context, fed_dataset,
                                                server.model.state_dict())
                overwrite_participants_models(round_model_state, clients)

                # Build the clustering configuration. Keyword order follows
                # the original calls: dataloader (if any) comes before the
                # round settings.
                cluster_kwargs = {
                    'linkage_mech': linkage_mech,
                    'criterion': criterion,
                    'dis_metric': dis_metric,
                    'max_value_criterion': max_value,
                    'plot_dendrogram': False,
                    'reallocate_clients': reallocate_clients,
                    'threshold_min_client_cluster':
                    threshold_min_client_cluster,
                }
                if partitioner_class is DatadependentPartitioner:
                    # NOTE(review): this loads FEMNIST data for a HAM10k
                    # experiment (the tabu list holds HAM10k client ids), and
                    # reloads it for every configuration. Confirm intent.
                    clustering_dataset = load_femnist_colored_dataset(
                        str((REPO_ROOT / 'data').absolute()),
                        num_clients=num_clients,
                        batch_size=batch_size,
                        sample_threshold=sample_threshold)
                    cluster_kwargs['dataloader'] = load_n_of_each_class(
                        clustering_dataset,
                        n=5,
                        tabu=list(fed_dataset.train_data_local_dict.keys()))
                cluster_kwargs['num_rounds_init'] = init_rounds
                cluster_kwargs['num_rounds_cluster'] = (total_fedavg_rounds -
                                                        init_rounds)
                cluster_args = ClusterArgs(partitioner_class, **cluster_kwargs)

                # New logger for this cluster experiment.
                experiment_logger = create_tensorboard_logger(
                    fedavg_context.name, f'{fedavg_context}_{cluster_args}')
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=model_args,
                                               context=fedavg_context)
                federated_round_fn = partial(run_fedavg_round,
                                             training_args=train_args,
                                             client_fraction=cf)

                after_post_clustering_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            'post_clustering')
                ]
                after_clustering_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger)
                ]
                after_federated_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            ['final hierarchical', global_tag])
                ]
                after_clustering_fn = [
                    partial(log_cluster_distribution,
                            experiment_logger,
                            num_classes=fed_dataset.class_num),
                    partial(log_sample_images_from_each_client,
                            experiment_logger)
                ]
                after_federated_round_fn = [
                    partial(
                        log_personalized_global_cluster_performance,
                        experiment_logger,
                        ['final hierarchical personalized', global_tag_local],
                        local_evaluation_steps)
                ]
                run_fedavg_hierarchical(
                    server,
                    clients,
                    cluster_args,
                    initial_train_fn,
                    federated_round_fn,
                    create_aggregator_fn,
                    after_post_clustering_evaluation,
                    after_clustering_round_evaluation,
                    after_federated_round_evaluation,
                    after_clustering_fn,
                    after_federated_round=after_federated_round_fn)
def _evaluate_reptile_clusters(reptile_context,
                               cluster_clients_dic,
                               cluster_server_dic,
                               experiment_logger,
                               after_round_evaluation,
                               step,
                               print_results=False):
    """Evaluate each cluster's Reptile model on that cluster's own clients.

    For every cluster, this function:

    - runs ``reptile_train_step`` in evaluation mode with the eval
      inner-training args;
    - evaluates the participants' local models;
    - optionally prints the result;
    - passes the result to each callback in ``after_round_evaluation`` under
      the tag ``cluster_<id>`` at ``step``.

    Returns:
        Tuple ``(global_loss, global_acc)``: flat lists with the
        ``test/loss`` and ``test/acc`` values of all clusters.
    """
    global_loss, global_acc = [], []
    for cluster_id, participants in cluster_clients_dic.items():
        reptile_train_step(
            aggregator=cluster_server_dic[cluster_id],
            participants=participants,
            inner_training_args=reptile_context.get_inner_training_args(
                eval=True),
            evaluation_mode=True)
        result = evaluate_local_models(participants=participants)
        loss = result.get('test/loss')
        acc = result.get('test/acc')
        if print_results:
            print(f'Cluster {cluster_id} ({len(participants)} part.): '
                  f'loss = {loss}, acc = {acc}')
        if after_round_evaluation is not None:
            for callback in after_round_evaluation:
                callback(experiment_logger, f'cluster_{cluster_id}', loss, acc,
                         step)
        # tolist() may return a bare number (e.g. for a 0-d tensor), so the
        # result is normalized to a list before extending.
        loss_list = loss.tolist()
        acc_list = acc.tolist()
        global_loss.extend(
            loss_list if isinstance(loss_list, list) else [loss_list])
        global_acc.extend(
            acc_list if isinstance(acc_list, list) else [acc_list])
    return global_loss, global_acc


def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval,
        mean=None,
        std=None):
    """Cluster clients after a FedAvg warm-up, then train Reptile per cluster.

    For every client fraction and FedAvg learning rate:

    1. FedAvg (``CNNLightning``, SGD) runs via ``run_fedavg`` for
       ``max(hc_cluster_initialization_rounds)`` rounds, with state saving.
    2. For each ``(init_rounds, max_value)`` pair from
       ``generate_configuration``, the clients are reset to the FedAvg state
       stored for ``init_rounds``. They are then trained for one round with
       ``hc_train_cluster_args`` and clustered.
    3. Each cluster gets a ``ReptileServer``, initialized from a
       ``FedAvgServer`` that aggregated that cluster's participants. Then
       ``rp_num_meta_steps`` Reptile meta steps run per cluster, with
       evaluation every ``rp_eval_interval`` steps and a final evaluation.

    Args:
        seed: Seed for ``fix_random_seeds``; also stored in the Reptile
            context.
        name: Experiment name for the contexts and the loggers.
        dataset: ``'femnist'`` or ``'ham10k'``.
        num_clients: Number of federated clients/partitions.
        batch_size: FedAvg batch size; also the Reptile inner batch size.
        num_label_limit: femnist only. ``-1`` keeps all labels; any other
            value is passed to ``scratch_labels``.
        use_colored_images: femnist only. Selects the colored loader. Also
            sets the CNN input channels (3 if true, else 1).
        sample_threshold: Passed to the femnist loaders.
        hc_*: FedAvg / clustering settings. ``hc_lr`` and
            ``hc_max_value_criterion`` may be scalars or iterables. The
            same holds for ``hc_client_fraction`` and
            ``hc_cluster_initialization_rounds``.
        rp_*: Reptile settings, forwarded to ``ReptileExperimentContext``.
        mean, std: Passed to ``load_ham10k_federated``.

    Raises:
        ValueError: For an unknown ``dataset``, or if not all clients took
            part in the pre-clustering training round.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN

    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    elif dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept a single value as well as a sweep for every looped setting.
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    if not hasattr(hc_client_fraction, '__iter__'):
        hc_client_fraction = [hc_client_fraction]
    if not hasattr(hc_cluster_initialization_rounds, '__iter__'):
        hc_cluster_initialization_rounds = [hc_cluster_initialization_rounds]

    input_channels = 3 if use_colored_images else 1
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Experiment contexts for the FedAvg warm-up and for Reptile.
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)

            experiment_logger = create_tensorboard_logger(
                name, f'{fedavg_context}')
            # The dataset distribution is logged only once per call.
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            # FedAvg warm-up up to the largest requested init round. States
            # are saved so each init_rounds checkpoint can be reloaded below.
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=[
                    partial(log_after_round_evaluation, experiment_logger,
                            'fedavg'),
                    partial(log_after_round_evaluation, experiment_logger,
                            global_tag)
                ],
                initialize_clients_fn=initialize_clients_fn)

            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # Reset all clients to the FedAvg state of `init_rounds`.
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)

                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=
                    hc_threshold_min_client_cluster,
                    num_rounds_init=init_rounds,
                    num_rounds_cluster=0)

                # New logger for this cluster experiment.
                experiment_logger = create_tensorboard_logger(
                    name,
                    f'{fedavg_context}_{cluster_args}_{reptile_context}')
                fedavg_context.experiment_logger = experiment_logger

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = run_fedavg_train_round(
                    round_model_state,
                    clients,
                    training_args=hc_train_cluster_args)
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                partitioner = cluster_args()
                # NOTE(review): `server` is the one returned by run_fedavg,
                # which presumably holds the state after the last warm-up
                # round, not after `init_rounds`. Confirm that the
                # partitioner does not depend on it.
                cluster_clients_dic = partitioner.cluster(clients, server)
                # NOTE(review): 62 presumably is the FEMNIST class count; it
                # looks wrong for dataset == 'ham10k'.
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # One Reptile server per cluster. It is initialized from a
                # FedAvgServer that aggregated the cluster's participants.
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # f-string, so non-str cluster ids work too (string
                    # concatenation would raise TypeError for them).
                    intermediate_cluster_server = FedAvgServer(
                        f'cluster_server{cluster_id}',
                        model_args=fedavg_model_args,
                        context=fedavg_context)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model.
                        state_dict())

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items():
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            # The start offset advances by meta_batch_size
                            # each step, modulo the cluster size.
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=i * reptile_context.meta_batch_size %
                                    len(participants),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context.
                            get_inner_training_args(),
                            meta_training_args=reptile_context.
                            get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Periodic evaluation on all clients of every cluster.
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = _evaluate_reptile_clusters(
                            reptile_context, cluster_clients_dic,
                            cluster_server_dic, experiment_logger,
                            after_round_evaluation, global_step)
                        for callback in after_round_evaluation:
                            callback(experiment_logger,
                                     'mean_over_all_clients',
                                     Tensor(global_loss), Tensor(global_acc),
                                     global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at the end of training.
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = _evaluate_reptile_clusters(
                        reptile_context, cluster_clients_dic,
                        cluster_server_dic, experiment_logger,
                        after_round_evaluation,
                        reptile_context.num_meta_steps,
                        print_results=True)
                    # NOTE(review): the overall mean is logged at step 0,
                    # while the per-cluster final values use num_meta_steps.
                    # Confirm which step is intended.
                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)