Example #1
def preprocess_mnist(context_manager):
    if context_manager is None:
        context_manager = NoopContextManager()

    with context_manager:
        # each party gets a unique temp directory
        with tempfile.TemporaryDirectory() as data_dir:
            mnist_train = datasets.MNIST(data_dir, download=True, train=True)
            mnist_test = datasets.MNIST(data_dir, download=True, train=False)

    # binarize labels: all non-zero digits become class 1, zeros stay class 0
    mnist_train.targets[mnist_train.targets != 0] = 1
    mnist_test.targets[mnist_test.targets != 0] = 1
    mnist_train.targets[mnist_train.targets == 0] = 0
    mnist_test.targets[mnist_test.targets == 0] = 0

    # compute normalization factors
    data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # normalize data
    data_train_norm = transforms.functional.normalize(mnist_train.data.float(),
                                                      tensor_mean, tensor_std)

    # partition features between Alice and Bob
    data_alice = data_train_norm[:, :, :20]
    data_bob = data_train_norm[:, :, 20:]
    train_labels = mnist_train.targets

    return data_alice, data_bob, train_labels
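
The function above splits each 28x28 image column-wise between two parties. A minimal usage sketch follows, assuming CrypTen is available and that rank 0 plays Alice and rank 1 plays Bob (both assumptions for illustration, not part of the original example):

import crypten

crypten.init()

# the context manager argument is optional; None falls back to NoopContextManager
data_alice, data_bob, train_labels = preprocess_mnist(context_manager=None)

# each party secret-shares only its own feature slice (src marks the providing rank)
enc_alice = crypten.cryptensor(data_alice, src=0)
enc_bob = crypten.cryptensor(data_bob, src=1)

# the encrypted feature columns can be recombined along the width dimension
enc_features = crypten.cat([enc_alice, enc_bob], dim=2)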
Example #2
def download_mnist(split="train"):
    """
    Loads split from the MNIST dataset and returns data.
    """
    train = split == "train"

    # If the MNIST dataset needs to be downloaded and uncompressed,
    # each process must use a separate directory.
    mnist_exists = os.path.exists(
        os.path.join(
            "/tmp/MNIST/processed", MNIST.training_file if train else MNIST.test_file
        )
    )

    if mnist_exists:
        mnist_root = "/tmp"
    else:
        rank = "0" if "RANK" not in os.environ else os.environ["RANK"]
        mnist_root = os.path.join("tmp", "bandits", rank)
        os.makedirs(mnist_root, exist_ok=True)

    # download the MNIST dataset:
    with NoopContextManager():
        mnist = MNIST(mnist_root, download=not mnist_exists, train=train)
    return mnist
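
A short usage sketch for the helper above; converting the raw byte tensors to normalized floats is an illustrative assumption, not something the helper itself does:

train_set = download_mnist(split="train")
test_set = download_mnist(split="test")

# raw pixel bytes live in .data; scale to [0, 1] floats for downstream training code
train_data = train_set.data.float().div_(255.0)
train_labels = train_set.targets
test_data = test_set.data.float().div_(255.0)
test_labels = test_set.targets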
Example #3
def run_mpc_cifar(
    epochs=25,
    start_epoch=0,
    batch_size=1,
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-6,
    print_freq=10,
    model_location="",
    resume=False,
    evaluate=True,
    seed=None,
    skip_plaintext=False,
    context_manager=None,
):
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    crypten.init()

    # create model
    model = LeNet()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # optionally resume from a checkpoint
    best_prec1 = 0
    if resume:
        if os.path.isfile(model_location):
            logging.info("=> loading checkpoint '{}'".format(model_location))
            checkpoint = torch.load(model_location)
            start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                model_location, checkpoint["epoch"]))
        else:
            raise IOError(
                "=> no checkpoint found at '{}'".format(model_location))

    # Data loading code
    def preprocess_data(context_manager, data_dirname):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        with context_manager:
            trainset = datasets.CIFAR10(data_dirname,
                                        train=True,
                                        download=True,
                                        transform=transform)
            testset = datasets.CIFAR10(data_dirname,
                                       train=False,
                                       download=True,
                                       transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=4,
                                                  shuffle=True,
                                                  num_workers=2)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        return trainloader, testloader

    if context_manager is None:
        context_manager = NoopContextManager()

    data_dir = tempfile.TemporaryDirectory()
    train_loader, val_loader = preprocess_data(context_manager, data_dir.name)

    if evaluate:
        if not skip_plaintext:
            logging.info("===== Evaluating plaintext LeNet network =====")
            validate(val_loader, model, criterion, print_freq)
        logging.info("===== Evaluating Private LeNet network =====")
        input_size = get_input_size(val_loader, batch_size)
        private_model = construct_private_model(input_size, model)
        validate(val_loader, private_model, criterion, print_freq)
        # logging.info("===== Validating side-by-side ======")
        # validate_side_by_side(val_loader, model, private_model)
        return

    # train and validate for the requested number of epochs
    for epoch in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, epoch, lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, print_freq)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, print_freq)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "LeNet",
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
        )
    data_dir.cleanup()
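
construct_private_model is a helper defined elsewhere in this example; below is a minimal sketch of how such a conversion typically works with CrypTen's public API (the function name, the rank-0 source, and the empty dummy input are assumptions for illustration):

def construct_private_model_sketch(input_size, model, src=0):
    # trace the plaintext model with a dummy input of the right shape,
    # then encrypt the resulting CrypTen graph so rank src provides the weights
    dummy_input = torch.empty(input_size)
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=src)
    return private_model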
Example #4
def run_tfe_benchmarks(
    network="B",
    epochs=5,
    start_epoch=0,
    batch_size=256,
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-6,
    print_freq=10,
    resume="",
    evaluate=True,
    seed=None,
    skip_plaintext=False,
    save_checkpoint_dir="/tmp/tfe_benchmarks",
    save_modelbest_dir="/tmp/tfe_benchmarks_best",
    context_manager=None,
    mnist_dir=None,
):
    crypten.init()

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # create model
    model = create_benchmark_model(network)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # optionally resume from a checkpoint
    best_prec1 = 0
    if resume:
        if os.path.isfile(resume):
            logging.info("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                resume, checkpoint["epoch"]))
        else:
            logging.info("=> no checkpoint found at '{}'".format(resume))

    # Load MNIST; normalization constants follow pytorch/examples/blob/master/mnist/main.py
    def preprocess_data(context_manager, data_dirname):
        if mnist_dir is not None:
            process_mnist_files(
                mnist_dir, os.path.join(data_dirname, "MNIST", "processed"))
            download = False
        else:
            download = True

        with context_manager:
            if not evaluate:
                mnist_train = datasets.MNIST(
                    data_dirname,
                    download=download,
                    train=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307, ), (0.3081, )),
                    ]),
                )

            mnist_test = datasets.MNIST(
                data_dirname,
                download=download,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ]),
            )
        train_loader = (torch.utils.data.DataLoader(
            mnist_train, batch_size=batch_size, shuffle=True)
                        if not evaluate else None)
        test_loader = torch.utils.data.DataLoader(mnist_test,
                                                  batch_size=batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    if context_manager is None:
        context_manager = NoopContextManager()

    warnings.filterwarnings("ignore")
    data_dir = tempfile.TemporaryDirectory()
    train_loader, val_loader = preprocess_data(context_manager, data_dir.name)

    flatten = False
    if network == "A":
        flatten = True

    if evaluate:
        if not skip_plaintext:
            logging.info("===== Evaluating plaintext benchmark network =====")
            validate(val_loader, model, criterion, print_freq, flatten=flatten)
        private_model = create_private_benchmark_model(model, flatten=flatten)
        logging.info("===== Evaluating Private benchmark network =====")
        validate(val_loader,
                 private_model,
                 criterion,
                 print_freq,
                 flatten=flatten)
        # validate_side_by_side(val_loader, model, private_model, flatten=flatten)
        return

    os.makedirs(save_checkpoint_dir, exist_ok=True)
    os.makedirs(save_modelbest_dir, exist_ok=True)

    for epoch in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, epoch, lr)

        # train for one epoch
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            print_freq,
            flatten=flatten,
        )

        # evaluate on validation set
        prec1 = validate(val_loader,
                         model,
                         criterion,
                         print_freq,
                         flatten=flatten)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_file = "checkpoint_bn" + network + ".pth.tar"
        model_best_file = "model_best_bn" + network + ".pth.tar"
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "Benchmark" + network,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=os.path.join(save_checkpoint_dir, checkpoint_file),
            model_best=os.path.join(save_modelbest_dir, model_best_file),
        )
    data_dir.cleanup()
    shutil.rmtree(save_checkpoint_dir)
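
Functions like run_tfe_benchmarks are normally executed once per party. A hedged launch sketch using CrypTen's multiprocess helper is shown below; the two-party world size and the argument values are illustrative assumptions:

import crypten.mpc as mpc

@mpc.run_multiprocess(world_size=2)
def launch_two_party_benchmark():
    # each spawned process runs the benchmark as a separate party
    run_tfe_benchmarks(network="B", epochs=1, batch_size=64, evaluate=True)

launch_two_party_benchmark()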
Example #5
def run_experiment(
    model_name,
    imagenet_folder=None,
    tensorboard_folder="/tmp",
    num_samples=None,
    context_manager=None,
):
    """Runs inference using specified vision model on specified dataset."""

    crypten.init()
    # check inputs:
    assert hasattr(models,
                   model_name), ("torchvision does not provide %s model" %
                                 model_name)
    if imagenet_folder is None:
        imagenet_folder = tempfile.gettempdir()
        download = True
    else:
        download = False
    if context_manager is None:
        context_manager = NoopContextManager()

    # load dataset and model:
    with context_manager:
        model = getattr(models, model_name)(pretrained=True)
        model.eval()
        dataset = datasets.ImageNet(imagenet_folder,
                                    split="val",
                                    download=download)

    # define appropriate transforms:
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    to_tensor_transform = transforms.ToTensor()

    # encrypt model:
    dummy_input = to_tensor_transform(dataset[0][0])
    dummy_input.unsqueeze_(0)
    encrypted_model = crypten.nn.from_pytorch(model, dummy_input=dummy_input)
    encrypted_model.encrypt()

    # show encrypted model in tensorboard:
    if SummaryWriter is not None:
        writer = SummaryWriter(log_dir=tensorboard_folder)
        writer.add_graph(encrypted_model)
        writer.close()

    # loop over dataset:
    meter = AccuracyMeter()
    for idx, sample in enumerate(dataset):

        # preprocess sample:
        image, target = sample
        image = transform(image)
        image.unsqueeze_(0)
        target = torch.tensor([target], dtype=torch.long)

        # perform inference using encrypted model on encrypted sample:
        encrypted_image = crypten.cryptensor(image)
        encrypted_output = encrypted_model(encrypted_image)

        # measure accuracy of prediction
        output = encrypted_output.get_plain_text()
        meter.add(output, target)

        # progress:
        logging.info("[sample %d of %d] Accuracy: %f" %
                     (idx + 1, len(dataset), meter.value()[1]))
        if num_samples is not None and idx == num_samples - 1:
            break

    # print final accuracy:
    logging.info("Accuracy on all %d samples: %f" %
                 (len(dataset), meter.value()[1]))