Example #1
def f():
    # Multiply ``vin`` by the transposed Jacobian of ``out`` w.r.t. the bias;
    # ``detach=False`` keeps the result attached to the autograd graph.
    r = transposed_jacobian_vector_product(
        out, module.bias, vin, detach=False)[0].contiguous()
    if vin.is_cuda:
        # Wait for pending CUDA kernels so the result is ready when f() returns.
        torch.cuda.synchronize()
    return r
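A minimal standalone sketch of what the L-operator call above computes (an assumption about the library internals, expressed with plain ``torch.autograd``): passing ``v`` as ``grad_outputs`` yields the transposed-Jacobian vector product Jᵀv, since reverse-mode AD computes exactly vᵀJ.

import torch

x = torch.randn(3, requires_grad=True)
W = torch.randn(4, 3)
y = W @ x                     # linear map, so the Jacobian dy/dx is W
v = torch.randn(4)            # vector to multiply from the left

# grad_outputs=v makes autograd return v^T J, i.e. J^T v as a single tensor.
(jt_v,) = torch.autograd.grad(y, x, grad_outputs=v)
print(torch.allclose(jt_v, W.t() @ v))  # True: J^T v equals W^T v here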
Example #2
    def param_jac_t_vec_prod(self, name, vec, sum_batch):
        """Multiply ``vec`` by the transposed Jacobian w.r.t. parameter ``name``."""
        input, output, named_params = self.problem.forward_pass()
        param = named_params[name]

        if sum_batch:
            # One VJP over the full batch: per-sample contributions are summed.
            return transposed_jacobian_vector_product(output, param, vec)[0]
        else:
            # Per-sample VJPs: pair each sample's output with its slice of ``vec``.
            N = input.shape[0]

            sample_outputs = [output[n] for n in range(N)]
            sample_vecs = [vec[n] for n in range(N)]

            jac_t_sample_prods = [
                transposed_jacobian_vector_product(n_out, param, n_vec)[0]
                for n_out, n_vec in zip(sample_outputs, sample_vecs)
            ]

            return torch.stack(jac_t_sample_prods)
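A quick hedged check of the ``sum_batch`` distinction above (hypothetical toy model, plain ``torch.autograd`` in place of the library call): the full-batch VJP with respect to a shared parameter equals the sum of the stacked per-sample VJPs.

import torch

torch.manual_seed(0)
lin = torch.nn.Linear(3, 2)
x = torch.randn(5, 3)
out = lin(x)
v = torch.randn(5, 2)

# Full-batch VJP w.r.t. the bias: autograd sums the per-sample contributions.
(summed,) = torch.autograd.grad(out, lin.bias, grad_outputs=v, retain_graph=True)

# Per-sample VJPs, stacked along a new leading axis.
per_sample = torch.stack([
    torch.autograd.grad(out[n], lin.bias, grad_outputs=v[n], retain_graph=True)[0]
    for n in range(x.shape[0])
])

print(torch.allclose(summed, per_sample.sum(0)))  # True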
Example #3
    def jac_t_vec_prod(self,
                       vec: Tensor,
                       subsampling=None) -> Tensor:  # noqa: D102
        input, output, _ = self.problem.forward_pass(input_requires_grad=True)

        if subsampling is None:
            return transposed_jacobian_vector_product(output, input, vec)[0]
        else:
            # for each sample, multiply by full input Jacobian, slice out result:
            # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n]
            batch_axis = 0
            output = subsample(output, dim=batch_axis, subsampling=subsampling)
            output = output.split(1, dim=batch_axis)
            vec = vec.split(1, dim=batch_axis)

            vjps: List[Tensor] = []
            for sample_idx, out, v in zip(subsampling, output, vec):
                vjp = transposed_jacobian_vector_product(out, input, v)[0]
                vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx])
                vjps.append(vjp)

            return cat(vjps, dim=batch_axis)
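A hedged illustration of the subsampling pattern above (hypothetical tensors, plain ``torch.autograd`` instead of the library helpers): each active sample's output slice is multiplied by its own vector slice, and only that sample's row of the input VJP is kept before concatenating.

import torch

x = torch.randn(5, 3, requires_grad=True)
out = torch.relu(x)            # elementwise op, so samples do not interact
v = torch.randn(2, 3)          # one vector per active sample
subsampling = [1, 3]           # indices of the active samples

vjps = []
for idx, v_n in zip(subsampling, v.split(1, dim=0)):
    out_n = out.split(1, dim=0)[idx]
    # VJP w.r.t. the full input; only row ``idx`` is nonzero for an elementwise op.
    (vjp,) = torch.autograd.grad(out_n, x, grad_outputs=v_n, retain_graph=True)
    vjps.append(vjp[[idx]])    # slice out that sample's row

print(torch.cat(vjps, dim=0).shape)  # torch.Size([2, 3])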
Example #4
        def sample_jac_t_mat_prod(layer, sample, mat):
            # Build J^T @ mat column by column via transposed Jacobian-vector products.
            result = torch.zeros(sample.numel(), mat.size(1))

            sample.requires_grad = True
            output = layer(sample)

            for col in range(mat.size(1)):
                # Reshape the column to the output's shape before the VJP.
                column = mat[:, col].reshape(output.shape)
                result[:, col] = transposed_jacobian_vector_product(
                    [output], [sample], [column],
                    retain_graph=True)[0].reshape(-1)

            return result
Example #5
            def sample_jac_t_mat_prod(sample_idx, mat):
                sample, output, _ = self.problem.forward_pass(
                    input_requires_grad=True, sample_idx=sample_idx)

                result = torch.zeros(sample.numel(),
                                     mat.size(1),
                                     device=sample.device)

                for col in range(mat.size(1)):
                    column = mat[:, col].reshape(output.shape)
                    result[:, col] = transposed_jacobian_vector_product(
                        [output], [sample], [column],
                        retain_graph=True)[0].reshape(-1)

                return result
Example #6
    def _param_vjp(
        self,
        name: str,
        vec: Tensor,
        sum_batch: bool,
        axis_batch: int = 0,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Compute the product of jac_t and the given vector.

        Args:
            name: name of the parameter to differentiate with respect to
            vec: vector to multiply
            sum_batch: whether to sum along the batch axis
            axis_batch: index of batch axis. Defaults to 0.
            subsampling: Indices of active samples. Default: ``None`` (all).

        Returns:
            product of jac_t and vec
        """
        input, output, named_params = self.problem.forward_pass()
        param = named_params[name]

        samples = (range(input.shape[axis_batch])
                   if subsampling is None else subsampling)
        sample_outputs = output.split(1, dim=axis_batch)
        sample_vecs = vec.split(1, dim=axis_batch)

        jac_t_sample_prods = stack([
            transposed_jacobian_vector_product(sample_outputs[n], param,
                                               vec_n)[0]
            for n, vec_n in zip(samples, sample_vecs)
        ])

        if sum_batch:
            jac_t_sample_prods = jac_t_sample_prods.sum(0)

        return jac_t_sample_prods
Example #7
        def test_input_jacobian_transpose(self):
            """Test multiplication by the transposed input Jacobian.

            Compare with result of L-operator.
            """
            # create input
            cvp_in = self._create_input()
            cvp_in.requires_grad = True
            cvp_layer = self._create_cvp_layer()
            # skip Flatten and Sequential layers:
            if isinstance(cvp_layer, (Flatten, CVPSequential)):
                return
            cvp_out = cvp_layer(cvp_in)
            for _ in range(self.NUM_HVP):
                v = torch.randn(cvp_out.numel(),
                                requires_grad=False).to(self.DEVICE)
                JTv = cvp_layer._input_jacobian_transpose(v)
                # compute via L-operator
                (result, ) = transposed_jacobian_vector_product(
                    cvp_out, cvp_in, v.view(cvp_out.size()))
                assert torch.allclose(JTv,
                                      result.contiguous().view(-1),
                                      atol=self.ATOL,
                                      rtol=self.RTOL)
Example #8
def jac_t_vec_prod(self, vec):
    """Multiply ``vec`` by the transposed Jacobian of the output w.r.t. the input."""
    input, output, _ = self.problem.forward_pass(input_requires_grad=True)
    return transposed_jacobian_vector_product(output, input, vec)[0]