def run_reptile(context: str, initial_model_state=None):
    """Run a Reptile meta-learning experiment on Omniglot, configured via CLI.

    Hyper-parameters are read from ``argument_parser().parse_args()``. The
    function builds train and test clients from the Omniglot data below
    ``REPO_ROOT / 'data' / 'omniglot'``, performs ``meta_iters`` meta-training
    steps on a ``ReptileServer``, logs the accuracy of one randomly drawn
    train client and one randomly drawn test client every ``eval_interval``
    steps, and finally evaluates a random sample of ``eval_samples`` clients
    from each of the two client sets.

    Args:
        context: Experiment name. It prefixes the TensorBoard run name and is
            handed to the clients and to the server.
        initial_model_state: Forwarded unchanged to ``ReptileServer`` as its
            ``initial_model_state``.
    """
    cli = argument_parser().parse_args()
    rng = random.Random(cli.seed)

    # TODO: Possibly implement logic using ReptileExperimentContext
    hyper_params = ReptileTrainingArgs(
        model_class=OmniglotLightning,
        sgd=cli.sgd,
        inner_learning_rate=cli.learning_rate,
        num_inner_steps=cli.inner_iters,
        num_inner_steps_eval=cli.eval_iters,
        log_every_n_steps=3,
        meta_learning_rate_initial=cli.meta_step,
        meta_learning_rate_final=cli.meta_step_final,
        num_classes_per_client=cli.classes
    )

    # The run name spells out the hyper-parameters; the decimal point is
    # removed from both learning rates.
    inner_lr_tag = str(cli.learning_rate).replace('.', '')
    meta_step_tag = str(cli.meta_step).replace('.', '')
    optimizer_tag = 'sgd' if cli.sgd else 'adam'
    run_name = ''.join([
        f"{context};seed{cli.seed};",
        f"train-clients{cli.train_clients};",
        f"{cli.classes}-way{cli.shots}-shot;",
        f"ib{cli.inner_batch}ii{cli.inner_iters}",
        f"ilr{inner_lr_tag}",
        f"ms{meta_step_tag}",
        f"mb{cli.meta_batch}ei{cli.eval_iters}",
        f"{optimizer_tag}",
    ])
    experiment_logger = create_tensorboard_logger('reptile', run_name)

    # Load and prepare Omniglot data
    omniglot_root = (REPO_ROOT / 'data' / 'omniglot').absolute()
    train_data, test_data = load_omniglot_datasets(
        str(omniglot_root),
        num_clients_train=cli.train_clients,
        num_clients_test=cli.test_clients,
        num_classes_per_client=cli.classes,
        num_shots_per_class=cli.shots,
        inner_batch_size=cli.inner_batch,
        random_seed=cli.seed
    )

    # Meta-learning needs disjoint sets of training and test clients.
    train_clients, test_clients = [
        initialize_clients(
            client_data,
            hyper_params.get_inner_model_args(),
            context,
            experiment_logger
        )
        for client_data in (train_data, test_data)
    ]

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=hyper_params.get_meta_model_args(),
        context=context,  # TODO: Change to ReptileExperimentContext
        initial_model_state=initial_model_state
    )

    # (scalar tag, client pool) pairs evaluated every ``eval_interval`` steps,
    # train pool first.
    periodic_eval_targets = (
        ('train-test/acc/{}/mean'.format('global_model'), train_clients),
        ('test-test/acc/{}/mean'.format('global_model'), test_clients),
    )

    # Perform training
    for step in range(cli.meta_iters):
        if cli.meta_batch == -1:
            # -1 selects every training client for the meta batch.
            meta_batch = train_clients
        else:
            pool_size = len(train_clients)
            # NOTE(review): the second positional argument is an end index
            # taken modulo the pool size - confirm that this matches the
            # positional signature of ``cyclerange``.
            batch_indices = cyclerange(
                step * cli.meta_batch % pool_size,
                (step + 1) * cli.meta_batch % pool_size,
                pool_size
            )
            meta_batch = [train_clients[idx] for idx in batch_indices]

        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=hyper_params.get_inner_training_args(),
            meta_training_args=hyper_params.get_meta_training_args(
                frac_done=step / cli.meta_iters
            )
        )

        # Periodic evaluation: draw one client per pool and test on it.
        if step % cli.eval_interval == 0:
            for scalar_tag, pool in periodic_eval_targets:
                drawn = [pool[rng.randrange(len(pool))]]
                reptile_train_step(
                    aggregator=server,
                    participants=drawn,
                    inner_training_args=hyper_params.get_inner_training_args(
                        eval=True
                    ),
                    evaluation_mode=True
                )
                outcome = evaluate_local_models(participants=drawn)
                experiment_logger.experiment.add_scalar(
                    scalar_tag,
                    torch.mean(outcome.get('test/acc')),
                    global_step=step + 1
                )
        logger.info('finished training round')

    # Final evaluation on a sample of train/test clients
    for label, pool in (('Train', train_clients), ('Test', test_clients)):
        eval_sample = rng.sample(pool, cli.eval_samples)
        reptile_train_step(
            aggregator=server,
            participants=eval_sample,
            inner_training_args=hyper_params.get_inner_training_args(eval=True),
            evaluation_mode=True
        )
        outcome = evaluate_local_models(participants=eval_sample)
        final_acc = outcome.get('test/acc')
        log_loss_and_acc('global_model', outcome.get('test/loss'), final_acc,
                         experiment_logger, global_step=cli.meta_iters)
        mean_acc = torch.mean(final_acc)
        experiment_logger.experiment.add_scalar(
            f'final_{label}_acc',
            mean_acc,
            global_step=0
        )
        print(f"{label} accuracy: {mean_acc}")
def run_hierarchical_clustering_reptile(
        seed,
        name,
        dataset,
        num_clients,
        batch_size,
        num_label_limit,
        use_colored_images,
        sample_threshold,
        hc_lr,
        hc_cluster_initialization_rounds,
        hc_client_fraction,
        hc_local_epochs,
        hc_train_args,
        hc_partitioner_class,
        hc_linkage_mech,
        hc_criterion,
        hc_dis_metric,
        hc_max_value_criterion,  # distance threshold
        hc_reallocate_clients,  #
        hc_threshold_min_client_cluster,  # only with hc_reallocate_clients = True,
        # results in clusters having at least this number of clients
        hc_train_cluster_args,
        rp_sgd,  # True -> Use SGD as inner optimizer; False -> Use Adam
        rp_adam_betas,  # Used only if sgd = False
        rp_meta_batch_size,
        rp_num_meta_steps,
        rp_meta_learning_rate_initial,
        rp_meta_learning_rate_final,
        rp_eval_interval,
        rp_inner_learning_rate,
        rp_num_inner_steps,
        rp_num_inner_steps_eval):
    """Cluster clients after FedAvg pre-training, then run Reptile per cluster.

    For every combination of client fraction and learning rate the function
    runs FedAvg for ``max(hc_cluster_initialization_rounds)`` rounds. For
    every configuration yielded by ``generate_configuration`` it then restores
    the FedAvg state of the given round, trains all clients once, clusters
    them with the configured partitioner, aggregates one model per cluster and
    uses it as the initial state of a per-cluster ``ReptileServer``. Reptile
    meta-training is then performed inside each cluster with periodic and
    (if enabled on the Reptile context) final evaluation on all clients of the
    cluster.

    Args:
        seed: Passed to ``fix_random_seeds`` and to the Reptile context.
        name: Experiment name used for the contexts and TensorBoard loggers.
        dataset: Dataset identifier; only ``'femnist'`` is supported.
        num_clients: Number of clients to load.
        batch_size: Batch size for data loading, FedAvg and Reptile inner
            training.
        num_label_limit: If not ``-1``, the dataset is passed through
            ``scratch_labels`` with this limit.
        use_colored_images: Selects the colored FEMNIST loader and three
            input channels instead of one.
        sample_threshold: Forwarded to the dataset loader.
        hc_lr: FedAvg learning rate, either a scalar or an iterable of values.
        hc_cluster_initialization_rounds: FedAvg round(s) after which
            clustering is performed, either a scalar or an iterable.
        hc_client_fraction: FedAvg client fraction, either a scalar or an
            iterable of values.
        hc_local_epochs: Local epochs for the FedAvg context.
        hc_train_args: Training arguments for the FedAvg context.
        hc_partitioner_class: First positional argument of ``ClusterArgs``.
        hc_linkage_mech: ``linkage_mech`` of ``ClusterArgs``.
        hc_criterion: ``criterion`` of ``ClusterArgs``.
        hc_dis_metric: ``dis_metric`` of ``ClusterArgs``.
        hc_max_value_criterion: Distance threshold, either a scalar or an
            iterable of values.
        hc_reallocate_clients: ``reallocate_clients`` of ``ClusterArgs``.
        hc_threshold_min_client_cluster: ``threshold_min_client_cluster`` of
            ``ClusterArgs``.
        hc_train_cluster_args: Training arguments for the training round that
            precedes clustering.
        rp_sgd: ``sgd`` of the Reptile context.
        rp_adam_betas: ``adam_betas`` of the Reptile context.
        rp_meta_batch_size: Clients per meta step inside a cluster; ``-1``
            uses all clients of the cluster.
        rp_num_meta_steps: Number of Reptile meta steps.
        rp_meta_learning_rate_initial: Initial meta learning rate.
        rp_meta_learning_rate_final: Final meta learning rate.
        rp_eval_interval: Evaluate every this many meta steps.
        rp_inner_learning_rate: Inner-loop learning rate.
        rp_num_inner_steps: Inner steps during training.
        rp_num_inner_steps_eval: Inner steps during evaluation.

    Raises:
        ValueError: If ``dataset`` is not ``'femnist'`` or if not every client
            took part in the training round that precedes clustering.
    """
    # Reject unsupported datasets before seeding or loading anything.
    if dataset != 'femnist':
        raise ValueError(f'dataset "{dataset}" unknown')

    fix_random_seeds(seed)
    global_tag = 'global_performance'

    load_dataset = (load_femnist_colored_dataset
                    if use_colored_images else load_femnist_dataset)
    fed_dataset = load_dataset(
        data_dir=str((REPO_ROOT / 'data').absolute()),
        num_clients=num_clients,
        batch_size=batch_size,
        sample_threshold=sample_threshold)
    if num_label_limit != -1:
        fed_dataset = scratch_labels(fed_dataset, num_label_limit)

    # Every swept hyper-parameter may be given as a scalar or as an iterable.
    hc_max_value_criterion = _as_list(hc_max_value_criterion)
    hc_lr = _as_list(hc_lr)
    hc_client_fraction = _as_list(hc_client_fraction)
    hc_cluster_initialization_rounds = _as_list(
        hc_cluster_initialization_rounds)

    input_channels = 3 if use_colored_images else 1
    data_distribution_logged = False
    for cf in hc_client_fraction:
        for lr_i in hc_lr:
            # Initialize experiment context parameters
            fedavg_optimizer_args = OptimizerArgs(optim.SGD, lr=lr_i)
            fedavg_model_args = ModelArgs(CNNLightning,
                                          optimizer_args=fedavg_optimizer_args,
                                          input_channels=input_channels,
                                          only_digits=False)
            fedavg_context = FedAvgExperimentContext(
                name=name,
                client_fraction=cf,
                local_epochs=hc_local_epochs,
                lr=lr_i,
                batch_size=batch_size,
                optimizer_args=fedavg_optimizer_args,
                model_args=fedavg_model_args,
                train_args=hc_train_args,
                dataset_name=dataset)
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=False,
                num_classes_per_client=0,
                num_shots_per_class=0,
                seed=seed,
                model_class=CNNLightning,
                sgd=rp_sgd,
                adam_betas=rp_adam_betas,
                num_clients_train=num_clients,
                num_clients_test=0,
                meta_batch_size=rp_meta_batch_size,
                num_meta_steps=rp_num_meta_steps,
                meta_learning_rate_initial=rp_meta_learning_rate_initial,
                meta_learning_rate_final=rp_meta_learning_rate_final,
                eval_interval=rp_eval_interval,
                num_eval_clients_training=-1,
                do_final_evaluation=True,
                num_eval_clients_final=-1,
                inner_batch_size=batch_size,
                inner_learning_rate=rp_inner_learning_rate,
                num_inner_steps=rp_num_inner_steps,
                num_inner_steps_eval=rp_num_inner_steps_eval)
            experiment_specification = f'{fedavg_context}'
            experiment_logger = create_tensorboard_logger(
                name, experiment_specification)
            if not data_distribution_logged:
                # The dataset does not change, so log its distribution once.
                log_dataset_distribution(experiment_logger, 'full dataset',
                                         fed_dataset)
                data_distribution_logged = True

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger,
                        'fedavg'),
                partial(log_after_round_evaluation, experiment_logger,
                        global_tag)
            ]
            server, clients = run_fedavg(
                context=fedavg_context,
                num_rounds=max(hc_cluster_initialization_rounds),
                dataset=fed_dataset,
                save_states=True,
                restore_state=True,
                after_round_evaluation=log_after_round_evaluation_fns)

            for init_rounds, max_value in generate_configuration(
                    hc_cluster_initialization_rounds, hc_max_value_criterion):
                # Restore the FedAvg state saved after ``init_rounds`` rounds.
                round_model_state = load_fedavg_state(fedavg_context,
                                                      init_rounds)
                overwrite_participants_models(round_model_state, clients)

                cluster_args = ClusterArgs(
                    hc_partitioner_class,
                    linkage_mech=hc_linkage_mech,
                    criterion=hc_criterion,
                    dis_metric=hc_dis_metric,
                    max_value_criterion=max_value,
                    plot_dendrogram=False,
                    reallocate_clients=hc_reallocate_clients,
                    threshold_min_client_cluster=(
                        hc_threshold_min_client_cluster),
                    num_rounds_init=init_rounds,
                    num_rounds_cluster=0)

                # Create a new logger for the cluster experiment
                experiment_specification = (
                    f'{fedavg_context}_{cluster_args}_{reptile_context}')
                experiment_logger = create_tensorboard_logger(
                    name, experiment_specification)
                fedavg_context.experiment_logger = experiment_logger

                # HIERARCHICAL CLUSTERING
                logger.debug('starting local training before clustering.')
                trained_participants = run_fedavg_train_round(
                    round_model_state,
                    clients,
                    training_args=hc_train_cluster_args)
                if len(trained_participants) != len(clients):
                    raise ValueError(
                        'not all clients successfully participated in the clustering round'
                    )

                # Clustering of participants by model updates
                partitioner = cluster_args()
                cluster_clients_dic = partitioner.cluster(clients, server)
                # NOTE(review): 62 is presumably the number of FEMNIST
                # classes - confirm against log_cluster_distribution.
                log_cluster_distribution(experiment_logger,
                                         cluster_clients_dic, 62)

                # Initialize cluster models: a FedAvg aggregate of the cluster
                # members seeds the Reptile server of that cluster.
                cluster_server_dic = {}
                for cluster_id, participants in cluster_clients_dic.items():
                    # The f-string also works for non-string cluster ids,
                    # where ``'cluster_server' + cluster_id`` would raise.
                    server_name = f'cluster_server{cluster_id}'
                    intermediate_cluster_server = FedAvgServer(
                        server_name,
                        model_args=fedavg_model_args,
                        context=fedavg_context)
                    intermediate_cluster_server.aggregate(participants)
                    cluster_server_dic[cluster_id] = ReptileServer(
                        participant_name=server_name,
                        model_args=reptile_context.meta_model_args,
                        context=reptile_context,
                        initial_model_state=(
                            intermediate_cluster_server.model.state_dict()))

                # REPTILE TRAINING INSIDE CLUSTERS
                for i in range(reptile_context.num_meta_steps):
                    for cluster_id, participants in cluster_clients_dic.items():
                        if reptile_context.meta_batch_size == -1:
                            meta_batch = participants
                        else:
                            meta_batch = [
                                participants[k] for k in cyclerange(
                                    start=(i * reptile_context.meta_batch_size
                                           % len(participants)),
                                    interval=reptile_context.meta_batch_size,
                                    total_len=len(participants))
                            ]
                        # Meta training step
                        reptile_train_step(
                            aggregator=cluster_server_dic[cluster_id],
                            participants=meta_batch,
                            inner_training_args=(
                                reptile_context.get_inner_training_args()),
                            meta_training_args=(
                                reptile_context.get_meta_training_args(
                                    frac_done=(
                                        i / reptile_context.num_meta_steps))))

                    # Periodic evaluation on all clients inside the clusters
                    if i % reptile_context.eval_interval == 0:
                        global_step = init_rounds + i
                        global_loss, global_acc = [], []

                        for cluster_id, participants in cluster_clients_dic.items():
                            loss, acc = _evaluate_cluster(
                                cluster_server_dic[cluster_id], participants,
                                reptile_context)
                            log_after_round_evaluation(
                                experiment_logger, f'cluster_{cluster_id}',
                                loss, acc, global_step)
                            global_loss.extend(_flatten_metric(loss))
                            global_acc.extend(_flatten_metric(acc))

                        log_after_round_evaluation(
                            experiment_logger, 'mean_over_all_clients',
                            Tensor(global_loss), Tensor(global_acc),
                            global_step)

                    logger.info(f'Finished Reptile training round {i}')

                # Final evaluation at end of training
                if reptile_context.do_final_evaluation:
                    global_loss, global_acc = [], []

                    for cluster_id, participants in cluster_clients_dic.items():
                        loss, acc = _evaluate_cluster(
                            cluster_server_dic[cluster_id], participants,
                            reptile_context)
                        print(
                            f'Cluster {cluster_id} ({len(participants)} part.): loss = {loss}, acc = {acc}'
                        )
                        global_loss.extend(_flatten_metric(loss))
                        global_acc.extend(_flatten_metric(acc))
                        log_after_round_evaluation(
                            experiment_logger, f'cluster_{cluster_id}', loss,
                            acc, reptile_context.num_meta_steps)

                    log_loss_and_acc('overall_mean', Tensor(global_loss),
                                     Tensor(global_acc), experiment_logger, 0)


def _as_list(value):
    """Return *value* unchanged if it has ``__iter__``, else wrap it in a list."""
    return value if hasattr(value, '__iter__') else [value]


def _flatten_metric(metric):
    """Return ``metric.tolist()`` as a list, wrapping a non-list result."""
    values = metric.tolist()
    return values if isinstance(values, list) else [values]


def _evaluate_cluster(cluster_server, participants, reptile_context):
    """Run an evaluation-mode Reptile step and return ``(loss, acc)``."""
    reptile_train_step(
        aggregator=cluster_server,
        participants=participants,
        inner_training_args=reptile_context.get_inner_training_args(eval=True),
        evaluation_mode=True)
    result = evaluate_local_models(participants=participants)
    return result.get('test/loss'), result.get('test/acc')
# Example #3
# 0
def run_reptile(context: ReptileExperimentContext,
                dataset_train: FederatedDatasetData,
                dataset_test: FederatedDatasetData,
                initial_model_state,
                after_round_evaluation):
    """Run Reptile meta-training and evaluation as described by *context*.

    Clients are created from the given datasets, a ``ReptileServer`` is
    trained for ``context.num_meta_steps`` meta steps, and every
    ``context.eval_interval`` steps (plus once at the end if
    ``context.do_final_evaluation`` is set) the train and test clients are
    evaluated and the results are handed to the callbacks.

    Args:
        context: Experiment configuration (seed, batch sizes, step counts,
            model arguments, logger, ...).
        dataset_train: Data of the training clients.
        dataset_test: Data of the test clients, or ``None`` to skip test
            clients; their metrics are then reported as ``None``.
        initial_model_state: Forwarded to ``ReptileServer``.
        after_round_evaluation: Iterable of callbacks invoked as
            ``cb(prefix, train_loss, train_acc, test_loss, test_acc, step)``,
            or ``None`` to disable reporting. The prefix is ``''`` during
            training and ``'final_'`` (with step ``0``) at the end.
    """
    rng = random.Random(context.seed)

    def _with_swapped_labels(fed_dataset):
        # Label swapping with the fixed class limit and the context seed.
        return swap_labels(
            fed_dataset=fed_dataset,
            max_classes_per_client=62,
            random_seed=context.seed
        )

    def _build_clients(fed_dataset):
        # Clients share model arguments, name and logger from the context.
        return initialize_clients(
            dataset=fed_dataset,
            model_args=context.inner_model_args,
            context=context.name,
            experiment_logger=context.experiment_logger
        )

    # Randomly swap labels
    if context.swap_labels:
        dataset_train = _with_swapped_labels(dataset_train)
        if dataset_test is not None:
            dataset_test = _with_swapped_labels(dataset_test)

    # Set up clients
    train_clients = _build_clients(dataset_train)
    test_clients = (
        None if dataset_test is None else _build_clients(dataset_test)
    )

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=context.meta_model_args,
        context=context.name,
        initial_model_state=initial_model_state
    )

    def _evaluate(num_eval_clients):
        # Evaluate train clients, then test clients. ``-1`` selects the whole
        # set, any other value draws a random sample of that size. Returns
        # (train_loss, train_acc, test_loss, test_acc); the entries of an
        # empty or missing client set are None.
        metrics = []
        for client_set in (train_clients, test_clients):
            if not client_set:
                metrics.extend((None, None))
                continue
            if num_eval_clients == -1:
                chosen = client_set
            else:
                chosen = rng.sample(client_set, num_eval_clients)
            reptile_train_step(
                aggregator=server,
                participants=chosen,
                inner_training_args=context.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            outcome = evaluate_local_models(participants=chosen)
            metrics.append(outcome.get('test/loss'))
            metrics.append(outcome.get('test/acc'))
        return tuple(metrics)

    # Perform training
    for step in range(context.num_meta_steps):
        if context.meta_batch_size == -1:
            # -1 selects every training client for the meta batch.
            meta_batch = train_clients
        else:
            batch_indices = cyclerange(
                start=step * context.meta_batch_size % len(train_clients),
                interval=context.meta_batch_size,
                total_len=len(train_clients)
            )
            meta_batch = [train_clients[idx] for idx in batch_indices]

        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=context.get_inner_training_args(),
            meta_training_args=context.get_meta_training_args(
                frac_done=step / context.num_meta_steps
            )
        )

        # Evaluation on train and test clients
        if step % context.eval_interval == 0:
            step_metrics = _evaluate(context.num_eval_clients_training)
            if after_round_evaluation is not None:
                for callback in after_round_evaluation:
                    callback('', *step_metrics, step)

        logger.info('finished training round')

    if not context.do_final_evaluation:
        return

    # Final evaluation on subsample of train / test clients
    final_metrics = _evaluate(context.num_eval_clients_final)
    if after_round_evaluation is not None:
        for callback in after_round_evaluation:
            callback('final_', *final_metrics, 0)