Example #1
0
    def test_cryptensor_registration(self):
        """Tests the registration mechanism for custom `CrypTensor` types.

        Checks that:
          * before registration, the default type is "mpc" and the custom
            name cannot be made the default (``ValueError``);
          * registering a `CrypTensor` subclass does not by itself change
            the default type;
          * after selecting the custom name as default, ``crypten.cryptensor``
            builds instances of the registered class.

        The default type is reset to "mpc" in a ``finally`` clause, so a
        failing assertion cannot leave the custom type selected for later
        tests.
        """
        cryptensor_name = "my_cryptensor"

        # before registration: "mpc" is the default and the name is unknown
        self.assertEqual(crypten.get_default_cryptensor_type(), "mpc")
        with self.assertRaises(ValueError):
            crypten.set_default_cryptensor_type(cryptensor_name)
        tensor = crypten.cryptensor(torch.zeros(1, 3))
        self.assertEqual(crypten.get_cryptensor_type(tensor), "mpc")

        # register new tensor type:
        # NOTE(review): the registration itself is not undone here; if
        # register_cryptensor rejects duplicate names, re-running this test
        # in the same process would fail at this decorator -- confirm.
        @crypten.register_cryptensor(cryptensor_name)
        class MyCrypTensor(crypten.CrypTensor):
            """Dummy `CrypTensor` type."""

            def __init__(self, *args, **kwargs):
                # Deliberately does not call CrypTensor.__init__; the flag only
                # lets the test see which class crypten.cryptensor built.
                self.is_custom_type = True

        # registering alone must not change the default type:
        self.assertEqual(crypten.get_default_cryptensor_type(), "mpc")
        try:
            crypten.set_default_cryptensor_type(cryptensor_name)
            self.assertEqual(crypten.get_default_cryptensor_type(),
                             cryptensor_name)
            tensor = crypten.cryptensor(torch.zeros(1, 3))
            self.assertTrue(getattr(tensor, "is_custom_type", False))
            self.assertEqual(crypten.get_cryptensor_type(tensor),
                             cryptensor_name)
        finally:
            # The default type is crypten module-level state (it persists
            # across calls, as get_default_cryptensor_type shows), so restore
            # it even when an assertion above fails.
            crypten.set_default_cryptensor_type("mpc")
Example #2
0
 def setUp(self):
     """Prepare CrypTen before each test.

     Always runs the parent ``setUp``. Afterwards, only when ``self.rank``
     is non-negative, initializes CrypTen and resets the default
     ``CrypTensor`` type to ``"mpc"`` so each test starts from that default.
     """
     super().setUp()
     if self.rank < 0:
         # NOTE(review): a negative rank appears to denote a process that
         # does not participate in the computation -- confirm in base class.
         return
     crypten.init()
     crypten.set_default_cryptensor_type("mpc")