Example #1
    def test_pos_pow(self):
        """Test gradient crypten pos_pow"""
        for power in [3, -2, 1.75]:
            # ensure base is positive for pos_pow
            tensor = get_random_test_tensor(is_float=True, max_value=2) + 4
            tensor.requires_grad = True
            tensor_encr = AutogradCrypTensor(
                crypten.cryptensor(tensor), requires_grad=True
            )

            reference = tensor.pow(power)
            out_encr = tensor_encr.pos_pow(power)
            self._check(
                out_encr, reference, f"pos_pow forward failed with power {power}"
            )

            grad_out = get_random_test_tensor(is_float=True)
            grad_out_encr = crypten.cryptensor(grad_out)
            reference.backward(grad_out)
            out_encr.backward(grad_out_encr)

            self._check(
                tensor_encr.grad,
                tensor.grad,
                f"pos_pow backward failed with power {power}",
            )
Example #2
def train_encrypted(
    x_encrypted,
    y_encrypted,
    encrypted_model,
    num_epochs,
    learning_rate,
    batch_size,
    print_freq,
):
    rank = comm.get().get_rank()
    loss = crypten.nn.MSELoss()

    num_samples = x_encrypted.size(0)
    label_eye = torch.eye(2)

    for epoch in range(num_epochs):
        last_progress_logged = 0
        # only print from rank 0 to avoid duplicates for readability
        if rank == 0:
            print(f"Epoch {epoch} in progress:")

        for j in range(0, num_samples, batch_size):

            # define the start and end of the training mini-batch
            start, end = j, min(j + batch_size, num_samples)

            # construct AutogradCrypTensors out of training examples
            x_train = AutogradCrypTensor(x_encrypted[start:end])
            y_one_hot = label_eye[y_encrypted[start:end]]
            y_train = AutogradCrypTensor(crypten.cryptensor(y_one_hot))

            # perform forward pass:
            output = encrypted_model(x_train)
            loss_value = loss(output, y_train)

            # backprop
            encrypted_model.zero_grad()
            loss_value.backward()
            encrypted_model.update_parameters(learning_rate)

            # log progress
            if j + batch_size - last_progress_logged >= print_freq:
                last_progress_logged += print_freq
                print(f"Loss {loss_value.get_plain_text().item():.4f}")

        # compute accuracy every epoch
        pred = output.get_plain_text().argmax(1)
        correct = pred.eq(y_encrypted[start:end])
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))

        loss_plaintext = loss_value.get_plain_text().item()
        print(f"Epoch {epoch} completed: "
              f"Loss {loss_plaintext:.4f} Accuracy {accuracy.item():.2f}")
Example #3
    def _check_forward_backward(
        self, func_name, input_tensor, *args, torch_func_name=None, msg=None, **kwargs
    ):
        """Checks forward and backward against PyTorch

        Args:
            func_name (str): PyTorch/CrypTen function name
            input_tensor (torch.tensor): primary input
            args (list): positional arguments for the function
            torch_func_name (str): PyTorch function name if it differs from func_name
            msg (str): additional message for mismatch
            kwargs (dict): keyword arguments for the function
        """

        if msg is None:
            msg = f"{func_name} grad_fn incorrect"

        input = input_tensor.clone()
        input.requires_grad = True
        input_encr = AutogradCrypTensor(crypten.cryptensor(input), requires_grad=True)

        for private in [False, True]:
            input.grad = None
            input_encr.grad = None
            args = self._set_grad_to_zero(args)
            args_encr = self._set_grad_to_zero(list(args), make_private=private)

            # obtain torch function
            if torch_func_name is not None:
                torch_func = self._get_torch_func(torch_func_name)
            else:
                torch_func = self._get_torch_func(func_name)

            reference = torch_func(input, *args, **kwargs)
            encrypted_out = getattr(input_encr, func_name)(*args_encr, **kwargs)

            # extract argmax output for max / min with keepdim=False
            if isinstance(encrypted_out, (list, tuple)):
                reference = reference[0]
                encrypted_out = encrypted_out[0]

            self._check(encrypted_out, reference, msg + " in forward")

            # check backward pass
            grad_output = get_random_test_tensor(
                max_value=2, size=reference.size(), is_float=True
            )
            grad_output_encr = crypten.cryptensor(grad_output)
            reference.backward(grad_output)
            encrypted_out.backward(grad_output_encr)

            self._check(input_encr.grad, input.grad, msg + " in backward")
            for i, arg_encr in enumerate(args_encr):
                if crypten.is_encrypted_tensor(arg_encr):
                    self._check(arg_encr.grad, args[i].grad, msg + " in backward args")
Example #4
    def test_training(self):
        """
        Tests training of simple model in crypten.nn.
        """

        # create MLP with one hidden layer:
        learning_rate = 0.1
        batch_size, num_inputs, num_intermediate, num_outputs = 8, 10, 5, 1
        model = crypten.nn.Sequential([
            crypten.nn.Linear(num_inputs, num_intermediate),
            crypten.nn.ReLU(),
            crypten.nn.Linear(num_intermediate, num_outputs),
        ])
        model.train()
        model.encrypt()
        loss = crypten.nn.MSELoss()

        # perform training iterations:
        for _ in range(10):
            for wrap in [True, False]:

                # get training sample:
                input = get_random_test_tensor(size=(batch_size, num_inputs),
                                               is_float=True)
                target = input.mean(dim=1, keepdim=True)

                # encrypt training sample:
                input = crypten.cryptensor(input)
                target = crypten.cryptensor(target)
                if wrap:
                    input = AutogradCrypTensor(input)
                    target = AutogradCrypTensor(target)

                # perform forward pass:
                output = model(input)
                loss_value = loss(output, target)

                # set gradients to "zero" (setting to None is more efficient):
                model.zero_grad()
                for param in model.parameters():
                    self.assertIsNone(param.grad,
                                      "zero_grad did not reset gradients")

                # perform backward pass:
                loss_value.backward()

                # perform parameter update:
                reference = {}
                reference = self._compute_reference_parameters(
                    "", reference, model, learning_rate)
                model.update_parameters(learning_rate)
                self._check_reference_parameters("", reference, model)
Example #5
    def test_losses(self):
        """
        Tests all Losses implemented in crypten.nn.
        """

        # create test tensor:
        input = get_random_test_tensor(max_value=0.999,
                                       is_float=True).abs() + 0.001
        target = get_random_test_tensor(max_value=0.999,
                                        is_float=True).abs() + 0.001
        encrypted_input = crypten.cryptensor(input)
        encrypted_target = crypten.cryptensor(target)

        # test forward() function of all simple losses:
        for loss_name in ["BCELoss", "L1Loss", "MSELoss"]:
            enc_loss_object = getattr(torch.nn, loss_name)()
            self.assertEqual(enc_loss_object.reduction, "mean",
                             "Reduction used is not 'mean'")

            loss = getattr(torch.nn, loss_name)()(input, target)
            encrypted_loss = getattr(crypten.nn, loss_name)()(encrypted_input,
                                                              encrypted_target)
            self._check(encrypted_loss, loss, "%s failed" % loss_name)
            encrypted_loss = getattr(crypten.nn, loss_name)()(
                AutogradCrypTensor(encrypted_input),
                AutogradCrypTensor(encrypted_target),
            )
            self._check(encrypted_loss, loss, "%s failed" % loss_name)

        # test forward() function of cross-entropy loss:
        batch_size, num_targets = 16, 5
        input = get_random_test_tensor(size=(batch_size, num_targets),
                                       is_float=True)
        target = get_random_test_tensor(size=(batch_size, ),
                                        max_value=num_targets - 1).abs()
        encrypted_input = crypten.cryptensor(input)
        encrypted_target = crypten.cryptensor(
            onehot(target, num_targets=num_targets))
        enc_loss_object = crypten.nn.CrossEntropyLoss()
        self.assertEqual(enc_loss_object.reduction, "mean",
                         "Reduction used is not 'mean'")

        loss = torch.nn.CrossEntropyLoss()(input, target)
        encrypted_loss = crypten.nn.CrossEntropyLoss()(encrypted_input,
                                                       encrypted_target)
        self._check(encrypted_loss, loss, "cross-entropy loss failed")
        encrypted_loss = crypten.nn.CrossEntropyLoss()(
            AutogradCrypTensor(encrypted_input),
            AutogradCrypTensor(encrypted_target))
        self._check(encrypted_loss, loss, "cross-entropy loss failed")
Example #6
    def test_graph(self):
        """
        Tests crypten.nn.Graph module.
        """
        for wrap in [True, False]:

            # define test case:
            input_size = (3, 10)
            input = get_random_test_tensor(size=input_size, is_float=True)
            encr_input = crypten.cryptensor(input)
            if wrap:
                encr_input = AutogradCrypTensor(encr_input)

            # test residual block with subsequent linear layer:
            graph = crypten.nn.Graph("input", "output")
            linear1 = get_random_linear(input_size[1], input_size[1])
            linear2 = get_random_linear(input_size[1], input_size[1])
            graph.add_module("linear", crypten.nn.from_pytorch(linear1, input),
                             ["input"])
            graph.add_module("residual", crypten.nn.Add(), ["input", "linear"])
            graph.add_module("output", crypten.nn.from_pytorch(linear2, input),
                             ["residual"])
            graph.encrypt()

            # check container:
            self.assertTrue(graph.encrypted, "nn.Graph not encrypted")
            for module in graph.modules():
                self.assertTrue(module.encrypted, "module not encrypted")
            assert (sum(1 for _ in graph.modules()) == 3
                    ), "nn.Graph contains incorrect number of modules"

            # compare output to reference:
            encr_output = graph(encr_input)
            reference = linear2(linear1(input) + input)
            self._check(encr_output, reference, "nn.Graph forward failed")
Example #7
    def test_autograd(self):
        """Tests autograd graph construction and backprop."""

        # define test cases:
        tests = [
            (1, ["relu", "neg", "relu", "sum"]),
            (2, ["t", "neg", "add", "sum"]),
            (2, ["relu", "mul", "t", "sum"]),
        ]
        binary_functions = ["add", "sub", "mul", "dot", "matmul"]

        # PyTorch test case:
        for test in tests:

            # get test case:
            number_of_inputs, ops = test
            inputs = [
                get_random_test_tensor(size=(12, 5), is_float=True)
                for _ in range(number_of_inputs)
            ]
            encr_inputs = [crypten.cryptensor(input) for input in inputs]

            # get autograd variables:
            for input in inputs:
                input.requires_grad = True
            encr_inputs = [AutogradCrypTensor(encr_input) for encr_input in encr_inputs]

            # perform forward pass, logging all intermediate outputs:
            outputs, encr_outputs = [inputs], [encr_inputs]
            for op in ops:

                # get inputs for current operation:
                input, output = outputs[-1], []
                encr_input, encr_output = encr_outputs[-1], []

                # apply current operation:
                if op in binary_functions:  # combine outputs via operation
                    output.append(getattr(input[0], op)(input[1]))
                    encr_output.append(getattr(encr_input[0], op)(encr_input[1]))
                else:
                    for idx in range(len(input)):
                        output.append(getattr(input[idx], op)())
                        encr_output.append(getattr(encr_input[idx], op)())

                # keep references to outputs of operation:
                outputs.append(output)
                encr_outputs.append(encr_output)

            # check output of forward pass:
            output, encr_output = outputs[-1][0], encr_outputs[-1][0]
            self._check(encr_output._tensor, output, "forward failed")
            self.assertTrue(encr_output.requires_grad, "requires_grad incorrect")

            # perform backward pass:
            output.backward()
            encr_output.backward()

            # test result of running backward function:
            for idx in range(number_of_inputs):
                self._check(encr_inputs[idx].grad, inputs[idx].grad, "backward failed")
Example #8
def _to_autograd(args):
    """
    Recursively converts inputs to AutogradCrypTensors.
    """

    # convert tuples to lists to allow changes:
    convert_to_tuple = False
    if isinstance(args, tuple):
        args = list(args)
        convert_to_tuple = True

    # wrap all input tensors in AutogradCrypTensor:
    for idx in range(len(args)):
        if isinstance(args[idx],
                      (list, tuple)):  # input may be list of tensors
            args[idx] = _to_autograd(args[idx])
        elif isinstance(args[idx], AutogradCrypTensor) or args[idx] is None:
            pass
        elif isinstance(args[idx], crypten.CrypTensor):
            args[idx] = AutogradCrypTensor(args[idx])
        else:
            raise ValueError(
                "Cannot convert type {} to AutogradCrypTensor.".format(
                    type(args[idx])))

    # return:
    if convert_to_tuple:
        args = tuple(args)
    return args
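A small usage sketch for the helper above, assuming crypten has been initialized and AutogradCrypTensor is imported as in the other examples: plain CrypTensors get wrapped, while existing AutogradCrypTensors and None values pass through unchanged, including inside nested lists.

a = crypten.cryptensor(torch.randn(3, 3))
b = AutogradCrypTensor(crypten.cryptensor(torch.randn(3)))

wrapped = _to_autograd((a, [b, None]))
assert isinstance(wrapped, tuple)                    # tuple inputs are returned as tuples
assert isinstance(wrapped[0], AutogradCrypTensor)    # plain CrypTensor was wrapped
assert wrapped[1][0] is b and wrapped[1][1] is None  # wrapped tensors and None are untouched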
Example #9
    def test_dropout(self):
        """Tests forward and backward passes for dropout"""
        # Create a separate test for dropout since it cannot use the
        # regular forward function
        all_prob_values = [x * 0.2 for x in range(0, 5)]
        for dropout_fn in [
                "dropout", "_feature_dropout", "dropout2d", "dropout3d"
        ]:
            for prob in all_prob_values:
                for size in [(5, 10), (5, 10, 15), (5, 10, 15, 20)]:
                    for use_zeros in [False, True]:
                        tensor = get_random_test_tensor(size=size,
                                                        ex_zero=True,
                                                        min_value=1.0,
                                                        is_float=True)
                        if use_zeros:
                            # zero out one row (index 1) along the first dimension
                            index = tuple(
                                [1] + [
                                    slice(0, tensor.size(i))
                                    for i in range(1, tensor.dim())
                                ]
                            )
                            tensor[index] = 0.0

                        encr_tensor = AutogradCrypTensor(
                            crypten.cryptensor(tensor), requires_grad=True)
                        encr_tensor_out = getattr(encr_tensor,
                                                  dropout_fn)(p=prob)
                        dropout_tensor = encr_tensor_out.get_plain_text()
                        # Check the scaling for non-zero elements
                        scaled_tensor = tensor / (1 - prob)
                        reference = dropout_tensor.where(
                            dropout_tensor == 0, scaled_tensor)
                        self._check(
                            encr_tensor_out,
                            reference,
                            "dropout failed with size {}, use_zeros {}, and "
                            "probability {}".format(size, use_zeros, prob),
                        )

                    # Check backward pass
                    grad_output = get_random_test_tensor(max_value=2,
                                                         size=reference.size(),
                                                         is_float=True)
                    grad_output_encr = crypten.cryptensor(grad_output)

                    # Do not check backward if pytorch backward fails
                    try:
                        reference.backward(grad_output)
                    except RuntimeError:
                        logging.info("skipped")
                        continue
                    encr_tensor_out.backward(grad_output_encr)

                    self._check(
                        encr_tensor.grad,
                        tensor.grad,
                        "dropout failed in backward with size {}, use_zeros {} and "
                        "probability {}".format(size, use_zeros, prob),
                    )
Example #10
    def _set_grad_to_zero(self, args, make_private=False):
        """Sets gradients for args to zero

        Args:
            args (list of torch.tensors): contains arguments
            make_private (bool): encrypt args using AutogradCrypTensor
        """
        args_zero_grad = []

        for arg in args:
            if is_float_tensor(arg) and make_private:
                arg = AutogradCrypTensor(crypten.cryptensor(arg), requires_grad=True)
            elif is_float_tensor(arg):
                arg.requires_grad = True
                arg.grad = None

            args_zero_grad.append(arg)

        return args_zero_grad
Example #11
    def test_cross_entropy(self):
        """Tests cross_entropy and binary_cross_entropy"""
        sizes = [(3, 2), (8, 4), (5, 10)]
        losses = [
            "binary_cross_entropy",
            "binary_cross_entropy_with_logits",
            "cross_entropy",
        ]

        for size, loss in itertools.product(sizes, losses):
            for skip_forward in [False, True]:
                batch_size, num_targets = size
                if loss in [
                        "binary_cross_entropy",
                        "binary_cross_entropy_with_logits"
                ]:
                    if loss == "binary_cross_entropy":
                        tensor = get_random_test_tensor(size=(batch_size, ),
                                                        max_value=0.998,
                                                        is_float=True)
                        tensor = tensor.abs().add_(0.001)
                    else:
                        tensor = get_random_test_tensor(size=(batch_size, ),
                                                        is_float=True)

                    target = get_random_test_tensor(size=(batch_size, ),
                                                    is_float=True)
                    target = target.gt(0.0).float()
                    target_encr = crypten.cryptensor(target)
                else:
                    tensor = get_random_test_tensor(size=size, is_float=True)
                    target = get_random_test_tensor(size=(batch_size, ),
                                                    max_value=num_targets - 1)
                    target = onehot(target.abs(), num_targets=num_targets)
                    target_encr = crypten.cryptensor(target)
                    # CrypTen, unlike PyTorch, uses one-hot targets
                    target = target.argmax(1)

                # forward
                tensor.requires_grad = True
                tensor_encr = AutogradCrypTensor(crypten.cryptensor(tensor),
                                                 requires_grad=True)
                reference = getattr(torch.nn.functional, loss)(tensor, target)
                out_encr = getattr(tensor_encr,
                                   loss)(target_encr,
                                         skip_forward=skip_forward)
                if not skip_forward:
                    self._check(out_encr, reference, f"{loss} forward failed")

                # backward
                reference.backward()
                out_encr.backward()
                self._check(tensor_encr.grad, tensor.grad,
                            f"{loss} backward failed")
Example #12
    def register_parameter(self, name, param, requires_grad=True):
        """Register parameter in the module."""
        if name in self._parameters or hasattr(self, name):
            raise ValueError("Parameter or field %s already exists." % name)
        if torch.is_tensor(param):  # unencrypted model
            param.requires_grad = requires_grad
            self._parameters[name] = param
        else:  # encrypted model
            self._parameters[name] = AutogradCrypTensor(
                param, requires_grad=requires_grad)
        setattr(self, name, param)
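A minimal sketch of how a custom module might call the method above. It assumes a crypten.nn.Module subclass with CrypTen already initialized; the module name and parameter shape are illustrative.

class TinyModule(crypten.nn.Module):
    def __init__(self):
        super().__init__()
        # registers a plaintext parameter; after encrypt() it is wrapped in an AutogradCrypTensor
        self.register_parameter("weight", torch.randn(2, 5))

    def forward(self, x):
        return x.matmul(self.weight.t())

module = TinyModule()
module.encrypt()  # see Example #17 for how parameters are encrypted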
Example #13
    def test_batchnorm_module(self):
        """Test module correctly sets and updates running stats"""
        batchnorm_fn_and_size = (
            ("BatchNorm1d", (500, 10, 3)),
            ("BatchNorm2d", (600, 7, 4, 20)),
            ("BatchNorm3d", (800, 5, 4, 8, 15)),
        )
        for batchnorm_fn, size in batchnorm_fn_and_size:
            for is_training in (True, False):
                tensor = get_random_test_tensor(size=size, is_float=True)
                tensor.requires_grad = True
                encrypted_input = AutogradCrypTensor(
                    crypten.cryptensor(tensor))

                C = size[1]
                weight = get_random_test_tensor(size=[C],
                                                max_value=1,
                                                is_float=True)
                bias = get_random_test_tensor(size=[C],
                                              max_value=1,
                                              is_float=True)
                weight.requires_grad = True
                bias.requires_grad = True

                # dimensions for mean and variance
                stats_dimensions = list(range(tensor.dim()))
                # perform on C dimension for tensor of shape (N, C, +)
                stats_dimensions.pop(1)

                # check running stats initial
                enc_model = getattr(crypten.nn.module,
                                    batchnorm_fn)(C).encrypt()
                plain_model = getattr(torch.nn.modules, batchnorm_fn)(C)
                stats = ["running_var", "running_mean"]
                for stat in stats:
                    self._check(
                        enc_model._buffers[stat],
                        plain_model._buffers[stat],
                        f"{stat} initial module value incorrect",
                    )

                # set training mode
                plain_model.training = is_training
                enc_model.training = is_training

                # check running_stats update
                enc_model.forward(encrypted_input)
                plain_model.forward(tensor)
                for stat in stats:
                    self._check(
                        enc_model._buffers[stat],
                        plain_model._buffers[stat],
                        f"{stat} momentum update in module incorrect",
                    )
Example #14
    def test_autograd_repetition(self):
        """Tests running autograd on the same input repeatedly."""

        # create test case:
        input = get_random_test_tensor(size=(12, 5), is_float=True)
        input.requires_grad = True
        encr_input = AutogradCrypTensor(crypten.cryptensor(input))

        # re-use the same input multiple times:
        for _ in range(7):

            # perform forward pass:
            output = input.exp().sum()
            encr_output = encr_input.exp().sum()
            self._check(encr_output._tensor, output, "forward failed")
            self.assertTrue(encr_output.requires_grad, "requires_grad incorrect")

            # perform backward computation:
            output.backward()
            encr_output.backward()
            self._check(encr_input.grad, input.grad, "backward failed")
Example #15
    def test_detach(self):
        """Tests that detach() works as expected."""

        for func_name in ["detach", "detach_"]:

            # get test case:
            input_size = (12, 5)
            input1 = get_random_test_tensor(size=input_size, is_float=True)
            input2 = get_random_test_tensor(size=input_size, is_float=True)
            input1 = AutogradCrypTensor(crypten.cryptensor(input1))
            input2 = AutogradCrypTensor(crypten.cryptensor(input2))

            # perform forward computation with detach in the middle:
            intermediate = input1.add(1.0)
            intermediate = getattr(intermediate, func_name)()
            output = intermediate.add(input2).sum()

            # perform backward:
            output.backward()
            msg = "detach() function does not behave as expected"
            self.assertIsNone(output.grad, msg)
            self.assertIsNone(intermediate.grad, msg)
            self.assertIsNone(input1.grad, msg)
            self.assertIsNotNone(input2.grad, msg)
Example #16
    def test_square(self):
        """Tests square function gradient.
        Note: torch pow(2) is used to verify gradient,
            since PyTorch does not implement square().
        """
        for size in SIZES:
            tensor = get_random_test_tensor(size=size, is_float=True)
            tensor.requires_grad = True
            tensor_encr = AutogradCrypTensor(crypten.cryptensor(tensor),
                                             requires_grad=True)

            out = tensor.pow(2)
            out_encr = tensor_encr.square()
            self._check(out_encr, out,
                        f"square forward failed with size {size}")

            grad_output = get_random_test_tensor(size=out.shape, is_float=True)
            out.backward(grad_output)
            out_encr.backward(crypten.cryptensor(grad_output))
            self._check(
                tensor_encr.grad,
                tensor.grad,
                f"square backward failed with size {size}",
            )
Example #17
    def encrypt(self, mode=True, src=0):
        """Encrypts the model."""
        if mode != self.encrypted:

            # encrypt / decrypt parameters:
            self.encrypted = mode
            for name, param in self.named_parameters(recurse=False):
                requires_grad = param.requires_grad
                if mode:  # encrypt parameter
                    self.set_parameter(
                        name,
                        AutogradCrypTensor(
                            crypten.cryptensor(param, **{"src": src}),
                            requires_grad=requires_grad,
                        ),
                    )
                else:  # decrypt parameter
                    self.set_parameter(name, param.get_plain_text())
                    self._parameters[name].requires_grad = requires_grad

            # encrypt / decrypt buffers:
            for name, buffer in self.named_buffers(recurse=False):
                if mode:  # encrypt buffer
                    self.set_buffer(
                        name,
                        AutogradCrypTensor(
                            crypten.cryptensor(buffer, **{"src": src}),
                            requires_grad=False,
                        ),
                    )
                else:  # decrypt buffer
                    self.set_buffer(name, buffer.get_plain_text())

            # apply encryption recursively:
            return self._apply(lambda m: m.encrypt(mode=mode, src=src))
        return self
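A short usage sketch for the method above, assuming crypten.init() has already been called and using an illustrative layer size: calling encrypt() with the current mode is a no-op, and mode=False decrypts the parameters again.

model = crypten.nn.Linear(10, 2)
model.encrypt()            # parameters are wrapped in AutogradCrypTensors
assert model.encrypted
model.encrypt(mode=True)   # same mode again: no-op
model.encrypt(mode=False)  # decrypts parameters back to plaintext tensors
assert not model.encrypted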
Example #18
    def test_non_differentiable_marking(self):
        """Tests whether marking of non-differentiability works correctly."""

        # generate random inputs:
        inputs = [get_random_test_tensor(is_float=True) for _ in range(5)]
        inputs = [crypten.cryptensor(input) for input in inputs]
        ctx = AutogradContext()

        # repeat test multiple times:
        for _ in range(10):

            # mark non-differentiable inputs as such:
            differentiable = [
                random.random() > 0.5 for _ in range(len(inputs))
            ]
            for idx, diff in enumerate(differentiable):
                if not diff:
                    ctx.mark_non_differentiable(inputs[idx])

            # check that inputs were correctly marked:
            for idx, input in enumerate(inputs):
                self.assertEqual(
                    ctx.is_differentiable(input),
                    differentiable[idx],
                    "marking of differentiability failed",
                )
            ctx.reset()

        # test behavior of AutogradCrypTensor:
        input = AutogradCrypTensor(inputs[0])
        reference = [True, True, False]
        for func_name in ["min", "max"]:
            outputs = [None] * 3
            outputs[0] = getattr(input, func_name)()
            outputs[1], outputs[2] = getattr(input, func_name)(dim=0)
            for idx, output in enumerate(outputs):
                self.assertEqual(
                    output.requires_grad,
                    reference[idx],
                    "value of requires_grad is incorrect",
                )

        # behavior of max_pool2d in which indices are returned:
        input = get_random_test_tensor(size=(1, 3, 8, 8), is_float=True)
        input = AutogradCrypTensor(crypten.cryptensor(input))
        reference = [True, True, False]
        outputs = [None] * 3
        outputs[0] = input.max_pool2d(2, return_indices=False)
        outputs[1], outputs[2] = input.max_pool2d(2, return_indices=True)
        for idx, output in enumerate(outputs):
            self.assertEqual(
                output.requires_grad,
                reference[idx],
                "value of requires_grad is incorrect",
            )
Example #19
    def test_sequential(self):
        """
        Tests crypten.nn.Sequential module.
        """

        # try networks of different depth:
        for num_layers in range(1, 6):
            for wrap in [True, False]:

                # construct sequential container:
                input_size = (3, 10)
                output_size = (input_size[0], input_size[1] - num_layers)
                layer_idx = range(input_size[1], output_size[1], -1)
                module_list = [
                    crypten.nn.Linear(num_feat, num_feat - 1)
                    for num_feat in layer_idx
                ]
                sequential = crypten.nn.Sequential(module_list)
                sequential.encrypt()

                # check container:
                self.assertTrue(sequential.encrypted,
                                "nn.Sequential not encrypted")
                for module in sequential.modules():
                    self.assertTrue(module.encrypted, "module not encrypted")
                assert sum(1 for _ in sequential.modules()) == len(
                    module_list
                ), "nn.Sequential contains incorrect number of modules"

                # construct test input and run through sequential container:
                input = get_random_test_tensor(size=input_size, is_float=True)
                encr_input = crypten.cryptensor(input)
                if wrap:
                    encr_input = AutogradCrypTensor(encr_input)
                encr_output = sequential(encr_input)

                # compute reference output:
                encr_reference = encr_input
                for module in sequential.modules():
                    encr_reference = module(encr_reference)
                reference = encr_reference.get_plain_text()

                # compare output to reference:
                self._check(encr_output, reference,
                            "nn.Sequential forward failed")
Example #20
    def test_cat_stack(self):
        for func in ["cat", "stack"]:
            for dimensions in range(1, 5):
                size = [5] * dimensions
                for num_tensors in range(1, 5):
                    for dim in range(dimensions):
                        tensors = [
                            get_random_test_tensor(size=size, is_float=True)
                            for _ in range(num_tensors)
                        ]
                        encrypted_tensors = [
                            AutogradCrypTensor(crypten.cryptensor(t))
                            for t in tensors
                        ]
                        for i in range(len(tensors)):
                            tensors[i].grad = None
                            tensors[i].requires_grad = True
                            encrypted_tensors[i].grad = None
                            encrypted_tensors[i].requires_grad = True

                        # Forward
                        reference = getattr(torch, func)(tensors, dim=dim)
                        encrypted_out = getattr(crypten,
                                                func)(encrypted_tensors,
                                                      dim=dim)
                        self._check(encrypted_out, reference,
                                    f"{func} forward failed")

                        # Backward
                        grad_output = get_random_test_tensor(
                            size=reference.size(), is_float=True)
                        encrypted_grad_output = crypten.cryptensor(grad_output)

                        reference.backward(grad_output)
                        encrypted_out.backward(encrypted_grad_output)
                        for i in range(len(tensors)):
                            self._check(
                                encrypted_tensors[i].grad,
                                tensors[i].grad,
                                f"{func} backward failed",
                            )
Example #21
    def test_gather_scatter(self):
        """Tests gather and scatter gradients"""
        sizes = [(2, 2), (3, 5), (3, 5, 10)]
        indices = [[0, 1, 0, 0], [0, 1, 0, 0, 1] * 3, [0, 0, 1] * 50]
        dims = [0, 1]
        funcs = ["scatter", "gather"]

        for dim, func in itertools.product(dims, funcs):
            for size, index in zip(sizes, indices):
                tensor = get_random_test_tensor(size=size, is_float=True)
                index = torch.tensor(index).reshape(tensor.shape)

                tensor.requires_grad = True
                tensor_encr = AutogradCrypTensor(crypten.cryptensor(tensor),
                                                 requires_grad=True)

                if func == "gather":
                    reference = getattr(tensor, func)(dim, index)
                    out_encr = getattr(tensor_encr, func)(dim, index)
                else:
                    src = get_random_test_tensor(size=index.shape,
                                                 is_float=True)
                    reference = getattr(tensor, func)(dim, index, src)
                    out_encr = getattr(tensor_encr, func)(dim, index, src)

                self._check(out_encr, reference,
                            f"{func} forward failed with index {index}")

                grad_out = get_random_test_tensor(size=reference.shape,
                                                  is_float=True)
                grad_out_encr = crypten.cryptensor(grad_out)
                reference.backward(grad_out)
                out_encr.backward(grad_out_encr)

                self._check(
                    tensor_encr.grad,
                    tensor.grad,
                    f"{func} backward failed with index {index}",
                )
Example #22
    def _check_forward_backward(self,
                                fn_name,
                                input_tensor,
                                *args,
                                msg=None,
                                **kwargs):
        if msg is None:
            msg = f"{fn_name} grad_fn incorrect"

        for requires_grad in [True]:
            # Setup input
            input = input_tensor.clone()
            input.requires_grad = requires_grad
            input_encr = AutogradCrypTensor(crypten.cryptensor(input),
                                            requires_grad=requires_grad)

            for private in [False, True]:
                input.grad = None
                input_encr.grad = None

                # Setup args
                args_encr = list(args)
                for i, arg in enumerate(args):
                    if private and is_float_tensor(arg):
                        args_encr[i] = AutogradCrypTensor(
                            crypten.cryptensor(arg),
                            requires_grad=requires_grad)
                        args_encr[i].grad = None  # zero grad
                    if is_float_tensor(arg):
                        args[i].requires_grad = requires_grad
                        args[i].grad = None  # zero grad

                # Check forward pass
                if hasattr(input, fn_name):
                    reference = getattr(input, fn_name)(*args, **kwargs)
                elif hasattr(F, fn_name):
                    reference = getattr(F, fn_name)(input, *args, **kwargs)
                elif fn_name == "square":
                    reference = input.pow(2)
                else:
                    raise ValueError("unknown PyTorch function: %s" % fn_name)

                encrypted_out = getattr(input_encr, fn_name)(*args_encr,
                                                             **kwargs)

                # Remove argmax output from max / min
                if isinstance(encrypted_out, (list, tuple)):
                    reference = reference[0]
                    encrypted_out = encrypted_out[0]

                self._check(encrypted_out, reference, msg + " in forward")

                # Check backward pass
                grad_output = get_random_test_tensor(max_value=2,
                                                     size=reference.size(),
                                                     is_float=True)
                grad_output_encr = crypten.cryptensor(grad_output)

                # Do not check backward if pytorch backward fails
                try:
                    reference.backward(grad_output)
                except RuntimeError:
                    logging.info("skipped")
                    continue
                encrypted_out.backward(grad_output_encr)

                self._check(input_encr.grad, input.grad, msg + " in backward")
                for i, arg_encr in enumerate(args_encr):
                    if crypten.is_encrypted_tensor(arg_encr):
                        self._check(arg_encr.grad, args[i].grad,
                                    msg + " in backward args")
Example #23
    def test_autograd_accumulation(self):
        """Tests accumulation in autograd."""

        # define test cases that have nodes with multiple parents:
        def test_case1(input, encr_input):
            output = input.add(1.0).add(input.exp()).sum()
            encr_output = encr_input.add(1.0).add(encr_input.exp()).sum()
            return output, encr_output

        def test_case2(input, encr_input):
            intermediate = input.pow(2.0)  # PyTorch
            output = intermediate.add(1.0).add(intermediate.mul(2.0)).sum()
            encr_intermediate = encr_input.square()  # CrypTen
            encr_output = (
                encr_intermediate.add(1.0).add(encr_intermediate.mul(2.0)).sum()
            )
            return output, encr_output

        def test_case3(input, encr_input):
            intermediate1 = input.pow(2.0)  # PyTorch
            intermediate2 = intermediate1.add(1.0).add(intermediate1.mul(2.0))
            output = intermediate2.pow(2.0).sum()
            encr_intermediate1 = encr_input.square()  # CrypTen
            encr_intermediate2 = encr_intermediate1.add(1.0).add(
                encr_intermediate1.mul(2.0)
            )
            encr_output = encr_intermediate2.square().sum()
            return output, encr_output

        # loop over test cases:
        for idx, test_case in enumerate([test_case1, test_case2, test_case3]):

            # get input tensors:
            input = get_random_test_tensor(size=(12, 5), is_float=True)
            input.requires_grad = True
            encr_input = AutogradCrypTensor(crypten.cryptensor(input))

            # perform multiple forward computations on input that get combined:
            output, encr_output = test_case(input, encr_input)
            self._check(
                encr_output._tensor, output, "forward for test case %d failed" % idx
            )
            self.assertTrue(
                encr_output.requires_grad,
                "requires_grad incorrect for test case %d" % idx,
            )

            # perform backward computation:
            output.backward()
            encr_output.backward()
            self._check(
                encr_input.grad, input.grad, "backward for test case %d failed" % idx
            )

        # test cases in which tensor gets combined with itself:
        for func_name in ["sub", "add", "mul"]:

            # get input tensors:
            input = get_random_test_tensor(size=(12, 5), is_float=True)
            input.requires_grad = True
            encr_input = AutogradCrypTensor(crypten.cryptensor(input))

            # perform forward-backward pass:
            output = getattr(input, func_name)(input).sum()
            encr_output = getattr(encr_input, func_name)(encr_input).sum()
            self._check(encr_output._tensor, output, "forward failed")
            self.assertTrue(encr_output.requires_grad, "requires_grad incorrect")
            output.backward()
            encr_output.backward()
            self._check(encr_input.grad, input.grad, "%s backward failed" % func_name)
Example #24
    def test_from_pytorch_training(self):
        """Tests the from_pytorch code path for training CrypTen models"""
        import torch.nn as nn
        import torch.nn.functional as F

        class ExampleNet(nn.Module):
            def __init__(self):
                super(ExampleNet, self).__init__()
                self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=1)
                self.fc1 = nn.Linear(16 * 13 * 13, 100)
                self.fc2 = nn.Linear(100, 2)

            def forward(self, x):
                out = self.conv1(x)
                out = F.relu(out)
                out = F.max_pool2d(out, 2)
                out = out.view(out.size(0), -1)
                out = self.fc1(out)
                out = F.relu(out)
                out = self.fc2(out)
                return out

        model_plaintext = ExampleNet()
        batch_size = 5
        x_orig = get_random_test_tensor(size=(batch_size, 1, 28, 28),
                                        is_float=True)
        y_orig = (get_random_test_tensor(size=(batch_size, 1),
                                         is_float=True).gt(0).long())
        y_one_hot = onehot(y_orig, num_targets=2)

        # encrypt training sample:
        x_train = AutogradCrypTensor(crypten.cryptensor(x_orig))
        y_train = crypten.cryptensor(y_one_hot)
        dummy_input = torch.empty((1, 1, 28, 28))

        for loss_name in ["BCELoss", "CrossEntropyLoss", "MSELoss"]:
            # create loss function
            loss = getattr(crypten.nn, loss_name)()

            # create encrypted model
            model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
            model.train()
            model.encrypt()

            num_epochs = 3
            learning_rate = 0.001

            for i in range(num_epochs):
                output = model(x_train)
                if loss_name == "MSELoss":
                    output_norm = output
                else:
                    output_norm = output.softmax(1)
                loss_value = loss(output_norm, y_train)

                # set gradients to "zero"
                model.zero_grad()
                for param in model.parameters():
                    self.assertIsNone(param.grad,
                                      "zero_grad did not reset gradients")

                # perform backward pass:
                loss_value.backward()
                for param in model.parameters():
                    if param.requires_grad:
                        self.assertIsNotNone(
                            param.grad,
                            "required parameter gradient not created")

                # update parameters
                orig_parameters, upd_parameters = {}, {}
                orig_parameters = self._compute_reference_parameters(
                    "", orig_parameters, model, 0)
                model.update_parameters(learning_rate)
                upd_parameters = self._compute_reference_parameters(
                    "", upd_parameters, model, learning_rate)

                # check that any parameter with a non-zero gradient has changed
                parameter_changed = False
                for name, value in orig_parameters.items():
                    if param.requires_grad and param.grad is not None:
                        unchanged = torch.allclose(upd_parameters[name], value)
                        if unchanged is False:
                            parameter_changed = True
                        self.assertTrue(
                            parameter_changed,
                            "no parameter changed in training step")

                # record initial and current loss
                if i == 0:
                    orig_loss = loss_value.get_plain_text()
                curr_loss = loss_value.get_plain_text()

            # check that the loss has decreased after training
            self.assertTrue(
                curr_loss.item() < orig_loss.item(),
                "loss has not decreased after training",
            )
Example #25
    def test_custom_module_training(self):
        """Tests training CrypTen models created directly using the crypten.nn.Module"""
        class ExampleNet(crypten.nn.Module):
            def __init__(self):
                super(ExampleNet, self).__init__()
                self.fc1 = crypten.nn.Linear(20, 5)
                self.fc2 = crypten.nn.Linear(5, 2)

            def forward(self, x):
                out = self.fc1(x)
                out = self.fc2(out)
                return out

        model = ExampleNet()

        batch_size = 5
        x_orig = get_random_test_tensor(size=(batch_size, 20), is_float=True)
        y_orig = (get_random_test_tensor(size=(batch_size, 1),
                                         is_float=True).gt(0).long())
        y_one_hot = onehot(y_orig, num_targets=2)

        # encrypt training sample:
        x_train = AutogradCrypTensor(crypten.cryptensor(x_orig))
        y_train = crypten.cryptensor(y_one_hot)

        for loss_name in ["BCELoss", "CrossEntropyLoss", "MSELoss"]:
            # create loss function
            loss = getattr(crypten.nn, loss_name)()

            # create encrypted model
            model.train()
            model.encrypt()

            num_epochs = 3
            learning_rate = 0.001

            for i in range(num_epochs):
                output = model(x_train)
                if loss_name == "MSELoss":
                    output_norm = output
                else:
                    output_norm = output.softmax(1)
                loss_value = loss(output_norm, y_train)

                # set gradients to "zero"
                model.zero_grad()
                for param in model.parameters():
                    self.assertIsNone(param.grad,
                                      "zero_grad did not reset gradients")

                # perform backward pass:
                loss_value.backward()
                for param in model.parameters():
                    if param.requires_grad:
                        self.assertIsNotNone(
                            param.grad,
                            "required parameter gradient not created")

                # update parameters
                orig_parameters, upd_parameters = {}, {}
                orig_parameters = self._compute_reference_parameters(
                    "", orig_parameters, model, 0)
                model.update_parameters(learning_rate)
                upd_parameters = self._compute_reference_parameters(
                    "", upd_parameters, model, learning_rate)

                parameter_changed = False
                for name, value in orig_parameters.items():
                    if param.requires_grad and param.grad is not None:
                        unchanged = torch.allclose(upd_parameters[name], value)
                        if unchanged is False:
                            parameter_changed = True
                        self.assertTrue(
                            parameter_changed,
                            "no parameter changed in training step")

                # record initial and current loss
                if i == 0:
                    orig_loss = loss_value.get_plain_text()
                curr_loss = loss_value.get_plain_text()

            # check that the loss has decreased after training
            self.assertTrue(
                curr_loss.item() < orig_loss.item(),
                "loss has not decreased after training",
            )
Example #26
    def test_pytorch_modules(self):
        """
        Tests all non-container Modules in crypten.nn that have equivalent
        modules in PyTorch.
        """

        # input arguments for modules and input sizes:
        module_args = {
            "AdaptiveAvgPool2d": (2, ),
            "AvgPool2d": (2, ),
            # "BatchNorm1d": (400,),  # FIXME: Unit tests claim gradients are incorrect.
            # "BatchNorm2d": (3,),
            # "BatchNorm3d": (6,),
            "ConstantPad1d": (3, 1.0),
            "ConstantPad2d": (2, 2.0),
            "ConstantPad3d": (1, 0.0),
            "Conv2d": (3, 6, 5),
            "Linear": (400, 120),
            "MaxPool2d": (2, ),
            "ReLU": (),
            "Softmax": (0, ),
            "LogSoftmax": (0, ),
        }
        input_sizes = {
            "AdaptiveAvgPool2d": (1, 3, 32, 32),
            "AvgPool2d": (1, 3, 32, 32),
            "BatchNorm1d": (8, 400),
            "BatchNorm2d": (8, 3, 32, 32),
            "BatchNorm3d": (8, 6, 32, 32, 4),
            "ConstantPad1d": (9, ),
            "ConstantPad2d": (3, 6),
            "ConstantPad3d": (4, 2, 7),
            "Conv2d": (1, 3, 32, 32),
            "Linear": (1, 400),
            "MaxPool2d": (1, 2, 32, 32),
            "ReLU": (1, 3, 32, 32),
            "Softmax": (5, 5, 5),
            "LogSoftmax": (5, 5, 5),
        }

        # loop over all modules:
        for module_name in module_args.keys():
            for wrap in [True, False]:

                # generate inputs:
                input = get_random_test_tensor(size=input_sizes[module_name],
                                               is_float=True)
                input.requires_grad = True
                encr_input = crypten.cryptensor(input)
                if wrap:
                    encr_input = AutogradCrypTensor(encr_input)

                # create PyTorch module:
                module = getattr(torch.nn,
                                 module_name)(*module_args[module_name])
                module.train()

                # create encrypted CrypTen module:
                encr_module = crypten.nn.from_pytorch(module, input)

                # check that module properly encrypts / decrypts and
                # check that encrypting with current mode properly performs no-op
                for encrypted in [False, True, True, False, True]:
                    encr_module.encrypt(mode=encrypted)
                    if encrypted:
                        self.assertTrue(encr_module.encrypted,
                                        "module not encrypted")
                    else:
                        self.assertFalse(encr_module.encrypted,
                                         "module encrypted")
                    for key in ["weight", "bias"]:
                        if hasattr(module, key):  # if PyTorch model has key
                            encr_param = None

                            # find that key in the crypten.nn.Graph:
                            if isinstance(encr_module, crypten.nn.Graph):
                                for encr_node in encr_module.modules():
                                    if hasattr(encr_node, key):
                                        encr_param = getattr(encr_node, key)
                                        break

                            # or get it from the crypten Module directly:
                            else:
                                encr_param = getattr(encr_module, key)

                            # compare with reference:
                            # NOTE: Because some parameters are initialized randomly
                            # with different values on each process, we only want to
                            # check that they are consistent with source parameter value
                            reference = getattr(module, key)
                            src_reference = comm.get().broadcast(reference,
                                                                 src=0)
                            msg = "parameter %s in %s incorrect" % (
                                key, module_name)
                            if not encrypted:
                                encr_param = crypten.cryptensor(encr_param)
                            self._check(encr_param, src_reference, msg)

                # compare model outputs:
                self.assertTrue(encr_module.training,
                                "training value incorrect")
                reference = module(input)
                encr_output = encr_module(encr_input)
                self._check(encr_output, reference,
                            "%s forward failed" % module_name)

                # test backward pass:
                reference.backward(torch.ones(reference.size()))
                encr_output.backward()
                if wrap:  # you cannot get input gradients on MPCTensor inputs
                    self._check(
                        encr_input.grad,
                        input.grad,
                        "%s backward on input failed" % module_name,
                    )
                else:
                    self.assertFalse(hasattr(encr_input, "grad"))
                for name, param in module.named_parameters():
                    encr_param = getattr(encr_module, name)
                    self._check(
                        encr_param.grad,
                        param.grad,
                        "%s backward on %s failed" % (module_name, name),
                    )
Example #27
    def test_dropout_module(self):
        """Tests the dropout module"""
        input_size = [3, 3, 3]
        prob_list = [0.2 * x for x in range(1, 5)]
        for module_name in ["Dropout", "Dropout2d", "Dropout3d"]:
            for prob in prob_list:
                for wrap in [True, False]:
                    # generate inputs:
                    input = get_random_test_tensor(size=input_size,
                                                   is_float=True,
                                                   ex_zero=True)
                    input.requires_grad = True
                    encr_input = crypten.cryptensor(input)
                    if wrap:
                        encr_input = AutogradCrypTensor(encr_input)

                    # create PyTorch module:
                    for inplace in [False, True]:
                        module = getattr(torch.nn,
                                         module_name)(prob, inplace=inplace)
                        module.train()

                        # create encrypted CrypTen module:
                        encr_module = crypten.nn.from_pytorch(module, input)

                        # check that the module properly encrypts / decrypts,
                        # and that encrypting with the current mode is a no-op
                        for encrypted in [False, True, True, False, True]:
                            encr_module.encrypt(mode=encrypted)
                            if encrypted:
                                self.assertTrue(encr_module.encrypted,
                                                "module not encrypted")
                            else:
                                self.assertFalse(encr_module.encrypted,
                                                 "module encrypted")

                        # compare model outputs:
                        # compare the zero and non-zero entries of the encrypted tensor
                        # with a directly constructed plaintext tensor, since we cannot
                        # ensure that the randomization produces the same output
                        # for both encrypted and plaintext tensors
                        self.assertTrue(encr_module.training,
                                        "training value incorrect")
                        encr_output = encr_module(encr_input)
                        plaintext_output = encr_output.get_plain_text()
                        scaled_tensor = input / (1 - prob)
                        reference = plaintext_output.where(
                            plaintext_output == 0, scaled_tensor)
                        self._check(encr_output, reference,
                                    "Dropout forward failed")
                        if inplace:
                            self._check(encr_input, reference,
                                        "In-place dropout failed")
                        else:
                            self._check(encr_input, input,
                                        "Out-of-place dropout failed")
                            # check backward
                            # compare the zero and non-zero entries of the grad in
                            # the encrypted tensor with a directly constructed plaintext
                            # tensor: we do this because we cannot ensure that the
                            # randomization produces the same output for the input
                            # encrypted and plaintext tensors and so we cannot ensure
                            # that the grad in the input tensor is populated identically
                            all_ones = torch.ones(reference.size())
                            ref_grad = plaintext_output.where(
                                plaintext_output == 0, all_ones)
                            ref_grad_input = ref_grad / (1 - prob)
                            encr_output.backward()
                            # you cannot get input gradients on MPCTensor inputs
                            if wrap:
                                self._check(
                                    encr_input.grad,
                                    ref_grad_input,
                                    "dropout backward on input failed",
                                )

                        # check testing mode for Dropout module
                        encr_module.train(mode=False)
                        encr_output = encr_module(encr_input)
                        result = encr_input.eq(encr_output)
                        result_plaintext = result.get_plain_text().bool()
                        self.assertTrue(result_plaintext.all(),
                                        "dropout failed in test mode")
Example #28
0
    def test_batchnorm(self):
        """
        Tests batchnorm forward and backward steps with training on / off.
        """
        # sizes for 1D, 2D, and 3D batchnorm
        # use batch_size (dim=0) > 500 and an increased tolerance to avoid flaky
        # precision errors in inv_var, which involves sqrt and reciprocal
        sizes = [(800, 5), (500, 8, 15), (600, 10, 3, 15)]
        tolerance = 0.5

        for size in sizes:
            for is_training in (False, True):
                tensor = get_random_test_tensor(size=size, is_float=True)
                tensor.requires_grad = True
                encrypted_input = crypten.cryptensor(tensor)

                C = size[1]
                weight = get_random_test_tensor(size=[C],
                                                max_value=1,
                                                is_float=True)
                bias = get_random_test_tensor(size=[C],
                                              max_value=1,
                                              is_float=True)
                weight.requires_grad = True
                bias.requires_grad = True

                # dimensions for mean and variance
                stats_dimensions = list(range(tensor.dim()))
                # perform on C dimension for tensor of shape (N, C, +)
                stats_dimensions.pop(1)

                running_mean = tensor.mean(stats_dimensions).detach()
                running_var = tensor.var(stats_dimensions).detach()
                enc_running_mean = encrypted_input.mean(stats_dimensions)
                enc_running_var = encrypted_input.var(stats_dimensions)

                reference = torch.nn.functional.batch_norm(tensor,
                                                           running_mean,
                                                           running_var,
                                                           weight=weight,
                                                           bias=bias)

                encrypted_input = AutogradCrypTensor(encrypted_input)
                ctx = AutogradContext()
                batch_norm_fn = crypten.gradients.get_grad_fn("batchnorm")
                encrypted_out = batch_norm_fn.forward(
                    ctx,
                    (encrypted_input, weight, bias),
                    training=is_training,
                    running_mean=enc_running_mean,
                    running_var=enc_running_var,
                )

                # check forward
                self._check(
                    encrypted_out,
                    reference,
                    "batchnorm forward failed with trainning "
                    f"{is_trainning} on {tensor.dim()}-D",
                    tolerance=tolerance,
                )

                # check backward (input, weight, and bias gradients)
                reference.backward(reference)
                encrypted_grad = batch_norm_fn.backward(ctx, encrypted_out)
                TorchGrad = namedtuple("TorchGrad", ["name", "value"])
                torch_gradients = [
                    TorchGrad("input gradient", tensor.grad),
                    TorchGrad("weight gradient", weight.grad),
                    TorchGrad("bias gradient", bias.grad),
                ]

                for i, torch_gradient in enumerate(torch_gradients):
                    self._check(
                        encrypted_grad[i],
                        torch_gradient.value,
                        f"batchnorm backward {torch_gradient.name} failed"
                        f"with training {is_trainning} on {tensor.dim()}-D",
                        tolerance=tolerance,
                    )
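
For reference, with training off torch.nn.functional.batch_norm normalizes each channel with the supplied running statistics as (x - mean) / sqrt(var + eps), then applies the affine weight and bias. A small illustrative sketch of that computation (manual_batch_norm is a hypothetical helper, not part of CrypTen or the test suite):

import torch


def manual_batch_norm(x, running_mean, running_var, weight, bias, eps=1e-5):
    # broadcast the per-channel statistics over an (N, C, ...) input
    shape = [1, -1] + [1] * (x.dim() - 2)
    x_hat = (x - running_mean.view(shape)) / (running_var.view(shape) + eps).sqrt()
    return x_hat * weight.view(shape) + bias.view(shape)


x = torch.randn(800, 5)
weight, bias = torch.rand(5), torch.rand(5)
running_mean, running_var = x.mean(0), x.var(0)
expected = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight=weight, bias=bias
)
assert torch.allclose(manual_batch_norm(x, running_mean, running_var, weight, bias),
                      expected, atol=1e-4)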