class Client:
    """A federated-learning participant that trains a local LeNet copy.

    Loads ImageFolder train/test data, keeps a local model, and performs
    FedProx-style local updates: the MU proximal term pulls local weights
    toward the server's center parameters during training.
    """

    def __init__(self, name, train_data_dir, test_data_dir):
        """
        Args:
            name: identifier for this client.
            train_data_dir: root directory of an ImageFolder training set.
            test_data_dir: root directory of an ImageFolder test set.
        """
        self.name = name

        transform = transforms.ToTensor()

        trainset = datasets.ImageFolder(train_data_dir, transform=transform)
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        testset = datasets.ImageFolder(test_data_dir, transform=transform)
        self.testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=BATCH_SIZE,
            shuffle=False
        )

        # Number of batches per epoch. DataLoader implements __len__, so
        # there is no need to materialize every batch in memory (the
        # original `list(self.trainloader)` loaded the whole dataset just
        # to count batches).
        self.dataset_len = len(self.trainloader)

        self.net = LeNet().to(device)

        self.criterion = nn.CrossEntropyLoss()

    def update(self, net_dict, center_params_dict):
        """Run LOCAL_EPOCH_NUM epochs of proximal SGD; return the new weights.

        Args:
            net_dict: state_dict broadcast by the server, loaded before training.
            center_params_dict: server "center" parameters used for the
                proximal penalty MU * (w - w_center) added to each gradient.

        Returns:
            The updated local state_dict.
        """
        self.net.load_state_dict(net_dict)

        # Create the optimizer once. The original rebuilt it on every
        # batch, which reset the momentum buffers each step and made
        # momentum=0.9 a no-op.
        optimizer = optim.SGD(self.net.parameters(), lr=LR, momentum=0.9)

        for _ in range(LOCAL_EPOCH_NUM):
            for inputs, labels in self.trainloader:
                # Keep only the first channel (single-channel input for LeNet).
                inputs = torch.index_select(inputs, 1, torch.LongTensor([0]))
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()

                # Proximal term: pull local weights toward the server center.
                for name, params in self.net.named_parameters():
                    params.grad += MU * (params.data - center_params_dict[name])

                optimizer.step()

        return self.net.state_dict()
class Client:
    """Federated client holding a key pair; ships quantized, encrypted gradients.

    Computes a single-batch gradient of its local LeNet, fixed-point encodes
    the flattened gradient vector, and encrypts it with the client's public key.
    """

    def __init__(self, name, train_data_dir, test_data_dir, pk, sk):
        """
        Args:
            name: identifier for this client.
            train_data_dir: root directory of an ImageFolder training set.
            test_data_dir: root directory of an ImageFolder test set.
            pk: public key used by Enc() to encrypt gradients.
            sk: secret key (held for decryption elsewhere).
        """
        self.name = name
        self.pk = pk
        self.sk = sk

        transform = transforms.ToTensor()

        trainset = datasets.ImageFolder(train_data_dir, transform=transform)
        self.trainloader = torch.utils.data.DataLoader(trainset,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)

        testset = datasets.ImageFolder(test_data_dir, transform=transform)
        self.testloader = torch.utils.data.DataLoader(testset,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=False)

        # Batches per epoch via DataLoader.__len__; the original built
        # list(self.trainloader), loading the entire dataset just to count.
        self.dataset_len = len(self.trainloader)

        self.net = LeNet().to(device)

        self.criterion = nn.CrossEntropyLoss()

    def get_encrypted_grad(self, client_inputs, client_labels, net_dict):
        """Compute one batch's gradient, fixed-point encode it, and encrypt it.

        Args:
            client_inputs: batch of input tensors (already on the right device).
            client_labels: corresponding target labels.
            net_dict: state_dict broadcast by the server; loaded before the pass.

        Returns:
            The encrypted, flattened, integer-encoded gradient vector.
        """
        self.net.load_state_dict(net_dict)
        client_outputs = self.net(client_inputs)
        client_loss = self.criterion(client_outputs, client_labels)
        client_optimizer = optim.SGD(self.net.parameters(),
                                     lr=LR,
                                     momentum=0.9)
        client_optimizer.zero_grad()
        client_loss.backward()

        # Flatten all parameter gradients into one vector. detach().clone()
        # snapshots each gradient without the overhead of copy.deepcopy.
        params_grad_list = [
            params.grad.detach().clone().view(-1)
            for _, params in self.net.named_parameters()
        ]

        # Fixed-point encode: shift by `bound` to make values non-negative,
        # scale by 2**prec, truncate to integers for encryption.
        # NOTE(review): hard-coded .cuda() — elsewhere this code uses
        # .to(device); confirm a GPU is actually required here.
        params_grad = ((torch.cat(params_grad_list, 0) + bound) *
                       2**prec).long().cuda()
        client_encrypted_grad = Enc(self.pk, params_grad)

        # Clear gradients so they do not leak into the next call.
        client_optimizer.zero_grad()

        return client_encrypted_grad
# Beispiel #3  (scraper artifact: snippet separator)
# 0            (scraper artifact: vote count)
class Client:
    """Federated client that returns a single-batch gradient as a plain dict.

    Computes the gradient of its local LeNet on one batch and returns a
    per-parameter copy of the gradients keyed by parameter name.
    """

    def __init__(self, name, train_data_dir, test_data_dir):
        """
        Args:
            name: identifier for this client.
            train_data_dir: root directory of an ImageFolder training set.
            test_data_dir: root directory of an ImageFolder test set.
        """
        self.name = name

        transform = transforms.ToTensor()

        trainset = datasets.ImageFolder(train_data_dir, transform=transform)
        self.trainloader = torch.utils.data.DataLoader(trainset,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)

        testset = datasets.ImageFolder(test_data_dir, transform=transform)
        self.testloader = torch.utils.data.DataLoader(testset,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=False)

        # Batches per epoch via DataLoader.__len__; the original built
        # list(self.trainloader), loading the entire dataset just to count.
        self.dataset_len = len(self.trainloader)

        self.net = LeNet().to(device)

        self.criterion = nn.CrossEntropyLoss()

    def get_grad(self, client_inputs, client_labels, net_dict):
        """Compute one batch's gradient and return it keyed by parameter name.

        Args:
            client_inputs: batch of input tensors (already on the right device).
            client_labels: corresponding target labels.
            net_dict: state_dict broadcast by the server; loaded before the pass.

        Returns:
            dict mapping parameter name -> detached copy of its gradient tensor.
        """
        self.net.load_state_dict(net_dict)
        client_outputs = self.net(client_inputs)
        client_loss = self.criterion(client_outputs, client_labels)
        client_optimizer = optim.SGD(self.net.parameters(),
                                     lr=LR,
                                     momentum=0.9)
        client_optimizer.zero_grad()
        client_loss.backward()

        # Snapshot each gradient with detach().clone() (cheaper than the
        # original copy.deepcopy and equivalent for gradient tensors).
        client_grad_dict = {
            name: params.grad.detach().clone()
            for name, params in self.net.named_parameters()
        }

        # Clear gradients so they do not leak into the next call.
        client_optimizer.zero_grad()
        return client_grad_dict
            ], net_dict, attacker_net)
        # client_attacker_grad_dict = get_client_grad_model(client_attacker_inputs, client_attacker_labels, attack_module)
        # 取各client参数梯度均值
        client_average_grad_dict = dict()
        # attacker_average_grad_dict = dict()
        # for key in client_attacker_grad_dict:
        #     attacker_average_grad_dict[key] = client_attacker_grad_dict[key]
        for key in client_0_grad_dict:
            client_average_grad_dict[key] = client_0_grad_dict[key] * (
                1 / Num_client) + client_1_grad_dict[key] * (
                    1 / Num_client) + client_2_grad_dict[key] * (
                        1 / Num_client) + client_attacker_grad_dict[key] * (
                            1 / Num_client)

        # 加载梯度
        params_modules_server = net.named_parameters()
        # params_modules_attacker = attacker_net.named_parameters()
        # for params_module in params_modules_attacker:
        #     (name_attacker, params) = params_module
        #     params.grad = attacker_average_grad_dict[name_attacker]
        # optimizer_backdoor.step()
        for params_module in params_modules_server:
            (name, params) = params_module
            params.grad = client_average_grad_dict[
                name]  # 用字典中存储的子模型的梯度覆盖server中的参数梯度
        optimizer_server.step()

    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
        correct = 0
        total = 0
                client_attacker_net)
            # 取各client参数梯度均值
            client_average_grad_dict = dict()
            attacker_average_grad_dict = dict()
            for key in client_attacker_grad_dict:
                attacker_average_grad_dict[key] = client_attacker_grad_dict[
                    key]
            for key in client_0_grad_dict:
                client_average_grad_dict[key] = client_0_grad_dict[key] * (
                    1 / Num_client) + client_1_grad_dict[key] * (
                        1 / Num_client) + client_2_grad_dict[key] * (
                            1 / Num_client
                        ) + client_attacker_grad_dict[key] * (1 / Num_client)

            # 加载梯度
            params_modules_server = net.named_parameters()
            params_modules_attacker = attacker_net.named_parameters()
            for params_module in params_modules_attacker:
                (name_attacker, params) = params_module
                params.grad = client_average_grad_dict[name_attacker]
            optimizer_backdoor.step()
            for params_module in params_modules_server:
                (name, params) = params_module
                params.grad = client_average_grad_dict[
                    name]  # 用字典中存储的子模型的梯度覆盖server中的参数梯度
            optimizer_server.step()
            # # 每训练100个batch打印一次平均loss
            # sum_loss += loss_c0.item()
            # if i % 100 == 99:
            #     print('[%d, %d] loss: %.03f'
            #           % (epoch + 1, i + 1, sum_loss / 100))