def test_cryptensor_registration(self):
    """Exercise registering and selecting a custom `CrypTensor` type.

    Before registration the default type is expected to be "mpc", and
    selecting the not-yet-registered name is expected to raise `ValueError`.
    After registration the default is expected to stay "mpc" until it is
    switched explicitly; from then on newly created tensors are expected to
    be instances of the custom type.
    """
    custom_type = "my_cryptensor"
    builtin_type = "mpc"

    # State before the custom type is registered.
    self.assertEqual(crypten.get_default_cryptensor_type(), builtin_type)
    self.assertRaises(
        ValueError, crypten.set_default_cryptensor_type, custom_type
    )
    builtin_tensor = crypten.cryptensor(torch.zeros(1, 3))
    self.assertEqual(crypten.get_cryptensor_type(builtin_tensor), builtin_type)

    # Register a dummy type under the custom name.
    register_custom = crypten.register_cryptensor(custom_type)

    @register_custom
    class MyCrypTensor(crypten.CrypTensor):
        """Placeholder `CrypTensor` subclass that only tags its instances."""

        def __init__(self, *args, **kwargs):
            # Marker attribute checked below to recognise this type.
            self.is_custom_type = True

    # Registering alone must not change the default type.
    self.assertEqual(crypten.get_default_cryptensor_type(), builtin_type)

    # Switching the default makes new tensors use the custom type.
    crypten.set_default_cryptensor_type(custom_type)
    self.assertEqual(crypten.get_default_cryptensor_type(), custom_type)
    custom_tensor = crypten.cryptensor(torch.zeros(1, 3))
    has_marker = getattr(custom_tensor, "is_custom_type", False)
    self.assertTrue(has_marker)
    self.assertEqual(crypten.get_cryptensor_type(custom_tensor), custom_type)
def setUp(self):
    """Run the parent set-up, then prepare CrypTen when the rank is >= 0.

    For a negative rank only the parent `setUp` runs. Otherwise CrypTen is
    initialised and the default cryptensor type is set to "mpc".
    """
    super().setUp()

    # Negative ranks skip CrypTen initialisation entirely.
    if self.rank < 0:
        return

    crypten.init()
    crypten.set_default_cryptensor_type("mpc")