예제 #1
0
def DefaultConfig():
    """Declare the default hyperparameters of the 'clustering_test' experiment.

    The function only binds local names and returns nothing, so the values
    are observable only if something captures the function's locals.
    NOTE(review): presumably this is registered as a config scope of an
    experiment framework (e.g. Sacred) -- confirm where it is decorated, and
    do not rename the locals without checking.
    """
    # Random seed and SGD learning rate (lr is forwarded to OptimizerArgs below).
    seed = 123123123
    lr = 0.01
    name = 'clustering_test'
    # Federated-averaging schedule.
    total_fedavg_rounds = 50
    # NOTE(review): a scalar here; confirm the consumer accepts a scalar
    # rather than requiring an iterable of fractions.
    client_fraction = 0.1
    local_epochs = 1
    batch_size = 16
    num_clients = 27
    num_classes = 7
    # Regular local training runs exactly `local_epochs` epochs
    # (min == max); the progress bar refresh rate is set to 0.
    train_args = TrainArgs(max_epochs=local_epochs,
                           min_epochs=local_epochs,
                           progress_bar_refresh_rate=0)
    # Separate training arguments pinned to exactly 3 epochs.
    # NOTE(review): by its name this is for the training pass that precedes
    # clustering -- confirm against the consumer.
    train_cluster_args = TrainArgs(max_epochs=3,
                                   min_epochs=3,
                                   progress_bar_refresh_rate=0)
    dataset = 'ham10k'
    partitioner_class = AlternativePartitioner
    optimizer_args = OptimizerArgs(optim.SGD, lr=lr)
    model_args = ModelArgs(MobileNetV2Lightning,
                           optimizer_args=optimizer_args,
                           num_classes=num_classes)
    # NOTE(review): presumably the FedAvg round counts after which clustering
    # is started -- confirm against the consumer.
    initialization_rounds = [25, 50]
    # NOTE(review): these look like hierarchical-clustering options (linkage
    # method, cut criterion, distance metric, cut thresholds) -- confirm
    # against the partitioner implementation.
    linkage_mech = 'ward'
    criterion = 'distance'
    dis_metric = 'euclidean'
    max_value_criterion = [100, 200, 300]
    # Per-channel normalisation statistics; the values match the widely used
    # ImageNet mean/std -- TODO confirm that is the intent.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
예제 #2
0
def create_mnist_experiment_context(
        name: str,
        local_epochs: int,
        batch_size: int,
        lr: float,
        client_fraction: float,
        dataset_name: str,
        num_classes: int,
        fixed_logger_version=None,
        no_progress_bar=False,
        cluster_args: Optional[ClusterArgs] = None):
    """Build a ``FedAvgExperimentContext`` for the MNIST CNN with a logger.

    The model is ``CNNMnistLightning`` optimised with SGD at learning rate
    ``lr``; local training runs exactly ``local_epochs`` epochs.

    Args:
        name: Experiment name; also passed to the tensorboard logger factory.
        local_epochs: Used as both ``min_epochs`` and ``max_epochs``.
        batch_size: Stored on the context.
        lr: Learning rate for the SGD optimizer arguments.
        client_fraction: Stored on the context.
        dataset_name: Stored on the context.
        num_classes: Forwarded to the model arguments.
        fixed_logger_version: Forwarded to ``create_tensorboard_logger``.
        no_progress_bar: When true, ``progress_bar_refresh_rate=0`` is added
            to the training arguments.
        cluster_args: Optional cluster configuration; when given it is
            attached to the context and appended to the logger specification.

    Returns:
        The context, with ``experiment_logger`` (and, if provided,
        ``cluster_args``) set.
    """
    logger.debug('creating experiment context ...')

    optimizer_args = OptimizerArgs(optim.SGD, lr=lr)
    model_args = ModelArgs(CNNMnistLightning,
                           num_classes=num_classes,
                           optimizer_args=optimizer_args)

    trainer_kwargs = dict(max_epochs=local_epochs, min_epochs=local_epochs)
    if no_progress_bar:
        trainer_kwargs.update(progress_bar_refresh_rate=0)

    context = FedAvgExperimentContext(name=name,
                                      client_fraction=client_fraction,
                                      local_epochs=local_epochs,
                                      lr=lr,
                                      batch_size=batch_size,
                                      optimizer_args=optimizer_args,
                                      model_args=model_args,
                                      train_args=TrainArgs(**trainer_kwargs),
                                      dataset_name=dataset_name)

    # The context is rendered before cluster_args is attached, so its string
    # form reflects the plain FedAvg configuration.
    specification_parts = [f'{context}', f'{optimizer_args}']
    if cluster_args is not None:
        context.cluster_args = cluster_args
        specification_parts.append(f'{cluster_args}')

    context.experiment_logger = create_tensorboard_logger(
        context.name, '_'.join(specification_parts), fixed_logger_version)
    return context
def run_hierarchical_clustering(local_evaluation_steps,
                                seed,
                                lr,
                                name,
                                total_fedavg_rounds,
                                cluster_initialization_rounds,
                                client_fraction,
                                local_epochs,
                                batch_size,
                                num_clients,
                                sample_threshold,
                                num_label_limit,
                                train_args,
                                dataset,
                                partitioner_class,
                                linkage_mech,
                                criterion,
                                dis_metric,
                                max_value_criterion,
                                reallocate_clients,
                                threshold_min_client_cluster,
                                use_colored_images,
                                use_pattern,
                                train_cluster_args=None,
                                mean=None,
                                std=None):
    """Run hierarchical-clustering FedAvg experiments on HAM10k.

    For every combination of client fraction, learning rate and
    ``generate_configuration(cluster_initialization_rounds,
    max_value_criterion)`` entry, a stored FedAvg model state is loaded via
    ``load_fedavg_state``, clients are initialised from it, and
    ``run_fedavg_hierarchical`` is executed with logging callbacks attached.

    Args:
        local_evaluation_steps: Forwarded to
            ``log_personalized_global_cluster_performance``.
        seed: Passed to ``fix_random_seeds``.
        lr: Learning rate or iterable of learning rates for SGD.
        name: Experiment name (context and logger).
        total_fedavg_rounds: Total round budget; the clustering phase gets
            ``total_fedavg_rounds - init_rounds`` rounds.
        cluster_initialization_rounds: Forwarded to
            ``generate_configuration``.
        client_fraction: Client fraction or iterable of fractions. A scalar
            is accepted and treated as a single-element list.
        local_epochs: Stored on the FedAvg context.
        batch_size: Batch size for data loading and the context.
        num_clients: Number of dataset partitions / clients.
        sample_threshold: Only used when ``partitioner_class`` is
            ``DatadependentPartitioner`` (clustering dataset loading).
        num_label_limit: Unused in this function.
        train_args: Training arguments for federated rounds.
        dataset: Dataset name; only ``'ham10k'`` is supported.
        partitioner_class: Partitioner class handed to ``ClusterArgs``.
        linkage_mech: Forwarded to ``ClusterArgs``.
        criterion: Forwarded to ``ClusterArgs``.
        dis_metric: Forwarded to ``ClusterArgs``.
        max_value_criterion: Value or iterable of values forwarded to
            ``generate_configuration``.
        reallocate_clients: Forwarded to ``ClusterArgs``.
        threshold_min_client_cluster: Forwarded to ``ClusterArgs``.
        use_colored_images: Unused in this function.
        use_pattern: Unused in this function.
        train_cluster_args: Training arguments for the initial training
            round that precedes clustering.
        mean: Forwarded to ``load_ham10k_federated``.
        std: Forwarded to ``load_ham10k_federated``.

    Raises:
        ValueError: If ``dataset`` is not ``'ham10k'``.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    global_tag_local = 'global_performance_personalized'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN
    if dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept scalars for every swept hyperparameter. client_fraction was
    # previously iterated without this normalisation, so a plain float
    # raised "TypeError: 'float' object is not iterable".
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]
    if not hasattr(lr, '__iter__'):
        lr = [lr]
    if not hasattr(client_fraction, '__iter__'):
        client_fraction = [client_fraction]

    for cf in client_fraction:
        for lr_i in lr:
            optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            model_args = ModelArgs(MobileNetV2Lightning,
                                   optimizer_args=optimizer_args,
                                   num_classes=7)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=optimizer_args,
                model_args=model_args,
                train_args=train_args,
                dataset_name=dataset)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name, experiment_specification)
            fedavg_context.experiment_logger = experiment_logger
            for init_rounds, max_value in generate_configuration(
                    cluster_initialization_rounds, max_value_criterion):
                # load the model state
                # NOTE(review): no FedAvg training happens in this function,
                # so the state for `init_rounds` presumably has to exist from
                # an earlier run -- confirm.
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)

                server = FedAvgServer('initial_server',
                                      fedavg_context.model_args,
                                      fedavg_context)
                server.overwrite_model_state(round_model_state)
                logger.info('initializing clients ...')
                clients = initialize_clients_fn(fedavg_context, fed_dataset,
                                                server.model.state_dict())

                overwrite_participants_models(round_model_state, clients)

                # initialize the cluster configuration; keyword order matches
                # the original calls (dataloader before the round settings)
                cluster_kwargs = dict(
                    linkage_mech=linkage_mech,
                    criterion=criterion,
                    dis_metric=dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=reallocate_clients,
                    threshold_min_client_cluster=threshold_min_client_cluster)
                if partitioner_class == DatadependentPartitioner:
                    # The data-dependent partitioner additionally receives a
                    # dataloader built from the colored FEMNIST dataset,
                    # skipping the ids already used as training clients.
                    clustering_dataset = load_femnist_colored_dataset(
                        str((REPO_ROOT / 'data').absolute()),
                        num_clients=num_clients,
                        batch_size=batch_size,
                        sample_threshold=sample_threshold)
                    cluster_kwargs['dataloader'] = load_n_of_each_class(
                        clustering_dataset,
                        n=5,
                        tabu=list(fed_dataset.train_data_local_dict.keys()))
                cluster_kwargs['num_rounds_init'] = init_rounds
                cluster_kwargs['num_rounds_cluster'] = (total_fedavg_rounds -
                                                        init_rounds)
                cluster_args = ClusterArgs(partitioner_class,
                                           **cluster_kwargs)

                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}'
                experiment_logger = create_tensorboard_logger(
                    fedavg_context.name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=model_args,
                                               context=fedavg_context)
                federated_round_fn = partial(run_fedavg_round,
                                             training_args=train_args,
                                             client_fraction=cf)

                after_post_clustering_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            'post_clustering')
                ]
                after_clustering_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger)
                ]
                after_federated_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            ['final hierarchical', global_tag])
                ]
                after_clustering_fn = [
                    partial(log_cluster_distribution,
                            experiment_logger,
                            num_classes=fed_dataset.class_num),
                    partial(log_sample_images_from_each_client,
                            experiment_logger)
                ]
                after_federated_round_fn = [
                    partial(
                        log_personalized_global_cluster_performance,
                        experiment_logger,
                        ['final hierarchical personalized', global_tag_local],
                        local_evaluation_steps)
                ]
                run_fedavg_hierarchical(
                    server,
                    clients,
                    cluster_args,
                    initial_train_fn,
                    federated_round_fn,
                    create_aggregator_fn,
                    after_post_clustering_evaluation,
                    after_clustering_round_evaluation,
                    after_federated_round_evaluation,
                    after_clustering_fn,
                    after_federated_round=after_federated_round_fn)
def _reptile_evaluate_cluster(cluster_server, participants, reptile_context):
    """Run an evaluation-mode Reptile step and return ``(loss, acc)``.

    Calls ``reptile_train_step`` with ``evaluation_mode=True`` for the given
    cluster server and participants, then reads ``'test/loss'`` and
    ``'test/acc'`` from ``evaluate_local_models``.
    """
    reptile_train_step(
        aggregator=cluster_server,
        participants=participants,
        inner_training_args=reptile_context.get_inner_training_args(
            eval=True),
        evaluation_mode=True)
    result = evaluate_local_models(participants=participants)
    return result.get('test/loss'), result.get('test/acc')


def _reptile_values_as_list(values):
    """Return ``values.tolist()`` as a list, wrapping a scalar result."""
    as_list = values.tolist()
    return as_list if isinstance(as_list, list) else [as_list]


def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,  #
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    """Run FedAvg, cluster the clients, then train Reptile inside each cluster.

    For every client fraction and learning rate, FedAvg is run for
    ``max(hc_cluster_initialization_rounds)`` rounds. For every
    ``generate_configuration(hc_cluster_initialization_rounds,
    hc_max_value_criterion)`` entry the stored FedAvg state is loaded, the
    clients are trained once and clustered, one ``ReptileServer`` per cluster
    is initialised from the aggregated cluster model, and
    ``rp_num_meta_steps`` Reptile meta steps are executed with periodic and
    final evaluation.

    Args:
        seed: Passed to ``fix_random_seeds`` and the Reptile context.
        name: Experiment name (contexts and loggers).
        dataset: Dataset name; only ``'femnist'`` is supported.
        num_clients: Number of clients to load.
        batch_size: Batch size for data loading, FedAvg and Reptile inner
            training.
        num_label_limit: If not ``-1``, ``scratch_labels`` is applied to the
            dataset with this limit.
        use_colored_images: Selects the colored FEMNIST loader and three
            input channels instead of one.
        sample_threshold: Forwarded to the dataset loader.
        hc_lr: Learning rate or iterable of learning rates for FedAvg SGD.
        hc_cluster_initialization_rounds: Iterable of FedAvg round counts;
            its maximum is the number of FedAvg rounds run.
        hc_client_fraction: Client fraction or iterable of fractions. A
            scalar is accepted and treated as a single-element list.
        hc_local_epochs: Stored on the FedAvg context.
        hc_train_args: Training arguments for the FedAvg context.
        hc_partitioner_class: Partitioner class handed to ``ClusterArgs``.
        hc_linkage_mech: Forwarded to ``ClusterArgs``.
        hc_criterion: Forwarded to ``ClusterArgs``.
        hc_dis_metric: Forwarded to ``ClusterArgs``.
        hc_max_value_criterion: Value or iterable of values forwarded to
            ``generate_configuration``.
        hc_reallocate_clients: Forwarded to ``ClusterArgs``.
        hc_threshold_min_client_cluster: Forwarded to ``ClusterArgs``.
        hc_train_cluster_args: Training arguments for the training round
            that precedes clustering.
        rp_sgd: Forwarded to ``ReptileExperimentContext`` (``sgd``).
        rp_adam_betas: Forwarded to ``ReptileExperimentContext``.
        rp_meta_batch_size: Participants per meta step and cluster; ``-1``
            uses all participants of the cluster.
        rp_num_meta_steps: Number of Reptile meta steps.
        rp_meta_learning_rate_initial: Forwarded to the Reptile context.
        rp_meta_learning_rate_final: Forwarded to the Reptile context.
        rp_eval_interval: Evaluate every this many meta steps.
        rp_inner_learning_rate: Forwarded to the Reptile context.
        rp_num_inner_steps: Forwarded to the Reptile context.
        rp_num_inner_steps_eval: Forwarded to the Reptile context.

    Raises:
        ValueError: If ``dataset`` is not ``'femnist'``, or if not all
            clients took part in the pre-clustering training round.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'

    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept scalars for every swept hyperparameter. hc_client_fraction was
    # previously iterated without this normalisation, so a plain float
    # raised "TypeError: 'float' object is not iterable".
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    if not hasattr(hc_client_fraction, '__iter__'):
        hc_client_fraction = [hc_client_fraction]
    input_channels = 3 if use_colored_images else 1
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            experiment_specification = f'{fedavg_context}'
            # NOTE(review): unlike run_hierarchical_clustering, this logger is
            # not attached to fedavg_context before FedAvg runs; it is only
            # reachable through the callbacks below -- confirm this is
            # intended.
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                # The dataset is the same for every sweep entry; log it once.
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)

            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration; no FedAvg rounds are
                # run inside the clusters (Reptile takes over instead)
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=
                    hc_threshold_min_client_cluster,
                    **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}_{reptile_context}'
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                # Clustering of participants by model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # NOTE(review): 62 is passed as the class count; presumably
                # the number of FEMNIST classes -- confirm.
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize cluster models
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # f-string instead of '+' so that non-string cluster ids
                    # do not raise TypeError; identical result for str ids.
                    cluster_server_name = f'cluster_server{cluster_id}'
                    intermediate_cluster_server = create_aggregator_fn(
                        cluster_server_name)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = ReptileServer(
                        participant_name=cluster_server_name,
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model.
                        state_dict())

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            # Walk cyclically through the cluster's
                            # participants, advancing by one meta batch per
                            # meta step.
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=i * reptile_context.meta_batch_size %
                                    len(participants),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context.
                            get_inner_training_args(),
                            meta_training_args=reptile_context.
                            get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on all clients of every cluster
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []

                        for cluster_id, participants in cluster_clients_dic.items(
                        ):
                            loss, acc = _reptile_evaluate_cluster(
                                cluster_server_dic[cluster_id], participants,
                                reptile_context)

                            # Log
                            for log_fn in after_round_evaluation:
                                log_fn(experiment_logger,
                                       f'cluster_{cluster_id}', loss, acc,
                                       global_step)
                            global_loss.extend(_reptile_values_as_list(loss))
                            global_acc.extend(_reptile_values_as_list(acc))

                        for log_fn in after_round_evaluation:
                            log_fn(experiment_logger, 'mean_over_all_clients',
                                   Tensor(global_loss), Tensor(global_acc),
                                   global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []

                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        loss, acc = _reptile_evaluate_cluster(
                            cluster_server_dic[cluster_id], participants,
                            reptile_context)
                        print(
                            f'Cluster {cluster_id} ({len(participants)} part.): loss = {loss}, acc = {acc}'
                        )

                        global_loss.extend(_reptile_values_as_list(loss))
                        global_acc.extend(_reptile_values_as_list(acc))

                        # Log
                        for log_fn in after_round_evaluation:
                            log_fn(experiment_logger, f'cluster_{cluster_id}',
                                   loss, acc, reptile_context.num_meta_steps)

                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)
예제 #5
0
    def __init__(
            self, name: str, dataset_name: str, swap_labels: bool,
            num_classes_per_client: int, num_shots_per_class: int, seed: int,
            model_class, sgd: bool, adam_betas: tuple, num_clients_train: int,
            num_clients_test: int, meta_batch_size: int, num_meta_steps: int,
            meta_learning_rate_initial: float, meta_learning_rate_final: float,
            eval_interval: int, num_eval_clients_training: int,
            do_final_evaluation: bool, num_eval_clients_final: int,
            inner_batch_size: int, inner_learning_rate: float,
            num_inner_steps: int, num_inner_steps_eval: int):
        """Store the experiment settings and derive the model arguments.

        Builds ``inner_model_args`` (SGD if ``sgd`` is true, otherwise Adam
        with ``adam_betas``, both at ``inner_learning_rate``) and
        ``meta_model_args`` (SGD optimizer arguments without a learning
        rate). For ``'omniglot'`` the models get
        ``num_classes=num_classes_per_client``, for ``'ham10k'``
        ``num_classes=7``; any other dataset passes no class count.

        Raises:
            ValueError: If ``num_eval_clients_training`` exceeds
                ``num_clients_train`` or a non-zero ``num_clients_test``;
                the same check applies to ``num_eval_clients_final`` when
                ``do_final_evaluation`` is true.
        """

        def exceeds_client_pool(requested):
            # True when more evaluation clients are requested than exist in
            # the train pool or in a non-empty test pool.
            if requested > num_clients_train:
                return True
            return bool(num_clients_test) and requested > num_clients_test

        self.name = name
        self.seed = seed

        # Data set settings
        self.dataset_name = dataset_name
        self.swap_labels = swap_labels
        self.num_classes_per_client = num_classes_per_client
        self.num_shots_per_class = num_shots_per_class
        self.num_clients_train = num_clients_train
        self.num_clients_test = num_clients_test

        # Model settings
        self.model_class = model_class
        self.sgd = sgd
        self.adam_betas = adam_betas

        # Inner (per-client) training settings
        self.num_inner_steps = num_inner_steps
        self.inner_learning_rate = inner_learning_rate
        self.inner_batch_size = inner_batch_size

        inner_optimizer_kwargs = {'lr': self.inner_learning_rate}
        if self.sgd:
            inner_optimizer_class = torch.optim.SGD
        else:
            inner_optimizer_class = torch.optim.Adam
            inner_optimizer_kwargs['betas'] = self.adam_betas
        inner_optimizer_args = OptimizerArgs(
            optimizer_class=inner_optimizer_class, **inner_optimizer_kwargs)

        # Extra model keyword arguments that depend on the data set; shared
        # by the inner and the meta model.
        if self.dataset_name == 'omniglot':
            dataset_model_kwargs = {
                'num_classes': self.num_classes_per_client
            }
        elif self.dataset_name == 'ham10k':
            dataset_model_kwargs = {'num_classes': 7}
        else:
            dataset_model_kwargs = {}

        self.inner_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=inner_optimizer_args,
            **dataset_model_kwargs)

        # Meta training settings
        self.num_meta_steps = num_meta_steps
        self.meta_learning_rate_initial = meta_learning_rate_initial
        self.meta_learning_rate_final = meta_learning_rate_final
        self.meta_batch_size = meta_batch_size  # number of clients per round
        # The meta model only needs placeholder optimizer arguments.
        placeholder_optimizer_args = OptimizerArgs(
            optimizer_class=torch.optim.SGD)
        self.meta_model_args = ModelArgs(
            model_class=self.model_class,
            optimizer_args=placeholder_optimizer_args,
            **dataset_model_kwargs)

        # Evaluation settings
        self.num_inner_steps_eval = num_inner_steps_eval
        self.eval_interval = eval_interval
        self.num_eval_clients_training = num_eval_clients_training
        if exceeds_client_pool(num_eval_clients_training):
            raise ValueError(
                "num_eval_clients_training must be lower or equal to "
                "num_clients_train and num_clients_test")
        self.do_final_evaluation = do_final_evaluation
        if do_final_evaluation and exceeds_client_pool(
                num_eval_clients_final):
            raise ValueError(
                "num_eval_clients_final must be lower or equal to "
                "num_clients_train and num_clients_test")
        self.num_eval_clients_final = num_eval_clients_final

        self.weighted_aggregation: bool = True
        self._cluster_args = None
        self._experiment_logger = None