def _compute_model_jacobians(self):
    """Compute Jacobians with respect to each model parameter

    If last_layer_only is True, this computes the jacobian only with
    respect to the parameters of the last layer of the model.
    """
    Z = self.logits.split(1, dim=-1)

    # Store partial Jacobian for each parameter
    jacobians = {}

    # dL/dW_i = sum_j (dL/dP_j * dP_j/dW_i)
    with crypten.no_grad():
        # TODO: Async / parallelize this
        for z in Z:
            z.backward(torch.ones(z.size()), retain_graph=True)

            params = self.model.parameters()
            for param in params:
                grad = param.grad.flatten().unsqueeze(-1)

                # Accumulate partial gradients: dL/dZ_j * dP_j/dW_i
                if param in jacobians.keys():
                    jacobians[param] = torch.cat([jacobians[param], grad], dim=-1)
                else:
                    jacobians[param] = grad
                param.grad = None  # Reset grad for next p_j.backward()
    return jacobians
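
# --- Illustrative sketch (plain torch, not CrypTen code) of the accumulation
# --- pattern above: back-propagate each logit column separately with
# --- retain_graph=True and stack the flattened per-parameter gradients as
# --- Jacobian columns. `toy_model` and `toy_logits` are hypothetical names.
import torch

toy_model = torch.nn.Linear(4, 3)
toy_logits = toy_model(torch.randn(2, 4))

toy_jacobians = {}
for z in toy_logits.split(1, dim=-1):
    z.backward(torch.ones(z.size()), retain_graph=True)
    for param in toy_model.parameters():
        grad = param.grad.flatten().unsqueeze(-1)
        toy_jacobians[param] = (
            grad
            if param not in toy_jacobians
            else torch.cat([toy_jacobians[param], grad], dim=-1)
        )
        param.grad = None  # reset before the next column's backward()
# Each value now has shape (param.numel(), number_of_logit_columns).
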
def _backward_layer_estimation(self, grad_output=None):
    with crypten.no_grad():
        # Find dLdW for last layer weights
        dLdW = self._compute_last_layer_grad(grad_output=grad_output)

        # Run backprop in plaintext
        if self.is_feature_src():
            dLdZ = self._solve_dLdZ(dLdW)
            self.logits.backward(dLdZ)
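
# --- Illustrative sketch (assumed, not CrypTen internals) of the estimation idea
# --- behind `_solve_dLdZ`: for a linear last layer Z = X @ W.T, the weight
# --- gradient is dL/dW = dL/dZ.T @ X, so dL/dZ can be recovered from dL/dW and
# --- the last-layer input X via a pseudoinverse (exactly when X has full row
# --- rank, i.e. batch size <= feature count). All names below are hypothetical.
import torch

X = torch.randn(4, 16)                    # last-layer input (batch x features)
dLdZ_true = torch.randn(4, 3)             # gradient w.r.t. logits (batch x classes)
dLdW = dLdZ_true.t().matmul(X)            # gradient w.r.t. last-layer weight (3 x 16)

dLdZ_est = dLdW.matmul(X.pinverse()).t()  # de-aggregate back to (batch x classes)
print(torch.allclose(dLdZ_est, dLdZ_true, atol=1e-4))  # True
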
def _add_dp_if_necessary(self, grad):
    if self.noise_magnitude is None or self.noise_magnitude == 0.0:
        return grad

    # Determine noise generation function
    generate_noise = (
        self._generate_noise_from_src
        if self.noise_src
        else self._generate_noise_no_src
    )
    noise = generate_noise(grad.size())
    with crypten.no_grad():
        grad += noise
    return grad
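
# --- Illustrative sketch (not CrypTen code) of the perturbation this helper
# --- applies: add zero-mean noise scaled by a noise magnitude to the gradient.
# --- In the class above, which party samples the noise is decided by `noise_src`;
# --- here a single-party Gaussian draw stands in, and `toy_grad` /
# --- `toy_noise_magnitude` are made-up names.
import torch

toy_noise_magnitude = 0.01
toy_grad = torch.randn(10, 5)
toy_noisy_grad = toy_grad + toy_noise_magnitude * torch.randn_like(toy_grad)
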
def backward(self, grad_output=None):
    protocol = cfg.nn.dpsmpc.protocol
    with crypten.no_grad():
        if protocol == "full_jacobian":
            self._backward_full_jacobian(grad_output=grad_output)
            raise NotImplementedError(
                "DPS protocol full_jacobian must be fixed before use."
            )
        elif protocol == "layer_estimation":
            with torch.no_grad():
                self._backward_layer_estimation(grad_output=grad_output)
        else:
            raise ValueError(
                f"Unrecognized DPSplitMPC backward protocol: {protocol}"
            )
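
# --- Illustrative sketch of switching the backward protocol through the global
# --- config read above. The import path (crypten.config) is an assumption; the
# --- key cfg.nn.dpsmpc.protocol is the one used in backward().
from crypten.config import cfg

cfg.nn.dpsmpc.protocol = "layer_estimation"  # "full_jacobian" is currently disabled
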
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    with crypten.no_grad():
        loss = None
        if closure is not None:
            with crypten.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p.mul(weight_decay))
                if momentum != 0:
                    param_state = self.state[id(p)]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = d_p.clone().detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p.mul(1 - dampening))
                    if nesterov:
                        d_p = d_p.add(buf.mul(momentum))
                    else:
                        d_p = buf

                p.sub_(d_p.mul(group["lr"]))

        return loss
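
# --- Minimal plaintext sketch (made-up values, plain torch) of the update rule
# --- implemented above: weight decay, momentum with dampening, optional Nesterov,
# --- then a learning-rate step.
#       d_p = grad + weight_decay * p
#       buf = momentum * buf + (1 - dampening) * d_p
#       d_p = d_p + momentum * buf   if nesterov,  else   d_p = buf
#       p   = p - lr * d_p
# Note: on the very first step the code above sets buf = d_p directly (no
# dampening); with dampening = 0 the two coincide.
import torch

p, grad = torch.tensor([1.0]), torch.tensor([0.5])
lr, momentum, weight_decay, dampening = 0.1, 0.9, 0.0, 0.0
buf = torch.zeros_like(p)

d_p = grad + weight_decay * p
buf = momentum * buf + (1 - dampening) * d_p
p = p - lr * buf  # non-Nesterov branch
print(p)          # tensor([0.9500])
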
def test_batchnorm(self):
    """
    Tests batchnorm forward and backward steps with training on / off.
    """
    tolerance = 0.1
    sizes = [(8, 5), (16, 3), (32, 5), (8, 6, 4), (8, 4, 3, 5)]

    for size in sizes:
        for is_training in (False, True):
            # sample input data, weight, and bias:
            tensor = get_random_test_tensor(size=size, is_float=True)
            encrypted_input = crypten.cryptensor(tensor)
            C = size[1]
            weight = get_random_test_tensor(size=[C], max_value=1, is_float=True)
            bias = get_random_test_tensor(size=[C], max_value=1, is_float=True)
            weight.requires_grad = True
            bias.requires_grad = True

            # dimensions over which means and variances are computed:
            stats_dimensions = list(range(tensor.dim()))
            stats_dimensions.pop(1)

            # dummy running mean and variance:
            running_mean = tensor.mean(stats_dimensions).detach()
            running_var = tensor.var(stats_dimensions).detach()
            enc_running_mean = crypten.cryptensor(running_mean)
            enc_running_var = crypten.cryptensor(running_var)

            # compute reference output:
            tensor.requires_grad = True
            reference = torch.nn.functional.batch_norm(
                tensor,
                running_mean,
                running_var,
                weight=weight,
                bias=bias,
                training=is_training,
            )

            # compute CrypTen output:
            encrypted_input.requires_grad = True
            ctx = AutogradContext()
            batch_norm_fn, _ = crypten.gradients.get_grad_fn("batchnorm")
            with crypten.no_grad():
                encrypted_out = batch_norm_fn.forward(
                    ctx,
                    encrypted_input,
                    weight,
                    bias,
                    training=is_training,
                    running_mean=enc_running_mean,
                    running_var=enc_running_var,
                )

            # check forward
            self._check(
                encrypted_out,
                reference,
                "batchnorm forward failed with training "
                f"{is_training} on {tensor.dim()}-D",
                tolerance=tolerance,
            )

            # check backward (input, weight, and bias gradients):
            reference.backward(reference)
            with crypten.no_grad():
                encrypted_grad = batch_norm_fn.backward(ctx, encrypted_out)
            TorchGrad = namedtuple("TorchGrad", ["name", "value"])
            torch_gradients = [
                TorchGrad("input gradient", tensor.grad),
                TorchGrad("weight gradient", weight.grad),
                TorchGrad("bias gradient", bias.grad),
            ]
            for i, torch_gradient in enumerate(torch_gradients):
                self._check(
                    encrypted_grad[i],
                    torch_gradient.value,
                    f"batchnorm backward {torch_gradient.name} failed "
                    f"with training {is_training} on {tensor.dim()}-D",
                    tolerance=tolerance,
                )
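
# --- Illustrative sketch (plain torch, separate from the test) of what the
# --- reference in training mode computes: normalize over every dimension except
# --- the channel dimension (index 1), then scale by weight and shift by bias.
import torch

x = torch.randn(8, 6, 4)
C = x.size(1)
weight, bias = torch.rand(C), torch.rand(C)

dims = [d for d in range(x.dim()) if d != 1]  # statistics over non-channel dims
mean = x.mean(dims, keepdim=True)
var = x.var(dims, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + 1e-5) * weight.view(1, C, 1) + bias.view(1, C, 1)

reference = torch.nn.functional.batch_norm(
    x, None, None, weight=weight, bias=bias, training=True
)
print(torch.allclose(manual, reference, atol=1e-4))  # True
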