def _jac_t_mat_prod(self, module, g_inp, g_out, mat): """ Note: ----- The Jacobian is *not independent* among the batch dimension, i.e. D z_i = D z_i(x_1, ..., x_B). This structure breaks the computation of the GGN diagonal, for curvature-matrix products it should still work. References: ----------- https://kevinzakka.github.io/2016/09/14/batch_normalization/ https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html """ assert module.affine is True N = self.get_batch(module) x_hat, var = self.get_normalized_input_and_var(module) ivar = 1.0 / (var + module.eps).sqrt() dx_hat = einsum("vni,i->vni", (mat, module.weight)) jac_t_mat = N * dx_hat jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat) jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat)) jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N)) return jac_t_mat
def hessian_mat_prod(mat):
    """Multiply every slice ``mat[c]`` with the per-sample matrix diag(p_b) - p_b p_bᵀ.

    ``probs`` (p, indexed ``[b, i]``) and ``module`` are free variables taken
    from the enclosing scope. ``mat`` is indexed ``[c, b, i]``. With
    ``module.reduction == "mean"`` the result is divided by the batch size
    ``module.input0.shape[0]``.
    """
    diagonal_part = einsum("bi,cbi->cbi", (probs, mat))
    outer_part = einsum("bi,bj,cbj->cbi", (probs, probs, mat))
    Hmat = diagonal_part - outer_part

    if module.reduction == "mean":
        Hmat /= module.input0.shape[0]
    return Hmat
def _sqrt_hessian(self, module, g_inp, g_out): probs = self.get_probs(module) tau = torchsqrt(probs) V_dim, C_dim = 0, 2 Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim) Id_tautau = Id - einsum("nv,nc->vnc", tau, tau) sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau) if module.reduction == "mean": N = module.input0.shape[0] sqrt_H /= sqrt(N) return sqrt_H
def extract_bias_diagonal(module, sqrt):
    """Return the bias diagonal from a backpropagated symmetric factor.

    `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC).
    It is indexed ``[v, n, c, *spatial]``. The spatial axes are summed out,
    the result is squared, and then summed over the v and n axes.

    Any number of trailing spatial axes is accepted, e.g. ``[v, n, c, l]``,
    ``[v, n, c, h, w]`` or ``[v, n, c, d, h, w]``. The previous equation
    ``"vnchw->vnc"`` only accepted exactly two.

    Returns:
        Tensor indexed ``[c]``.
    """
    V_axis, N_axis = 0, 1
    # Sum over all trailing (spatial) axes after the channel axis.
    per_channel = einsum("vnc...->vnc", sqrt)
    bias_diagonal = (per_channel ** 2).sum([V_axis, N_axis])
    return bias_diagonal
def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch): if not sum_batch: warn("BatchNorm batch summation disabled." "This may not compute meaningful quantities") x_hat, _ = self.get_normalized_input_and_var(module) equation = "vni,ni->v{}i".format("" if sum_batch is True else "n") operands = [mat, x_hat] return einsum(equation, operands)
def _sum_hessian(self, module, g_inp, g_out): probs = self.get_probs(module) sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs)) if module.reduction == "mean": N = module.input0.shape[0] sum_H /= N return sum_H
def extract_weight_diagonal(module, input, grad_output):
    """Return the weight diagonal of a convolution, shaped like ``module.weight``.

    input must be the unfolded input to the convolution (see unfold_func)
    and grad_output the backpropagated gradient.

    ``input`` is contracted with the channel/pixel-separated ``grad_output``
    over the shared pixel axis l. The result is squared and summed over the
    v and n axes.
    """
    grad_output_viewed = separate_channels_and_pixels(module, grad_output)
    # [v, n, m, k]: contraction over the pixel axis l, output axes already in (m, k) order.
    per_sample = einsum("nkl,vnml->vnmk", (input, grad_output_viewed))
    squared_sum = (per_sample ** 2).sum([0, 1])
    return squared_sum.view_as(module.weight)
def _factors_from_input(self, ext, module):
    """Yield the input-based factor (1/B) * sum_b X_b X_bᵀ of a convolution.

    X is ``convUtils.unfold_func(module)(module.input0)``, indexed ``[b, i, k]``.
    It is contracted with itself over the b and k axes and divided by the
    batch size B.

    This is a generator: nothing runs until it is iterated.

    Raises:
        NotImplementedError: on iteration, if the extension's
            expectation-approximation strategy asks to average the parameter
            Jacobian.
    """
    unfolded = convUtils.unfold_func(module)(module.input0)
    batch_size = unfolded.size(0)

    strategy = ext.get_ea_strategy()
    if ExpectationApproximation.should_average_param_jac(strategy):
        raise NotImplementedError("Undefined")

    yield einsum("bik,bjk->ij", (unfolded, unfolded)) / batch_size
def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): """Apply transposed Jacobian of the output w.r.t. the weight.""" d_weight = module.input0 contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi" return einsum(contract, (mat, d_weight))
def diag_embed_multi_dim(H):
    """Embed each sample of ``H`` as a diagonal and unflatten the trailing axis.

    With ``H`` of shape ``[N, *F]`` and ``V = prod(F)``:

    1. flatten to ``[N, V]``;
    2. embed each row as a diagonal matrix: ``[N, V, V]``;
    3. move the last axis to the front: ``[V, N, V]``;
    4. unflatten the trailing axis back to ``F``: ``[V, N, *F]``.

    Entry ``[v, n, *idx]`` equals ``H[n, *idx]`` when ``idx`` is the v-th
    flattened feature, and zero otherwise.
    """
    num_samples = H.shape[0]
    feature_shape = H.shape[1:]
    num_features = prod(feature_shape)

    # [N, V, V]: one diagonal matrix per sample
    embedded = diag_embed(H.view(num_samples, num_features))
    # [V, N, V]
    embedded = einsum("nic->cni", embedded)
    return embedded.view(num_features, num_samples, *feature_shape)
def two_kfacs_to_mat(A, B):
    """Given matrices A, B, return their Kronecker product A ⊗ B as a dense matrix.

    The result has shape ``[A.rows * B.rows, A.cols * B.cols]``; its block
    ``(i, j)`` is ``A[i, j] * B``.
    """
    assert is_matrix(A)
    assert is_matrix(B)

    (a_rows, a_cols), (b_rows, b_cols) = A.shape, B.shape
    # [a_rows, b_rows, a_cols, b_cols]; contiguous so the view below is valid.
    blocks = einsum("ij,kl->ikjl", (A, B)).contiguous()
    return blocks.view(a_rows * b_rows, a_cols * b_cols)
def sym_mat_inv(mat, shift, truncate=1e-8):
    """Inverse of a symmetric matrix A -> (A + 𝜆I)⁻¹.

    Computed by eigenvalue decomposition. The inverted shifted eigenvalues
    1 / (λ_i + shift) are clamped to [-1/truncate, 1/truncate]. Eigenvalues
    whose shifted absolute value is below ``truncate`` therefore contribute
    at most 1/truncate instead of blowing up.

    Args:
        mat: Symmetric matrix.
        shift: Damping 𝜆 added to every eigenvalue.
        truncate: Inverse of the largest allowed magnitude of an inverted
            eigenvalue.

    Returns:
        eigvecs · diag(clamped inverse eigenvalues) · eigvecsᵀ.
    """
    import torch

    # `Tensor.symeig` is deprecated since PyTorch 1.9 and removed in newer
    # releases. Prefer `torch.linalg.eigh` (same ascending eigenvalue order) and
    # fall back to `symeig` only on versions that predate `torch.linalg.eigh`.
    eigh = getattr(getattr(torch, "linalg", None), "eigh", None)
    if eigh is not None:
        eigvals, eigvecs = eigh(mat)
    else:
        eigvals, eigvecs = mat.symeig(eigenvectors=True)

    eigvals = eigvals + shift
    inv_truncate = 1.0 / truncate
    inv_eigvals = (1.0 / eigvals).clamp(min=-inv_truncate, max=inv_truncate)
    return einsum("ij,j,kj->ik", (eigvecs, inv_eigvals, eigvecs))
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): M = mc_samples C = module.input0.shape[1] probs = self.get_probs(module) V_dim = 0 probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1) multi = multinomial(probs, M, replacement=True) classes = one_hot(multi, num_classes=C) classes = einsum("nvc->vnc", classes).float() sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M) if module.reduction == "mean": N = module.input0.shape[0] sqrt_mc_h /= sqrt(N) return sqrt_mc_h
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
    """Compute Wᵀ · mat · W by applying the transposed Jacobian twice.

    W is the Jacobian of the flattened output w.r.t. the flattened input of a
    single sample. ``self.__apply_jacobian_t_of`` multiplies with Wᵀ on
    output-shaped tensors (per the original comments, via a transposed
    convolution). ``mat`` is a square matrix over the flattened output
    features (C_out·H_out·W_out). The returned matrix is over the flattened
    input features. For a non-symmetric ``mat`` the result is Wᵀ · matᵀ · W.
    """
    _, in_channels, in_h, in_w = module.input0.size()
    num_in = in_channels * in_h * in_w
    _, out_channels, out_h, out_w = module.output.size()
    num_out = out_channels * out_h * out_w

    # Interpret every column of `mat` as an output-shaped map: [num_out, C_out, H_out, W_out]
    columns = mat.view(out_channels, out_h, out_w, num_out)
    columns = einsum("cxyf->fcxy", (columns, ))

    # First application of Wᵀ, then flatten to [num_out, num_in]
    left = self.__apply_jacobian_t_of(module, columns)
    left = left.view(num_out, num_in)

    # Transpose and apply Wᵀ again on the other side
    rows = left.t().view(num_in, out_channels, out_h, out_w)
    both = self.__apply_jacobian_t_of(module, rows)

    return both.view(num_in, num_in).t()
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
    """Use fact that average pooling can be implemented as conv.

    Computes Wᵀ · mat · W, with W the per-sample Jacobian of the flattened
    output w.r.t. the flattened input, by applying
    ``self.__apply_jacobian_t_of`` twice. The channel axis is folded into the
    batch axis before each application, so the operator acts on one
    single-channel map at a time. For a non-symmetric ``mat`` the result is
    Wᵀ · matᵀ · W.
    """
    _, channels, in_h, in_w = module.input0.size()
    num_in = channels * in_h * in_w
    _, _, out_h, out_w = module.output.size()
    num_out = channels * out_h * out_w

    # Columns of `mat` as output-shaped maps, one single-channel map per (column, channel)
    columns = mat.view(channels, out_h, out_w, num_out)
    columns = einsum("cxyf->fcxy", (columns, )).contiguous()
    columns = columns.view(num_out * channels, 1, out_h, out_w)

    # First application of Wᵀ: [num_out, num_in]
    left = self.__apply_jacobian_t_of(module, columns)
    left = left.view(num_out, num_in)

    # Transpose and apply Wᵀ again, again one channel at a time
    rows = left.t().contiguous()
    rows = rows.view(num_in * channels, 1, out_h, out_w)
    both = self.__apply_jacobian_t_of(module, rows)

    return both.view(num_in, num_in).t()
def _jac_mat_prod(self, module, g_inp, g_out, mat): """Apply Jacobian of the output w.r.t. the input.""" d_input = module.weight.data return einsum("oi,vni->vno", (d_input, mat))
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): x_hat, _ = self.get_normalized_input_and_var(module) return einsum("ni,vi->vni", (x_hat, mat))
def _jac_t_mat_prod(self, module, g_inp, g_out, mat): df_elementwise = self.df(module, g_inp, g_out) return einsum("...,v...->v...", (df_elementwise, mat))
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
    """Batch-averaged Jᵀ · mat · J for an element-wise operation.

    With the flattened per-sample derivative d_n = ``df_flat[n]``, this returns
    (1/N) * sum_n diag(d_n) · mat · diag(d_n), computed entry-wise as
    ``sum_n d_ni d_nj mat_ij / N``.
    """
    num_samples, flat_derivative = self.batch_flat(self.df(module, g_inp, g_out))
    summed = einsum("ni,nj,ij->ij", (flat_derivative, flat_derivative, mat))
    return summed / num_samples
def weight(self, ext, module, g_inp, g_out, backproped):
    """Sum over the batch of the squared per-sample weight gradients.

    ``convUtils.get_weight_gradient_factors`` supplies the factors X and dE/dY.
    Per sample n, ``dE_dY[n] @ X[n]ᵀ`` is formed (contraction over the shared
    last axis), squared element-wise, summed over n, and reshaped like
    ``module.weight``. Presumably X is the unfolded input [n, k, l] and dE/dY
    the output gradient [n, m, l]; confirm against convUtils.
    """
    unfolded_input, grad_output = convUtils.get_weight_gradient_factors(
        module.input0, g_out[0], module)
    per_sample_grad = einsum("nml,nkl->nmk", (grad_output, unfolded_input))
    return per_sample_grad.pow(2).sum(0).view_as(module.weight)
def extract_bias_diagonal(module, backproped):
    """Return the bias diagonal: squared ``backproped`` summed over its v and n axes.

    ``backproped`` is indexed ``[v, n, o]``; the result is indexed ``[o]``.
    """
    squared = backproped ** 2
    return einsum("vno->o", squared)
def weight(self, ext, module, g_inp, g_out, backproped):
    """Per-sample squared Frobenius norm of ``dE_dY[n] @ X[n]ᵀ``.

    The factors come from ``convUtils.get_weight_gradient_factors``. A single
    four-operand einsum evaluates sum_{m,k} (sum_l dE_dY[n,m,l] X[n,k,l])^2
    for each n, without reshaping to the weight's shape.

    Returns:
        Tensor indexed ``[n]``.
    """
    unfolded_input, grad_output = convUtils.get_weight_gradient_factors(
        module.input0, g_out[0], module)
    operands = (grad_output, unfolded_input, grad_output, unfolded_input)
    return einsum("nml,nkl,nmi,nki->n", operands)
def kfacmp(mat):
    """Contract a matrix with ``factors`` using the einsum ``equation``.

    ``col_dims``, ``equation`` and ``factors`` are free variables from the
    enclosing scope. Each column of ``mat`` is reshaped to ``col_dims`` before
    the contraction, and the result is flattened back to a matrix with the
    same number of columns. Presumably this multiplies with a Kronecker-factored
    matrix; confirm against the enclosing function.
    """
    assert is_matrix(mat)
    _, num_cols = mat.shape
    reshaped = mat.view(*col_dims, num_cols)
    product = einsum(equation, reshaped, *factors)
    return product.contiguous().view(-1, num_cols)
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
    """Return Wᵀ · mat · W with W = ``module.weight.data``."""
    weight = module.weight.data
    return einsum("ik,ij,jl->kl", weight, mat, weight)
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): """Apply Jacobian of the output w.r.t. the weight.""" d_weight = module.input0 return einsum("ni,voi->vno", (d_weight, mat))
def bias(self, ext, module, g_inp, g_out, backproped):
    """Sum over the batch of squared per-sample bias gradients.

    ``g_out[0]`` is indexed ``[n, c, *spatial]`` (assumed to be the gradient
    w.r.t. the layer output). Per sample it is summed over all spatial axes,
    squared, and then summed over the batch.

    Any number of spatial axes is accepted (1d/2d/3d convolutions). The
    previous equation ``"nchw->nc"`` only accepted exactly two.

    Returns:
        Tensor indexed ``[c]``.
    """
    N_axis = 0
    per_sample_grad = einsum("nc...->nc", g_out[0])
    return (per_sample_grad ** 2).sum(N_axis)
def extract_weight_diagonal(module, backproped):
    """Return the weight diagonal of a linear layer.

    Entry ``[o, i]`` is sum_{v,n} backproped[v,n,o]^2 * input0[n,i]^2.
    """
    squared_backproped = backproped ** 2
    squared_input = module.input0 ** 2
    return einsum("vno,ni->oi", squared_backproped, squared_input)
def _factor_from_sqrt(self, module, backproped):
    """Build a square factor from a backpropagated symmetric factor.

    ``backproped`` is split by ``convUtils.separate_channels_and_pixels`` into
    a 4-d tensor ``[c, b, i, j]``. It is summed over the last axis and then
    contracted with itself over c and b.

    Returns:
        Tensor indexed ``[i, l]``.
    """
    separated = convUtils.separate_channels_and_pixels(module, backproped)
    # Sum out the trailing axis: [c, b, i, j] -> [c, b, i]
    reduced = einsum("cbij->cbi", (separated, ))
    return einsum("cbi,cbl->il", (reduced, reduced))
def R_mat_prod(mat):
    """Multiply with the residual: mat → [∑_{k} Hz_k(x) δz_k] mat.

    Second term of the module input Hessian backpropagation equation.

    ``R_mod`` is a free variable from the enclosing scope. It is multiplied
    element-wise into every slice ``mat[v]`` (no axis is contracted).
    """
    return einsum("n...,vn...->vn...", R_mod, mat)
def make_quadratic_psd(mat):
    """Make matrix positive semi-definite: A -> AAᵀ + shift·I.

    ``shift`` is ``self.PSD_KFAC_MIN_EIGVAL`` and the identity comes from
    ``self.torch_eye_like``; ``self`` is a free variable from the enclosing
    scope.
    """
    gram = einsum("ij,kj->ik", mat, mat)
    shift = self.PSD_KFAC_MIN_EIGVAL * self.torch_eye_like(gram)
    return gram + shift