# Example #1
def create_reptile_omniglot_experiment_context() -> ReptileExperimentContext:
    """Build a ReptileExperimentContext preconfigured for Omniglot training.

    Uses a 5-way / 5-shot task setup with Adam as the inner (task-level)
    optimizer and SGD as the meta optimizer, and attaches a tensorboard
    logger to the returned context.
    """
    # Task / client configuration.
    client_classes = 5
    clients_train = 5000
    clients_test = 50
    classes_per_client = 5
    shots_per_class = 5
    evaluation_iterations = 100
    # Inner (task-level) training configuration.
    task_lr = 0.001
    task_steps = 5
    task_batch_size = 10
    # Meta (outer-loop) training configuration.
    outer_batch_size = 5
    outer_lr_start = 1
    outer_lr_end = 0
    outer_steps = 100000

    # Train every task for exactly `task_steps` steps.
    task_train_args = TrainArgs(min_steps=task_steps, max_steps=task_steps)
    outer_train_args = TrainArgs(meta_learning_rate=outer_lr_start)
    task_optimizer = OptimizerArgs(optimizer_class=optim.Adam, lr=task_lr)
    outer_optimizer = OptimizerArgs(optimizer_class=optim.SGD,
                                    lr=outer_lr_start)
    task_model = ModelArgs(model_class=OmniglotLightning,
                           num_classes=client_classes,
                           optimizer_args=task_optimizer)
    outer_model = ModelArgs(model_class=OmniglotLightning,
                            num_classes=client_classes,
                            optimizer_args=outer_optimizer)
    context = ReptileExperimentContext(
        name='reptile',
        dataset_name='omniglot',
        eval_iters=evaluation_iterations,
        inner_training_steps=task_steps,
        inner_batch_size=task_batch_size,
        inner_optimizer_args=task_optimizer,
        inner_learning_rate=task_lr,
        inner_model_args=task_model,
        inner_train_args=task_train_args,
        num_clients_train=clients_train,
        num_clients_test=clients_test,
        num_classes_per_client=classes_per_client,
        num_shots_per_class=shots_per_class,
        meta_model_args=outer_model,
        meta_batch_size=outer_batch_size,
        meta_learning_rate_initial=outer_lr_start,
        meta_learning_rate_final=outer_lr_end,
        meta_num_steps=outer_steps,
        meta_optimizer_args=outer_optimizer,
        meta_training_args=outer_train_args)
    # Logger version is derived from the context's string representation.
    context.experiment_logger = create_tensorboard_logger(
        context.name, str(context))
    return context
# Example #2
def create_mnist_experiment_context(
        name: str,
        local_epochs: int,
        batch_size: int,
        lr: float,
        client_fraction: float,
        dataset_name: str,
        num_classes: int,
        fixed_logger_version=None,
        no_progress_bar=False,
        cluster_args: Optional[ClusterArgs] = None):
    """Create a FedAvgExperimentContext for an MNIST-style experiment.

    Builds SGD optimizer args and CNN model args, fixes local training to
    exactly `local_epochs` epochs, optionally stores `cluster_args` on the
    context, and attaches a tensorboard logger before returning the context.
    """
    logger.debug('creating experiment context ...')
    optimizer_args = OptimizerArgs(optim.SGD, lr=lr)
    model_args = ModelArgs(CNNMnistLightning,
                           num_classes=num_classes,
                           optimizer_args=optimizer_args)
    train_kwargs = dict(max_epochs=local_epochs, min_epochs=local_epochs)
    if no_progress_bar:
        # A refresh rate of 0 disables the progress bar entirely.
        train_kwargs['progress_bar_refresh_rate'] = 0
    training_args = TrainArgs(**train_kwargs)
    context = FedAvgExperimentContext(name=name,
                                      client_fraction=client_fraction,
                                      local_epochs=local_epochs,
                                      lr=lr,
                                      batch_size=batch_size,
                                      optimizer_args=optimizer_args,
                                      model_args=model_args,
                                      train_args=training_args,
                                      dataset_name=dataset_name)
    spec = f'{context}_{optimizer_args}'
    if cluster_args is not None:
        context.cluster_args = cluster_args
        spec = f'{spec}_{cluster_args}'
    context.experiment_logger = create_tensorboard_logger(
        context.name, spec, fixed_logger_version)
    return context
# Example #3
def run_reptile_experiment(
    name,
    dataset,
    swap_labels,
    classes,
    shots,
    seed,
    model_class,
    sgd,
    adam_betas,
    num_clients_train,
    num_clients_test,
    meta_batch_size,
    num_meta_steps,
    meta_learning_rate_initial,
    meta_learning_rate_final,
    eval_interval,
    num_eval_clients_training,
    do_final_evaluation,
    num_eval_clients_final,
    inner_batch_size,
    inner_learning_rate,
    num_inner_steps,
    num_inner_steps_eval,
    mean=None,
    std=None
):
    """Run Reptile over a grid of inner learning rates and inner-step counts.

    Loads the federated dataset selected by `dataset` ('femnist', 'omniglot'
    or 'ham10k'; anything else raises ValueError), then for every combination
    of `inner_learning_rate` x `num_inner_steps` (both expected iterable)
    builds a ReptileExperimentContext, attaches a tensorboard logger and
    invokes run_reptile.
    """
    fix_random_seeds(seed)

    # Only the omniglot loader provides a dedicated set of test clients.
    fed_dataset_test = None
    if dataset == 'femnist':
        fed_dataset_train = load_femnist_dataset(
            data_dir=str((REPO_ROOT / 'data').absolute()),
            num_clients=num_clients_train,
            batch_size=inner_batch_size,
            random_seed=seed)
    elif dataset == 'omniglot':
        fed_dataset_train, fed_dataset_test = load_omniglot_datasets(
            data_dir=str((REPO_ROOT / 'data' / 'omniglot').absolute()),
            num_clients_train=num_clients_train,
            num_clients_test=num_clients_test,
            num_classes_per_client=classes,
            num_shots_per_class=shots,
            inner_batch_size=inner_batch_size,
            random_seed=seed)
    elif dataset == 'ham10k':
        fed_dataset_train = load_ham10k_federated(
            partitions=num_clients_train,
            batch_size=inner_batch_size,
            mean=mean,
            std=std)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    for current_lr in inner_learning_rate:
        for current_steps in num_inner_steps:
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=swap_labels,
                num_classes_per_client=classes,
                num_shots_per_class=shots,
                seed=seed,
                model_class=model_class,
                sgd=sgd,
                adam_betas=adam_betas,
                num_clients_train=num_clients_train,
                num_clients_test=num_clients_test,
                meta_batch_size=meta_batch_size,
                num_meta_steps=num_meta_steps,
                meta_learning_rate_initial=meta_learning_rate_initial,
                meta_learning_rate_final=meta_learning_rate_final,
                eval_interval=eval_interval,
                num_eval_clients_training=num_eval_clients_training,
                do_final_evaluation=do_final_evaluation,
                num_eval_clients_final=num_eval_clients_final,
                inner_batch_size=inner_batch_size,
                inner_learning_rate=current_lr,
                # Evaluation uses the same inner step count as training.
                num_inner_steps=current_steps,
                num_inner_steps_eval=current_steps)

            experiment_logger = create_tensorboard_logger(
                reptile_context.name, f'{reptile_context}')
            reptile_context.experiment_logger = experiment_logger

            run_reptile(
                context=reptile_context,
                dataset_train=fed_dataset_train,
                dataset_test=fed_dataset_test,
                initial_model_state=None,
                after_round_evaluation=[
                    partial(log_after_round_evaluation, experiment_logger)
                ])
def run_hierarchical_clustering(local_evaluation_steps,
                                seed,
                                lr,
                                name,
                                total_fedavg_rounds,
                                cluster_initialization_rounds,
                                client_fraction,
                                local_epochs,
                                batch_size,
                                num_clients,
                                sample_threshold,
                                num_label_limit,
                                train_args,
                                dataset,
                                partitioner_class,
                                linkage_mech,
                                criterion,
                                dis_metric,
                                max_value_criterion,
                                reallocate_clients,
                                threshold_min_client_cluster,
                                use_colored_images,
                                use_pattern,
                                train_cluster_args=None,
                                mean=None,
                                std=None):
    """Run hierarchical clustering on top of pre-trained FedAvg states.

    For every (client_fraction, lr) combination a FedAvg context is created,
    a previously saved model state is restored for each configured number of
    initialization rounds, clients are clustered by their model updates and
    federated training continues inside the clusters via
    run_fedavg_hierarchical.

    Only the 'ham10k' dataset is supported; any other value raises
    ValueError.

    NOTE(review): num_label_limit, use_colored_images and use_pattern are
    accepted but never used in this function — presumably kept for interface
    parity with sibling runners; confirm before removing.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'
    global_tag_local = 'global_performance_personalized'
    initialize_clients_fn = DEFAULT_CLIENT_INIT_FN
    if dataset == 'ham10k':
        fed_dataset = load_ham10k_federated(partitions=num_clients,
                                            batch_size=batch_size,
                                            mean=mean,
                                            std=std)
        initialize_clients_fn = initialize_ham10k_clients
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept scalars as well as iterables for the swept hyper-parameters.
    if not hasattr(max_value_criterion, '__iter__'):
        max_value_criterion = [max_value_criterion]
    if not hasattr(lr, '__iter__'):
        lr = [lr]

    for cf in client_fraction:
        for lr_i in lr:
            optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            model_args = ModelArgs(MobileNetV2Lightning,
                                   optimizer_args=optimizer_args,
                                   num_classes=7)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=optimizer_args,
                model_args=model_args,
                train_args=train_args,
                dataset_name=dataset)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name, experiment_specification)
            fedavg_context.experiment_logger = experiment_logger
            for init_rounds, max_value in generate_configuration(
                    cluster_initialization_rounds, max_value_criterion):
                # load the model state saved after `init_rounds` rounds
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)

                server = FedAvgServer('initial_server',
                                      fedavg_context.model_args,
                                      fedavg_context)
                server.overwrite_model_state(round_model_state)
                logger.info('initializing clients ...')
                clients = initialize_clients_fn(fedavg_context, fed_dataset,
                                                server.model.state_dict())

                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': total_fedavg_rounds - init_rounds
                }
                # Shared ClusterArgs configuration; the data-dependent
                # partitioner additionally requires a dataloader (this was
                # previously duplicated across two near-identical branches).
                cluster_args_kwargs = dict(
                    linkage_mech=linkage_mech,
                    criterion=criterion,
                    dis_metric=dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=reallocate_clients,
                    threshold_min_client_cluster=threshold_min_client_cluster,
                    **round_configuration)
                if partitioner_class == DatadependentPartitioner:
                    clustering_dataset = load_femnist_colored_dataset(
                        str((REPO_ROOT / 'data').absolute()),
                        num_clients=num_clients,
                        batch_size=batch_size,
                        sample_threshold=sample_threshold)
                    # Five samples per class, excluding clients that are
                    # already part of the training data.
                    dataloader = load_n_of_each_class(
                        clustering_dataset,
                        n=5,
                        tabu=list(fed_dataset.train_data_local_dict.keys()))
                    cluster_args_kwargs['dataloader'] = dataloader
                cluster_args = ClusterArgs(partitioner_class,
                                           **cluster_args_kwargs)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}'
                experiment_logger = create_tensorboard_logger(
                    fedavg_context.name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=model_args,
                                               context=fedavg_context)
                federated_round_fn = partial(run_fedavg_round,
                                             training_args=train_args,
                                             client_fraction=cf)

                # Evaluation/logging callbacks for the individual phases.
                after_post_clustering_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            'post_clustering')
                ]
                after_clustering_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger)
                ]
                after_federated_round_evaluation = [
                    partial(log_after_round_evaluation, experiment_logger,
                            ['final hierarchical', global_tag])
                ]
                after_clustering_fn = [
                    partial(log_cluster_distribution,
                            experiment_logger,
                            num_classes=fed_dataset.class_num),
                    partial(log_sample_images_from_each_client,
                            experiment_logger)
                ]
                after_federated_round_fn = [
                    partial(
                        log_personalized_global_cluster_performance,
                        experiment_logger,
                        ['final hierarchical personalized', global_tag_local],
                        local_evaluation_steps)
                ]
                run_fedavg_hierarchical(
                    server,
                    clients,
                    cluster_args,
                    initial_train_fn,
                    federated_round_fn,
                    create_aggregator_fn,
                    after_post_clustering_evaluation,
                    after_clustering_round_evaluation,
                    after_federated_round_evaluation,
                    after_clustering_fn,
                    after_federated_round=after_federated_round_fn)
def clustering_test(mean, std, seed, lr, local_epochs, client_fraction,
                    optimizer_args, total_fedavg_rounds, batch_size,
                    num_clients, model_args, train_args, train_cluster_args,
                    initialization_rounds, partitioner_class, linkage_mech,
                    criterion, dis_metric, max_value_criterion):
    """Cluster HAM10k clients from saved FedAvg states.

    Pre-trains (or restores) a FedAvg model, then for every combination of
    initialization rounds and max_value_criterion runs one local training
    round, partitions the clients and logs the resulting cluster
    distribution to a dedicated tensorboard logger.
    """
    fix_random_seeds(seed)

    fed_dataset = load_ham10k_federated(partitions=num_clients,
                                        batch_size=batch_size,
                                        mean=mean,
                                        std=std)
    initialize_clients_fn = initialize_ham10k_clients

    fedavg_context = FedAvgExperimentContext(name='ham10k_clustering',
                                             client_fraction=client_fraction,
                                             local_epochs=local_epochs,
                                             lr=lr,
                                             batch_size=batch_size,
                                             optimizer_args=optimizer_args,
                                             model_args=model_args,
                                             train_args=train_args,
                                             dataset_name='ham10k')
    experiment_specification = f'{fedavg_context}'
    experiment_logger = create_tensorboard_logger(fedavg_context.name,
                                                  experiment_specification)

    log_dataset_distribution(experiment_logger, 'full dataset', fed_dataset)

    # FedAvg pre-training; saved states are restored where available.
    server, clients = run_fedavg(context=fedavg_context,
                                 num_rounds=total_fedavg_rounds,
                                 dataset=fed_dataset,
                                 save_states=True,
                                 restore_state=True,
                                 evaluate_rounds=False,
                                 initialize_clients_fn=initialize_clients_fn)

    for init_rounds in initialization_rounds:
        # Restore the state after `init_rounds` rounds and run one local
        # training round to obtain the model updates used for clustering.
        round_model_state = load_fedavg_state(fedavg_context, init_rounds)
        overwrite_participants_models(round_model_state, clients)
        run_fedavg_train_round(round_model_state,
                               training_args=train_cluster_args,
                               participants=clients)
        for max_value in max_value_criterion:
            cluster_args = ClusterArgs(
                partitioner_class,
                linkage_mech=linkage_mech,
                criterion=criterion,
                dis_metric=dis_metric,
                max_value_criterion=max_value,
                plot_dendrogram=False,
                reallocate_clients=False,
                threshold_min_client_cluster=-1,
                num_rounds_init=init_rounds,
                num_rounds_cluster=total_fedavg_rounds - init_rounds)
            experiment_logger = create_tensorboard_logger(
                fedavg_context.name,
                f'{experiment_specification}{cluster_args}')
            partitioner = cluster_args()
            cluster_clients_dic = partitioner.cluster(clients, server)
            log_cluster_distribution(experiment_logger, cluster_clients_dic, 7)
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,  #
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
    hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    """FedAvg pre-training, hierarchical clustering, then per-cluster Reptile.

    Pipeline for each (hc_client_fraction, hc_lr) combination:
      1. Run FedAvg on the femnist dataset (states saved/restored).
      2. For each (init_rounds, max_value) configuration, cluster the clients
         by their local model updates.
      3. Run Reptile meta-training inside each cluster, with periodic
         evaluation every `rp_eval_interval` steps and an optional final
         evaluation, all logged to tensorboard.

    Only the 'femnist' dataset is supported; any other value raises
    ValueError. Parameters prefixed `hc_` configure the hierarchical
    clustering phase, `rp_` the Reptile phase.
    """
    fix_random_seeds(seed)
    global_tag = 'global_performance'

    if dataset == 'femnist':
        if use_colored_images:
            fed_dataset = load_femnist_colored_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        else:
            fed_dataset = load_femnist_dataset(
                data_dir=str((REPO_ROOT / 'data').absolute()),
                num_clients=num_clients,
                batch_size=batch_size,
                sample_threshold=sample_threshold)
        # -1 means "use all labels"; otherwise restrict the label set.
        if num_label_limit != -1:
            fed_dataset = scratch_labels(fed_dataset, num_label_limit)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Accept scalars as well as iterables for the swept hyper-parameters.
    if not hasattr(hc_max_value_criterion, '__iter__'):
        hc_max_value_criterion = [hc_max_value_criterion]
    if not hasattr(hc_lr, '__iter__'):
        hc_lr = [hc_lr]
    input_channels = 3 if use_colored_images else 1
    # The full-dataset distribution is logged only once across all sweeps.
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            # FedAvg pre-training up to the largest initialization round;
            # intermediate states are saved so they can be restored below.
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)

            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # load the model state
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)
                # initialize the cluster configuration
                round_configuration = {
                    'num_rounds_init': init_rounds,
                    'num_rounds_cluster': 0
                }
                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=
                    hc_threshold_min_client_cluster,
                    **round_configuration)
                # create new logger for cluster experiment
                experiment_specification = f'{fedavg_context}_{cluster_args}_{reptile_context}'
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                initial_train_fn = partial(run_fedavg_train_round,
                                           round_model_state,
                                           training_args=hc_train_cluster_args)
                create_aggregator_fn = partial(FedAvgServer,
                                               model_args=fedavg_model_args,
                                               context=fedavg_context)

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = initial_train_fn(clients)
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                # Clustering of participants by model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # NOTE(review): _cluster_clients_dic (client names per
                # cluster) is built but never used afterwards — confirm
                # before removing.
                _cluster_clients_dic = dict()
                for cluster_id, participants in cluster_clients_dic.items():
                    _cluster_clients_dic[cluster_id] = [
                        c._name for c in participants
                    ]
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize cluster models
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # Aggregate the cluster members once to obtain the
                    # initial model state for this cluster's Reptile server.
                    intermediate_cluster_server = create_aggregator_fn(
                        'cluster_server' + cluster_id)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server = ReptileServer(
                        participant_name=f'cluster_server{cluster_id}',
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=intermediate_cluster_server.model.
                        state_dict())
                    #create_aggregator_fn('cluster_server' + cluster_id)
                    #cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = cluster_server

                # REPTILE TRAINING INSIDE CLUSTERS
                after_round_evaluation = [log_after_round_evaluation]
                # NOTE(review): RANDOM is initialized but never used below —
                # confirm before removing.
                RANDOM = random.Random(seed)

                # Perform training
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items(
                    ):

                        # meta_batch_size == -1 means "use all participants".
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            # Cycle through the cluster's participants in
                            # meta_batch_size chunks, wrapping around.
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=i * reptile_context.meta_batch_size %
                                    len(participants),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=reptile_context.
                            get_inner_training_args(),
                            meta_training_args=reptile_context.
                            get_meta_training_args(
                                frac_done=i / reptile_context.num_meta_steps))

                    # Evaluation on train and test clients
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []

                        for cluster_id, participants in cluster_clients_dic.items(
                        ):
                            # Test on all clients inside clusters
                            reptile_train_step(
                                aggregator=cluster_server_dic[cluster_id],
                                participants=participants,
                                inner_training_args=reptile_context.
                                get_inner_training_args(eval=True),
                                evaluation_mode=True)
                            result = evaluate_local_models(
                                participants=participants)
                            loss = result.get('test/loss')
                            acc = result.get('test/acc')

                            # Log
                            if after_round_evaluation is not None:
                                for c in after_round_evaluation:
                                    c(experiment_logger,
                                      f'cluster_{cluster_id}', loss, acc,
                                      global_step)
                            # Collect per-client values; tolist() may return
                            # a scalar for 0-dim tensors, hence the isinstance
                            # checks.
                            loss_list = loss.tolist()
                            acc_list = acc.tolist()
                            global_loss.extend(loss_list if isinstance(
                                loss_list, list) else [loss_list])
                            global_acc.extend(acc_list if isinstance(
                                acc_list, list) else [acc_list])

                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, 'mean_over_all_clients',
                                  Tensor(global_loss), Tensor(global_acc),
                                  global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []

                    for cluster_id, participants in cluster_clients_dic.items(
                    ):
                        # Final evaluation on train and test clients
                        # Test on all clients inside clusters
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=participants,
                            inner_training_args=reptile_context.
                            get_inner_training_args(eval=True),
                            evaluation_mode=True)
                        result = evaluate_local_models(
                            participants=participants)
                        loss = result.get('test/loss')
                        acc = result.get('test/acc')
                        print(
                            f'Cluster {cluster_id} ({len(participants)} part.): loss = {loss}, acc = {acc}'
                        )

                        loss_list = loss.tolist()
                        acc_list = acc.tolist()
                        global_loss.extend(loss_list if isinstance(
                            loss_list, list) else [loss_list])
                        global_acc.extend(acc_list if isinstance(
                            acc_list, list) else [acc_list])

                        # Log
                        if after_round_evaluation is not None:
                            for c in after_round_evaluation:
                                c(experiment_logger, f'cluster_{cluster_id}',
                                  loss, acc, reptile_context.num_meta_steps)

                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)
def run_reptile(context: ExperimentContext, initial_model_state=None):
    """Train a Reptile meta-learner on Omniglot few-shot tasks.

    Loads disjoint train/test client datasets, builds a single global
    ``OmniglotLightning`` model wrapped by a ``Reptile`` trainer, runs
    ``num_meta_steps`` meta-iterations with a linearly annealed meta step
    size, and logs train-test / test-test accuracy to TensorBoard every
    ``eval_iters`` steps.

    Args:
        context: experiment context; only ``context.name`` is read here
            (it names the TensorBoard run).
        initial_model_state: accepted but not used in this variant —
            TODO confirm whether it should seed the global model.
    """

    # Experiment-scale constants, hard-coded for this run.
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    # Evaluate on one random train client and one random test client
    # every `eval_iters` meta steps.
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(
        context.name, "dataloading_ours;models_ours")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    #######
    # NOTE(review): presumably the Omniglot loader below uses a TF1-style
    # input pipeline that requires graph mode — confirm against
    # load_omniglot_datasets before removing this.
    tf.disable_eager_execution()
    #######

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(model_class=reptile_args.model,
                                 optimizer_args=inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    # The meta "optimizer" args are a placeholder: the Reptile update is
    # applied manually below, so SGD here is never stepped.
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)
    """
    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = []
    for c in omniglot_train_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_train_clients.train_data_local_dict[c],
            num_train_samples=omniglot_train_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_train_clients.test_data_local_dict[c],
            num_test_samples=omniglot_train_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger
        )
        checkpoint_callback = ModelCheckpoint(
            filepath=str(client.get_checkpoint_path(suffix='cb').absolute()))
        client.set_trainer_callbacks([checkpoint_callback])
        train_clients.append(client)
    test_clients = []
    for c in omniglot_test_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_test_clients.train_data_local_dict[c],
            num_train_samples=omniglot_test_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_test_clients.test_data_local_dict[c],
            num_test_samples=omniglot_test_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger
        )
        test_clients.append(client)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=meta_model_args,
        context=context,
        initial_model_state=initial_model_state
    )"""

    #torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_model = OmniglotLightning(participant_name='global_model',
                                    **inner_model_args.kwargs)
    #torch_optimizer = inner_optimizer_args.optimizer_class(
    #    torch_model.parameters(),
    #    **inner_optimizer_args.optimizer_kwargs
    #)

    reptile = Reptile(global_model=torch_model,
                      model_kwargs=inner_model_args.kwargs,
                      inner_iterations=reptile_args.num_inner_steps,
                      inner_iterations_eval=reptile_args.num_inner_steps_eval)

    for i in range(reptile_args.num_meta_steps):
        # Linearly anneal the meta step size from the initial to the final
        # learning rate over the course of training.
        frac_done = i / reptile_args.num_meta_steps
        cur_meta_step_size = frac_done * reptile_args.meta_learning_rate_final + (
            1 - frac_done) * reptile_args.meta_learning_rate_initial

        # Select the next `meta_batch_size` client dataloaders, cycling
        # through the client pool with wrap-around.
        meta_batch = {
            k: omniglot_train_clients.train_data_local_dict[k]
            for k in cyclerange(
                i * reptile_args.meta_batch_size %
                len(omniglot_train_clients.train_data_local_dict), (i + 1) *
                reptile_args.meta_batch_size %
                len(omniglot_train_clients.train_data_local_dict),
                len(omniglot_train_clients.train_data_local_dict))
        }

        reptile.train_step(meta_batch=meta_batch,
                           meta_step_size=cur_meta_step_size)

        if i % eval_iters == 0:
            accuracies = []
            # Pick one random client from the train pool and one from the
            # held-out test pool; fine-tune on its train split and evaluate
            # on its test split.
            k = RANDOM.randrange(
                len(omniglot_train_clients.train_data_local_dict))
            train_train = omniglot_train_clients.train_data_local_dict[k]
            train_test = omniglot_train_clients.test_data_local_dict[k]
            k = RANDOM.randrange(
                len(omniglot_test_clients.train_data_local_dict))
            test_train = omniglot_test_clients.train_data_local_dict[k]
            test_test = omniglot_test_clients.test_data_local_dict[k]

            for train_dl, test_dl in [(train_train, train_test),
                                      (test_train, test_test)]:
                correct = reptile.evaluate(train_dl, test_dl)
                # NOTE(review): dividing by num_classes_per_client assumes
                # the test split holds exactly one sample per class — TODO
                # confirm against the dataset loader.
                accuracies.append(correct / num_classes_per_client)
            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
    """
Beispiel #8
0
def create_client_batches(clients: List[ReptileClient],
                          batch_size: int) -> List[List[ReptileClient]]:
    """Partition *clients* into consecutive batches of *batch_size*.

    A ``batch_size`` of ``-1`` is a sentinel meaning "a single batch
    containing every client". Otherwise clients are grouped in order; the
    final batch may be shorter when ``len(clients)`` is not an exact
    multiple of ``batch_size``.
    """
    if batch_size == -1:
        return [clients]
    starts = range(0, len(clients), batch_size)
    return [clients[start:start + batch_size] for start in starts]

    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotModel,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(
        context.name, "dataloading_ours;models_ours")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    #######
    tf.disable_eager_execution()
    #######

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(reptile_args.model,
                                 inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)
    """
    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_reptile_clients(context, train_datasets)
    test_clients = initialize_reptile_clients(context, test_datasets)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=context.meta_model_args,
        context=context,
        initial_model_state=initial_model_state
    )"""

    torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_optimizer = inner_optimizer_args.optimizer_class(
        torch_model.parameters(), **inner_optimizer_args.optimizer_kwargs)

    reptile = Reptile(model=torch_model,
                      optimizer=torch_optimizer,
                      inner_iterations=reptile_args.num_inner_steps,
                      inner_iterations_eval=reptile_args.num_inner_steps_eval)

    for i in range(reptile_args.num_meta_steps):
        frac_done = i / reptile_args.num_meta_steps
        cur_meta_step_size = frac_done * reptile_args.meta_learning_rate_final + (
            1 - frac_done) * reptile_args.meta_learning_rate_initial

        meta_batch = {
            k: omniglot_train_clients.train_data_local_dict[k]
            for k in cyclerange(
                i * reptile_args.meta_batch_size %
                len(omniglot_train_clients.train_data_local_dict), (i + 1) *
                reptile_args.meta_batch_size %
                len(omniglot_train_clients.train_data_local_dict),
                len(omniglot_train_clients.train_data_local_dict))
        }

        reptile.train_step(meta_batch=meta_batch,
                           meta_step_size=cur_meta_step_size)

        if i % eval_iters == 0:
            accuracies = []
            k = RANDOM.randrange(
                len(omniglot_train_clients.train_data_local_dict))
            train_train = omniglot_train_clients.train_data_local_dict[k]
            train_test = omniglot_train_clients.test_data_local_dict[k]
            k = RANDOM.randrange(
                len(omniglot_test_clients.train_data_local_dict))
            test_train = omniglot_test_clients.train_data_local_dict[k]
            test_test = omniglot_test_clients.test_data_local_dict[k]

            for train_dl, test_dl in [(train_train, train_test),
                                      (test_train, test_test)]:
                correct = reptile.evaluate(train_dl, test_dl)
                accuracies.append(correct / num_classes_per_client)
            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
    """
    for r in range(8, 23, 10):
        for c in range(8, 23, 10):
            rnd = np.random.random_sample((8, ))
            dotmap[r - 2, c] = rnd[0] * 0.3
            dotmap[r - 1, c - 1] = rnd[1] * 0.3
            dotmap[r - 1, c + 1] = rnd[2] * 0.3
            dotmap[r + 2, c] = rnd[3] * 0.3
            dotmap[r + 1, c - 1] = rnd[4] * 0.3
            dotmap[r + 1, c + 1] = rnd[5] * 0.3
            dotmap[r, c + 2] = rnd[6] * 0.3
            dotmap[r, c - 2] = rnd[7] * 0.3
    return original_pixels * dotmap


if __name__ == '__main__':
    from mlmi.settings import REPO_ROOT
    from mlmi.utils import create_tensorboard_logger
    experiment_logger = create_tensorboard_logger('colortest', 'femnist')
    dataset = load_femnist_colored_dataset(str(
        (REPO_ROOT / 'data').absolute()))
    dataloaders = list(dataset.train_data_local_dict.values())[0:5]
    images = []
    for dl in dataloaders:
        for i, (x, y) in enumerate(dl):
            for s in x:
                images.append(s)
    images_array = np.stack(images)
    experiment_logger.experiment.add_image('test',
                                           images_array,
                                           dataformats='NCHW')
Beispiel #10
0
    def run():
        """Parse CLI flags, load the requested federated dataset and run a
        FedAvg experiment with fixed hyperparameters.

        Side effects: reads ``sys.argv``, loads datasets from disk and
        writes TensorBoard logs. Returns early for the pure-inspection
        flags (``--log_data_distribution``, ``--plot_client_labels``).
        """
        parser = argparse.ArgumentParser()
        add_args(parser)
        args = parser.parse_args()

        # fix_random_seeds(args.seed)

        logger.debug('loading experiment data ...')
        data_dir = REPO_ROOT / 'data'
        fed_dataset = None
        context = None

        if args.cifar10:
            # NOTE(review): cifar10/cifar100 loading is not implemented;
            # fed_dataset stays None and any later use of it will fail —
            # confirm whether these flags should raise instead.
            pass
        elif args.cifar100:
            pass
        elif args.mnist:
            fed_dataset = load_mnist_dataset(str(data_dir.absolute()),
                                             num_clients=100,
                                             batch_size=10)
        else:
            # default to femnist dataset
            fed_dataset = load_femnist_dataset(str(data_dir.absolute()),
                                               num_clients=367,
                                               batch_size=10,
                                               only_digits=False,
                                               sample_threshold=250)

        if args.non_iid_scratch:
            non_iid_scratch(fed_dataset, num_mnist_label_zero=5)

        if args.scratch_data:
            # Randomly drop data from a fraction of clients to simulate
            # unevenly sized local datasets; record the setting in the
            # dataset name so TensorBoard runs stay distinguishable.
            client_fraction_to_scratch = 0.75
            data_fraction_to_scratch = 0.9
            scratch_data(fed_dataset,
                         client_fraction_to_scratch=client_fraction_to_scratch,
                         fraction_to_scratch=data_fraction_to_scratch)
            fed_dataset.name += f'_scratched{client_fraction_to_scratch:.2f}by{data_fraction_to_scratch:.2f}'

        if args.log_data_distribution:
            logger.info(
                '... found log distribution flag, only logging data distribution information'
            )
            experiment_logger = create_tensorboard_logger('datadistribution',
                                                          fed_dataset.name,
                                                          version=0)
            log_data_distribution_by_dataset('fedavg', fed_dataset,
                                             experiment_logger)
            return

        if args.plot_client_labels:
            augment_for_clustering(fed_dataset,
                                   0.1,
                                   4,
                                   label_core_num=12,
                                   label_deviation=3)
            image = generate_data_label_heatmap(
                'initial distribution',
                fed_dataset.train_data_local_dict.values(), 62)
            experiment_logger = create_tensorboard_logger(
                'datadistribution', fed_dataset.name)
            experiment_logger.experiment.add_image('label distribution/test',
                                                   image.numpy())
            return
        """
        default: run fed avg with fixed parameters
        """
        try:
            context = create_mnist_experiment_context(
                name='fedavg',
                client_fraction=0.1,
                local_epochs=5,
                num_classes=10,
                lr=0.1,
                batch_size=fed_dataset.batch_size,
                dataset_name='mnist_momentum0.5',
                no_progress_bar=args.no_progress_bar)
            logger.info(
                f'running FedAvg with the following configuration: {context}')
            run_fedavg(context, 50, save_states=False, dataset=fed_dataset)
        except Exception:
            # Bug fix: the original passed the exception object as a
            # %-format argument (`logger.exception(msg, e)`) with no
            # placeholder in the message, which makes the logging module
            # itself error at emit time. `logger.exception` already
            # attaches the active traceback, so the message alone suffices.
            logger.exception(f'Failed to execute configuration {context}')
def run_reptile(context: str, initial_model_state=None):
    """Run Reptile meta-training on Omniglot with the client/server framework.

    Reads hyperparameters from the command line, builds separate train and
    test client pools plus a ``ReptileServer``, performs ``meta_iters``
    rounds of ``reptile_train_step``, periodically evaluates on one random
    train and one random test client, and finishes with an evaluation on a
    random sample of each pool.

    Args:
        context: free-form experiment tag used to name the TensorBoard run.
        initial_model_state: optional model state passed to the server to
            resume from.
    """

    args = argument_parser().parse_args()
    # Dedicated RNG so client sampling is reproducible given args.seed.
    RANDOM = random.Random(args.seed)

    # TODO: Possibly implement logic using ReptileExperimentContext
    reptile_args = ReptileTrainingArgs(
        model_class=OmniglotLightning,
        sgd=args.sgd,
        inner_learning_rate=args.learning_rate,
        num_inner_steps=args.inner_iters,
        num_inner_steps_eval=args.eval_iters,
        log_every_n_steps=3,
        meta_learning_rate_initial=args.meta_step,
        meta_learning_rate_final=args.meta_step_final,
        num_classes_per_client=args.classes
    )
    # Run name encodes every hyperparameter so TensorBoard runs are
    # self-describing.
    experiment_logger = create_tensorboard_logger(
        'reptile',
        (
            f"{context};seed{args.seed};"
            f"train-clients{args.train_clients};"
            f"{args.classes}-way{args.shots}-shot;"
            f"ib{args.inner_batch}ii{args.inner_iters}"
            f"ilr{str(args.learning_rate).replace('.', '')}"
            f"ms{str(args.meta_step).replace('.', '')}"
            f"mb{args.meta_batch}ei{args.eval_iters}"
            f"{'sgd' if args.sgd else 'adam'}"
        )
    )

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        random_seed=args.seed
    )

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_clients(omniglot_train_clients, reptile_args.get_inner_model_args(), context,
                                       experiment_logger)
    test_clients = initialize_clients(omniglot_test_clients, reptile_args.get_inner_model_args(), context,
                                      experiment_logger)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=reptile_args.get_meta_model_args(),
        context=context,  # TODO: Change to ReptileExperimentContext
        initial_model_state=initial_model_state
    )

    # Perform training
    for i in range(args.meta_iters):
        if args.meta_batch == -1:
            # Sentinel: -1 means train on all clients every round.
            meta_batch = train_clients
        else:
            # Cycle through the client list with wrap-around so every
            # client is visited in order across rounds.
            meta_batch = [
                train_clients[k] for k in cyclerange(
                    i*args.meta_batch % len(train_clients),
                    (i+1)*args.meta_batch % len(train_clients),
                    len(train_clients)
                )
            ]
        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / args.meta_iters
            )
        )

        # Evaluation on train and test clients
        if i % args.eval_interval == 0:
            # train-test set
            # Pick one train client at random and test on it
            k = RANDOM.randrange(len(train_clients))
            client = [train_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
            # test-test set
            # Pick one test client at random and test on it
            k = RANDOM.randrange(len(test_clients))
            client = [test_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
        logger.info('finished training round')

    # Final evaluation on a sample of train/test clients
    for label, client_set in zip(['Train', 'Test'], [train_clients, test_clients]):
        eval_sample = RANDOM.sample(client_set, args.eval_samples)
        reptile_train_step(
            aggregator=server,
            participants=eval_sample,
            inner_training_args=reptile_args.get_inner_training_args(eval=True),
            evaluation_mode=True
        )
        result = evaluate_local_models(participants=eval_sample)
        log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
                         experiment_logger, global_step=args.meta_iters)
        experiment_logger.experiment.add_scalar(
            f'final_{label}_acc',
            torch.mean(result.get('test/acc')),
            global_step=0
        )
        print(f"{label} accuracy: {torch.mean(result.get('test/acc'))}")
Beispiel #12
0
def run_reptile(context: ExperimentContext, initial_model_state=None):
    """Run Reptile meta-training on Omniglot, driving a raw torch model
    through the client/server framework.

    Builds ``ReptileClient`` pools for train and test, a ``ReptileServer``
    seeded with the torch model's state, and runs ``num_meta_steps`` rounds
    of ``reptile_train_step``. Every ``eval_iters`` rounds, one random train
    client and one random test client are fine-tuned and scored manually on
    their first test batch; accuracies go to stdout and TensorBoard.

    Args:
        context: experiment context; ``context.name`` names the TensorBoard
            run and the context is forwarded to clients/server.
        initial_model_state: optional state forwarded to the server to
            resume from.
    """

    # Experiment-scale constants, hard-coded for this run.
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    # Every *eval_iters* meta steps, evaluation is performed on one random
    # client in the training and test set, respectively
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotModel,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name,
                                                  "framework_ours_test_1")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(reptile_args.model,
                                 inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    # The meta optimizer args are a placeholder; the Reptile meta-update is
    # applied by the server, so this SGD is never stepped.
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)

    ####################
    # Raw torch model/optimizer/criterion shared across all
    # reptile_train_step calls below.
    torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_optimizer = inner_optimizer_args.optimizer_class(
        torch_model.parameters(), **inner_optimizer_args.optimizer_kwargs)
    torch_criterion = torch.nn.CrossEntropyLoss()
    ####################

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = []
    for c in omniglot_train_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_train_clients.train_data_local_dict[c],
            num_train_samples=omniglot_train_clients.
            data_local_train_num_dict[c],
            test_dataloader=omniglot_train_clients.test_data_local_dict[c],
            num_test_samples=omniglot_train_clients.
            data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        #checkpoint_callback = ModelCheckpoint(
        #    filepath=str(client.get_checkpoint_path(suffix='cb').absolute()))
        #client.set_trainer_callbacks([checkpoint_callback])
        train_clients.append(client)
    test_clients = []
    for c in omniglot_test_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_test_clients.train_data_local_dict[c],
            num_train_samples=omniglot_test_clients.
            data_local_train_num_dict[c],
            test_dataloader=omniglot_test_clients.test_data_local_dict[c],
            num_test_samples=omniglot_test_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        test_clients.append(client)

    # Set up server
    server = ReptileServer(model_state=copy.deepcopy(torch_model.state_dict()),
                           participant_name='initial_server',
                           model_args=meta_model_args,
                           context=context,
                           initial_model_state=initial_model_state)

    for i in range(reptile_args.num_meta_steps):

        # Select the next `meta_batch_size` clients, cycling through the
        # client list with wrap-around.
        meta_batch = [
            train_clients[k] for k in cyclerange(
                i * reptile_args.meta_batch_size %
                len(train_clients), (i + 1) * reptile_args.meta_batch_size %
                len(train_clients), len(train_clients))
        ]
        """reptile.train_step(
            meta_batch=meta_batch,
            meta_step_size=cur_meta_step_size
        )"""
        reptile_train_step(
            model=torch_model,
            optimizer=torch_optimizer,
            criterion=torch_criterion,
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / reptile_args.num_meta_steps))

        if i % eval_iters == 0:
            accuracies = []
            # train-test set
            # Pick one train client at random and test on it
            client = [RANDOM.choice(train_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            # Manual accuracy check on the client's first test batch.
            # NOTE(review): dividing by num_classes_per_client assumes that
            # batch holds exactly one sample per class — TODO confirm
            # against the dataset loader.
            inputs, labels = list(client[0]._test_dataloader)[0]
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum([pred == label
                     for pred, label in zip(test_preds, labels)]))
            ##############result = evaluate_local_models(participants=client)
            accuracies.append(num_correct / num_classes_per_client)
            # log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
            #                 experiment_logger, global_step=i+1)
            #experiment_logger.experiment.add_scalar('train-test/acc/{}/mean'.format('global_model'),
            #                                        torch.mean(result.get('test/acc')),
            #                                        global_step=i + 1)
            # test-test set
            # Pick one test client at random and test on it
            client = [RANDOM.choice(test_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            inputs, labels = list(client[0]._test_dataloader)[0]
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum([pred == label
                     for pred, label in zip(test_preds, labels)]))
            ##############result = evaluate_local_models(participants=client)
            accuracies.append(num_correct / num_classes_per_client)
            # log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
            #                 experiment_logger, global_step=i+1)

            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
    """
Beispiel #13
0
def run_reptile(context: ExperimentContext, initial_model_state=None):
    """Meta-train a Reptile model on federated Omniglot data (TF1 graph mode).

    Hyperparameters are fixed inside the function. Every ``eval_iters``
    meta-steps, one random client from the training set and one from the
    test set are evaluated, and the resulting accuracies are printed and
    written to TensorBoard.

    Args:
        context: experiment context; only ``context.name`` is used, as the
            TensorBoard logger name.
        initial_model_state: unused; kept for signature compatibility with
            the other ``run_reptile`` variants in this file.
    """
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    # Every *eval_iters* meta steps, evaluation is performed on one random
    # client in the training and test set, respectively.
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name, (
        f"nichol_reptile_our_dataloading;{num_clients_train}clients;{num_classes_per_client}"
        f"-way{num_shots_per_class}-shot;"
        f"mlr{str(reptile_args.meta_learning_rate_initial).replace('.', '')}"
        f"ilr{str(reptile_args.inner_learning_rate).replace('.', '')}"
        f"is{reptile_args.num_inner_steps}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        tensorflow=True,
        random_seed=RANDOM_SEED)

    with tf.Session() as sess:
        model = OmniglotModel(num_classes_per_client,
                              learning_rate=reptile_args.inner_learning_rate)
        reptile = Reptile(sess, transductive=True, pre_step_op=weight_decay(1))
        # NOTE(review): this summary graph is never evaluated — results go
        # to TensorBoard through experiment_logger below instead.
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        # Initialize variables exactly once. (The previous version ran the
        # initializer twice in a row, needlessly re-randomizing all weights.)
        sess.run(tf.global_variables_initializer())

        # Loop-invariant: number of training clients, used for wrap-around
        # indexing of the meta-batch.
        num_train_clients = len(omniglot_train_clients.train_data_local_dict)
        for i in range(reptile_args.num_meta_steps):
            # Linearly anneal the meta step size from its initial to its
            # final value over the course of training.
            frac_done = i / reptile_args.num_meta_steps
            cur_meta_step_size = (
                frac_done * reptile_args.meta_learning_rate_final +
                (1 - frac_done) * reptile_args.meta_learning_rate_initial)

            # Select the next meta-batch of clients, cycling through the
            # training clients with wrap-around.
            batch_keys = cyclerange(
                i * reptile_args.meta_batch_size % num_train_clients,
                (i + 1) * reptile_args.meta_batch_size % num_train_clients,
                num_train_clients)
            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in batch_keys
            }

            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=reptile_args.num_inner_steps,
                               meta_step_size=cur_meta_step_size)
            if i % eval_iters == 0:
                accuracies = []
                # One random client from the training set, one from the
                # test set; each contributes a (train_dl, test_dl) pair.
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]

                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    # `correct` is the number of correct predictions out of
                    # num_classes_per_client query examples.
                    correct = reptile.evaluate(
                        train_data_loader=train_dl,
                        test_data_loader=test_dl,
                        input_ph=model.input_ph,
                        label_ph=model.label_ph,
                        minimize_op=model.minimize_op,
                        predictions=model.predictions,
                        inner_iters=reptile_args.num_inner_steps_eval)
                    accuracies.append(correct / num_classes_per_client)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0],
                    global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1],
                    global_step=i)
def run_reptile(context: str, initial_model_state=None):
    """Meta-train a Reptile model on federated Omniglot data, configured
    from command-line arguments (TF1 graph mode).

    Hyperparameters come from ``argument_parser()``. Every
    ``args.eval_interval`` meta-steps one random training client and one
    random test client are evaluated; after training, a final evaluation is
    run on ``args.eval_samples`` clients sampled from each of the training
    and test client sets.

    Args:
        context: free-form experiment tag used in the TensorBoard logger
            name.
        initial_model_state: unused; kept for signature compatibility with
            the other ``run_reptile`` variants in this file.
    """
    args = argument_parser().parse_args()
    RANDOM = random.Random(args.seed)

    experiment_logger = create_tensorboard_logger(
        'reptile', (f"{context};seed{args.seed};"
                    f"train-clients{args.train_clients};"
                    f"{args.classes}-way{args.shots}-shot;"
                    f"ib{args.inner_batch}ii{args.inner_iters}"
                    f"ilr{str(args.learning_rate).replace('.', '')}"
                    f"ms{str(args.meta_step).replace('.', '')}"
                    f"mb{args.meta_batch}ei{args.eval_iters}"
                    f"{'sgd' if args.sgd else 'adam'}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        tensorflow=True,
        random_seed=args.seed)

    # The model defaults to Adam; --sgd switches the inner optimizer.
    model_kwargs = {'learning_rate': args.learning_rate}
    if args.sgd:
        model_kwargs['optimizer'] = tf.train.GradientDescentOptimizer
    model = OmniglotModel(num_classes=args.classes, **model_kwargs)

    with tf.Session() as sess:
        reptile = ReptileForFederatedData(session=sess,
                                          transductive=True,
                                          pre_step_op=weight_decay(1))
        # NOTE(review): this summary graph is never evaluated — results go
        # to TensorBoard through experiment_logger below instead.
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        # Initialize variables exactly once. (The previous version ran the
        # initializer twice in a row, needlessly re-randomizing all weights.)
        sess.run(tf.global_variables_initializer())

        for i in range(args.meta_iters):
            # Linearly anneal the meta step size from args.meta_step to
            # args.meta_step_final over the course of training.
            frac_done = i / args.meta_iters
            cur_meta_step_size = frac_done * args.meta_step_final + (
                1 - frac_done) * args.meta_step

            # Select the next meta-batch of clients, cycling through the
            # training clients with wrap-around.
            crange = cyclerange(
                start=i * args.meta_batch %
                len(omniglot_train_clients.train_data_local_dict),
                stop=(i + 1) * args.meta_batch %
                len(omniglot_train_clients.train_data_local_dict),
                len=len(omniglot_train_clients.train_data_local_dict))
            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in crange
            }
            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=args.inner_iters,
                               meta_step_size=cur_meta_step_size)
            if i % args.eval_interval == 0:
                accuracies = []
                # One random client from the training set, one from the
                # test set; each contributes a (train_dl, test_dl) pair.
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]

                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    # `correct` is the number of correct predictions out of
                    # args.classes query examples.
                    correct = reptile.evaluate(train_data_loader=train_dl,
                                               test_data_loader=test_dl,
                                               input_ph=model.input_ph,
                                               label_ph=model.label_ph,
                                               minimize_op=model.minimize_op,
                                               predictions=model.predictions,
                                               inner_iters=args.eval_iters)
                    accuracies.append(correct / args.classes)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0],
                    global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1],
                    global_step=i)

        # Final evaluation on a sample of training/test clients
        for label, dataset in zip(
            ['Train', 'Test'],
            [omniglot_train_clients, omniglot_test_clients]):
            # random.sample requires a sequence: sampling directly from the
            # dict-keys view is deprecated since Python 3.9 and raises
            # TypeError on 3.11+, so materialize the keys first.
            keys = RANDOM.sample(list(dataset.train_data_local_dict.keys()),
                                 args.eval_samples)
            train_eval_sample = {
                k: dataset.train_data_local_dict[k]
                for k in keys
            }
            test_eval_sample = {
                k: dataset.test_data_local_dict[k]
                for k in keys
            }
            accuracy = evaluate(sess=sess,
                                model=model,
                                train_dataloaders=train_eval_sample,
                                test_dataloaders=test_eval_sample,
                                num_classes=args.classes,
                                eval_inner_iters=args.eval_iters,
                                transductive=True)
            experiment_logger.experiment.add_scalar(f'final_{label}_acc',
                                                    accuracy,
                                                    global_step=0)
            print(f"{label} accuracy: {accuracy}")