Exemplo n.º 1
0
    def init_train_loader(self, tl):
        """Attach a training data loader to this object.

        The loader is stored as-is on ``self.train_loader``; it is neither
        validated nor copied, and any previously stored loader is replaced.

        Args:
            tl: The loader to keep. Presumably a loader built by the caller
                (e.g. via ``get_data_loader``) -- confirm against call sites.
        """
        self.train_loader = tl


if __name__ == "__main__":
    # Script entry point: parse arguments, build the server and its client
    # models, construct a sampler-driven training loader, and create one
    # CIFAR10ItPrClient per client model.
    # NOTE: statement order matters here -- the RNG is seeded before any
    # model is constructed, and every name bound below is a module global
    # that later code may read.
    args = parse_args()
    # Seed torch's RNG first so that VGG11() construction and
    # server.init_clients() below (presumably RNG-consuming) are reproducible
    # for a given args.seed.
    # NOTE(review): only torch is seeded; Python `random` / NumPy RNGs, if
    # FLSampler or the data pipeline use them, are not seeded here.
    torch.manual_seed(args.seed)

    num_users = 100
    # Number of slices handed to FLSampler: one per user when client
    # selection is enabled, otherwise one per client (NUM_CLIENTS).
    # Presumably the number of data partitions the sampler draws from --
    # confirm against FLSampler.
    num_slices = num_users if args.client_selection else NUM_CLIENTS

    server = CIFAR10ItPrServer(args, config, VGG11())
    # init_clients() yields the client models and the index lists that are
    # passed straight to FLSampler below (presumably dataset indices per
    # slice -- verify).
    list_models, list_indices = server.init_clients()

    # Third argument: NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE, presumably the
    # number of samples drawn per client per round (local steps x batch size).
    sampler = FLSampler(list_indices, MAX_ROUND,
                        NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE,
                        args.client_selection, num_slices)
    print("Sampler initialized")

    # Sample order comes from FLSampler, so shuffle stays False. Assuming
    # get_data_loader forwards these kwargs to torch's DataLoader, passing
    # shuffle=True together with a sampler would be rejected there.
    train_loader = get_data_loader(EXP_NAME,
                                   data_type="train",
                                   batch_size=CLIENT_BATCH_SIZE,
                                   shuffle=False,
                                   sampler=sampler,
                                   num_workers=8,
                                   pin_memory=True)

    # One client wrapper per model; list_models must provide at least
    # NUM_CLIENTS entries or the indexing below fails.
    # NOTE(review): train_loader is not attached to any client at this point.
    client_list = [
        CIFAR10ItPrClient(config, list_models[idx])
        for idx in range(NUM_CLIENTS)
    ]
Exemplo n.º 2
0
    # FEMNIST script body: ensure the cached user count exists, build the
    # server and client models, a sampler-driven training loader, and one
    # FEMNISTItPrClient per client model. `args` and `config` are bound
    # earlier, outside this visible section.

    # Cached user count; the path is relative to the current working
    # directory.
    num_user_path = os.path.join("datasets", "FEMNIST", "processed",
                                 "num_users.pt")
    # The returned loader is discarded: this call is made only for its side
    # effect. Presumably get_data_loader preprocesses FEMNIST and writes
    # num_users.pt on first use -- TODO confirm.
    if not os.path.isfile(num_user_path):
        get_data_loader(EXP_NAME,
                        data_type="train",
                        batch_size=CLIENT_BATCH_SIZE,
                        shuffle=False,
                        num_workers=8,
                        pin_memory=True)
    # NOTE(review): torch.load raises if the file was still not produced.
    # num_users is not used in the rest of this visible section.
    num_users = torch.load(num_user_path)

    server = FEMNISTItPrServer(args, config, Conv2())
    # NOTE(review): list_users is not used in this visible section; the
    # sampler below takes its indices from get_indices_list() instead.
    list_models, list_users = server.init_clients()

    # Third argument: NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE, presumably the
    # number of samples drawn per client per round.
    # NOTE(review): the last argument is NUM_CLIENTS even when
    # args.client_selection is set -- confirm this is intended rather than a
    # per-user count such as num_users.
    sampler = FLSampler(get_indices_list(), MAX_ROUND,
                        NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE,
                        args.client_selection, NUM_CLIENTS)
    print("Sampler initialized")

    # Sample order comes from FLSampler, so shuffle stays False. Assuming
    # get_data_loader forwards these kwargs to torch's DataLoader, passing
    # shuffle=True together with a sampler would be rejected there.
    train_loader = get_data_loader(EXP_NAME,
                                   data_type="train",
                                   batch_size=CLIENT_BATCH_SIZE,
                                   shuffle=False,
                                   sampler=sampler,
                                   num_workers=8,
                                   pin_memory=True)

    # One client wrapper per model; list_models must provide at least
    # NUM_CLIENTS entries or the indexing below fails.
    client_list = [
        FEMNISTItPrClient(config, list_models[idx])
        for idx in range(NUM_CLIENTS)
    ]