Ejemplo n.º 1
0
def run_experiment(xp, xp_count, n_experiments):
    """Run one federated experiment with optional server-side distillation.

    Builds the clients and the server from ``xp.hyperparameters``. In every
    communication round it:

    1. samples participating clients;
    2. synchronizes them with the server and trains them locally;
    3. optionally aggregates their weight updates (``hp["aggregate"]``);
    4. optionally distills them into the server model
       (``hp["use_distillation"]``).

    On logging rounds it records metrics, evaluates the server and saves
    results. At the end the server model is saved and GPU memory is released.

    Args:
        xp: Experiment handle exposing ``hyperparameters``, ``log``,
            ``is_log_round`` and ``save_to_disc``.
        xp_count: Zero-based index of this experiment (only printed).
        n_experiments: Total number of experiments in the schedule (only
            printed).
    """
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    # Values present in ``hp`` override the model's default optimizer settings.
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v
              for k, v in optimizer_hp.items()})
    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)
    # Keep a random subset of hp["n_distill"] samples for distillation.
    distill_data = torch.utils.data.Subset(
        distill_data,
        np.random.permutation(len(distill_data))[:hp["n_distill"]])

    client_loaders, test_loader = data.get_loaders(
        train_data,
        test_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"],
        batch_size=hp["batch_size"],
        n_data=None)
    distill_loader = torch.utils.data.DataLoader(distill_data,
                                                 batch_size=128,
                                                 shuffle=False)

    clients = [
        Client(model_fn, optimizer_fn, loader) for loader in client_loaders
    ]
    server = Server(model_fn, lambda x: torch.optim.Adam(x, lr=0.001),
                    test_loader, distill_loader)
    server.load_model(path=args.CHECKPOINT_PATH, name=hp["pretrained"])

    # print model
    models.print_model(server.model)

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()
    # Stats of the most recently trained client. Initialised so that the
    # logging step does not raise NameError if no client has trained yet.
    train_stats = {}
    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients,
                                                      hp["participation_rate"])

        for client in tqdm(participating_clients):
            client.synchronize_with_server(server)
            train_stats = client.compute_weight_update(hp["local_epochs"])

        if hp["aggregate"]:
            server.aggregate_weight_updates(participating_clients)

        if hp["use_distillation"]:
            server.distill(participating_clients,
                           hp["distill_epochs"],
                           compress=hp["compress"])

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1,
                                                  n_experiments))

            xp.log({
                'communication_round': c_round,
                'epochs': c_round * hp['local_epochs']
            })
            # Current optimizer settings, read from the first client.
            xp.log({
                key: clients[0].optimizer.__dict__['param_groups'][0][key]
                for key in optimizer_hp
            })

            # Evaluate
            xp.log({
                "client_train_{}".format(key): value
                for key, value in train_stats.items()
            })
            xp.log({
                "server_val_{}".format(key): value
                for key, value in server.evaluate().items()
            })

            # Save results to disk. This is best effort: training continues
            # if saving fails, but Ctrl-C / SystemExit are no longer
            # swallowed the way a bare ``except:`` would.
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except Exception:
                print("Saving results Failed!")

            # Timing: extrapolate the mean time per round so far.
            e = int((time.time() - t1) / c_round *
                    (hp['communication_rounds'] - c_round))
            print(
                "Remaining Time (approx.):",
                '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60),
                                              e % 60),
                "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] *
                                     100))

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
Ejemplo n.º 2
0
def run_experiment(xp, xp_count, n_experiments):
    """Run one federated experiment that tracks per-client accuracy each round.

    In every communication round:

    1. Every client is synchronized with the server and evaluated
       (``client_accuracies`` / ``mean_accuracy``).
    2. The participating clients train locally.
    3. Every client is evaluated again (``post_client_accuracies`` /
       ``post_mean_accuracy``).
    4. The server aggregates the updates of all clients.
    5. The angle between each participating client's update and the server
       is logged.

    At the end the layer-wise pairwise client angles are logged, and results
    and the server model are saved.

    Args:
        xp: Experiment handle exposing ``hyperparameters``, ``log``,
            ``is_log_round`` and ``save_to_disc``.
        xp_count: Zero-based index of this experiment (only printed).
        n_experiments: Total number of experiments in the schedule (only
            printed).
    """
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer_fn = models.get_model(hp["net"])
    client_data, server_data = data.get_data(hp["dataset"],
                                             n_clients=hp["n_clients"],
                                             alpha=hp["dirichlet_alpha"],
                                             path=DATA_PATH)

    # Per-client input / label transforms, keyed by the client index.
    for i, d in enumerate(client_data):
        d.subset_transform = data.get_x_transform(hp["x_transform"], i)
        d.label_transform = data.get_y_transform(hp["y_transform"], i)

    clients = [
        Client(model_fn,
               optimizer_fn,
               subset,
               hp["batch_size"],
               layers=hp["layers"],
               idnum=i) for i, subset in enumerate(client_data)
    ]
    server = Server(model_fn, server_data, layers=hp["layers"])
    server.load_model(path=CHECKPOINT_PATH, name=hp["pretrained"])

    # print model
    models.print_model(server.model)
    xp.log({"shared_layers": list(server.W.keys())})

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()
    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients,
                                                      hp["participation_rate"])

        # Accuracy of every client right after receiving the server model.
        pre_accs = []
        for client in clients:
            client.synchronize_with_server(server)
            pre_accs.append(client.evaluate()["accuracy"])
        xp.log({"client_accuracies": pre_accs}, printout=False)
        xp.log({"mean_accuracy": np.mean(pre_accs)})

        for client in tqdm(participating_clients):
            client.compute_weight_update(hp["local_epochs"])

        # Accuracy after local training. A fresh list is required: appending
        # to the pre-training list would mix both sets of accuracies into
        # the "post" metrics.
        post_accs = [client.evaluate()["accuracy"] for client in clients]
        xp.log({"post_client_accuracies": post_accs}, printout=False)
        xp.log({"post_mean_accuracy": np.mean(post_accs)})

        server.aggregate_weight_updates(clients)
        for client in participating_clients:
            xp.log(client.compute_server_angle(server), printout=False)

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1,
                                                  n_experiments))

            xp.log({
                'communication_round': c_round,
                'epochs': c_round * hp['local_epochs']
            })

            # Save results to disk. This is best effort, but interrupts are
            # not swallowed.
            try:
                xp.save_to_disc(path=RESULTS_PATH, name=hp['log_path'])
            except Exception:
                print("Saving results Failed!")

            # Timing: extrapolate the mean time per round so far.
            e = int((time.time() - t1) / c_round *
                    (hp['communication_rounds'] - c_round))
            print(
                "Remaining Time (approx.):",
                '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60),
                                              e % 60),
                "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] *
                                     100))

    xp.log(server.compute_pairwise_angles_layerwise(clients), printout=False)
    xp.save_to_disc(path=RESULTS_PATH, name=hp['log_path'])

    # Save model to disk
    server.save_model(path=CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
Ejemplo n.º 3
0
def run_experiment(xp, xp_count, n_experiments):
    """Run one federated experiment using averaging or distillation aggregation.

    ``hp["aggregation_mode"]`` selects the server update:

    - ``"FA"`` averages the participating clients' weight updates.
    - The ``"FD*"`` modes distill the participating clients into the server
      model. Optionally, the distillation subset is then re-selected
      (``hp["active"]``), and the server is co-distilled
      (``hp["init_mode"] == "co_distill"``).

    Args:
        xp: Experiment handle exposing ``hyperparameters``, ``log``,
            ``is_log_round`` and ``save_to_disc``.
        xp_count: Zero-based index of this experiment (only printed).
        n_experiments: Total number of experiments in the schedule (only
            printed).

    Raises:
        ValueError: If ``hp["active"]`` is truthy but names an unknown
            selection strategy.
    """
    print(xp)
    hp = xp.hyperparameters

    # Aggregation modes for which hard predictions on the distillation set
    # may be saved (controlled by hp["save_softlabels"]).
    softlabel_modes = ["FDcup", "FDsample", "FDcupdown", "FDer", "FDquant",
                       "FDquantdown"]
    # All aggregation modes that distill the clients into the server model.
    distill_modes = ["FD"] + softlabel_modes

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    # Values present in ``hp`` override the model's default optimizer settings.
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v
              for k, v in optimizer_hp.items()})
    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)
    # Candidate pool scored by the active-selection strategies: at most the
    # first 100000 samples, clamped so a smaller dataset is never over-indexed.
    all_distill_data_indexed = data.IdxSubset(
        all_distill_data, np.arange(min(100000, len(all_distill_data))))
    print(len(all_distill_data_indexed))

    np.random.seed(hp["random_seed"])
    # Random subset of hp["n_distill"] unlabeled samples used for distillation.
    distill_data = data.IdxSubset(
        all_distill_data,
        np.random.permutation(len(all_distill_data))[:hp["n_distill"]])

    client_data, label_counts = data.get_client_data(
        train_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"])

    client_loaders = [
        data.DataMerger({'base': local_data},
                        mixture_coefficients={'base': 1},
                        **hp) for local_data in client_data
    ]

    test_loader = data.DataMerger(
        {'base': data.IdxSubset(test_data, list(range(len(test_data))))},
        mixture_coefficients={'base': 1},
        batch_size=256)
    distill_loader = DataLoader(distill_data,
                                batch_size=hp["batch_size"],
                                shuffle=True,
                                num_workers=8)
    # Unshuffled loader so saved prediction matrices have a stable row order.
    distill_dummy_loader = DataLoader(distill_data,
                                      batch_size=2048,
                                      shuffle=False,
                                      num_workers=8)
    all_distill_loader = DataLoader(all_distill_data_indexed,
                                    batch_size=hp["batch_size"],
                                    shuffle=False)

    clients = [
        Client(model_fn,
               optimizer_fn,
               loader,
               idnum=i,
               counts=counts,
               distill_loader=distill_loader)
        for i, (loader, counts) in enumerate(zip(client_loaders, label_counts))
    ]
    server = Server(model_fn, lambda x: torch.optim.Adam(x, lr=1e-3),
                    test_loader, distill_loader)

    models.print_model(server.model)

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()

    xp.log({
        "server_val_{}".format(key): value
        for key, value in server.evaluate().items()
    })
    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients,
                                                      hp["participation_rate"])
        xp.log({
            "participating_clients":
            np.array([client.id for client in participating_clients])
        })

        save_softlabels = (hp["save_softlabels"]
                           and hp["aggregation_mode"] in softlabel_modes)

        for client in participating_clients:
            client.synchronize_with_server(server, c_round)

            # NOTE(review): the keyword is spelled "train_oulier_model".
            # Presumably this matches Client.compute_weight_update; rename
            # both together if ever fixed.
            train_stats = client.compute_weight_update(
                hp["local_epochs"],
                train_oulier_model=hp["aggregation_mode"]
                in ["FAD+S", "FAD+P+S"],
                c_round=c_round,
                max_c_round=hp["communication_rounds"],
                **hp)
            print(train_stats)

            if save_softlabels:
                predictions = client.compute_prediction_matrix(
                    distill_dummy_loader, argmax=True)
                xp.log(
                    {"client_{}_predictions".format(client.id): predictions})

        if hp["aggregation_mode"] in ["FA"]:
            server.aggregate_weight_updates(participating_clients)

        if hp["aggregation_mode"] in distill_modes:
            mode_name = hp["aggregation_mode"]
            if mode_name in ("FDquant", "FDquantdown"):
                # Only the quantized modes need (and therefore read) the bit
                # width.
                distill_mode = ["quantized", hp["quantization_bits"]]
            else:
                distill_mode = {
                    "FD": "mean_probs",
                    "FDcup": "pate_up",
                    "FDsample": "sample",
                    "FDcupdown": "pate",
                    "FDer": "mean_logits_er",
                }[mode_name]

            reset_model = hp["init_mode"] == "random"

            server.distill(participating_clients,
                           hp["distill_iter"],
                           mode=distill_mode,
                           reset_model=reset_model)

            if hp["active"]:
                # Re-select the distillation subset for the next round.
                if hp["active"] == "random":
                    idcs = np.random.permutation(
                        len(all_distill_data))[:hp["n_distill"]]
                elif hp["active"] == "entropy":
                    mat = server.compute_prediction_matrix(all_distill_loader,
                                                           argmax=False)
                    # Small offset so zero probabilities contribute 0
                    # instead of 0 * log(0) = NaN.
                    entropy = -np.sum(mat * np.log(mat + 1e-12), axis=1)
                    # Highest-entropy samples first.
                    idcs = np.argsort(-entropy)[:hp["n_distill"]]
                elif hp["active"] == "certainty":
                    mat = server.compute_prediction_matrix(all_distill_loader,
                                                           argmax=False)
                    # Least confident (lowest max-probability) samples first.
                    idcs = np.argsort(np.max(mat, axis=1))[:hp["n_distill"]]
                elif hp["active"] == "margin":
                    mat = server.compute_prediction_matrix(all_distill_loader,
                                                           argmax=False)
                    # Smallest gap between the top-2 probabilities first.
                    idcs = np.argsort(
                        np.diff(np.sort(mat, axis=1)[:, -2:],
                                axis=1).flatten())[:hp["n_distill"]]
                else:
                    raise ValueError("Unknown active selection mode: {!r}".format(
                        hp["active"]))

                distill_data = data.IdxSubset(all_distill_data, idcs)
                distill_loader = DataLoader(distill_data,
                                            batch_size=128,
                                            shuffle=True)
                server.distill_loader = distill_loader

            if hp["init_mode"] == "co_distill":

                if save_softlabels:
                    predictions = server.compute_prediction_matrix(
                        distill_dummy_loader, argmax=True)
                    xp.log({"server_predictions": predictions})

                server.co_distill(
                    hp["co_distill_iter"],
                    quantization_bits=hp["quantization_bits_down"]
                    if hp["aggregation_mode"] == "FDquantdown" else None)

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1,
                                                  n_experiments))

            xp.log({
                'communication_round': c_round,
                'epochs': c_round * hp['local_epochs']
            })
            # Current optimizer settings, read from the first client.
            xp.log({
                key: clients[0].optimizer.__dict__['param_groups'][0][key]
                for key in optimizer_hp
            })

            # Evaluate
            xp.log({
                "server_val_{}".format(key): value
                for key, value in server.evaluate().items()
            })

            # Save results to disk. This is best effort, but interrupts are
            # not swallowed.
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except Exception:
                print("Saving results Failed!")

            # Timing: extrapolate the mean time per round so far.
            e = int((time.time() - t1) / c_round *
                    (hp['communication_rounds'] - c_round))
            print(
                "Remaining Time (approx.):",
                '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60),
                                              e % 60),
                "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] *
                                     100))

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
Ejemplo n.º 4
0
def run_experiment(xp, xp_count, n_experiments):
  """Run one heterogeneous-model federated experiment (FedAvg / FedDF / FedAUX).

  Clients may use different architectures (``hp["models"]`` maps a model name
  to a client count). The auxiliary dataset is split into a distillation set
  and a disjoint public set. Depending on ``hp["aggregation_mode"]``:

  - pretrained weights are loaded;
  - FedAUX client scores are computed;
  - feature layers are frozen;
  - and, after parameter averaging each round, the ensemble is distilled
    into the server models.

  Args:
    xp: Experiment handle exposing ``hyperparameters``, ``log``,
      ``is_log_round`` and ``save_to_disc``.
    xp_count: Zero-based index of this experiment (only printed).
    n_experiments: Total number of experiments in the schedule (only printed).
  """
  t0 = time.time()
  print(xp)
  hp = xp.hyperparameters

  num_classes = {"cifar10" : 10, "cifar100" : 100}[hp["dataset"]]

  # One entry per client, e.g. {"resnet8": 2} -> ["resnet8", "resnet8"].
  model_names = [model_name for model_name, k in hp["models"].items() for _ in range(k)]

  # hp["local_optimizer"] / hp["distill_optimizer"] are (torch.optim class name, default kwargs);
  # values present in ``hp`` override those defaults.
  optimizer, optimizer_hp = getattr(torch.optim, hp["local_optimizer"][0]), hp["local_optimizer"][1]
  optimizer_fn = lambda x : optimizer(x, **{k : hp[k] if k in hp else v for k, v in optimizer_hp.items()})

  distill_optimizer, distill_optimizer_hp = getattr(torch.optim, hp["distill_optimizer"][0]), hp["distill_optimizer"][1]
  distill_optimizer_fn = lambda x : distill_optimizer(x, **{k : hp[k] if k in hp else v for k, v in distill_optimizer_hp.items()})

  train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
  all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)

  np.random.seed(hp["random_seed"])
  torch.manual_seed(hp["random_seed"])

  n_distill = int(hp["n_distill_frac"] * len(all_distill_data))

  # Split a SINGLE permutation so the distillation and public sets are
  # disjoint and together cover the whole auxiliary dataset. Drawing two
  # independent permutations would make the two sets overlap.
  perm = np.random.permutation(len(all_distill_data))
  distill_data = data.IdxSubset(all_distill_data, perm[:n_distill], return_index=True)
  public_data = data.IdxSubset(all_distill_data, perm[n_distill:], return_index=False)

  print(len(distill_data), len(public_data))

  client_loaders, test_loader = data.get_loaders(train_data, test_data, n_clients=len(model_names),
        alpha=hp["alpha"], batch_size=hp["batch_size"], n_data=None, num_workers=0, seed=hp["random_seed"])
  distill_loader = torch.utils.data.DataLoader(distill_data, batch_size=hp["distill_batch_size"], shuffle=True, num_workers=8)
  public_loader = torch.utils.data.DataLoader(public_data, batch_size=128, shuffle=True, num_workers=8)

  clients = [Client(model_name, optimizer_fn, loader, idnum=i, num_classes=num_classes) for i, (loader, model_name) in enumerate(zip(client_loaders, model_names))]
  server = Server(np.unique(model_names), distill_optimizer_fn, test_loader, distill_loader, num_classes=num_classes)

  for client in clients:
    client.public_loader = public_loader
    client.distill_loader = distill_loader

  # print model
  models.print_model(clients[0].model)

  # NOTE(review): ``"P" in mode`` also matches any mode containing "PROX";
  # confirm that FedProx-style modes are meant to load pretrained weights.
  if "P" in hp["aggregation_mode"] or hp["aggregation_mode"] == "FedAUX":
    for model_name, model in server.model_dict.items():
      pretrained = hp["pretrained"] if hp["pretrained"] else "{}_{}.pth".format(model_name, hp["distill_dataset"])

      loaded_state = torch.load(args.CHECKPOINT_PATH + pretrained, map_location='cpu')
      model_keys = model.state_dict().keys()  # build once, not once per key
      loaded_layers = [key for key in loaded_state if key in model_keys]
      model.load_state_dict(loaded_state, strict=False)
      # Report inside the loop so every server model is covered and nothing
      # is referenced when ``model_dict`` is empty.
      print("Successfully loaded layers {} from".format(loaded_layers), pretrained)

    for client in clients:
      client.synchronize_with_server(server)

    if hp["aggregation_mode"] == "FedAUX":
      print("Computing Scores...")

      for client in clients:
        client.scores = client.extract_features_and_compute_scores(client.loader, public_loader, distill_loader, lambda_reg=hp["lambda_reg_score"], eps_delt=hp["eps_delt"])

        if hp["save_scores"]:
          xp.log({"client_{}_scores".format(client.id) : client.scores.detach().cpu().numpy()})

  # "L" modes: train only the classification layer; freeze everything else.
  if "L" in hp["aggregation_mode"]:
    for client in clients:
      for name, param in client.model.named_parameters():
        if "classification_layer" not in name:
          param.requires_grad = False

    for model in server.model_dict.values():
      for name, param in model.named_parameters():
        if "classification_layer" not in name:
          param.requires_grad = False

  # Start Distributed Training Process
  print("Start Distributed Training..\n")
  t1 = time.time()
  xp.log({"prep_time" : t1-t0})

  xp.log({"server_val_{}".format(key) : value for key, value in server.evaluate_ensemble().items()})
  for c_round in range(1, hp["communication_rounds"]+1):

    participating_clients = server.select_clients(clients, hp["participation_rate"])
    xp.log({"participating_clients" : np.array([c.id for c in participating_clients])})

    for client in participating_clients:
      client.synchronize_with_server(server)

      train_stats = client.compute_weight_update(hp["local_epochs"], lambda_fedprox=hp["lambda_fedprox"] if "PROX" in hp["aggregation_mode"] else 0.0)

      print(train_stats)

    # Averaging
    server.aggregate_weight_updates(participating_clients)
    avg_stats = server.evaluate_ensemble()

    xp.log({"averaging_{}".format(key) : value for key, value in avg_stats.items()})

    if hp["aggregation_mode"] in ["FedDF", "FedAUX", "FedDF+P"]:
      distill_mode = "weighted_logits_precomputed" if hp["aggregation_mode"]=="FedAUX" else "mean_logits"

      distill_stats = server.distill(participating_clients, hp["distill_epochs"], mode=distill_mode, num_classes=num_classes)
      xp.log({"distill_{}".format(key) : value for key, value in distill_stats.items()})

    # Logging
    if xp.is_log_round(c_round):
      print("Experiment: {} ({}/{})".format(args.schedule, xp_count+1, n_experiments))

      xp.log({'communication_round' : c_round, 'epochs' : c_round*hp['local_epochs']})
      xp.log({key : clients[0].optimizer.__dict__['param_groups'][0][key] for key in optimizer_hp})

      # Evaluate
      xp.log({"server_val_{}".format(key) : value for key, value in server.evaluate_ensemble().items()})

      xp.log({"epoch_time" : (time.time()-t1)/c_round})
      # Save results to disk (best effort, but interrupts are not swallowed).
      try:
        xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
      except Exception:
        print("Saving results Failed!")

      # Timing: extrapolate the mean time per round so far.
      e = int((time.time()-t1)/c_round*(hp['communication_rounds']-c_round))
      print("Remaining Time (approx.):", '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                "[{:.2f}%]\n".format(c_round/hp['communication_rounds']*100))

  # Save model to disk
  server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

  # Delete objects to free up GPU memory
  del server
  clients.clear()
  torch.cuda.empty_cache()
Ejemplo n.º 5
0
def run_experiment(xp, xp_count, n_experiments):
    """Run one federated experiment with compressed client weight updates.

    In every communication round the participating clients are synchronized
    with the server, train locally and compress their weight update with
    ``comp.get_compression(hp["compression"])``. With ``hp["accumulate"]``
    the compression residual is passed along as ``accumulate``. The server
    then aggregates the updates of all clients. On logging rounds, metrics
    are recorded and results are saved under ``results/``.

    Args:
        xp: Experiment handle exposing ``hyperparameters``, ``log``,
            ``is_log_round`` and ``save_to_disc``.
        xp_count: Zero-based index of this experiment (only printed).
        n_experiments: Total number of experiments in the schedule (only
            printed).
    """
    print(xp)
    hp = xp.hyperparameters

    model_fn, optimizer_fn = models.get_model(hp["net"])
    compression_fn = comp.get_compression(hp["compression"])
    client_data, server_data = data.get_data(hp["dataset"],
                                             n_clients=hp["n_clients"],
                                             alpha=hp["dirichlet_alpha"])

    clients = [
        Client(model_fn, optimizer_fn, subset, hp["batch_size"], idnum=i)
        for i, subset in enumerate(client_data)
    ]
    server = Server(model_fn, server_data)
    server.load_model(path="checkpoints/", name=hp["pretrained"])

    # print model
    models.print_model(server.model)

    # Start Distributed Training Process
    print("Start Distributed Training..\n")
    t1 = time.time()
    # Stats of the most recently trained client. Initialised so that logging
    # does not raise NameError if no client has trained yet.
    train_stats = {}
    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients,
                                                      hp["participation_rate"])

        for client in tqdm(participating_clients):
            client.synchronize_with_server(server)
            train_stats = client.compute_weight_update(hp["local_epochs"])
            client.compress_weight_update(compression_fn,
                                          accumulate=hp["accumulate"])

        server.aggregate_weight_updates(clients)

        # Logging
        if xp.is_log_round(c_round):
            print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1,
                                                  n_experiments))

            xp.log({
                'communication_round': c_round,
                'epochs': c_round * hp['local_epochs']
            })

            # Evaluate
            xp.log({
                "client_train_{}".format(key): value
                for key, value in train_stats.items()
            })
            xp.log({
                "server_val_{}".format(key): value
                for key, value in server.evaluate().items()
            })

            # Save results to disk. This is best effort, but interrupts are
            # not swallowed.
            try:
                xp.save_to_disc(path="results/", name=hp['log_path'])
            except Exception:
                print("Saving results Failed!")

            # Timing: extrapolate the mean time per round so far.
            e = int((time.time() - t1) / c_round *
                    (hp['communication_rounds'] - c_round))
            print(
                "Remaining Time (approx.):",
                '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60),
                                              e % 60),
                "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] *
                                     100))

    # Save model to disk
    server.save_model(path="checkpoints/", name=hp["save_model"])

    # Delete objects to free up GPU memory
    del server
    clients.clear()
    torch.cuda.empty_cache()
Ejemplo n.º 6
0
def run_experiment(xp, xp_count, n_experiments, args, seed=0):
    """Run one federated experiment and return the final server error rate.

    ``hp["aggregation_mode"]`` selects how the server combines clients:

    - ``"FA"``, ``"FAD"``, ``"FAD+P"``, ``"FAD+S"`` and ``"FAD+P+S"`` average
      the participating clients' parameters.
    - ``"FD"``, ``"FD+S"``, ``"FAD"``, ``"FAD+P"``, ``"FAD+S"`` and
      ``"FAD+P+S"`` distill the clients into the server model.
    - ``"+S"`` modes mix the remaining (non-distillation) auxiliary samples
      into local training.
    - ``"+P"`` modes load a pretrained checkpoint when it exists.

    Args:
        xp: Experiment handle exposing ``hyperparameters``, ``log``,
            ``is_log_round`` and ``save_to_disc``.
        xp_count: Zero-based index of this experiment (only logged).
        n_experiments: Total number of experiments in the schedule (only
            logged).
        args: Namespace providing ``DATA_PATH``, ``CHECKPOINT_PATH``,
            ``RESULTS_PATH`` and ``schedule``.
        seed: Seed for numpy's global RNG, used to split the auxiliary data.

    Returns:
        ``1 - accuracy`` of the server model. The accuracy comes from the
        last logging round if one happened, otherwise from a fresh
        ``server.evaluate()``.
    """
    logger.debug(xp)
    hp = xp.hyperparameters

    model_fn, optimizer, optimizer_hp = models.get_model(hp["net"])
    if "local_optimizer" in hp:
        # (torch.optim class name, default kwargs) overriding the model default.
        optimizer = getattr(torch.optim, hp["local_optimizer"][0])
        optimizer_hp = hp["local_optimizer"][1]

    # Values present in ``hp`` override the optimizer's default kwargs.
    optimizer_fn = lambda x: optimizer(
        x, **{k: hp[k] if k in hp else v
              for k, v in optimizer_hp.items()})
    train_data, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
    all_distill_data = data.get_data(hp["distill_dataset"], args.DATA_PATH)

    np.random.seed(seed)
    # Split a SINGLE permutation: the first hp["n_distill"] samples are used
    # for distillation, the rest as public data. Two independent
    # permutations would make the sets overlap.
    perm = np.random.permutation(len(all_distill_data))
    distill_data = data.IdxSubset(all_distill_data,
                                  perm[:hp["n_distill"]])  # data used for distillation

    client_data, label_counts = data.get_client_data(
        train_data,
        n_clients=hp["n_clients"],
        classes_per_client=hp["classes_per_client"])

    if hp["aggregation_mode"] in ["FD+S", "FAD+S", "FAD+P+S"]:
        # Remaining auxiliary samples, mixed into local training
        # (data used to train the outlier detector).
        public_data = data.IdxSubset(all_distill_data, perm[hp["n_distill"]:])

        logger.debug(
            "Using {} public data points for distillation and {} public data points for local training.\n"
            .format(len(distill_data), len(public_data)))

        client_loaders = [
            data.DataMerger({
                'base': local_data,
                'public': public_data
            }, **hp) for local_data in client_data
        ]

    else:
        client_loaders = [
            data.DataMerger({'base': local_data}, **hp)
            for local_data in client_data
        ]

    test_loader = data.DataMerger(
        {'base': data.IdxSubset(test_data, list(range(len(test_data))))},
        mixture_coefficients={'base': 1},
        batch_size=100)
    distill_loader = DataLoader(distill_data, batch_size=128, shuffle=True)

    clients = [
        Client(model_fn,
               optimizer_fn,
               loader,
               idnum=i,
               counts=counts,
               distill_loader=distill_loader)
        for i, (loader, counts) in enumerate(zip(client_loaders, label_counts))
    ]
    # hp["distill_optimizer"] is (torch.optim class name, kwargs).
    server = Server(
        model_fn, lambda x: getattr(torch.optim, hp["distill_optimizer"][0])
        (x, **hp["distill_optimizer"][1]), test_loader, distill_loader)

    # Modes that use pretrained representation (only if the checkpoint exists).
    if hp["aggregation_mode"] in [
            "FAD+P", "FAD+P+S"
    ] and os.path.isfile(args.CHECKPOINT_PATH + hp["pretrained"]):
        for device in clients + [server]:
            device.model.load_state_dict(torch.load(args.CHECKPOINT_PATH +
                                                    hp["pretrained"],
                                                    map_location='cpu'),
                                         strict=False)
        logger.debug(f"Successfully loaded model from {hp['pretrained']}")

    # NOTE: training of shallow outlier detectors for the "FAD+S"/"FAD+P+S"
    # modes is disabled in this version.

    # Accuracy after parameter averaging; passed to distillation as ``acc0``.
    averaging_stats = {"accuracy": 0.0}
    eval_results = None
    models.print_model(server.model)

    # Start Distributed Training Process
    logger.info("Start Distributed Training..\n")
    t1 = time.time()

    xp.log({
        "server_val_{}".format(key): value
        for key, value in server.evaluate().items()
    })
    for c_round in range(1, hp["communication_rounds"] + 1):

        participating_clients = server.select_clients(clients,
                                                      hp["participation_rate"])
        xp.log({
            "participating_clients":
            np.array([client.id for client in participating_clients])
        })

        for client in tqdm(participating_clients):
            client.synchronize_with_server(server, c_round)

            # NOTE(review): the keyword is spelled "train_oulier_model".
            # Presumably this matches Client.compute_weight_update; rename
            # both together if ever fixed.
            train_stats = client.compute_weight_update(
                hp["local_epochs"],
                train_oulier_model=hp["aggregation_mode"]
                in ["FAD+S", "FAD+P+S"],
                c_round=c_round,
                max_c_round=hp["communication_rounds"],
                **hp)
            logger.debug(train_stats)

        if hp["aggregation_mode"] in [
                "FA", "FAD", "FAD+P", "FAD+S", "FAD+P+S"
        ]:
            server.aggregate_weight_updates(participating_clients)

            averaging_stats = server.evaluate()
            xp.log({
                "parameter_averaging_{}".format(key): value
                for key, value in averaging_stats.items()
            })

        # "FD+S" belongs here too: it is a distillation-only mode (like "FD"),
        # and the ``distill_mode`` selection below explicitly handles it.
        if hp["aggregation_mode"] in [
                "FD", "FD+S", "FAD", "FAD+P", "FAD+S", "FAD+P+S"
        ]:

            distill_mode = hp["distill_mode"] if hp["aggregation_mode"] in [
                "FD+S", "FAD+S", "FAD+P+S"
            ] else "mean_logits"
            distill_stats = server.distill(participating_clients,
                                           hp["distill_epochs"],
                                           mode=distill_mode,
                                           acc0=averaging_stats["accuracy"],
                                           fallback=hp["fallback"])
            xp.log({
                "distill_{}".format(key): value
                for key, value in distill_stats.items()
            })

        # Logging
        if xp.is_log_round(c_round):
            logger.debug("Experiment: {} ({}/{})".format(
                args.schedule, xp_count + 1, n_experiments))

            xp.log({
                'communication_round': c_round,
                'epochs': c_round * hp['local_epochs']
            })
            # Current optimizer settings, read from the first client.
            xp.log({
                key: clients[0].optimizer.__dict__['param_groups'][0][key]
                for key in optimizer_hp
            })

            # Evaluate
            eval_results = server.evaluate()
            xp.log({
                "server_val_{}".format(key): value
                for key, value in eval_results.items()
            })

            # Save results to disk. This is best effort: the failure is
            # logged with its traceback, and interrupts are not swallowed.
            try:
                xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
            except Exception:
                logger.exception("Saving results Failed!")

            # Timing: extrapolate the mean time per round so far.
            e = int((time.time() - t1) / c_round *
                    (hp['communication_rounds'] - c_round))
            logger.debug(
                f"Remaining Time (approx.): {e // 3600:02d}:{e % 3600 // 60:02d}:{e % 60:02d} [{c_round/hp['communication_rounds']*100:.2f}%]\n"
            )

    # Save model to disk
    server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])

    try:
        return 1 - eval_results[
            'accuracy'] if eval_results else 1 - server.evaluate()['accuracy']
    finally:
        # Delete objects to free up GPU memory
        del server
        clients.clear()
        del clients
        del xp
        gc.collect()
Ejemplo n.º 7
0
def main():
    """Simulate clustered federated learning on non-IID EMNIST clients.

    Command-line arguments set the number of clients and client groups, the
    number of communication rounds, local epochs, the non-IID alpha, the
    participation fraction and the two cluster-split thresholds eps1/eps2.

    Each communication round:
      1. the server selects a fraction of clients; each selected client
         computes a local weight update and then calls ``reset()``;
      2. the server computes pairwise similarities over all clients;
      3. every cluster with mean update norm < eps1, max update norm > eps2
         and more than two members is bisected with ``server.cluster_clients``;
      4. updates are aggregated cluster-wise and every client is evaluated.
    Per-round statistics go to an ``ExperimentLogger`` and are plotted with
    ``display_train_stats``. Finally each cluster's model is cached on the
    server.

    NOTE(review): similarities, norms and aggregation use *all* clients, not
    only ``participating_clients``. With ``--frac < 1`` the non-participating
    clients presumably contribute an update left over from an earlier round;
    confirm this is intended.
    """
    parser = argparse.ArgumentParser(description='Simulate FL for noniid clients ')

    parser.add_argument('--num_client', type=int, default=10,
                        help='number of clients')
    parser.add_argument('--num_group', type=int, default=2,
                        help='number of clients groups')
    parser.add_argument('--rounds', type=int, default=50,
                        help='training rounds')
    parser.add_argument('--local_epoch', type=int, default=1,
                        help='local training epochs')
    parser.add_argument('--noniid_alpha', type=float, default=1.0,
                        help='DIRICHLET_ALPHA')
    parser.add_argument('--frac', type=float, default=0.5,
                        help='fraction of participants ')
    parser.add_argument('--eps1', type=float, default=0.4,
                        help='split threshold: a cluster may split only if '
                             'its mean client-update norm is below eps1')
    parser.add_argument('--eps2', type=float, default=1.6,
                        help='split threshold: a cluster may split only if '
                             'its max client-update norm is above eps2')

    args = parser.parse_args()
    if args.rounds < 1:
        # The final cache_model() call needs acc_clients from at least one
        # completed round; without this guard rounds=0 raises NameError.
        parser.error('--rounds must be at least 1')
    print(args)

    # initialize data, clients, server, logger
    dataset = ClientDataGenerater(args.num_client, args.num_group, args.noniid_alpha)
    dataset.load_emnist()
    print('Finish loading data')

    noniid_train, noniid_test = dataset.rotate_noniid()
    clients = [
        Client(ConvNet, lambda x: torch.optim.SGD(x, lr=0.1, momentum=0.9), dat, idnum=i)
        for i, dat in enumerate(noniid_train)
    ]
    server = Server(ConvNet, noniid_test)
    cfl_stats = ExperimentLogger()
    # Training starts with a single cluster holding every client index.
    cluster_indices = [np.arange(len(clients)).astype("int")]
    cfl_stats.log({'args': args})
    print('Finish setting up')

    # Every client starts from the server's model before the first round.
    for client in clients:
        client.synchronize_with_server(server)

    for c_round in range(1, args.rounds + 1):
        participating_clients = server.select_clients(clients, frac=args.frac)

        # clients train locally; reset() is called right after each update
        for client in participating_clients:
            client.compute_weight_update(epochs=args.local_epoch)
            client.reset()

        similarities = server.compute_pairwise_similarities(clients)

        cluster_indices_new = []
        for idc in cluster_indices:
            members = [clients[i] for i in idc]
            max_norm = server.compute_max_update_norm(members)
            mean_norm = server.compute_mean_update_norm(members)
            print(max_norm, mean_norm)
            # mean update norm below eps1: the cluster as a whole is converging;
            # max update norm above eps2: some clients still pull away, so split.
            if mean_norm < args.eps1 and max_norm > args.eps2 and len(idc) > 2:
                c1, c2 = server.cluster_clients(similarities[idc][:, idc])
                cluster_indices_new += [c1, c2]
                cfl_stats.log({"split": c_round})
            else:
                cluster_indices_new += [idc]
        cluster_indices = cluster_indices_new

        client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]
        server.aggregate_clusterwise(client_clusters)
        acc_clients = [client.evaluate() for client in clients]

        # NOTE: mean_norm / max_norm are the values of the *last* cluster
        # examined in this round, not an aggregate over all clusters.
        log_dict = {
            "acc_clients": acc_clients,
            "mean_norm": mean_norm,
            "max_norm": max_norm,
            "rounds": c_round,
            "clusters": cluster_indices,
        }
        cfl_stats.log(log_dict)
        print(log_dict)

    display_train_stats(cfl_stats, args.eps1, args.eps2, args.rounds)

    print("Clustering result: ", cluster_indices)
    for idc in cluster_indices:
        server.cache_model(idc, clients[idc[0]].W, acc_clients)