def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
    """Weight-Jacobian-transpose / matrix product (per the method name).

    The product is evaluated with a single grouped convolution
    (``self.conv_func``) of the tiled layer input with the tiled ``mat``.

    Args:
        module: Convolution layer. Reads ``groups``, ``output_shape``,
            ``input0_shape``, ``input0``, ``weight``, ``stride``, ``padding``
            and ``dilation``.
        g_inp: Not used here.
        g_out: Not used here.
        mat: Tensor whose axes follow ``"v,n,c,<self.dim_text>"``.
        sum_batch: If True, ``n`` is left out of the output pattern handed to
            ``eingroup`` (presumably summing over the batch). Otherwise ``n`` is
            kept as the second output axis.

    Returns:
        Tensor grouped as ``v,o,i,<dims>`` (``v,n,o,i,<dims>`` when
        ``sum_batch`` is False). Spatial sizes are cropped to the kernel sizes
        in ``module.weight.shape``.

    Raises:
        NotImplementedError: If ``module.groups != 1``.
    """
    if module.groups != 1:
        raise NotImplementedError(
            "Groups greater than 1 are not supported yet")

    V = mat.shape[0]
    N, C_out = module.output_shape[0], module.output_shape[1]
    C_in = module.input0_shape[1]
    dims = self.dim_text
    spatial_ones = [1] * self.conv_dims

    # Merge v and n and tile the channel axis C_in times. Then flatten all
    # non-spatial axes into a stack of single-channel filters.
    filters = eingroup("v,n,c,{}->vn,c,{}".format(dims, dims), mat)
    filters = filters.repeat(1, C_in, *spatial_ones)
    filters = eingroup("a,b,{}->ab,{}".format(dims, dims), filters).unsqueeze(1)

    # Merge batch and channel axes of the layer input, add a leading unit axis,
    # then tile it V times along the channel axis.
    signal = eingroup("n,c,{}->nc,{}".format(dims, dims), module.input0)
    signal = signal.unsqueeze(0).repeat(1, V, *spatial_ones)

    # NOTE: the module's dilation is passed as ``stride`` and its stride as
    # ``dilation``; this mirrors the 2d variant of this method.
    grad_weight = self.conv_func(
        signal,
        filters,
        bias=None,
        stride=module.dilation,
        padding=module.padding,
        dilation=module.stride,
        groups=C_in * N * V,
    ).squeeze(0)

    # Crop every spatial axis to the matching kernel size of the weight.
    for offset in range(self.conv_dims):
        grad_weight = grad_weight.narrow(
            offset + 1, 0, module.weight.shape[2 + offset])

    batch_part = "" if sum_batch else "n,"
    equation = "vnio,{}->v,{}o,i,{}".format(dims, batch_part, dims)
    return eingroup(equation, grad_weight, dim={
        "v": V,
        "n": N,
        "i": C_in,
        "o": C_out
    })
def get_convtranspose3d_weight_gradient_factors(input, grad_out, module):
    """Return the factors whose product gives a 3d transposed-conv weight gradient.

    Args:
        input: Layer input. Axis 0 is used as the batch size ``N`` and axis 1
            as ``C_in``.
        grad_out: Gradient w.r.t. the layer output, with axes ``n,c,d,h,w``.
        module: Transposed convolution layer. Its ``kernel_size`` is assumed
            to be a tuple of ints.

    Returns:
        Tuple ``(X, dE_dY)``:

        - ``X`` is ``unfold_by_conv_transpose(input, module)`` reshaped to
          ``[N, C_in * prod(kernel_size), -1]``.
        - ``dE_dY`` is ``grad_out`` with its three spatial axes merged into
          one (``n,c,dhw``).
    """
    N, C_in = input.shape[0], input.shape[1]

    # Exact integer product of the kernel sizes. This avoids allocating a
    # float32 tensor just to multiply a handful of small ints.
    kernel_size_numel = 1
    for size in module.kernel_size:
        kernel_size_numel *= size

    X = unfold_by_conv_transpose(input, module).reshape(
        N, C_in * kernel_size_numel, -1)
    dE_dY = eingroup("n,c,d,h,w->n,c,dhw", grad_out)
    return X, dE_dY
def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
    """Weight-Jacobian-transpose / matrix product for a 2d convolution.

    Unintuitive, but faster due to convolution. The product is evaluated
    with a single grouped ``conv2d`` of the tiled layer input with the
    tiled ``mat``.

    Args:
        module: 2d convolution layer. Reads ``groups``, ``output_shape``,
            ``input0_shape``, ``input0``, ``weight``, ``stride``, ``padding``
            and ``dilation``.
        g_inp: Not used here.
        g_out: Not used here.
        mat: Tensor with axes ``v,n,c,<2 spatial>``.
        sum_batch: If True, ``n`` is left out of the output pattern handed to
            ``eingroup`` (presumably summing over the batch). Otherwise ``n`` is
            kept as the second output axis.

    Returns:
        Tensor grouped as ``v,o,i,x,y`` (``v,n,o,i,x,y`` when ``sum_batch`` is
        False). ``x, y`` are cropped to the kernel size of ``module.weight``.

    Raises:
        NotImplementedError: If ``module.groups != 1``. The computation below
            pairs every output channel with all ``C_in`` input channels. For
            grouped convolutions the result would not match the shape of
            ``module.weight`` and would not be its gradient. This guard matches
            the N-d variant of this method.
    """
    if module.groups != 1:
        raise NotImplementedError(
            "Groups greater than 1 are not supported yet")

    V = mat.shape[0]
    N, C_out, _, _ = module.output_shape
    _, C_in, _, _ = module.input0_shape

    C_in_axis = 1
    mat = eingroup("v,n,c,w,h->vn,c,w,h", mat).repeat(1, C_in, 1, 1)
    # a,b represent the combined/repeated dimensions
    mat = eingroup("a,b,w,h->ab,w,h", mat).unsqueeze(C_in_axis)

    N_axis = 0
    layer_input = eingroup("n,c,h,w->nc,h,w", module.input0).unsqueeze(N_axis)
    layer_input = layer_input.repeat(1, V, 1, 1)

    # NOTE: the module's dilation is passed as ``stride`` and its stride as
    # ``dilation``.
    grad_weight = conv2d(
        layer_input,
        mat,
        bias=None,
        stride=module.dilation,
        padding=module.padding,
        dilation=module.stride,
        groups=C_in * N * V,
    ).squeeze(0)

    # Crop both spatial axes to the kernel size of the weight.
    K_H_axis, K_W_axis = 1, 2
    _, _, K_H, K_W = module.weight.shape
    grad_weight = grad_weight.narrow(K_H_axis, 0, K_H).narrow(K_W_axis, 0, K_W)

    eingroup_eq = "vnio,x,y->v,{}o,i,x,y".format("" if sum_batch else "n,")
    return eingroup(eingroup_eq, grad_weight, dim={
        "v": V,
        "n": N,
        "i": C_in,
        "o": C_out
    })
def __pool_idx_for_jac(self, module, V):
    """Manipulated pooling indices ready-to-use in jac(t).

    The indices from ``self.get_pooling_idx(module)`` (axes ``n,c,h,w``) get
    their spatial axes merged. A leading axis of size ``V`` is then added via
    ``expand``, so the result is a broadcast view, not a copy.
    """
    flat_idx = eingroup("n,c,h,w->n,c,hw", self.get_pooling_idx(module))
    with_v_axis = flat_idx.unsqueeze(0)
    return with_v_axis.expand(V, -1, -1, -1)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
    """Jacobian / matrix product of a 2d convolution (per the method name).

    ``mat`` has axes ``v,n,c,h,w``. The ``v`` and ``n`` axes are merged so that
    one ``conv2d`` call processes every stacked vector. That call uses the
    layer's weight data (no bias) and its stride, padding, dilation and
    groups. The result is passed through ``reshape_like_output``.
    """
    batched = eingroup("v,n,c,h,w->vn,c,h,w", mat)
    convolved = conv2d(
        batched,
        module.weight.data,
        groups=module.groups,
        dilation=module.dilation,
        padding=module.padding,
        stride=module.stride,
    )
    return self.reshape_like_output(convolved, module)
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
    """Weight-Jacobian / matrix product for an N-d convolution.

    ``mat`` has axes ``v,o,i,<self.dim_text>``. The input-channel and kernel
    axes are merged into one. The result is contracted with the unfolded
    layer input via ``einsum("nij,vki->vnkj")`` and reshaped like the output.

    Raises:
        NotImplementedError: If ``module.groups != 1``.
    """
    if module.groups != 1:
        raise NotImplementedError(
            "Groups greater than 1 are not supported yet")

    spatial = self.dim_text
    merged = spatial.replace(",", "")
    flat_mat = eingroup("v,o,i,{}->v,o,i{}".format(spatial, merged), mat)
    unfolded = self.get_unfolded_input(module)
    product = einsum("nij,vki->vnkj", unfolded, flat_mat)
    return self.reshape_like_output(product, module)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
    """Jacobian / matrix product of an N-d convolution (per the method name).

    ``mat`` has axes ``v,n,c,<self.dim_text>``. ``v`` and ``n`` are merged so
    that one ``self.conv_func`` call handles all stacked vectors. That call
    uses the layer's weight data (no bias) and its stride, padding, dilation
    and groups. The result is passed through ``reshape_like_output``.
    """
    spatial = self.dim_text
    equation = "v,n,c,{}->vn,c,{}".format(spatial, spatial)
    convolved = self.conv_func(
        eingroup(equation, mat),
        module.weight.data,
        groups=module.groups,
        dilation=module.dilation,
        padding=module.padding,
        stride=module.stride,
    )
    return self.reshape_like_output(convolved, module)
def __make_single_channel(self, mat, module):
    """Create fake single-channel images from ``mat``.

    The ``v``, ``n`` and ``c`` axes of a 5d ``mat`` are merged into one leading
    axis, and a unit channel axis is inserted at position 1. ``module`` is not
    used.
    """
    merged = eingroup("v,n,c,w,h->vnc,w,h", mat)
    return merged.unsqueeze(1)
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
    """Weight-Jacobian / matrix product for a 2d convolution.

    ``mat`` has axes ``v,o,i,h,w``. The input-channel and kernel axes are
    merged into one. The result is contracted with the unfolded layer input
    via ``einsum("nij,vki->vnkj")`` and reshaped like the module output.

    Raises:
        NotImplementedError: If ``module.groups != 1``. In that case the merged
            ``i`` axis of ``mat`` would not match the unfolded input, and the
            contraction would fail with an opaque shape error. This guard
            matches the N-d variant of this method.
    """
    if module.groups != 1:
        raise NotImplementedError(
            "Groups greater than 1 are not supported yet")

    jac_mat = eingroup("v,o,i,h,w->v,o,ihw", mat)
    X = self.get_unfolded_input(module)
    # Operands are passed variadically, matching the N-d variant.
    jac_mat = einsum("nij,vki->vnkj", X, jac_mat)
    return self.reshape_like_output(jac_mat, module)
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
    """Transposed-Jacobian / matrix product (per the method name).

    ``mat`` has axes ``v,n,c,h,w``. ``v`` and ``n`` are merged before the
    private ``__jac_t`` helper is applied. The result is passed through
    ``reshape_like_input``.
    """
    jac_t_result = self.__jac_t(module, eingroup("v,n,c,h,w->vn,c,h,w", mat))
    return self.reshape_like_input(jac_t_result, module)
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
    """Transposed-Jacobian / matrix product (presumably for a pooling layer).

    The two spatial axes of ``mat`` (``v,n,c,h,w``) are flattened before the
    private ``__apply_jacobian_t_of`` helper is called. The result is passed
    through ``view_like_input``.
    """
    flat = eingroup("v,n,c,h,w->v,n,c,hw", mat)
    return self.view_like_input(
        self.__apply_jacobian_t_of(module, flat), module)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
    """Jacobian / matrix product of a constant-padding layer.

    ``mat`` has axes ``v,n,c,h,w``. ``v`` and ``n`` are merged, every slice is
    padded by ``module.padding``, and the result is reshaped like the output.

    Constant padding ``x -> pad(x, value)`` is affine in ``x``, and the fill
    value does not depend on the input. Its Jacobian therefore inserts
    zeros at the padded positions, so the vectors are padded with ``0``
    rather than ``module.value``. For a zero-padding layer
    (``module.value == 0``) this gives the same result as before.
    """
    mat = eingroup("v,n,c,h,w->vn,c,h,w", mat)
    pad_mat = functional.pad(mat, module.padding, "constant", 0.0)
    return self.reshape_like_output(pad_mat, module)
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
    """Transposed-Jacobian / matrix product for an N-d layer.

    ``mat`` has axes ``v,n,c,<self.dim_text>``. ``v`` and ``n`` are merged
    before the private ``__jac_t`` helper is applied. The result is passed
    through ``reshape_like_input``.
    """
    equation = "v,n,c,{0}->vn,c,{0}".format(self.dim_text)
    jac_t_result = self.__jac_t(module, eingroup(equation, mat))
    return self.reshape_like_input(jac_t_result, module)
def get_weight_gradient_factors(input, grad_out, module):
    """Return the factors whose product gives a 2d convolution weight gradient.

    Returns:
        Tuple ``(X, dE_dY)``:

        - ``X`` is the input unfolded by ``unfold_func(module)``, with shape
          [N, C_in * K_x * K_y, H_out * W_out].
        - ``dE_dY`` is ``grad_out`` (axes ``n,c,h,w``) with its spatial axes
          merged.
    """
    unfold = unfold_func(module)
    return unfold(input), eingroup("n,c,h,w->n,c,hw", grad_out)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
    """Jacobian / matrix product for an N-d layer.

    ``mat`` has axes ``v,n,c,<self.dim_text>``. ``v`` and ``n`` are merged
    before the private ``__jac`` helper is applied. The result is passed
    through ``reshape_like_output``.
    """
    dims = self.dim_text
    grouped = eingroup("v,n,c,{}->vn,c,{}".format(dims, dims), mat)
    return self.reshape_like_output(self.__jac(module, grouped), module)
def separate_channels_and_pixels(module, tensor):
    """Reshape (V, N, C, H, W) into (V, N, C, H * W).

    Only the two trailing spatial axes are merged. ``module`` is not used.
    """
    equation = "v,n,c,h,w->v,n,c,hw"
    return eingroup(equation, tensor)
def get_conv3d_weight_gradient_factors(input, grad_out, module):
    """Return the factors whose product gives a 3d convolution weight gradient.

    Returns:
        Tuple ``(X, dE_dY)``:

        - ``X`` is ``unfold_by_conv(input, module)``, with shape
          [N, C_in * K_x * K_y * K_z, D_out * H_out * W_out].
        - ``dE_dY`` is ``grad_out`` (axes ``n,c,d,h,w``) with its three spatial
          axes merged.
    """
    unfolded = unfold_by_conv(input, module)
    flat_grad = eingroup("n,c,d,h,w->n,c,dhw", grad_out)
    return unfolded, flat_grad