def _check_forward_backward(
    self, func_name, input_tensor, *args, torch_func_name=None, msg=None, **kwargs
):
    """Check a CrypTen function's forward and backward passes against PyTorch.

    The comparison runs twice. In the first run, the extra ``args`` are passed
    to the CrypTen function unencrypted. In the second run, they are prepared
    with ``self._set_grad_to_zero(..., make_private=True)``. In each run this
    helper compares:

    * the forward outputs;
    * the gradient with respect to the input;
    * the gradient of every arg that ``crypten.is_encrypted_tensor`` reports
      as encrypted.

    Args:
        func_name (str): name of the CrypTen method called on the encrypted
            input. It is also the PyTorch reference name unless
            ``torch_func_name`` is given.
        input_tensor (torch.Tensor): primary input. It is cloned, so the
            caller's tensor is not modified.
        args: extra positional arguments passed to both functions. They are
            run through ``self._set_grad_to_zero`` before each run.
        torch_func_name (str, optional): PyTorch function name to resolve via
            ``self._get_torch_func`` when it differs from ``func_name``.
        msg (str, optional): prefix for mismatch messages. Defaults to
            ``"<func_name> grad_fn incorrect"``.
        kwargs (dict): keyword arguments passed to both functions.
    """
    if msg is None:
        msg = f"{func_name} grad_fn incorrect"

    input_plain = input_tensor.clone()
    input_plain.requires_grad = True
    input_encr = AutogradCrypTensor(crypten.cryptensor(input_plain), requires_grad=True)

    # The reference function is the same for both runs; resolve it once.
    torch_func = self._get_torch_func(
        func_name if torch_func_name is None else torch_func_name
    )

    for private in [False, True]:
        # Clear gradients left over from the previous run.
        input_plain.grad = None
        input_encr.grad = None
        args = self._set_grad_to_zero(args)
        args_encr = self._set_grad_to_zero(list(args), make_private=private)

        # Forward pass.
        reference = torch_func(input_plain, *args, **kwargs)
        encrypted_out = getattr(input_encr, func_name)(*args_encr, **kwargs)

        # Some ops (e.g. max / min with keepdim=False) return a tuple that
        # includes argmax output; compare only the first element.
        if isinstance(encrypted_out, (list, tuple)):
            reference = reference[0]
            encrypted_out = encrypted_out[0]
        self._check(encrypted_out, reference, msg + " in forward")

        # Backward pass: both sides receive the same random upstream gradient.
        grad_output = get_random_test_tensor(
            max_value=2, size=reference.size(), is_float=True
        )
        grad_output_encr = crypten.cryptensor(grad_output)
        reference.backward(grad_output)
        encrypted_out.backward(grad_output_encr)

        self._check(input_encr.grad, input_plain.grad, msg + " in backward")
        for arg_encr, arg in zip(args_encr, args):
            if crypten.is_encrypted_tensor(arg_encr):
                self._check(arg_encr.grad, arg.grad, msg + " in backward args")
def _set_grad_to_zero(self, args, make_private=False):
    """Prepare arguments for a fresh forward/backward run.

    Entries for which ``is_float_tensor`` is false are passed through
    unchanged. Float-tensor entries are handled in one of two ways:

    * With ``make_private``, each is replaced by a new
      ``AutogradCrypTensor(crypten.cryptensor(arg), requires_grad=True)``.
    * Otherwise, the caller's tensor is modified in place: it gets
      ``requires_grad = True`` and its ``.grad`` is reset to ``None``. The
      gradient is cleared, not filled with zeros.

    Args:
        args (iterable): arguments, which may mix tensors and other values.
        make_private (bool): wrap float tensors in ``AutogradCrypTensor``.

    Returns:
        list: a new list with one entry per element of ``args``, in order.
    """

    def _prepare(value):
        if not is_float_tensor(value):
            return value
        if make_private:
            return AutogradCrypTensor(crypten.cryptensor(value), requires_grad=True)
        value.requires_grad = True
        value.grad = None
        return value

    return [_prepare(value) for value in args]
def _check_forward_backward(
    self, fn_name, input_tensor, *args, torch_func_name=None, msg=None, **kwargs
):
    """Check a CrypTen function's forward and backward passes against PyTorch.

    The comparison runs twice. In the first run, ``args`` are passed to the
    CrypTen function as-is. In the second run, each arg for which
    ``is_float_tensor`` is true is wrapped in an ``AutogradCrypTensor``.

    Each run compares the forward outputs. If the PyTorch backward succeeds,
    it then compares the gradient with respect to the input and the gradient
    of every encrypted arg. If the PyTorch backward raises ``RuntimeError``,
    the backward check for that run is logged and skipped.

    Args:
        fn_name (str): name of the CrypTen method called on the encrypted
            input. It is also the PyTorch reference name unless
            ``torch_func_name`` is given.
        input_tensor (torch.Tensor): primary input. It is cloned, so the
            caller's tensor is not modified.
        args: extra positional arguments passed to both functions. Args for
            which ``is_float_tensor`` is true get ``requires_grad = True``,
            and their ``.grad`` is cleared in place before each run.
        torch_func_name (str, optional): PyTorch reference name when it
            differs from ``fn_name``.
        msg (str, optional): prefix for mismatch messages. Defaults to
            ``"<fn_name> grad_fn incorrect"``.
        kwargs: keyword arguments passed to both functions.

    Raises:
        ValueError: if no PyTorch reference can be found for the name.
    """
    if msg is None:
        msg = f"{fn_name} grad_fn incorrect"
    ref_name = fn_name if torch_func_name is None else torch_func_name

    input_plain = input_tensor.clone()
    input_plain.requires_grad = True

    # Resolve the PyTorch reference once; it does not change between runs.
    # Lookup order: tensor method, then torch.nn.functional (F), then the
    # "square" special case.
    if hasattr(input_plain, ref_name):
        torch_fn = getattr(input_plain, ref_name)
    elif hasattr(F, ref_name):
        functional = getattr(F, ref_name)

        def torch_fn(*fn_args, **fn_kwargs):
            return functional(input_plain, *fn_args, **fn_kwargs)

    elif ref_name == "square":
        # Extra args / kwargs are ignored for "square".
        def torch_fn(*_fn_args, **_fn_kwargs):
            return input_plain.pow(2)

    else:
        raise ValueError("unknown PyTorch function: %s" % ref_name)

    input_encr = AutogradCrypTensor(crypten.cryptensor(input_plain), requires_grad=True)

    for private in [False, True]:
        input_plain.grad = None
        input_encr.grad = None

        # Build the CrypTen-side args. In the public run they are the same
        # objects as in ``args``; in the private run, float tensors are
        # wrapped as AutogradCrypTensors.
        args_encr = list(args)
        for i, arg in enumerate(args):
            if private and is_float_tensor(arg):
                args_encr[i] = AutogradCrypTensor(
                    crypten.cryptensor(arg), requires_grad=True
                )
                args_encr[i].grad = None
            if is_float_tensor(arg):
                arg.requires_grad = True
                arg.grad = None

        # Forward pass.
        reference = torch_fn(*args, **kwargs)
        encrypted_out = getattr(input_encr, fn_name)(*args_encr, **kwargs)

        # Some ops (e.g. max / min) return a tuple that includes argmax output;
        # compare only the first element.
        if isinstance(encrypted_out, (list, tuple)):
            reference = reference[0]
            encrypted_out = encrypted_out[0]
        self._check(encrypted_out, reference, msg + " in forward")

        # Backward pass: both sides receive the same random upstream gradient.
        grad_output = get_random_test_tensor(
            max_value=2, size=reference.size(), is_float=True
        )
        grad_output_encr = crypten.cryptensor(grad_output)

        # Deliberate best effort: if the PyTorch backward fails there is
        # nothing to compare against, so log it and skip this run.
        try:
            reference.backward(grad_output)
        except RuntimeError as err:
            logging.info("skipped backward check for %s: %s", fn_name, err)
            continue
        encrypted_out.backward(grad_output_encr)

        self._check(input_encr.grad, input_plain.grad, msg + " in backward")
        for arg_encr, arg in zip(args_encr, args):
            if crypten.is_encrypted_tensor(arg_encr):
                self._check(arg_encr.grad, arg.grad, msg + " in backward args")