Example #1
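    # Assumed context: these snippets are methods of a CrypTen autograd test
    # class; they rely on module-level imports of crypten, the
    # AutogradCrypTensor wrapper, and test helpers such as
    # get_random_test_tensor and is_float_tensor.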
    def _check_forward_backward(
        self, func_name, input_tensor, *args, torch_func_name=None, msg=None, **kwargs
    ):
        """Checks forward and backward against PyTorch

        Args:
            func_name (str): PyTorch/CrypTen function name
            input_tensor (torch.tensor): primary input
            args (list): contains arguments for function
            msg (str): additional message for mismatch
            kwargs (list): keyword arguments for function
        """

        if msg is None:
            msg = f"{func_name} grad_fn incorrect"

        input = input_tensor.clone()
        input.requires_grad = True
        input_encr = AutogradCrypTensor(crypten.cryptensor(input), requires_grad=True)

        for private in [False, True]:
            input.grad = None
            input_encr.grad = None
            args = self._set_grad_to_zero(args)
            args_encr = self._set_grad_to_zero(list(args), make_private=private)

            # obtain torch function
            if torch_func_name is not None:
                torch_func = self._get_torch_func(torch_func_name)
            else:
                torch_func = self._get_torch_func(func_name)

            reference = torch_func(input, *args, **kwargs)
            encrypted_out = getattr(input_encr, func_name)(*args_encr, **kwargs)

            # max / min with a dim argument return (values, argmax); compare
            # only the values
            if isinstance(encrypted_out, (list, tuple)):
                reference = reference[0]
                encrypted_out = encrypted_out[0]

            self._check(encrypted_out, reference, msg + " in forward")

            # check backward pass
            grad_output = get_random_test_tensor(
                max_value=2, size=reference.size(), is_float=True
            )
            grad_output_encr = crypten.cryptensor(grad_output)
            reference.backward(grad_output)
            encrypted_out.backward(grad_output_encr)

            self._check(input_encr.grad, input.grad, msg + " in backward")
            for i, arg_encr in enumerate(args_encr):
                if crypten.is_encrypted_tensor(arg_encr):
                    self._check(arg_encr.grad, args[i].grad, msg + " in backward args")
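
A minimal usage sketch for the helper above (the test names and tensor shapes
are illustrative, not from the source; get_random_test_tensor is the same
helper used inside _check_forward_backward):

    def test_relu(self):
        # unary function: only the primary input is checked
        tensor = get_random_test_tensor(size=(4, 5), is_float=True)
        self._check_forward_backward("relu", tensor)

    def test_matmul(self):
        # positional tensor arguments travel through *args and are encrypted
        # on the private=True iteration of the loop above
        tensor = get_random_test_tensor(size=(3, 4), is_float=True)
        other = get_random_test_tensor(size=(4, 2), is_float=True)
        self._check_forward_backward("matmul", tensor, other)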
Example #2
    def _set_grad_to_zero(self, args, make_private=False):
        """Sets gradients for args to zero

        Args:
            args (list of torch.tensors): contains arguments
            make_private (bool): encrypt args using AutogradCrypTensor
        """
        args_zero_grad = []

        for arg in args:
            if is_float_tensor(arg) and make_private:
                arg = AutogradCrypTensor(crypten.cryptensor(arg), requires_grad=True)
            elif is_float_tensor(arg):
                arg.requires_grad = True
                arg.grad = None

            args_zero_grad.append(arg)

        return args_zero_grad
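
For illustration, how _set_grad_to_zero treats a mixed argument list (the
weight tensor and the integer argument below are made up for this sketch):

    # plaintext path: float tensors come back with requires_grad=True and a
    # cleared grad; non-tensor args such as the int pass through unchanged
    args = self._set_grad_to_zero([weight, 1])

    # private path: float tensors are wrapped as AutogradCrypTensors instead
    args_encr = self._set_grad_to_zero([weight, 1], make_private=True)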
Example #3
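    # Assumed context: this variant additionally relies on module-level
    # imports of logging and torch.nn.functional as F.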
    def _check_forward_backward(self,
                                fn_name,
                                input_tensor,
                                *args,
                                msg=None,
                                **kwargs):
        if msg is None:
            msg = f"{fn_name} grad_fn incorrect"

        for requires_grad in [True]:
            # Setup input
            input = input_tensor.clone()
            input.requires_grad = requires_grad
            input_encr = AutogradCrypTensor(crypten.cryptensor(input),
                                            requires_grad=requires_grad)

            for private in [False, True]:
                input.grad = None
                input_encr.grad = None

                # Setup args
                args_encr = list(args)
                for i, arg in enumerate(args):
                    if private and is_float_tensor(arg):
                        args_encr[i] = AutogradCrypTensor(
                            crypten.cryptensor(arg),
                            requires_grad=requires_grad)
                        args_encr[i].grad = None  # zero grad
                    if is_float_tensor(arg):
                        args[i].requires_grad = requires_grad
                        args[i].grad = None  # zero grad

                # Check forward pass
                if hasattr(input, fn_name):
                    reference = getattr(input, fn_name)(*args, **kwargs)
                elif hasattr(F, fn_name):
                    reference = getattr(F, fn_name)(input, *args, **kwargs)
                elif fn_name == "square":
                    reference = input.pow(2)
                else:
                    raise ValueError("unknown PyTorch function: %s" % fn_name)

                encrypted_out = getattr(input_encr, fn_name)(*args_encr,
                                                             **kwargs)

                # Remove argmax output from max / min
                if isinstance(encrypted_out, (list, tuple)):
                    reference = reference[0]
                    encrypted_out = encrypted_out[0]

                self._check(encrypted_out, reference, msg + " in forward")

                # Check backward pass
                grad_output = get_random_test_tensor(max_value=2,
                                                     size=reference.size(),
                                                     is_float=True)
                grad_output_encr = crypten.cryptensor(grad_output)

                # Do not check backward if pytorch backward fails
                try:
                    reference.backward(grad_output)
                except RuntimeError:
                    logging.info("skipped")
                    continue
                encrypted_out.backward(grad_output_encr)

                self._check(input_encr.grad, input.grad, msg + " in backward")
                for i, arg_encr in enumerate(args_encr):
                    if crypten.is_encrypted_tensor(arg_encr):
                        self._check(arg_encr.grad, args[i].grad,
                                    msg + " in backward args")