Пример #1
0
    def test_save_load_module(self):
        """Test that crypten.save and crypten.load properly save and load modules.

        For each model type, every rank fills its model's parameters with its
        own rank value. Each party in turn acts as ``src``:
        - its model is saved with ``crypten.save``;
        - it is loaded back into a same-shaped dummy model;
        - the src party checks the loaded parameter values.
        Loading into a dummy model built with different constructor arguments
        must raise ``AssertionError``.
        """
        import os
        import tempfile

        comm = crypten.communicator
        # Rank and world size do not change between iterations; read them once.
        rank = comm.get().get_rank()
        world_size = comm.get().get_world_size()

        for model_type in [TestModule, NestedTestModule]:
            # Create models with different parameter values on each rank
            test_model = model_type(200, 10)
            test_model.set_all_parameters(rank)

            # NamedTemporaryFile is used only to obtain a unique path; the
            # placeholder file it creates is removed when the object is dropped.
            filename = tempfile.NamedTemporaryFile(delete=True).name
            try:
                for src in range(world_size):
                    crypten.save(test_model, filename, src=src)

                    dummy_model = model_type(200, 10)
                    result = crypten.load(filename, dummy_model=dummy_model, src=src)
                    # Only the src party knows the expected parameter values.
                    if src == rank:
                        for param in result.parameters(recurse=True):
                            self.assertTrue(
                                param.eq(rank).all().item(), "Model load failed"
                            )
                    self.assertEqual(result.src, src)

                    # A dummy model with different constructor arguments must be
                    # rejected by crypten.load.
                    failure_dummy_model = model_type(200, 11)
                    with self.assertRaises(
                        AssertionError, msg="Expected load failure not raised"
                    ):
                        crypten.load(
                            filename, dummy_model=failure_dummy_model, src=src
                        )
            finally:
                # Remove the saved file so repeated test runs do not leak temp
                # files; it may not exist on this rank, so check first.
                if os.path.exists(filename):
                    os.remove(filename)
Пример #2
0
def test():
    """Print this party's rank and save a small integer tensor from it.

    The tensor holds the values ``0 .. world_size`` and is passed to
    ``crypten.save`` targeting ``test_<rank>.pth`` with this rank as ``src``.
    """
    world_size = mpc_comm.get().world_size
    my_rank = mpc_comm.get().get_rank()
    print(my_rank)
    values = torch.tensor(list(range(world_size + 1)))
    crypten.save(values, f"test_{my_rank}.pth", src=my_rank)
Пример #3
0
def encrypt_digits():
    """Alice has images. Bob has labels.

    Loads the MNIST training split (downloading it under ``/tmp/data`` if
    needed) and draws samples via ``take_samples``. The images are then saved
    with ``src=ALICE`` and the labels with ``src=BOB``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    mnist_train = torchvision.datasets.MNIST(
        root="/tmp/data", train=True, transform=to_tensor, download=True
    )
    sample_images, sample_labels = take_samples(mnist_train)
    crypten.save(sample_images, "/tmp/data/alice_images.pth", src=ALICE)
    crypten.save(sample_labels, "/tmp/data/bob_labels.pth", src=BOB)
Пример #4
0
    def test_save_load(self):
        """Test that crypten.save and crypten.load properly save and load
        shares of cryptensors.

        The round trip is checked with both the torch serializers and custom
        pickle-based closures. After loading, these must match the original
        cryptensor:
        - the share;
        - the decrypted plaintext;
        - every non-dunder, non-callable attribute except ``share``,
          ``_tensor`` and ``ctx``.
        """
        import io
        import pickle

        def custom_load_function(f):
            obj = pickle.load(f)
            return obj

        def custom_save_function(obj, f):
            pickle.dump(obj, f)

        # Each save closure is paired with the load closure that reads its output.
        closure_pairs = [
            (torch.save, torch.load),
            (custom_save_function, custom_load_function),
        ]

        tensor = get_random_test_tensor()
        cryptensor1 = crypten.cryptensor(tensor)

        for save_closure, load_closure in closure_pairs:
            # The buffer is local to this process, so one buffer is enough; the
            # previous per-rank list only ever used this rank's entry.
            buffer = io.BytesIO()
            crypten.save(cryptensor1, buffer, save_closure=save_closure)
            buffer.seek(0)
            cryptensor2 = crypten.load(buffer, load_closure=load_closure)
            # test whether share matches
            self.assertTrue(cryptensor1.share.allclose(cryptensor2.share))
            # test whether tensor matches
            self.assertTrue(
                cryptensor1.get_plain_text().allclose(cryptensor2.get_plain_text())
            )
            attributes = [
                a
                for a in dir(cryptensor1)
                if not a.startswith("__")
                and not callable(getattr(cryptensor1, a))
                and a not in ["share", "_tensor", "ctx"]
            ]
            for a in attributes:
                attr1, attr2 = getattr(cryptensor1, a), getattr(cryptensor2, a)
                msg = f"attribute {a!r} differs after load"
                if a == "encoder":
                    # Encoders are compared by their configuration fields.
                    self.assertEqual(attr1._scale, attr2._scale, msg)
                    self.assertEqual(
                        attr1._precision_bits, attr2._precision_bits, msg
                    )
                elif torch.is_tensor(attr1):
                    self.assertTrue(attr1.eq(attr2).all(), msg)
                else:
                    self.assertEqual(attr1, attr2, msg)
Пример #5
0
    def test_save_load(self):
        """Test that crypten.save and crypten.load properly save and load tensors.

        For 1- to 4-dimensional tensors whose side length is ``self.rank + 1``
        (so sizes differ per rank), each party in turn acts as ``src``. The
        loaded cryptensor must have src's size and must match src's tensor,
        which is broadcast to all ranks as the reference.
        """
        import os
        import tempfile

        # NamedTemporaryFile is used only to obtain a unique path.
        filename = tempfile.NamedTemporaryFile(delete=True).name
        world_size = crypten.communicator.get().get_world_size()
        try:
            for dimensions in range(1, 5):
                # Create tensors with different sizes on each rank
                size = tuple([self.rank + 1] * dimensions)
                tensor = torch.randn(size=size)

                for src in range(world_size):
                    crypten.save(tensor, filename, src=src)
                    encrypted_load = crypten.load(filename, src=src)
                    reference_size = tuple([src + 1] * dimensions)
                    self.assertEqual(encrypted_load.size(), reference_size)

                    # Non-src ranks allocate a buffer of src's size and receive
                    # src's plaintext tensor through the broadcast.
                    if self.rank == src:
                        reference = tensor
                    else:
                        reference = torch.empty(size=reference_size)
                    dist.broadcast(reference, src=src)
                    self._check(encrypted_load, reference, "crypten.load() failed")
        finally:
            # Remove the saved file so repeated test runs do not leak temp
            # files; it may not exist on this rank, so check first.
            if os.path.exists(filename):
                os.remove(filename)
Пример #6
0
def test_solo(world_size=world_size):
    """Call crypten.save once per party index on the tensor ``[0 .. world_size]``.

    For every ``party`` in ``range(world_size)``, the same tensor is passed to
    ``crypten.save`` targeting ``test_<party>.pth`` with ``src=party``.
    """
    payload = torch.tensor(list(range(world_size + 1)))
    for party in range(world_size):
        target = f"test_{party}.pth"
        crypten.save(payload, target, src=party)