# Run on the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# Worker count.
N_w = 20
# Training-set size:
#   CIFAR-10       50,000
#   Fashion-MNIST  60,000
N_s = 50000

# Mini-batch size handed to the data loader.
batch = 64
# NOTE(review): tau divides both the step budget (24000) and the
# 120-step interval below; presumably local steps per round -- confirm.
tau = 4
runs = 24000 // tau

# One training loader per worker is implied by the plural name; the exact
# return structure is defined by data_loader.CIFAR_data -- confirm there.
trainloaders, testloader = data_loader.CIFAR_data(batch, N_w, N_s)

w_index = 0
# Uninitialised 1 x (runs / (120 / tau)) buffer, filled via res_ind.
results = np.empty((1, runs // (120 // tau)))
res_ind = 0

# Models are built in a fixed order (workers, then ps_model, then
# avg_model) so that random initialisation consumes the RNG identically.
nets = [nn_classes.ResNet18().to(device) for _ in range(N_w)]

ps_model = nn_classes.ResNet18().to(device)
avg_model = nn_classes.ResNet18().to(device)

# Optimisation hyperparameters.
lr = 1e-1
momentum = 0.9
weight_decay = 1e-4
alpha = 0.45

# A separate cross-entropy loss object for each worker.
criterions = [nn.CrossEntropyLoss() for _ in range(N_w)]
    for param, param_ps in zip(model.parameters(), agg_model.parameters()):
        param.grad.data = param_ps.grad.data
        param.data.add_(-scale, param.grad.data)

    return None

for a in range(2):
    # select gpu
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # load data
    print('loadinf data...')
    mini_batch = 64
    trainloader, testloader = data_loader.CIFAR_data(mini_batch)

    for d in trainloader:
        # used in initializing the gradients
        x_init, y_init = d
        x_init, y_init = x_init.to(device), y_init.to(device)

    # number of clusters
    num_cl = 7
    # number of workers per cluster
    num_w_per_cluster = 5
    nets = [[
        nn_classes.ResNet18().to(device) for n in range(num_w_per_cluster)
    ] for c in range(num_cl)]
    for c in range(num_cl):