def f():
    # vector-Jacobian product w.r.t. the bias (detach=False keeps the result in the graph)
    r = transposed_jacobian_vector_product(
        out, module.bias, vin, detach=False
    )[0].contiguous()
    # wait for the GPU kernel to finish before returning
    if vin.is_cuda:
        torch.cuda.synchronize()
    return r

def param_jac_t_vec_prod(self, name, vec, sum_batch):
    """Multiply ``vec`` by the transposed Jacobian w.r.t. the parameter ``name``."""
    input, output, named_params = self.problem.forward_pass()
    param = named_params[name]

    if sum_batch:
        # one VJP over the whole mini-batch (implicitly sums over samples)
        return transposed_jacobian_vector_product(output, param, vec)[0]
    else:
        # one VJP per sample, stacked along a new leading axis
        N = input.shape[0]

        sample_outputs = [output[n] for n in range(N)]
        sample_vecs = [vec[n] for n in range(N)]

        jac_t_sample_prods = [
            transposed_jacobian_vector_product(n_out, param, n_vec)[0]
            for n_out, n_vec in zip(sample_outputs, sample_vecs)
        ]

        return torch.stack(jac_t_sample_prods)

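# A minimal, self-contained sketch of the sum_batch semantics above. All names below
# (layer, x, vec, ...) are made up for illustration, and the import path assumes
# BackPACK's hessianfree utilities. Summing the stacked per-sample parameter VJPs over
# the batch axis should match the single VJP of the full batch output, since every
# sample contributes additively to the parameter gradient.
import torch
from backpack.hessianfree.lop import transposed_jacobian_vector_product

torch.manual_seed(0)
layer = torch.nn.Linear(3, 2)
x = torch.randn(4, 3)
out = layer(x)
vec = torch.randn_like(out)

summed = transposed_jacobian_vector_product(out, layer.weight, vec, retain_graph=True)[0]
per_sample = torch.stack(
    [
        transposed_jacobian_vector_product(
            out[n], layer.weight, vec[n], retain_graph=True
        )[0]
        for n in range(x.shape[0])
    ]
)
assert torch.allclose(summed, per_sample.sum(0), atol=1e-6)
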
def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor:  # noqa: D102
    input, output, _ = self.problem.forward_pass(input_requires_grad=True)

    if subsampling is None:
        return transposed_jacobian_vector_product(output, input, vec)[0]
    else:
        # for each sample, multiply by full input Jacobian, slice out result:
        # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n]
        batch_axis = 0
        output = subsample(output, dim=batch_axis, subsampling=subsampling)
        output = output.split(1, dim=batch_axis)
        vec = vec.split(1, dim=batch_axis)

        vjps: List[Tensor] = []
        for sample_idx, out, v in zip(subsampling, output, vec):
            vjp = transposed_jacobian_vector_product(out, input, v)[0]
            vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx])
            vjps.append(vjp)

        return cat(vjps, dim=batch_axis)

def sample_jac_t_mat_prod(layer, sample, mat):
    # multiply each column of mat by the transposed input Jacobian of the layer
    result = torch.zeros(sample.numel(), mat.size(1))

    sample.requires_grad = True
    output = layer(sample)

    for col in range(mat.size(1)):
        column = mat[:, col].reshape(output.shape)
        result[:, col] = transposed_jacobian_vector_product(
            [output], [sample], [column], retain_graph=True
        )[0].reshape(-1)

    return result

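# Hedged check of the column-by-column scheme above (all names except
# sample_jac_t_mat_prod are made up): the result should equal Jᵀ @ mat, with J the
# flattened input Jacobian obtained independently via torch.autograd.functional.jacobian.
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(3, 2)
sample = torch.randn(1, 3)
mat = torch.randn(2, 5)  # rows indexed by the flattened output, one column per vector

result = sample_jac_t_mat_prod(layer, sample, mat)

J = torch.autograd.functional.jacobian(layer, sample).reshape(2, 3)  # (output.numel(), sample.numel())
assert torch.allclose(result, J.T @ mat, atol=1e-6)
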
def sample_jac_t_mat_prod(sample_idx, mat):
    sample, output, _ = self.problem.forward_pass(
        input_requires_grad=True, sample_idx=sample_idx
    )
    result = torch.zeros(sample.numel(), mat.size(1), device=sample.device)

    for col in range(mat.size(1)):
        column = mat[:, col].reshape(output.shape)
        result[:, col] = transposed_jacobian_vector_product(
            [output], [sample], [column], retain_graph=True
        )[0].reshape(-1)

    return result

def _param_vjp(
    self,
    name: str,
    vec: Tensor,
    sum_batch: bool,
    axis_batch: int = 0,
    subsampling: List[int] = None,
) -> Tensor:
    """Compute the product of jac_t and the given vector.

    Args:
        name: Name of the parameter to differentiate w.r.t.
        vec: Vector to multiply with.
        sum_batch: Whether to sum along the batch axis.
        axis_batch: Index of the batch axis. Defaults to 0.
        subsampling: Indices of active samples. Default: ``None`` (all).

    Returns:
        Product of jac_t and vec.
    """
    input, output, named_params = self.problem.forward_pass()
    param = named_params[name]

    samples = range(input.shape[axis_batch]) if subsampling is None else subsampling
    sample_outputs = output.split(1, dim=axis_batch)
    sample_vecs = vec.split(1, dim=axis_batch)

    jac_t_sample_prods = stack(
        [
            transposed_jacobian_vector_product(sample_outputs[n], param, vec_n)[0]
            for n, vec_n in zip(samples, sample_vecs)
        ]
    )

    if sum_batch:
        jac_t_sample_prods = jac_t_sample_prods.sum(0)

    return jac_t_sample_prods

def test_input_jacobian_transpose(self):
    """Test multiplication by the transposed input Jacobian.

    Compare with result of L-operator.
    """
    # create input
    cvp_in = self._create_input()
    cvp_in.requires_grad = True
    cvp_layer = self._create_cvp_layer()

    # skip for Flatten and Sequential layers
    if isinstance(cvp_layer, (Flatten, CVPSequential)):
        return

    cvp_out = cvp_layer(cvp_in)
    for _ in range(self.NUM_HVP):
        v = torch.randn(cvp_out.numel(), requires_grad=False).to(self.DEVICE)
        JTv = cvp_layer._input_jacobian_transpose(v)

        # compute via L-operator
        (result,) = transposed_jacobian_vector_product(
            cvp_out, cvp_in, v.view(cvp_out.size())
        )

        assert torch.allclose(
            JTv, result.contiguous().view(-1), atol=self.ATOL, rtol=self.RTOL
        )

def jac_t_vec_prod(self, vec):
    input, output, _ = self.problem.forward_pass(input_requires_grad=True)
    return transposed_jacobian_vector_product(output, input, vec)[0]

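# A self-contained sanity check of the basic call pattern above. Everything except
# transposed_jacobian_vector_product is a made-up example, and the import path assumes
# BackPACK's hessianfree utilities. For y = x Wᵀ + b the input Jacobian of each sample
# is W, so the transposed-Jacobian-vector product is v @ W.
import torch
from backpack.hessianfree.lop import transposed_jacobian_vector_product

torch.manual_seed(0)
layer = torch.nn.Linear(3, 2)
x = torch.randn(5, 3, requires_grad=True)
y = layer(x)
v = torch.randn_like(y)

vjp = transposed_jacobian_vector_product(y, x, v)[0]
assert torch.allclose(vjp, v @ layer.weight, atol=1e-6)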