Code example #1
def aggregate(
    conf,
    fedavg_model,
    client_models,
    criterion,
    metrics,
    data_info,
    flatten_local_models,
    fa_val_perf,
):
    if fa_val_perf["top1"] > conf.fl_aggregate["top1_starting_threshold"]:
        # recover the models.
        _, local_models = agg_utils.recover_models(conf, client_models,
                                                   flatten_local_models)

        # create the virtual labels.
        dataset, labels, eps = create_virtual_labels(conf, fedavg_model,
                                                     local_models, data_info)
        conf.logger.log(f"the used label smoothing={eps}")

        # train the model on the server with the created virtual labels.
        data_loaders = rebuild_dataset(conf, data_info, dataset, labels)
        fedavg_model = training(conf, fedavg_model, criterion, data_loaders,
                                eps)

        # free the memory.
        del local_models
    else:
        conf.logger.log(f"skip and directly return the model.")
    return fedavg_model
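
The pattern in example #1 is: only spend server-side compute once the averaged model clears a quality bar. A minimal, self-contained sketch of that gate, with plain dicts standing in for the real conf/perf objects (retrain_fn is a hypothetical placeholder for the recover/label/train pipeline above):

def gated_aggregate(fl_aggregate, fedavg_model, fa_val_perf, retrain_fn):
    # Only refine on the server once the averaged model clears the threshold;
    # otherwise return it unchanged, as in example #1 above.
    if fa_val_perf["top1"] > fl_aggregate["top1_starting_threshold"]:
        return retrain_fn(fedavg_model)
    return fedavg_model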
Code example #2
def fedavg(
    conf,
    clientid2arch,
    n_selected_clients,
    flatten_local_models,
    client_models,
    criterion,
    metrics,
    val_data_loader,
):
    if ("server_teaching_scheme" not in conf.fl_aggregate
            or "drop" not in conf.fl_aggregate["server_teaching_scheme"]):
        # average directly.
        conf.logger.log("No indices to be removed.")
        return _fedavg(clientid2arch, n_selected_clients, flatten_local_models,
                       client_models)
    else:
        # we will first perform the evaluation.
        # recover the models on the computation device.
        _, local_models = agg_utils.recover_models(conf, client_models,
                                                   flatten_local_models)

        # get the weights from the validation performance.
        weights = []
        relationship = {}
        indices_to_remove = []
        random_guess_perf = agg_utils.get_random_guess_perf(conf)
        for idx, (client_id, local_model) in enumerate(local_models.items()):
            relationship[idx] = client_id
            validated_perfs = validate(
                conf,
                model=local_model,
                criterion=criterion,
                metrics=metrics,
                data_loader=val_data_loader,
            )
            perf = validated_perfs["top1"]
            weights.append(perf)

            # check the perf.
            if perf < random_guess_perf:
                indices_to_remove.append(idx)

        # drop the poorly performing clients before averaging.
        conf.logger.log(
            f"Indices to be removed for FedAvg: {indices_to_remove}; the original performance is: {weights}."
        )
        for index in indices_to_remove[::-1]:
            flatten_local_models.pop(relationship[index])
        return _fedavg(
            clientid2arch,
            n_selected_clients - len(indices_to_remove),
            flatten_local_models,
            client_models,
        )
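
The filtering step above drops any client whose validation top-1 falls below random guessing before averaging. A sketch of the same filter, assuming (an assumption, not shown in the source) that get_random_guess_perf returns 100 / num_classes for top-1 accuracy in percent; validate_fn is a hypothetical stand-in for validate:

def filter_below_random_guess(local_models, validate_fn, num_classes):
    # Assumed definition of the random-guess baseline (top-1, in percent).
    random_guess_perf = 100.0 / num_classes
    kept = {}
    for client_id, model in local_models.items():
        perf = validate_fn(model)["top1"]
        if perf >= random_guess_perf:
            kept[client_id] = model
    return kept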
Code example #3
def aggregate(
    conf,
    fedavg_model,
    client_models,
    criterion,
    metrics,
    flatten_local_models,
    fa_val_perf,
):
    if ("top1_starting_threshold" in conf.fl_aggregate and fa_val_perf["top1"]
            > conf.fl_aggregate["top1_starting_threshold"]):
        # recover the models on the computation device.
        _, local_models = agg_utils.recover_models(conf, client_models,
                                                   flatten_local_models)

        # generate the data for each local model.
        generated_data = {}
        for idx, local_model in local_models.items():
            conf.logger.log(f"distill the knowledge for model_idx={idx}.")
            kt_data_generator = DataGenerator(conf,
                                              model=local_model,
                                              model_idx=idx)
            generated_data[idx] = kt_data_generator.construct_data()

        # iteratively distill each local model's knowledge into the server model.
        for out_iter in range(int(conf.fl_aggregate["outer_iters"])):
            conf.logger.log(f"starting the {out_iter}-th knowledge transfer.")
            for idx, dataset in generated_data.items():
                master_model = distill_knowledge(
                    conf,
                    fedavg_model,
                    dataset=dataset,
                    num_epochs=int(conf.fl_aggregate["inner_epochs"]),
                    batch_size=int(
                        conf.fl_aggregate["kt_g_batch_size_per_class"]),
                    teacher_model=local_models[idx]
                    if "softmax_temperature" in conf.fl_aggregate else None,
                    softmax_temperature=1
                    if "softmax_temperature" not in conf.fl_aggregate else
                    conf.fl_aggregate["softmax_temperature"],
                )

        # free the memory.
        del local_models
    else:
        conf.logger.log("skip and directly return the model.")
        # without this, master_model would be undefined below.
        master_model = fedavg_model

    # a temp hack (for debugging only).
    client_models = dict((used_client_arch, master_model.cpu())
                         for used_client_arch in conf.used_client_archs)
    return master_model, client_models
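
The softmax_temperature knob above is the standard temperature of knowledge distillation (Hinton et al.). As a reference for what distill_knowledge plausibly minimizes (a sketch, not the repo's actual code):

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, T=1.0):
    # Soften both distributions with temperature T; the T^2 factor keeps
    # gradient magnitudes comparable across temperatures.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)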
Code example #4
def learning2aggregate(conf, fedavg_model, client_models, flatten_local_models,
                       criterion, data_loader):
    # init the local models.
    num_models, local_models = agg_utils.recover_models(
        conf, client_models, flatten_local_models)

    # init the agg_weights
    fedavg_model = fedavg_model.cuda() if conf.graph.on_cuda else fedavg_model
    agg_weights, optimizer, is_layerwise = _get_init_agg_weights(
        conf, fedavg_model, num_models)

    # learning the aggregation weights.
    for _ in range(int(conf.fl_aggregate["epochs"])):
        for _ind, (_input, _target) in enumerate(data_loader):
            # place model and data.
            if conf.graph.on_cuda:
                _input, _target = _input.cuda(), _target.cuda()

            # get mixed model.
            mixed_model = get_mixed_model(
                conf=conf,
                model=fedavg_model,
                local_models=local_models,
                agg_weights=agg_weights,
                is_layerwise=is_layerwise,
            )

            # forward with the mixed model and update the aggregation weights.
            mixed_model.train()
            optimizer.zero_grad()
            loss = criterion(mixed_model(_input), _target)
            loss.backward()
            optimizer.step()

    # extract the final agg_weights.
    weighted_avg_model = get_mixed_model(
        conf=conf,
        model=fedavg_model,
        local_models=local_models,
        agg_weights=agg_weights,
        is_layerwise=is_layerwise,
        display_agg_weights=True,
    )
    del local_models
    return weighted_avg_model.cpu()
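
get_mixed_model is not shown in the source; one plausible reading is a softmax-weighted average of the local models' parameters (one weight per model, or per model per layer in the layerwise variant). A sketch that materializes such a mixture (all names hypothetical; inside the training loop above the mixture must additionally stay differentiable w.r.t. agg_weights):

import copy
import torch

def mix_models(template_model, local_models, agg_weights):
    # Softmax keeps the aggregation weights on the probability simplex.
    probs = torch.softmax(agg_weights, dim=0)
    mixed = copy.deepcopy(template_model)
    with torch.no_grad():
        for name, param in mixed.named_parameters():
            param.copy_(sum(
                p * dict(m.named_parameters())[name]
                for p, m in zip(probs, local_models)))
    return mixed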
Code example #5
def aggregate(conf, client_models, flatten_local_models):
    # init the local models.
    wasserstein_conf = {
        "exact": True,
        "correction": True,
        "proper_marginals": True,
        "skip_last_layer": False,
        "ensemble_step": 0.5,
        "reg": 1e-2,
        "past_correction": True,
        "unbalanced": False,
        "ground_metric": "euclidean",
        "ground_metric_eff": True,
        "ground_metric_normalize": "none",
        "clip_gm": False,
        "clip_min": 0.0,
        "clip_max": 5,
        "activation_histograms": False,
        "dist_normalize": True,
        "act_num_samples": 100,
        "softmax_temperature": 1,
        "geom_ensemble_type": "wts",
        "normalize_wts": True,
        "importance": None,
    }

    num_models, local_models = agg_utils.recover_models(
        conf, client_models, flatten_local_models)

    local_models = list(local_models.values())
    _model = local_models[0]

    for idx in range(1, num_models):
        avg_aligned_layers = get_wassersteinized_layers_modularized(
            conf, wasserstein_conf, [_model, local_models[idx]])
        _model = get_network_from_param_list(avg_aligned_layers, _model)
    return _model
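
Example #5 fuses models sequentially with optimal transport: each next model is aligned to the running fusion before its parameters are averaged in (ensemble_step=0.5). With the exact flag and uniform marginals of equal size, OT reduces to an assignment problem, so the per-layer alignment can be sketched as a Hungarian matching of output neurons (a simplified, hypothetical stand-in for get_wassersteinized_layers_modularized, covering a single linear layer):

import numpy as np
from scipy.optimize import linear_sum_assignment

def fuse_linear_weights(w_a, w_b, ensemble_step=0.5):
    # cost[i, j] = Euclidean distance between neuron i of A and neuron j of B.
    cost = np.linalg.norm(w_a[:, None, :] - w_b[None, :, :], axis=2)
    _, col_ind = linear_sum_assignment(cost)
    w_b_aligned = w_b[col_ind]  # permute B's neurons to match A's
    return (1 - ensemble_step) * w_a + ensemble_step * w_b_aligned

The full method must also propagate each layer's permutation to the next layer's input dimension; this sketch shows only the single-layer matching.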
Code example #6
def aggregate(
    conf,
    fedavg_models,
    client_models,
    criterion,
    metrics,
    flatten_local_models,
    fa_val_perf,
    distillation_sampler,
    distillation_data_loader,
    val_data_loader,
    test_data_loader,
):
    fl_aggregate = conf.fl_aggregate

    # recover the models on the computation device.
    _, local_models = agg_utils.recover_models(conf,
                                               client_models,
                                               flatten_local_models,
                                               use_cuda=conf.graph.on_cuda)

    # evaluate the local models on the test_loader
    if "eval_local" in fl_aggregate and fl_aggregate["eval_local"]:
        perfs = []
        for idx, local_model in enumerate(local_models.values()):
            conf.logger.log(f"Evaluate the local model-{idx}.")
            perf = master_utils.validate(
                conf,
                coordinator=None,
                model=local_model,
                criterion=criterion,
                metrics=metrics,
                data_loader=test_data_loader,
                label=None,
                display=False,
            )
            perfs.append(perf["top1"])
        conf.logger.log(
            f"The averaged test performance of the local models: {sum(perfs) / len(perfs)}; the details of the local performance: {perfs}."
        )

    # evaluate the ensemble of local models on the test_loader
    if "eval_ensemble" in fl_aggregate and fl_aggregate["eval_ensemble"]:
        master_utils.ensembled_validate(
            conf,
            coordinator=None,
            models=list(local_models.values()),
            criterion=criterion,
            metrics=metrics,
            data_loader=test_data_loader,
            label="ensemble_test_loader",
            ensemble_scheme=None if "update_student_scheme" not in fl_aggregate
            else fl_aggregate["update_student_scheme"],
        )

    # distillation.
    _client_models = {}
    for arch, fedavg_model in fedavg_models.items():
        conf.logger.log(
            f"Master: we have {len(local_models)} local models for noise distillation (use {arch} for the distillation)."
        )

        # sample models.
        assert len(fedavg_models) == 1, (
            "right now, we only support the homo-arch case.")
        # TODO.
        sampled_models = sample_from_swag(conf,
                                          fedavg_model,
                                          local_models,
                                          loader=val_data_loader)

        # initialize knowledge distillation solver.
        kt = SWAKTSolver(
            conf=conf,
            teacher_models=list(sampled_models.values()),
            student_model=fedavg_model,
            criterion=criterion,
            metrics=metrics,
            batch_size=128 if "batch_size" not in fl_aggregate else int(
                fl_aggregate["batch_size"]),
            total_n_server_pseudo_batches=250
            if "total_n_server_pseudo_batches" not in fl_aggregate else int(
                fl_aggregate["total_n_server_pseudo_batches"]),
            val_data_loader=val_data_loader,
            distillation_sampler=distillation_sampler,
            distillation_data_loader=get_unlabeled_data(
                fl_aggregate, distillation_data_loader),
            student_learning_rate=1e-3 if "student_learning_rate"
            not in fl_aggregate else fl_aggregate["student_learning_rate"],
            AT_beta=0
            if "AT_beta" not in fl_aggregate else fl_aggregate["AT_beta"],
            KL_temperature=1 if "temperature" not in fl_aggregate else
            fl_aggregate["temperature"],
            log_fn=conf.logger.log,
            eval_batches_freq=100 if "eval_batches_freq" not in fl_aggregate
            else int(fl_aggregate["eval_batches_freq"]),
            update_student_scheme="avg_logits",
            server_teaching_scheme=None if "server_teaching_scheme"
            not in fl_aggregate else fl_aggregate["server_teaching_scheme"],
            optimizer="sgd"
            if "optimizer" not in fl_aggregate else fl_aggregate["optimizer"],
        )
        getattr(
            kt,
            "distillation" if "noise_kt_scheme" not in fl_aggregate else
            fl_aggregate["noise_kt_scheme"],
        )()
        _client_models[arch] = kt.server_student.cpu()

    # free the memory.
    del local_models, sampled_models, kt
    torch.cuda.empty_cache()
    return _client_models
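
Examples #6 and #7 repeat the `X if "key" not in fl_aggregate else fl_aggregate["key"]` ternary for every optional knob. If fl_aggregate supports dict-style lookups (an assumption), the same defaults read more compactly:

def resolve_kd_options(fl_aggregate):
    # Same defaults as the ternary chains above, expressed via .get
    # (assumes fl_aggregate supports dict-style access).
    return dict(
        batch_size=int(fl_aggregate.get("batch_size", 128)),
        KL_temperature=fl_aggregate.get("temperature", 1),
        student_learning_rate=fl_aggregate.get("student_learning_rate", 1e-3),
        noise_kt_scheme=fl_aggregate.get("noise_kt_scheme", "distillation"),
    )

Usage would mirror the examples, e.g. opts = resolve_kd_options(fl_aggregate) followed by getattr(kt, opts["noise_kt_scheme"])() for the same dynamic dispatch.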
Code example #7
def aggregate(
    conf,
    fedavg_models,
    client_models,
    criterion,
    metrics,
    flatten_local_models,
    fa_val_perf,
    distillation_sampler,
    distillation_data_loader,
    val_data_loader,
    test_data_loader,
):
    fl_aggregate = conf.fl_aggregate

    # recover the models on the computation device.
    _, local_models = agg_utils.recover_models(conf, client_models,
                                               flatten_local_models)

    # include models from previous comm. rounds.
    if ("include_previous_models" in fl_aggregate
            and fl_aggregate["include_previous_models"] > 0):
        local_models = agg_utils.include_previous_models(conf, local_models)

    # evaluate the local models on the test_loader
    if "eval_local" in fl_aggregate and fl_aggregate["eval_local"]:
        perfs = []
        for idx, local_model in enumerate(local_models.values()):
            conf.logger.log(f"Evaluate the local model-{idx}.")
            perf = master_utils.validate(
                conf,
                coordinator=None,
                model=local_model,
                criterion=criterion,
                metrics=metrics,
                data_loader=test_data_loader,
                label=None,
                display=False,
            )
            perfs.append(perf["top1"])
        conf.logger.log(
            f"The averaged test performance of the local models: {sum(perfs) / len(perfs)}; the details of the local performance: {perfs}."
        )

    # evaluate the ensemble of local models on the test_loader
    if "eval_ensemble" in fl_aggregate and fl_aggregate["eval_ensemble"]:
        master_utils.ensembled_validate(
            conf,
            coordinator=None,
            models=list(local_models.values()),
            criterion=criterion,
            metrics=metrics,
            data_loader=test_data_loader,
            label="ensemble_test_loader",
            ensemble_scheme=None if "update_student_scheme" not in fl_aggregate
            else fl_aggregate["update_student_scheme"],
        )

    # distillation.
    _client_models = {}
    for arch, fedavg_model in fedavg_models.items():
        conf.logger.log(
            f"Master: we have {len(local_models)} local models for noise distillation (use {arch} for the distillation)."
        )
        kt = NoiseKTSolver(
            conf=conf,
            teacher_models=list(local_models.values()),
            student_model=fedavg_model
            if "use_fedavg_as_start" not in fl_aggregate else
            (fedavg_model if fl_aggregate["use_fedavg_as_start"] else
             copy.deepcopy(client_models[arch])),
            criterion=criterion,
            metrics=metrics,
            batch_size=128 if "batch_size" not in fl_aggregate else int(
                fl_aggregate["batch_size"]),
            total_n_server_pseudo_batches=1000 *
            10 if "total_n_server_pseudo_batches" not in fl_aggregate else int(
                fl_aggregate["total_n_server_pseudo_batches"]),
            server_local_steps=1 if "server_local_steps" not in fl_aggregate
            else int(fl_aggregate["server_local_steps"]),
            val_data_loader=val_data_loader,
            distillation_sampler=distillation_sampler,
            distillation_data_loader=get_unlabeled_data(
                fl_aggregate, distillation_data_loader),
            use_server_model_scheduler=True
            if "use_server_model_scheduler" not in fl_aggregate else
            fl_aggregate["use_server_model_scheduler"],
            same_noise=True if "same_noise" not in fl_aggregate else
            fl_aggregate["same_noise"],
            student_learning_rate=1e-3 if "student_learning_rate"
            not in fl_aggregate else fl_aggregate["student_learning_rate"],
            AT_beta=0
            if "AT_beta" not in fl_aggregate else fl_aggregate["AT_beta"],
            KL_temperature=1 if "temperature" not in fl_aggregate else
            fl_aggregate["temperature"],
            log_fn=conf.logger.log,
            eval_batches_freq=100 if "eval_batches_freq" not in fl_aggregate
            else int(fl_aggregate["eval_batches_freq"]),
            early_stopping_server_batches=2000
            if "early_stopping_server_batches" not in fl_aggregate else int(
                fl_aggregate["early_stopping_server_batches"]),
            update_student_scheme="avg_losses" if "update_student_scheme"
            not in fl_aggregate else fl_aggregate["update_student_scheme"],
            server_teaching_scheme=None if "server_teaching_scheme"
            not in fl_aggregate else fl_aggregate["server_teaching_scheme"],
            return_best_model_on_val=False if "return_best_model_on_val"
            not in fl_aggregate else fl_aggregate["return_best_model_on_val"],
        )
        getattr(
            kt,
            "distillation" if "noise_kt_scheme" not in fl_aggregate else
            fl_aggregate["noise_kt_scheme"],
        )()
        _client_models[arch] = kt.server_student.cpu()

    # update local models from the current comm. round.
    if ("include_previous_models" in fl_aggregate
            and fl_aggregate["include_previous_models"] > 0):
        agg_utils.update_previous_models(conf, _client_models)

    # free the memory.
    del local_models, kt
    torch.cuda.empty_cache()
    return _client_models
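
The update_student_scheme values seen above ("avg_logits" in example #6, "avg_losses" as the default here) plausibly differ only in where the teacher ensemble is averaged. NoiseKTSolver's internals are not shown in the source; the following sketch is one reading of that distinction, not the repo's code:

import torch
import torch.nn.functional as F

def kd_step(student, teachers, batch, optimizer, scheme="avg_losses", T=1.0):
    with torch.no_grad():
        teacher_logits = [t(batch) for t in teachers]
    log_p_student = F.log_softmax(student(batch) / T, dim=1)
    if scheme == "avg_logits":
        # average the teachers' logits, then apply a single KD loss.
        p_teacher = F.softmax(torch.stack(teacher_logits).mean(0) / T, dim=1)
        loss = F.kl_div(log_p_student, p_teacher,
                        reduction="batchmean") * T * T
    else:
        # "avg_losses": one KD loss per teacher, averaged.
        loss = sum(
            F.kl_div(log_p_student, F.softmax(tl / T, dim=1),
                     reduction="batchmean")
            for tl in teacher_logits) * T * T / len(teacher_logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()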