Example #1
import os

import torch

import data  # project-local data helpers (assumed)
import net   # project-local model definitions (assumed)
from learner import FederatedLearner  # assumed import path for FederatedLearner


def FederatedTrain(args):

    if args.dataset == 'MNIST':
        dataset = data.load_mnist()
        dataloaders_train, dataloader_test = data.create_split_dataloaders(
            dataset=dataset,
            args=args
        )
        dataiters_train = [iter(loader) for loader in dataloaders_train]
        dataiters_test = iter(dataloader_test)
        n_channels = 1
    else:
        print('Dataset is not supported')
        exit(1)


    n_clients = args.n_clients

    global_net = net.LeNet(n_channels=n_channels)
    print(global_net)

    learner = FederatedLearner(net=global_net, args=args)
    learner.gpu = args.gpu

    model_dir = args.model_dir
    global_model_name = args.global_model_name
    global_optim_name = args.global_optimizor_name
    global_model_suffix = '_init_.pth'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save(learner.net.state_dict(),
               os.path.join(model_dir, global_model_name + global_model_suffix))
    print('Model saved')

    for t in range(args.epochs):
        if t == 0:
            global_model_suffix = '_init_.pth'
        else:
            global_model_suffix = '_{cur}.pth'.format(cur=t-1)

        learner.load_model(model_path=os.path.join(model_dir, global_model_name + global_model_suffix))

        for i in range(n_clients):
            print('t=', t, 'client model idx=', i)
            try:
                batchX, batchY = next(dataiters_train[i])
            except StopIteration:
                dataiters_train[i] = iter(dataloaders_train[i])
                batchX, batchY = next(dataiters_train[i])

            learner.comp_grad(i, batchX, batchY)
        learner._update_model()

        global_model_suffix = '_{cur}.pth'.format(cur=t)
        torch.save(learner.net.state_dict(),
                   os.path.join(model_dir, global_model_name + global_model_suffix))

        if (t+1) % args.n_eval_iters == 0:
            learner._evalTest(test_loader=dataloader_test)
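
The FederatedLearner class is not part of this listing. As a reference for how the calls above fit together, here is a minimal sketch of one way it could be implemented, assuming a plain FedAvg-style update: comp_grad accumulates one gradient per client and _update_model applies a single SGD step on their average. The class internals and the learning-rate name args.lr are assumptions, not the original implementation.

class FederatedLearner(object):
    # Minimal sketch (assumed implementation): accumulate per-client
    # gradients, then apply one SGD step on their average (FedAvg-style).

    def __init__(self, net, args):
        self.net = net
        self.args = args
        self.gpu = False
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self._grads = []  # one gradient list per client, cleared each round

    def load_model(self, model_path):
        self.net.load_state_dict(torch.load(model_path))

    def comp_grad(self, client_idx, batchX, batchY):
        # Forward/backward on one client's batch; stash a copy of the grads.
        if self.gpu:
            self.net = self.net.cuda()
            batchX, batchY = batchX.cuda(), batchY.cuda()
        self.net.zero_grad()
        loss = self.loss_fn(self.net(batchX), batchY)
        loss.backward()
        self._grads.append([p.grad.detach().clone()
                            for p in self.net.parameters()])

    def _update_model(self):
        # One SGD step on the average of the stored per-client gradients.
        # args.lr is an assumed hyperparameter name.
        with torch.no_grad():
            for i, p in enumerate(self.net.parameters()):
                avg_grad = sum(g[i] for g in self._grads) / len(self._grads)
                p -= self.args.lr * avg_grad
        self._grads = []

    def _evalTest(self, test_loader):
        # Accuracy on the held-out test set.
        self.net.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batchX, batchY in test_loader:
                if self.gpu:
                    batchX, batchY = batchX.cuda(), batchY.cuda()
                pred = self.net(batchX).argmax(dim=1)
                correct += (pred == batchY).sum().item()
                total += batchY.size(0)
        self.net.train()
        print('test accuracy:', correct / total)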
Example #2
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import net  # project-local model definitions (assumed)

# Hyperparameters (batch_size matches the shape comment below;
# the other two values are assumed, as they are not in the original listing)
batch_size = 128
learning_rate = 1e-2
num_epoches = 20

# Normalization: ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) shifts them to [-1, 1]
data_tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])])

train_dataset = datasets.MNIST(root='F:/shujuji/',
                               train=True,
                               transform=data_tf,
                               download=True)
test_dataset = datasets.MNIST(root='F:/shujuji/',
                              train=False,
                              transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = net.LeNet()
if torch.cuda.is_available():
    model = model.cuda()

# 定义loss函数和优化方法
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(num_epoches):
    model.train()
    for data in train_loader:  # one batch of batch_size images per iteration
        img, label = data  # img shape: 128 x 1 x 28 x 28
        # img = img.view(img.size(0), -1)  # would flatten to 128 x 784 (28*28); not needed for LeNet
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
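
        # (The original listing is truncated here; the lines below are a
        # minimal, assumed completion of the training step, following the
        # standard PyTorch forward / loss / backward / step pattern.)
        out = model(img)
        loss = loss_fn(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()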