Example #1
def main(net_state_name, net, testset):
    warming_up_cuda()
    # print("CUDA is available:", torch.cuda.is_available())

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=minibatch,
                                             shuffle=False,
                                             num_workers=4)

    # seed = 100
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)

    net.to(device)

    sys.stdout = Logger()

    NumShowInter = 100
    NumEpoch = 200
    IterCounter = -1
    training_start = time.time()

    load_state_dict(net_state_name, net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.05,
                          momentum=0.9,
                          weight_decay=5e-4)

    correct = 0
    total = 0

    loss_sum = 0.0
    cnt = 0
    inference_start = time.time()
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            loss = criterion(outputs, labels)
            loss_sum += loss.data.cpu().item() * images.size(0)
            correct += (predicted == labels).sum().item()
            cnt += int(images.size()[0])
        print('Accuracy of the network on the 10000 test images: %f %%' %
              (100 * correct / total))
        print("loss=", loss_sum / float(cnt))

    elapsed_time = time.time() - inference_start
    print("Elapsed time for Prediction", elapsed_time)
Example #2
    def test_server():
        rank = Config.server_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        net_state = torch.load(net_state_name)
        load_weight_params(secure_nn, store_configs, net_state)

        meta_rg = MetaTruncRandomGenerator()
        meta_rg.reset_seed()

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("server-online")

        secure_nn.check_correctness(check_correctness)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()
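
NamedTimerInstance is used here as a context manager that times the offline and online phases. A minimal sketch consistent with that usage (the output format is an assumption, not taken from the original code):

import time

class NamedTimerInstance:
    """Context manager that reports the wall-clock time of a named block."""
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print(f"{self.name}: {time.time() - self.start:.4f} s")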
Example #3
                    args=[secure_nn, correctness_func, master_address, master_port])
        p.start()
        processes.append(p)
        for p in processes:
            p.join()

    if party == Config.server_rank:
        run_secure_nn_server_with_random_data(secure_nn, correctness_func, master_address, master_port)
    if party == Config.client_rank:
        run_secure_nn_client_with_random_data(secure_nn, correctness_func, master_address, master_port)

    print(f"\nTest for {test_name}: End")

if __name__ == "__main__":
    input_sid, master_addr, master_port, test_to_run = argparser_distributed()
    sys.stdout = Logger()

    print("====== New Tests ======")
    print("Test To run:", test_to_run)

    num_repeat = 5

    for _ in range(num_repeat):
        if test_to_run in ["small", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_small_nn(), correctness_small_nn)
        if test_to_run in ["relu", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_relu_only_nn(), correctness_relu_only_nn)
        if test_to_run in ["maxpool", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_maxpool2x2(), correctness_maxpool2x2)
        if test_to_run in ["conv2d", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_conv2d(), correctness_conv2d)
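
The top of this snippet is truncated; the p.start()/p.join() lines suggest a launcher that can also run both parties locally in separate processes. A hypothetical reconstruction of that pattern (the function name and the decision to always spawn both parties are assumptions):

from multiprocessing import Process

def run_both_parties_locally(secure_nn, correctness_func, master_address, master_port):
    # Spawn the server and client parties as separate processes and wait for both.
    processes = []
    for target in (run_secure_nn_server_with_random_data,
                   run_secure_nn_client_with_random_data):
        p = Process(target=target,
                    args=[secure_nn, correctness_func, master_address, master_port])
        p.start()
        processes.append(p)
    for p in processes:
        p.join()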
Example #4
def train(dataset, args):
    # For reproducibility
    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    logger = Logger(model=args.model_type)

    # build model
    num_feats = NUM_FEATURES if not args.use_refex else NUM_FEATURES + NUM_ROLX_FEATURES
    model = models.GNNStack(
        num_feats,
        args.hidden_dim,
        3,  # dataset.num_classes
        args,
        torch.tensor([1, 0, 15], device=dev).float()  # weights for each class
    )
    if torch.cuda.is_available():
        model = model.cuda(dev)

    scheduler, opt = build_optimizer(args, model.parameters())
    skf, x, y = get_stratified_batches()

    # train
    for epoch in range(args.epochs):
        total_loss = 0
        accs, f1s, aucs, recalls = [], [], [], []
        model.train()
        # No need to loop over batches since we only have one batch
        num_splits = 0
        for train_indices, test_indices in skf.split(x, y):
            train_indices, test_indices = x[train_indices], x[test_indices]
            batch = dataset
            opt.zero_grad()
            pred = model(batch)
            label = batch.y

            pred = pred[train_indices]
            label = label[train_indices]

            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_splits += 1

            acc_score, f1, auc_score, recall = test(dataset, model,
                                                    test_indices)
            accs.append(acc_score)
            f1s.append(f1)
            aucs.append(auc_score)
            recalls.append(recall)

        total_loss /= num_splits
        accs = np.array(accs)
        f1s = np.array(f1s)
        aucs = np.array(aucs)
        recalls = np.array(recalls)
        log_metrics = {
            'total_loss': total_loss,
            'acc': accs,
            'f1': f1s,
            'auc': aucs,
            'recall': recalls
        }

        logger.log(log_metrics, epoch)
        if epoch % 5 == 0:
            logger.display_status(epoch, args.epochs, total_loss, accs, f1s,
                                  aucs, recalls)
    logger.close()
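
get_stratified_batches() is not shown; the skf.split(x, y) call matches sklearn's StratifiedKFold interface, so a plausible sketch looks like this (the split count and the source of x and y are assumptions):

import numpy as np
from sklearn.model_selection import StratifiedKFold

def get_stratified_batches(n_splits=5):
    # Illustration only: assumes a module-level `dataset` whose labels live in
    # dataset.y; the real data source and split count are not in the snippet.
    y = dataset.y.cpu().numpy()
    x = np.arange(len(y))
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=1)
    return skf, x, y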
Example #5
    def test_client():
        rank = Config.client_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()

        load_trunc_params(secure_nn, store_configs)

        def input_shift(data):
            first_layer_name = "conv1"
            return data_shift(data,
                              store_configs[first_layer_name + "ForwardX"])

        def testset():
            if model_name_base in ["vgg16_cifar100"]:
                return torchvision.datasets.CIFAR100(
                    root='./data',
                    train=False,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.2010)),
                        input_shift
                    ]))
            elif model_name_base in ["vgg16_cifar10", "minionn_maxpool"]:
                return torchvision.datasets.CIFAR10(
                    root='./data',
                    train=False,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.2010)),
                        input_shift
                    ]))

        testloader = torch.utils.data.DataLoader(testset(),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=2)

        data_iter = iter(testloader)
        image, truth = next(data_iter)
        image = image.reshape(secure_nn.get_input_shape())
        secure_nn.fill_input(image)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("client-online")

        secure_nn.check_correctness(check_correctness, truth=truth)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()
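
data_shift itself is not shown; judging from modulus_net_input_transform in Example #9 below, it rescales the input by the stored forward exponent and clamps it to the quantized range. A hypothetical version under that assumption:

def data_shift(data, forward_x_config, bits=8):
    # forward_x_config is assumed to be the stored (exponent, _) pair for the
    # first layer's forward activations.
    input_exp, _ = forward_x_config
    exp = -input_exp + (bits - 2)
    res = data * (2.0 ** exp)   # stand-in for shift_by_exp_plain
    res.clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1)
    return res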
Example #6
def prepare_logger(xargs):
    args = copy.deepcopy(xargs)
    from logger_utils import Logger
    logger = Logger(args.save_dir, args.rand_seed, sparse_flag=False)

    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python  Version  : {:}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow  Version  : {:}".format(PIL.__version__))
    logger.log("PyTorch Version  : {:}".format(torch.__version__))
    logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version()))
    logger.log("CUDA available   : {:}".format(torch.cuda.is_available()))
    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
    logger.log("CUDA_VISIBLE_DEVICES : {:}".format(
        os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in
        os.environ else 'None'))
    return logger
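
A minimal usage sketch for prepare_logger: the only attributes it relies on are save_dir and rand_seed plus whatever else the argparse namespace carries (the argument names and defaults below are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='./logs')
parser.add_argument('--rand_seed', type=int, default=0)
xargs = parser.parse_args()

logger = prepare_logger(xargs)
logger.log('setup complete')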
Example #7
def main(test_to_run, net, trainset, testset):
    print("CUDA is available:", torch.cuda.is_available())
    # torch.backends.cudnn.deterministic = True
    #n_classes = 10
    withnorm = False
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=minibatch,
                                              shuffle=True,
                                              num_workers=2)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=minibatch,
                                             shuffle=False,
                                             num_workers=2)
    checkpoint_path = "./model/checkpoint-minionn-cifar10.pt"
    load_swalp_state_dict(checkpoint_path, net, withnorm)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.05,
                          momentum=0.9,
                          weight_decay=5e-4)

    model_name_base = test_to_run + "_swalp"
    loss_name_base = f"./model/{model_name_base}"
    os.makedirs("./model/", exist_ok=True)
    torch.save(net.state_dict(), loss_name_base + "_net.pth")

    sys.stdout = Logger()

    NumShowInter = 100
    epoch = 300
    # https://github.com/chengyangfu/pytorch-vgg-cifar10
    training_start = time.time()

    lr = schedule(epoch)
    adjust_learning_rate(optimizer, lr)
    running_loss = 0.0

    data_iter = iter(trainloader)
    image, _ = next(data_iter)
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print statistics
        #running_loss += loss.item()
        #signific_acc = "%03d"%int((correct / total) * 1000)
        np.save(loss_name_base + "_exp_configs.npy", store_configs)
        break

    correct = 0
    total = 0
    loss_sum = 0.0
    cnt = 0
    inference_start = time.time()
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            loss = criterion(outputs, labels)
            loss_sum += loss.data.cpu().item() * images.size(0)
            correct += (predicted == labels).sum().item()
            cnt += int(images.size()[0])
        print('Accuracy of the network on the 10000 test images: %f %%' %
              (100 * correct / total))
        print("loss=", loss_sum / float(cnt))

    print('Finished Training')
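
The evaluation loop above is identical to the one in Example #1; factoring it out into a helper (an illustrative refactor, not part of the original code) makes that shared structure explicit:

def evaluate(net, testloader, criterion, device):
    # Return (accuracy in percent, mean per-sample loss) over the test set.
    net.eval()
    correct, total, loss_sum, cnt = 0, 0, 0.0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            loss_sum += criterion(outputs, labels).item() * images.size(0)
            correct += (predicted == labels).sum().item()
            cnt += images.size(0)
    return 100.0 * correct / total, loss_sum / cnt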
Example #8
def main(test_to_run, net, trainset, testset):
    print("CUDA is available:", torch.cuda.is_available())
    # torch.backends.cudnn.deterministic = True
    #n_classes = 10

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    #trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([
    #transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4), transforms.ToTensor(), normalize, ]))
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=minibatch,
                                              shuffle=True,
                                              num_workers=2)

    #testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
    #transforms.ToTensor(), normalize, ]))
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=minibatch,
                                             shuffle=False,
                                             num_workers=2)

    #classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    #net = NetQ()
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.05,
                          momentum=0.9,
                          weight_decay=5e-4)
    #optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    #optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.2)

    # register_linear_layer(net.conv2, "conv2")
    sys.stdout = Logger()

    NumShowInter = 100
    NumEpoch = 300
    IterCounter = -1
    # https://github.com/chengyangfu/pytorch-vgg-cifar10
    training_start = time.time()
    for epoch in range(NumEpoch):  # loop over the dataset multiple times

        #scheduler.step()
        lr = schedule(epoch)
        adjust_learning_rate(optimizer, lr)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            global training_state
            training_state = Config.training
            IterCounter += 1
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % NumShowInter == NumShowInter - 1 or (
                    epoch == 0 and i == 0):  # print every NumShowInter mini-batches
                elapsed_time = time.time() - training_start
                print('[%d, %5d, %6d, %6f] loss: %.3f' %
                      (epoch + 1, i + 1, IterCounter, elapsed_time,
                       running_loss / NumShowInter))

                running_loss = 0.0
                correct = 0
                total = 0

                # store_layer_name = 'conv2'
                # store_name = f"quantize_{store_layer_name}_{epoch + 1}_{i + 1}"
                # store_layer(store_layer_name, store_name)
                train_store_configs = store_configs.copy()

                with torch.no_grad():
                    training_state = Config.testing
                    for data in testloader:
                        images, labels = data[0].to(device), data[1].to(device)
                        outputs = net(images)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()
                    print(
                        'Accuracy of the network on the 10000 test images: %d %%'
                        % (100 * correct / total))

                model_name_base = test_to_run + "_swalp"
                signific_acc = "%03d" % int((correct / total) * 1000)
                loss_name_base = f"./model/{model_name_base}_{signific_acc}"
                os.makedirs("./model/", exist_ok=True)

                #print("Net's state_dict:")
                #for var_name in net.state_dict():
                #        print(var_name, "\t", net.state_dict()[var_name].size())

                torch.save(net.state_dict(), loss_name_base + "_net.pth")
                # model = TheModelClass(*args, **kwargs)
                # model.load_state_dict(torch.load("./model/vgg_swalp_xxxx_net.pth"))
                # model.eval()
                np.save(loss_name_base + "_exp_configs.npy",
                        train_store_configs)
                # Load
                # read_dictionary = np.load('my_file.npy',allow_pickle='TRUE').item()
                # print(read_dictionary['hello']) # displays "world"

    print('Finished Training')
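
schedule and adjust_learning_rate are not shown. adjust_learning_rate is the standard per-parameter-group update; schedule is presumably a step-decay policy keyed on the epoch (the breakpoints below are assumptions, only the base rate of 0.05 comes from the snippet):

def schedule(epoch, base_lr=0.05):
    # Hypothetical step decay; the real breakpoints are not in the snippet.
    if epoch < 150:
        return base_lr
    if epoch < 250:
        return base_lr * 0.1
    return base_lr * 0.01

def adjust_learning_rate(optimizer, lr):
    # Apply the new learning rate to every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr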
Example #9
def main(test_to_run, modnet):
    print("CUDA is available:", torch.cuda.is_available())
    # torch.backends.cudnn.deterministic = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    minibatch = 512
    device = torch.device("cuda:0")

    # model_name_base = test_to_run
    # signific_acc = "897"
    # loss_name_base = f"./model/{model_name_base}_{signific_acc}"

    net_state_name, config_name = get_net_config_name(test_to_run)

    net_state = torch.load(net_state_name)
    store_configs = np.load(config_name, allow_pickle="TRUE").item()

    net = modnet(store_configs)
    net.load_weight_bias(net_state)
    net.to(device)

    def modulus_net_input_transform(data):
        bits = 8
        input_exp, _ = store_configs[first_layer_name + "ForwardX"]
        exp = -input_exp + (bits - 2)
        res = shift_by_exp_plain(data, exp)
        res.clamp_(-2**(bits - 1), 2**(bits - 1) - 1)

        return res

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transforms.Compose([
                                               transforms.ToTensor(),
                                               normalize,
                                               modulus_net_input_transform
                                           ]))
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=minibatch,
                                             shuffle=False,
                                             num_workers=2)

    sys.stdout = Logger()

    training_start = time.time()

    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on the 10000 test images: %2.3f %%' %
              (100 * correct / total))

    elapsed_time = time.time() - training_start
    print("Elapsed time for Prediction", elapsed_time)

    print('Finished Testing for Accuracy')
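
get_net_config_name is not shown, but the commented-out lines above and the save paths in Examples #7 and #8 suggest it simply rebuilds the "<base>_net.pth" / "<base>_exp_configs.npy" pair. A hypothetical version (the accuracy suffix is illustrative):

def get_net_config_name(test_to_run, signific_acc="897"):
    # Mirrors the naming scheme used when the model and configs were saved.
    loss_name_base = f"./model/{test_to_run}_swalp_{signific_acc}"
    return loss_name_base + "_net.pth", loss_name_base + "_exp_configs.npy"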