# NOTE(review): takes ``self``, so this is presumably the last method of the
# client class defined just above this span — confirm its indentation.
    def init_train_loader(self, tl):
        """Store ``tl`` on this object as ``self.train_loader``."""
        self.train_loader = tl


if __name__ == "__main__":
    # Script entry point: build the CIFAR-10 server, the federated sampler,
    # the shared training loader and the per-client wrappers.
    args = parse_args()
    # Seed torch before any model is constructed (VGG11() and
    # server.init_clients() below), presumably so initialisation is
    # reproducible for a given args.seed.
    torch.manual_seed(args.seed)
    # NOTE(review): hard-coded user count, used only as the slice count when
    # client selection is enabled — confirm it matches the data split.
    num_users = 100
    # Number of slices handed to FLSampler: num_users when
    # args.client_selection is set, otherwise NUM_CLIENTS.
    num_slices = num_users if args.client_selection else NUM_CLIENTS
    server = CIFAR10ItPrServer(args, config, VGG11())
    # Presumably one model and one list of sample indices per client.
    list_models, list_indices = server.init_clients()
    # FLSampler arguments: the index lists, MAX_ROUND,
    # NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE (presumably the samples a client
    # consumes per round — confirm), the client-selection flag, num_slices.
    sampler = FLSampler(list_indices, MAX_ROUND, NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE, args.client_selection,
                        num_slices)
    print("Sampler initialized")
    # shuffle=False: sample order comes from ``sampler``. If get_data_loader
    # forwards these kwargs to torch's DataLoader, shuffle=True would be
    # rejected together with a custom sampler.
    train_loader = get_data_loader(EXP_NAME, data_type="train", batch_size=CLIENT_BATCH_SIZE, shuffle=False,
                                   sampler=sampler, num_workers=8, pin_memory=True)
    # One client wrapper per client, each bound to its own model.
    client_list = [
        CIFAR10ItPrClient(config, list_models[idx])
        for idx in range(NUM_CLIENTS)
    ]
# FEMNIST experiment setup. It reads ``args``, so it presumably runs inside
# the script's ``if __name__ == "__main__":`` block — confirm indentation.
    # NOTE(review): the dataset directory is hard-coded as "FEMNIST" rather
    # than derived from EXP_NAME — confirm the two always agree.
    num_user_path = os.path.join("datasets", "FEMNIST", "processed", "num_users.pt")
    # If the user-count file is missing, build a training loader once and
    # discard it; presumably this runs the preprocessing that writes the file.
    if not os.path.isfile(num_user_path):
        get_data_loader(EXP_NAME, data_type="train", batch_size=CLIENT_BATCH_SIZE, shuffle=False, num_workers=8,
                        pin_memory=True)
    num_users = torch.load(num_user_path)
    server = FEMNISTItPrServer(args, config, Conv2())
    list_models, list_users = server.init_clients()
    # NOTE(review): the last FLSampler argument is always NUM_CLIENTS here,
    # even when args.client_selection is set, and neither num_users nor
    # list_users is used in this span — confirm this is intended.
    sampler = FLSampler(get_indices_list(), MAX_ROUND, NUM_LOCAL_UPDATES * CLIENT_BATCH_SIZE, args.client_selection,
                        NUM_CLIENTS)
    print("Sampler initialized")
    # shuffle=False: sample order comes from ``sampler``. If get_data_loader
    # forwards these kwargs to torch's DataLoader, shuffle=True would be
    # rejected together with a custom sampler.
    train_loader = get_data_loader(EXP_NAME, data_type="train", batch_size=CLIENT_BATCH_SIZE, shuffle=False,
                                   sampler=sampler, num_workers=8, pin_memory=True)
    # One client wrapper per client, each bound to its own model.
    client_list = [
        FEMNISTItPrClient(config, list_models[idx])
        for idx in range(NUM_CLIENTS)
    ]