def test_save_load_module(self):
    """Test that crypten.save and crypten.load properly save and load modules.

    For each model type, every party builds a model whose parameters are all
    set to its own rank. For every source rank ``src`` the model is saved with
    ``src=src`` and loaded back into a same-shaped dummy model:

    * on the ``src`` party every loaded parameter must equal ``src``'s rank;
    * ``result.src`` must equal ``src``;
    * loading into a differently-shaped dummy model must raise AssertionError.
    """
    import os
    import tempfile

    comm = crypten.communicator
    # Loop-invariant: the local rank and world size do not change per model.
    rank = comm.get().get_rank()
    world_size = comm.get().get_world_size()

    for model_type in [TestModule, NestedTestModule]:
        # Create models with different parameter values on each rank
        test_model = model_type(200, 10)
        test_model.set_all_parameters(rank)

        # NamedTemporaryFile(delete=True).name only yields a path; the file
        # crypten.save writes there would never be removed. A private
        # temporary directory guarantees cleanup, even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "model.pth")
            for src in range(world_size):
                crypten.save(test_model, filename, src=src)

                dummy_model = model_type(200, 10)
                result = crypten.load(filename, dummy_model=dummy_model, src=src)
                if src == rank:
                    for param in result.parameters(recurse=True):
                        self.assertTrue(
                            param.eq(rank).all().item(), "Model load failed"
                        )
                self.assertEqual(result.src, src)

                # A dummy model with a different shape must be rejected.
                failure_dummy_model = model_type(200, 11)
                with self.assertRaises(
                    AssertionError, msg="Expected load failure not raised"
                ):
                    crypten.load(
                        filename, dummy_model=failure_dummy_model, src=src
                    )
def test():
    """Debug helper: save a small integer tensor from this party.

    Prints the local rank, builds the tensor ``[0, 1, ..., world_size]`` and
    calls ``crypten.save`` with this party's own rank as ``src``, writing to
    ``test_<rank>.pth``.
    """
    communicator = mpc_comm.get()
    my_rank = communicator.get_rank()
    print(my_rank)

    payload = torch.arange(communicator.world_size + 1)
    crypten.save(payload, f"test_{my_rank}.pth", src=my_rank)
def encrypt_digits(root="/tmp/data"):
    """Prepare MNIST samples for two parties: Alice has images, Bob has labels.

    Loads the MNIST training set (downloading it into ``root`` if needed),
    takes a subset via ``take_samples``, and saves the images with
    ``src=ALICE`` and the labels with ``src=BOB``.

    Args:
        root: Directory used both as the torchvision dataset root and as the
            output directory for ``alice_images.pth`` and ``bob_labels.pth``.
            Defaults to ``"/tmp/data"``, the previously hard-coded location.
    """
    import os

    digits = torchvision.datasets.MNIST(
        root=root,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    images, labels = take_samples(digits)
    crypten.save(images, os.path.join(root, "alice_images.pth"), src=ALICE)
    crypten.save(labels, os.path.join(root, "bob_labels.pth"), src=BOB)
def test_save_load(self):
    """Test that crypten.save and crypten.load properly save and load shares of cryptensors.

    A cryptensor is round-tripped through an in-memory buffer, once with
    ``torch.save``/``torch.load`` and once with pickle-based custom closures.
    After each round trip the test checks that the following match the
    original:

    * the share;
    * the decrypted plaintext;
    * every non-dunder, non-callable attribute, except ``share``,
      ``_tensor`` and ``ctx``.
    """
    import io
    import pickle

    def custom_load_function(f):
        obj = pickle.load(f)
        return obj

    def custom_save_function(obj, f):
        pickle.dump(obj, f)

    # Matching (save_closure, load_closure) pairs exercised by this test.
    closure_pairs = [
        (torch.save, torch.load),
        (custom_save_function, custom_load_function),
    ]
    # Attributes compared separately (share) or not meaningful to compare.
    excluded_attributes = {"share", "_tensor", "ctx"}

    tensor = get_random_test_tensor()
    cryptensor1 = crypten.cryptensor(tensor)

    for save_closure, load_closure in closure_pairs:
        # Each party only ever touches its own buffer, so one per process is
        # enough (previously world_size buffers were allocated, one used).
        buffer = io.BytesIO()
        crypten.save(cryptensor1, buffer, save_closure=save_closure)
        buffer.seek(0)
        cryptensor2 = crypten.load(buffer, load_closure=load_closure)

        # test whether share matches
        self.assertTrue(cryptensor1.share.allclose(cryptensor2.share))
        # test whether tensor matches
        self.assertTrue(
            cryptensor1.get_plain_text().allclose(cryptensor2.get_plain_text())
        )

        attributes = [
            a
            for a in dir(cryptensor1)
            if not a.startswith("__")
            and not callable(getattr(cryptensor1, a))
            and a not in excluded_attributes
        ]
        for a in attributes:
            attr1, attr2 = getattr(cryptensor1, a), getattr(cryptensor2, a)
            if a == "encoder":
                # Compare the encoder's configuration, not object identity.
                self.assertEqual(attr1._scale, attr2._scale)
                self.assertEqual(attr1._precision_bits, attr2._precision_bits)
            elif torch.is_tensor(attr1):
                self.assertTrue(attr1.eq(attr2).all())
            else:
                self.assertEqual(attr1, attr2)
def test_save_load(self):
    """Test that crypten.save and crypten.load properly save and load tensors.

    Each rank creates a random tensor of shape ``(rank + 1,) * dimensions``
    for 1 to 4 dimensions. For every source rank ``src`` the tensor is saved
    with ``src=src`` and loaded back with ``crypten.load``. The test then
    checks two things:

    * the loaded size equals ``src``'s shape;
    * its value matches ``src``'s plaintext, which is distributed to all
      ranks with ``dist.broadcast``.
    """
    import os
    import tempfile

    world_size = crypten.communicator.get().get_world_size()

    # NamedTemporaryFile(delete=True).name only yields a path; the file saved
    # there would be leaked. A private temporary directory is always removed.
    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "tensor.pth")
        for dimensions in range(1, 5):
            # Create tensors with different sizes on each rank
            size = (self.rank + 1,) * dimensions
            tensor = torch.randn(size=size)

            for src in range(world_size):
                crypten.save(tensor, filename, src=src)
                encrypted_load = crypten.load(filename, src=src)

                reference_size = (src + 1,) * dimensions
                self.assertEqual(encrypted_load.size(), reference_size)

                # src broadcasts its plaintext; other ranks receive it into a
                # buffer with src's shape.
                if self.rank == src:
                    reference = tensor
                else:
                    reference = torch.empty(size=reference_size)
                dist.broadcast(reference, src=src)
                self._check(encrypted_load, reference, "crypten.load() failed")
def test_solo(world_size=world_size):
    """Call ``crypten.save`` once for every rank in ``range(world_size)``.

    The tensor ``[0, 1, ..., world_size]`` is passed with ``src`` set to each
    rank in turn and file name ``test_<rank>.pth``. Note that the default for
    ``world_size`` is the module-level ``world_size`` as it was when this
    function was defined.
    """
    payload = torch.arange(world_size + 1)
    for src_rank in range(world_size):
        crypten.save(payload, f"test_{src_rank}.pth", src=src_rank)