def test_plaintext_save_load_module_from_party(self):
    """Test that crypten.save_from_party and crypten.load_from_party properly
    save and load plaintext modules.

    For every model type, each rank builds a model whose parameters are set
    from its own rank via ``set_all_parameters(rank)``. Every rank then takes
    a turn as ``src``: the model is saved from ``src`` and loaded back with
    ``load_from_party``. On the ``src`` rank every loaded parameter must equal
    ``rank``, and the loaded result must report ``src`` as its ``src``.
    """
    import os
    import tempfile

    comm = crypten.communicator
    # Rank and world size do not change between iterations; look them up once.
    rank = comm.get().get_rank()
    world_size = comm.get().get_world_size()

    for model_type in [TestModule, NestedTestModule]:
        # Create models with different parameter values on each rank
        test_model = model_type(200, 10)
        test_model.set_all_parameters(rank)
        serial.register_safe_class(model_type)

        # Use a private temporary directory instead of the name of an
        # already-deleted NamedTemporaryFile: the path cannot be claimed by
        # another process in between, and the saved file is removed when the
        # ``with`` block exits instead of being left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "module.pth")
            for src in range(world_size):
                crypten.save_from_party(test_model, filename, src=src)

                result = crypten.load_from_party(filename, src=src)
                if src == rank:
                    for param in result.parameters(recurse=True):
                        self.assertTrue(
                            param.eq(rank).all().item(), "Model load failed"
                        )
                self.assertEqual(result.src, src)
def encrypt_digits():
    """Alice has images. Bob has labels.

    Loads the MNIST training split rooted at ``/tmp/data`` (with
    ``download=True``) and passes it to ``take_samples`` (defined elsewhere),
    which returns an ``(images, labels)`` pair. The images are then saved with
    ``crypten.save_from_party`` using ``src=ALICE`` and the labels using
    ``src=BOB``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    mnist_train = torchvision.datasets.MNIST(
        root="/tmp/data",
        train=True,
        transform=to_tensor,
        download=True,
    )
    sample_images, sample_labels = take_samples(mnist_train)

    # (payload, destination path, owning party) -- saved in this exact order.
    per_party_outputs = (
        (sample_images, "/tmp/data/alice_images.pth", ALICE),
        (sample_labels, "/tmp/data/bob_labels.pth", BOB),
    )
    for payload, destination, owner in per_party_outputs:
        crypten.save_from_party(payload, destination, src=owner)
def test_save_load(self):
    """Test that crypten.save_from_party and crypten.load_from_party properly
    save and load tensors.

    Tensors with 1 to 4 dimensions, each side of length ``self.rank + 1`` (so
    the shape differs per rank), are round-tripped through two formats:
    ``torch.save``/``torch.load`` (``.pth``) and a NumPy-based pair of custom
    closures (``.npy``). For every ``src`` party the test checks the loaded
    size and values, that a ``load_closure`` returning ``None`` raises
    ``TypeError``, and that loading via ``preloaded`` matches the reference.
    """
    import os
    import tempfile

    import numpy as np

    def custom_load_function(f):
        np_arr = np.load(f)
        tensor = torch.from_numpy(np_arr)
        return tensor

    def custom_save_function(obj, f):
        np_arr = obj.numpy()
        np.save(f, np_arr)

    comm = crypten.communicator
    world_size = comm.get().get_world_size()

    # Each entry: (save closure, load closure, file suffix, plain loader that
    # the src party uses to verify the on-disk format directly).
    formats = [
        (torch.save, torch.load, ".pth", torch.load),
        (custom_save_function, custom_load_function, ".npy", np.load),
    ]

    # Use a private temporary directory instead of the name of an
    # already-deleted NamedTemporaryFile: the paths cannot be claimed by
    # another process in between, and every saved file is removed when the
    # ``with`` block exits instead of being left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_filename = os.path.join(tmp_dir, "tensor")
        for dimensions in range(1, 5):
            # Create tensors with different sizes on each rank
            size = tuple([self.rank + 1] * dimensions)
            tensor = torch.randn(size=size)

            for save_closure, load_closure, suffix, test_load_fn in formats:
                complete_file = base_filename + suffix
                for src in range(world_size):
                    crypten.save_from_party(
                        tensor, complete_file, src=src, save_closure=save_closure
                    )

                    # the following line will throw an error if an object saved
                    # with torch.save is attempted to be loaded with np.load
                    if self.rank == src:
                        test_load_fn(complete_file)

                    encrypted_load = crypten.load_from_party(
                        complete_file, src=src, load_closure=load_closure
                    )

                    # Every party must see the shape of the src party's tensor.
                    reference_size = tuple([src + 1] * dimensions)
                    self.assertEqual(encrypted_load.size(), reference_size)

                    # Non-src parties allocate a buffer of the src shape that
                    # the broadcast from ``src`` fills in.
                    reference = (
                        tensor
                        if self.rank == src
                        else torch.empty(size=reference_size)
                    )
                    comm.get().broadcast(reference, src=src)
                    self._check(encrypted_load, reference, "crypten.load() failed")

                    # test for invalid load_closure
                    with self.assertRaises(TypeError):
                        crypten.load_from_party(
                            complete_file, src=src, load_closure=(lambda f: None)
                        )

                    # test pre-loaded
                    encrypted_preloaded = crypten.load_from_party(
                        src=src, preloaded=tensor
                    )
                    self._check(
                        encrypted_preloaded,
                        reference,
                        "crypten.load() failed using preloaded",
                    )