Example #1
    def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        """
        Note:
        -----
        The Jacobian is *not independent* across the batch dimension, i.e.
        D z_i = D z_i(x_1, ..., x_B).

        This structure breaks the computation of the GGN diagonal;
        curvature-matrix products should still work.

        References:
        -----------
        https://kevinzakka.github.io/2016/09/14/batch_normalization/
        https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
        """
        assert module.affine is True

        N = self.get_batch(module)
        x_hat, var = self.get_normalized_input_and_var(module)
        ivar = 1.0 / (var + module.eps).sqrt()

        dx_hat = einsum("vni,i->vni", (mat, module.weight))

        jac_t_mat = N * dx_hat
        jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat)
        jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat))
        jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N))

        return jac_t_mat
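
For reference, the four lines above implement the batch-norm backward formula from the two references: writing dx̂ = γ ⊙ v for an incoming vector v,

    Jᵀv |_n = (1 / (N·√(σ² + ε))) · [ N·dx̂_n − Σ_s dx̂_s − x̂_n ⊙ Σ_s (dx̂_s ⊙ x̂_s) ],

applied independently to each of the V stacked vectors in mat.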
Example #2
        def hessian_mat_prod(mat):
            Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
                "bi,bj,cbj->cbi", (probs, probs, mat)
            )

            if module.reduction == "mean":
                N = module.input0.shape[0]
                Hmat /= N

            return Hmat
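
The two contractions implement the softmax cross-entropy Hessian H_n = diag(p_n) − p_n p_nᵀ applied to each of the stacked vectors. A minimal self-contained check; all shapes and names here are illustrative, not from the original closure:

    import torch

    B, I, C = 3, 5, 2
    probs = torch.randn(B, I).softmax(dim=1)   # [B, I] softmax outputs
    mat = torch.randn(C, B, I)                 # C stacked vectors per sample
    Hmat = torch.einsum("bi,cbi->cbi", probs, mat) - torch.einsum(
        "bi,bj,cbj->cbi", probs, probs, mat
    )
    # Reference: apply the explicit per-sample Hessian diag(p_n) - p_n p_n^T.
    H = torch.diag_embed(probs) - torch.einsum("bi,bj->bij", probs, probs)
    assert torch.allclose(Hmat, torch.einsum("bij,cbj->cbi", H, mat), atol=1e-6)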
Example #3
    def _sqrt_hessian(self, module, g_inp, g_out):
        probs = self.get_probs(module)
        tau = torchsqrt(probs)
        V_dim, C_dim = 0, 2
        Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
        Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
        sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_H /= sqrt(N)

        return sqrt_H
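
As a sanity check, this symmetric factorization satisfies S_n S_nᵀ = diag(p_n) − p_n p_nᵀ (expand √p_c(δ_vc − √p_v √p_c) and use Σ_c p_c = 1). A small sketch under assumed shapes:

    import torch
    from torch import diag_embed, einsum, ones_like

    N, C = 4, 3
    probs = torch.randn(N, C).softmax(dim=1)
    tau = probs.sqrt()
    Id = diag_embed(ones_like(probs), dim1=0, dim2=2)  # Id[v, n, c] = delta_vc
    sqrt_H = einsum("nc,vnc->vnc", tau, Id - einsum("nv,nc->vnc", tau, tau))
    # Contracting the leading axis recovers the Hessian diag(p_n) - p_n p_n^T.
    H = diag_embed(probs) - einsum("nc,nd->ncd", probs, probs)
    assert torch.allclose(einsum("vnc,vnd->ncd", sqrt_H, sqrt_H), H, atol=1e-6)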
Example #4
def extract_bias_diagonal(module, sqrt):
    """
    `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC)
    """
    V_axis, N_axis = 0, 1
    bias_diagonal = (einsum("vnchw->vnc", sqrt)**2).sum([V_axis, N_axis])
    return bias_diagonal
Example #5
 def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
     if not sum_batch:
         warn("BatchNorm batch summation disabled."
              "This may not compute meaningful quantities")
     x_hat, _ = self.get_normalized_input_and_var(module)
     equation = "vni,ni->v{}i".format("" if sum_batch is True else "n")
     operands = [mat, x_hat]
     return einsum(equation, operands)
Example #6
    def _sum_hessian(self, module, g_inp, g_out):
        probs = self.get_probs(module)
        sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs))

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sum_H /= N

        return sum_H
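
In matrix form this is the Hessian of the batch-summed loss,

    Σ_n H_n = Σ_n [ diag(p_n) − p_n p_nᵀ ] = diag(Σ_n p_n) − Σ_n p_n p_nᵀ,

which is exactly diag(probs.sum(0)) minus the einsum term.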
Example #7
def extract_weight_diagonal(module, input, grad_output):
    """
    input must be the unfolded input to the convolution (see unfold_func)
    and grad_output the backpropagated gradient
    """
    grad_output_viewed = separate_channels_and_pixels(module, grad_output)
    AX = einsum("nkl,vnml->vnkm", (input, grad_output_viewed))
    weight_diagonal = (AX**2).sum([0, 1]).transpose(0, 1)
    return weight_diagonal.view_as(module.weight)
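
A shape sketch, assuming `unfold_func` wraps torch.nn.Unfold with the module's kernel/stride/padding and `separate_channels_and_pixels` views the factor as [V, N, C_out, L]; the concrete sizes are made up:

    import torch

    N, C_in, C_out, K = 2, 3, 4, 3
    conv = torch.nn.Conv2d(C_in, C_out, K)
    x = torch.randn(N, C_in, 8, 8)
    X = torch.nn.functional.unfold(x, K)       # [N, C_in*K*K, L] patches
    V, L = 5, X.shape[-1]
    S = torch.randn(V, N, C_out, L)            # backpropagated factor per pixel
    AX = torch.einsum("nkl,vnml->vnkm", X, S)  # [V, N, C_in*K*K, C_out]
    diagonal = (AX**2).sum([0, 1]).transpose(0, 1)
    print(diagonal.view_as(conv.weight).shape)  # torch.Size([4, 3, 3, 3])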
Example #8
    def _factors_from_input(self, ext, module):
        X = convUtils.unfold_func(module)(module.input0)
        batch = X.size(0)

        ea_strategy = ext.get_ea_strategy()

        if ExpectationApproximation.should_average_param_jac(ea_strategy):
            raise NotImplementedError("Undefined")
        else:
            yield einsum("bik,bjk->ij", (X, X)) / batch
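
The yielded tensor is the uncentered second moment of the unfolded inputs, Ω = (1/N) Σ_n X_n X_nᵀ with X_n the unfolded patches of sample n; this is the input Kronecker factor used by KFAC-style approximations for convolutions.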
Example #9
 def _weight_jac_t_mat_prod(self,
                            module,
                            g_inp,
                            g_out,
                            mat,
                            sum_batch=True):
     """Apply transposed Jacobian of the output w.r.t. the weight."""
     d_weight = module.input0
     contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi"
     return einsum(contract, (mat, d_weight))
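
A shape sketch for both branches of `contract`, with made-up sizes (a Linear layer with i inputs and o outputs); the per-sample result sums to the batch-summed one:

    import torch

    V, N, i, o = 3, 4, 5, 2
    mat = torch.randn(V, N, o)                         # output-space vectors
    x = torch.randn(N, i)                              # module.input0
    summed = torch.einsum("vno,ni->voi", mat, x)       # sum_batch=True: [V, o, i]
    per_sample = torch.einsum("vno,ni->vnoi", mat, x)  # sum_batch=False: [V, N, o, i]
    assert torch.allclose(per_sample.sum(1), summed)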
Example #10
        def diag_embed_multi_dim(H):
            """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ...,],
            embed into [N, C_in * H_in * ..., C_in * H_in = V], convert back
            to [V, N, C_in, H_in, ...,  V]."""
            feature_shapes = H.shape[1:]
            V, N = prod(feature_shapes), H.shape[0]

            H_diag = diag_embed(H.view(N, V))
            # [V, N, C_in, H_in, ...]
            shape = (V, N, *feature_shapes)
            return einsum("nic->cni", H_diag).view(shape)
Example #11
def two_kfacs_to_mat(A, B):
    """Given A, B, return A ⊗ B."""
    assert is_matrix(A)
    assert is_matrix(B)

    mat_shape = (
        A.shape[0] * B.shape[0],
        A.shape[1] * B.shape[1],
    )
    mat = einsum("ij,kl->ikjl", (A, B)).contiguous().view(mat_shape)
    return mat
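
For 2d factors this agrees with torch.kron (available in recent PyTorch releases); a quick check:

    import torch

    A, B = torch.randn(2, 3), torch.randn(4, 5)
    mat = torch.einsum("ij,kl->ikjl", A, B).contiguous().view(2 * 4, 3 * 5)
    assert torch.allclose(mat, torch.kron(A, B))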
Example #12
    def sym_mat_inv(mat, shift, truncate=1e-8):
        """Inverse of a symmetric matrix A -> (A + 𝜆I)⁻¹.

        Computed by eigenvalue decomposition. Inverse eigenvalues are
        clamped to ±1/truncate, capping the effect of (shifted)
        eigenvalues whose magnitude falls below `truncate`.
        """
        eigvals, eigvecs = mat.symeig(eigenvectors=True)
        eigvals.add_(shift)
        inv_eigvals = 1.0 / eigvals
        inv_truncate = 1.0 / truncate
        inv_eigvals.clamp_(min=-inv_truncate, max=inv_truncate)
        return einsum("ij,j,kj->ik", (eigvecs, inv_eigvals, eigvecs))
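
`Tensor.symeig` was deprecated and later removed from PyTorch; a sketch of the same routine against the current `torch.linalg.eigh` API, intended to match the original behavior:

    import torch

    def sym_mat_inv(mat, shift, truncate=1e-8):
        """(A + shift * I)^-1 via eigendecomposition, clamping inverse eigenvalues."""
        eigvals, eigvecs = torch.linalg.eigh(mat)
        inv_eigvals = 1.0 / (eigvals + shift)
        inv_truncate = 1.0 / truncate
        inv_eigvals.clamp_(min=-inv_truncate, max=inv_truncate)
        return torch.einsum("ij,j,kj->ik", eigvecs, inv_eigvals, eigvecs)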
Example #13
    def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
        M = mc_samples
        C = module.input0.shape[1]

        probs = self.get_probs(module)
        V_dim = 0
        probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

        multi = multinomial(probs, M, replacement=True)
        classes = one_hot(multi, num_classes=C)
        classes = einsum("nvc->vnc", classes).float()

        sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_mc_h /= sqrt(N)

        return sqrt_mc_h
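
The 1/√M scaling makes this an unbiased Monte Carlo factorization of the exact Hessian: for a one-hot sample c ~ Categorical(p),

    E[(p − c)(p − c)ᵀ] = E[c cᵀ] − p pᵀ = diag(p) − p pᵀ,

since E[c] = p and c cᵀ = diag(c). Summing the M rescaled samples' outer products therefore approximates the softmax cross-entropy Hessian from the examples above.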
Example #14
    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        _, in_c, in_x, in_y = module.input0.size()
        in_features = in_c * in_x * in_y
        _, out_c, out_x, out_y = module.output.size()
        out_features = out_c * out_x * out_y

        # 1) apply conv_transpose to multiply with W^T
        result = mat.view(out_c, out_x, out_y, out_features)
        result = einsum("cxyf->fcxy", (result, ))
        # result: W^T mat
        result = self.__apply_jacobian_t_of(module, result).view(
            out_features, in_features)

        # 2) transpose: mat^T W
        result = result.t()

        # 3) apply conv_transpose
        result = result.view(in_features, out_c, out_x, out_y)
        result = self.__apply_jacobian_t_of(module, result)

        # 4) transpose to obtain W^T mat W
        return result.view(in_features, in_features).t()
Example #15
    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        """Use fact that average pooling can be implemented as conv."""
        _, channels, in_x, in_y = module.input0.size()
        in_features = channels * in_x * in_y
        _, _, out_x, out_y = module.output.size()
        out_features = channels * out_x * out_y

        # 1) apply conv_transpose to multiply with W^T
        result = mat.view(channels, out_x, out_y, out_features)
        result = einsum("cxyf->fcxy", (result, )).contiguous()
        result = result.view(out_features * channels, 1, out_x, out_y)
        # result: W^T mat
        result = self.__apply_jacobian_t_of(module, result)
        result = result.view(out_features, in_features)

        # 2) transpose: mat^T W
        result = result.t().contiguous()

        # 3) apply conv_transpose
        result = result.view(in_features * channels, 1, out_x, out_y)
        result = self.__apply_jacobian_t_of(module, result)

        # 4) transpose to obtain W^T mat W
        return result.view(in_features, in_features).t()
Example #16
 def _jac_mat_prod(self, module, g_inp, g_out, mat):
     """Apply Jacobian of the output w.r.t. the input."""
     d_input = module.weight.data
     return einsum("oi,vni->vno", (d_input, mat))
Example #17
 def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
     x_hat, _ = self.get_normalized_input_and_var(module)
     return einsum("ni,vi->vni", (x_hat, mat))
Example #18
 def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
     df_elementwise = self.df(module, g_inp, g_out)
     return einsum("...,v...->v...", (df_elementwise, mat))
Example #19
 def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
     batch, df_flat = self.batch_flat(self.df(module, g_inp, g_out))
     return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / batch
Example #20
 def weight(self, ext, module, g_inp, g_out, backproped):
     N_axis = 0
     X, dE_dY = convUtils.get_weight_gradient_factors(
         module.input0, g_out[0], module)
     d1 = einsum("nml,nkl->nmk", (dE_dY, X))
     return (d1**2).sum(N_axis).view_as(module.weight)
Example #21
def extract_bias_diagonal(module, backproped):
    return einsum("vno->o", backproped**2)
Example #22
 def weight(self, ext, module, g_inp, g_out, backproped):
     X, dE_dY = convUtils.get_weight_gradient_factors(
         module.input0, g_out[0], module)
     return einsum("nml,nkl,nmi,nki->n", (dE_dY, X, dE_dY, X))
Example #23
 def kfacmp(mat):
     """Kronecker-factored matrix product. `equation`, `col_dims`, and
     `factors` are taken from the enclosing scope."""
     assert is_matrix(mat)
     _, mat_cols = mat.shape
     mat_reshaped = mat.view(*(col_dims), mat_cols)
     return einsum(equation, mat_reshaped,
                   *factors).contiguous().view(-1, mat_cols)
Example #24
 def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
     jac = module.weight.data
     return einsum("ik,ij,jl->kl", (jac, mat, jac))
Example #25
 def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
     """Apply Jacobian of the output w.r.t. the weight."""
     d_weight = module.input0
     return einsum("ni,voi->vno", (d_weight, mat))
Example #26
 def bias(self, ext, module, g_inp, g_out, backproped):
     N_axis = 0
     return (einsum("nchw->nc", g_out[0])**2).sum(N_axis)
Example #27
def extract_weight_diagonal(module, backproped):
    return einsum("vno,ni->oi", (backproped**2, module.input0**2))
Example #28
    def _factor_from_sqrt(self, module, backproped):
        sqrt_ggn = backproped

        sqrt_ggn = convUtils.separate_channels_and_pixels(module, sqrt_ggn)
        sqrt_ggn = einsum("cbij->cbi", (sqrt_ggn, ))
        return einsum("cbi,cbl->il", (sqrt_ggn, sqrt_ggn))
Example #29
            def R_mat_prod(mat):
                """Multiply with the residual: mat Рєњ [РѕЉ_{k} Hz_k(x) ­ЮЏ┐z_k] mat.

                Second term of the module input Hessian backpropagation equation.
                """
                return einsum("n...,vn...->vn...", (R_mod, mat))
Example #30
 def make_quadratic_psd(mat):
     """Make matrix positive semi-definite: A -> AAᵀ."""
     mat_squared = einsum("ij,kj->ik", (mat, mat))
     shift = self.PSD_KFAC_MIN_EIGVAL * self.torch_eye_like(mat_squared)
     return mat_squared + shift
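
AAᵀ is positive semi-definite since vᵀAAᵀv = ‖Aᵀv‖² ≥ 0; the PSD_KFAC_MIN_EIGVAL · I shift moves the spectrum strictly away from zero so that subsequent inversions stay well conditioned.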