def forward_pass(
    self, subsampling: List[int] = None
) -> Tuple[Tensor, Tensor, Tensor]:
    """Run model and loss function on the stored batch (or a subset of it).

    Without sub-sampling, the whole stored batch is used.

    Args:
        subsampling: Indices of selected samples. Default: ``None`` (all samples).

    Returns:
        input, output, and loss of the forward pass
    """
    if subsampling is None:
        input = self.input.clone()
        target = self.target.clone()
    else:
        # sub-sample along the batch axis (0) of both input and target
        input = subsample(self.input, dim=0, subsampling=subsampling)
        target = subsample(self.target, dim=0, subsampling=subsampling)

    output = self.model(input)
    return input, output, self.loss_function(output, target)
def _forward_pass(
    module: LSTM, mat: Tensor, subsampling: List[int] = None
) -> Tuple[Tensor, Tensor, Tensor]:
    """This performs an additional forward pass and returns the hidden variables.

    This is important because the PyTorch implementation does not grant access to
    some of the hidden variables. Those are computed and returned.

    See also forward pass in class docstring.

    The recurrence is re-evaluated from the stored ``module.input0`` and
    ``module.output`` (both sub-sampled along axis 0) using the layer-0 weights
    and biases. The initial hidden and cell states are treated as zero: the
    recurrent term and the forget-gate term are skipped at ``t = 0``.

    Args:
        module: module
        mat: matrix, used only to extract device, dtype, and the sizes ``N``
            (axis 1) and ``T`` (axis 2); its values are not read.
        subsampling: Indices of active samples. Defaults to ``None`` (all samples).

    Returns:
        ifgo, c, c_tanh (all in format ``[N, T, ...]``). ``ifgo`` holds the
        activated gates i, f, g, o concatenated along the last axis (``4 * H``).
    """
    _, N, T, _ = mat.shape
    H: int = module.hidden_size
    # boundaries of the gate blocks i | f | g | o along the last axis of ifgo
    H0: int = 0 * H
    H1: int = 1 * H
    H2: int = 2 * H
    H3: int = 3 * H
    H4: int = 4 * H
    # forward pass and save i, f, g, o, c, c_tanh-> ifgo, c, c_tanh
    ifgo: Tensor = zeros(N, T, 4 * H, device=mat.device, dtype=mat.dtype)
    c: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)
    c_tanh: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)
    input0 = subsample(module.input0, dim=0, subsampling=subsampling)
    output = subsample(module.output, dim=0, subsampling=subsampling)
    for t in range(T):
        # gate pre-activations: input contribution plus both bias vectors
        ifgo[:, t] = (
            einsum("hi,ni->nh", module.weight_ih_l0, input0[:, t])
            + module.bias_ih_l0
            + module.bias_hh_l0
        )
        if t != 0:
            # recurrent contribution of the previous hidden state (= previous output)
            ifgo[:, t] += einsum("hg,ng->nh", module.weight_hh_l0, output[:, t - 1])
        # nonlinearities: sigmoid for i, f, o; tanh for g
        ifgo[:, t, H0:H1] = sigmoid(ifgo[:, t, H0:H1])
        ifgo[:, t, H1:H2] = sigmoid(ifgo[:, t, H1:H2])
        ifgo[:, t, H2:H3] = tanh(ifgo[:, t, H2:H3])
        ifgo[:, t, H3:H4] = sigmoid(ifgo[:, t, H3:H4])
        # cell state: c_t = i * g (+ f * c_{t-1} for t > 0)
        c[:, t] = ifgo[:, t, H0:H1] * ifgo[:, t, H2:H3]
        if t != 0:
            c[:, t] += ifgo[:, t, H1:H2] * c[:, t - 1]
        c_tanh[:, t] = tanh(c[:, t])
    return ifgo, c, c_tanh
def _weight_ih_l0_jac_t_mat_prod(
    self,
    module: RNN,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. ``weight_ih_l0``.

    Args:
        module: extended module
        g_inp: input gradient
        g_out: output gradient
        mat: matrix to multiply
        sum_batch: Whether to sum along batch axis. Defaults to True.
        subsampling: Indices of active samples. Defaults to ``None`` (all samples).

    Returns:
        product
    """
    self._check_parameters(module)
    batch_index = "" if sum_batch else "n"
    a_product = self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling)
    layer_input = subsample(module.input0, dim=0, subsampling=subsampling)
    return einsum(f"vnth,ntj->v{batch_index}hj", a_product, layer_input)
def _weight_hh_l0_jac_t_mat_prod(
    self,
    module: RNN,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. ``weight_hh_l0``.

    Args:
        module: extended module
        g_inp: input gradient
        g_out: output gradient
        mat: matrix to multiply
        sum_batch: Whether to sum along batch axis. Defaults to True.
        subsampling: Indices of active samples. Defaults to ``None`` (all samples).

    Returns:
        product
    """
    self._check_parameters(module)
    _, N, _, H = mat.shape
    hidden_states = subsample(module.output, dim=0, subsampling=subsampling)
    # hidden state entering step t is output[:, t - 1]; a zero step is used at t = 0
    zero_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
    previous_hidden = cat([zero_step, hidden_states[:, :-1]], dim=1)
    a_product = self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling)
    batch_index = "" if sum_batch else "n"
    return einsum(f"vnth,ntk->v{batch_index}hk", a_product, previous_hidden)
def param_function(
    ext: BatchGrad,
    module: Module,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    bpQuantities: None,
) -> Tensor:
    """Compute individual gradients via the derivatives object.

    Args:
        ext: extension that is used
        module: module that performed forward pass
        g_inp: input gradient tensors
        g_out: output gradient tensors
        bpQuantities: additional quantities for second order

    Returns:
        Scaled individual gradients
    """
    subsampling = ext.get_subsampling()
    # only the active samples of the output gradient are back-propagated
    active_grad_output = subsample(g_out[0], dim=0, subsampling=subsampling)
    return self._derivatives.param_mjp(
        param_str,
        module,
        g_inp,
        g_out,
        active_grad_output,
        sum_batch=False,
        subsampling=subsampling,
    )
def _weight_jac_t_mat_prod(
    self,
    module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the weight.

    The (sub-sampled) input is unfolded with ``unfold_by_conv_transpose`` and
    contracted with ``mat`` per channel group.

    Args:
        module: Transpose convolution layer with stored ``input0`` and ``output``.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply; axis 0 is the vector axis ``V``.
        sum_batch: Whether to sum over the batch axis. Default: ``True``.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product of shape ``[V, *module.weight.shape]`` if ``sum_batch``, else
        ``[V, N, *module.weight.shape]``.
    """
    num_vectors = mat.shape[0]
    groups = module.groups
    in_channels = module.input0.shape[1]
    batch_size = module.output.shape[0] if subsampling is None else len(subsampling)
    out_channels = module.output.shape[1]
    out_spatial = module.output.shape[2:]

    mat_grouped = mat.reshape(
        num_vectors, batch_size, groups, out_channels // groups, *out_spatial
    )
    unfolded = unfold_by_conv_transpose(
        subsample(module.input0, subsampling=subsampling), module
    ).reshape(
        batch_size,
        groups,
        in_channels // groups,
        *module.weight.shape[2:],
        *out_spatial,
    )

    kernel_letters = "xyz"[: self.conv_dims]
    data_letters = "abc"[: self.conv_dims]
    if sum_batch:
        result_letters = "vgio" + kernel_letters
        result_shape = (num_vectors, *module.weight.shape)
    else:
        result_letters = "vngio" + kernel_letters
        result_shape = (num_vectors, batch_size, *module.weight.shape)

    equation = (
        f"ngi{kernel_letters}{data_letters},vngo{data_letters}->{result_letters}"
    )
    return einsum(equation, unfolded, mat_grouped).reshape(result_shape)
def _weight_jac_t_mat_prod(
    self,
    module: Linear,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Batch-apply transposed Jacobian of the output w.r.t. the weight.

    Args:
        module: Linear layer.
        g_inp: Gradients w.r.t. module input. Not required by the implementation.
        g_out: Gradients w.r.t. module output. Not required by the implementation.
        mat: Batch of ``V`` vectors of same shape as the layer output
            (``[N, *, out_features]``) to which the transposed output-input Jacobian
            is applied. Has shape ``[V, N, *, out_features]`` if subsampling is not
            used, otherwise ``N`` must be ``len(subsampling)`` instead.
        sum_batch: Sum the result's batch axis. Default: ``True``.
        subsampling: Indices of samples along the output's batch dimension that
            should be considered. Defaults to ``None`` (use all samples).

    Returns:
        Batched transposed Jacobian vector products. Has shape
        ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With
        ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. If sub-
        sampling is used, ``N`` must be ``len(subsampling)`` instead.
    """
    # The weight Jacobian of a linear layer is the layer input, so the product
    # is an outer product of ``mat`` and the (sub-sampled) input, summed over
    # any extra feature axes ``*`` (and over the batch if ``sum_batch``).
    d_weight = subsample(module.input0, subsampling=subsampling)
    equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi"
    return einsum(equation, mat, d_weight)
def forward_pass(
    self, input_requires_grad: bool = False, subsampling: List[int] = None
) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]:
    """Run the module on a detached copy of the stored input.

    Args:
        input_requires_grad: Mark the (non-``long``) input as requiring gradients.
        subsampling: Indices of active samples along axis 0. ``None`` uses all.
            Sub-sampling is not supported for loss modules.

    Returns:
        input, output, and a dict of the module's named parameters.
    """
    input: Tensor = self.input.clone().detach()
    if subsampling is not None:
        input = subsample(input, dim=0, subsampling=subsampling)

    # integer (``long``) tensors cannot require gradients in PyTorch
    if input_requires_grad and input.dtype is not long:
        input.requires_grad = True

    if not self.is_loss():
        output: Tensor = self.module(input)
    else:
        assert subsampling is None
        output: Tensor = self.module(input, self.target)

    # RNN, GRU, LSTM return a tuple (output, ...); keep only the output
    if isinstance(output, tuple):
        output: Tensor = output[0]

    return input, output, dict(self.module.named_parameters())
def _check_like(mat, module, name, diff=1, *args, **kwargs):
    """Check the shape of ``mat`` against ``getattr(module, name)`` via ``check_shape``.

    For ``name`` in ``("output", "input0")`` the reference tensor is first
    sub-sampled along axis 0 when a ``subsampling`` keyword argument is given.
    ``diff`` is forwarded to ``check_shape``.
    """
    reference = getattr(module, name)
    if name in ("output", "input0") and "subsampling" in kwargs:
        reference = subsample(reference, dim=0, subsampling=kwargs["subsampling"])
    return check_shape(mat, reference, diff=diff)
def df(
    self,
    module: Tanh,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """First tanh derivative: ``tanh'(x) = 1 - tanh(x)^2``, using the stored output."""
    tanh_x = subsample(module.output, subsampling=subsampling)
    return 1.0 - tanh_x**2
def df(
    self,
    module: LogSigmoid,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1) `."""
    x = subsample(module.input0, subsampling=subsampling)
    denominator = exp(x) + 1
    return 1 / denominator
def df(
    self,
    module: Sigmoid,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """First sigmoid derivative: `σ'(x) = σ(x) (1 - σ(x))`, using the stored output."""
    sigma = subsample(module.output, subsampling=subsampling)
    complement = 1.0 - sigma
    return sigma * complement
def df(
    self,
    module: ReLU,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1` (0 is used at x = 0)."""
    x = subsample(module.input0, subsampling=subsampling)
    is_positive = x > 0
    return is_positive.to(x.dtype)
def df(
    self,
    module: LeakyReLU,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """``LeakyReLU'(x) = negative_slope if x < 0 else 1`` (``negative_slope`` at x = 0)."""
    x = subsample(module.input0, subsampling=subsampling)
    is_positive = gt(x, 0)
    derivative = is_positive.to(x.dtype)
    derivative[~is_positive] = module.negative_slope
    return derivative
def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor:
    """Apply the transposed input-output Jacobian to ``vec`` via autograd.

    Args:
        vec: Vector(s) shaped like the (sub-sampled) output.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Transposed Jacobian-vector product shaped like the (sub-sampled) input.
    """
    input, output, _ = self.problem.forward_pass(input_requires_grad=True)
    if subsampling is None:
        return transposed_jacobian_vector_product(output, input, vec)[0]

    # for each sample, multiply by full input Jacobian, slice out result:
    # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n]
    batch_axis = 0
    output_slices = subsample(output, dim=batch_axis, subsampling=subsampling).split(
        1, dim=batch_axis
    )
    vec_slices = vec.split(1, dim=batch_axis)
    per_sample: List[Tensor] = [
        subsample(
            transposed_jacobian_vector_product(out_n, input, vec_n)[0],
            dim=batch_axis,
            subsampling=[sample_idx],
        )
        for sample_idx, out_n, vec_n in zip(subsampling, output_slices, vec_slices)
    ]
    return cat(per_sample, dim=batch_axis)
def _get_probs(module: CrossEntropyLoss, subsampling: List[int] = None) -> Tensor:
    """Compute the softmax probabilities from the module input.

    Args:
        module: cross-entropy loss with I/O.
        subsampling: Indices of samples to be considered. Default of ``None`` uses
            the full mini-batch.

    Returns:
        Softmax probabilities along axis 1 of the (sub-sampled) input.
    """
    scores = subsample(module.input0, subsampling=subsampling)
    probabilities = softmax(scores, dim=1)
    return probabilities
def _weight_jac_t_mat_prod(
    self,
    module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the weight.

    ``mat`` is contracted with the (sub-sampled) normalized input ``x_hat`` over
    all trailing axes, keeping the vector axis, the channel axis, and (if
    ``sum_batch`` is ``False``) the batch axis.

    Args:
        module: Batch normalization layer.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply; layout ``[V, N, C, ...]``.
        sum_batch: Whether to sum over the batch axis. Default: ``True``.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product of shape ``[V, C]`` or ``[V, N, C]``.
    """
    normalized, _ = self._get_normalized_input_and_var(module)
    normalized = subsample(normalized, subsampling=subsampling)
    result_axes = "vc" if sum_batch else "vnc"
    return einsum(f"vnc...,nc...->{result_axes}", mat, normalized)
def df(
    self,
    module: Dropout,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """Elementwise derivative of dropout w.r.t. its input.

    In evaluation mode this is all ones. In training mode, output entries that
    are exactly zero are treated as dropped (derivative 0); all others get the
    scaling ``1 / (1 - p)``.
    """
    output = subsample(module.output, subsampling=subsampling)
    if not module.training:
        return ones_like(output)

    kept = 1 - eq(output, 0.0).to(output.dtype)
    return kept * (1 / (1 - module.p))
def __same_conv_weight_jac_t(
    self,
    module: Union[Conv1d, Conv2d, Conv3d],
    mat: Tensor,
    sum_batch: bool,
    subsampling: List[int] = None,
) -> Tensor:
    """Uses convolution of same order.

    Computes the transposed weight Jacobian applied to ``mat`` by convolving the
    (sub-sampled) layer input with ``mat`` through ``self.conv_func``. The vector
    (V), batch (N), and channel-group (G) axes are folded into the convolution's
    ``groups`` dimension.

    Args:
        module: Convolution layer with stored ``input0`` and ``output``.
        mat: Matrix to multiply; axis 0 is the vector axis ``V`` and it is
            rearranged with the pattern ``"v n (g c) ..."``.
        sum_batch: Whether to sum the result over the batch axis.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product with layout ``[V, (G * C_out // G), C_in // G, *kernel]`` if
        ``sum_batch``, else ``[V, N, (G * C_out // G), C_in // G, *kernel]``.
    """
    G = module.groups
    V = mat.shape[0]
    C_out = module.output.shape[1]
    N = module.output.shape[0] if subsampling is None else len(subsampling)
    C_in = module.input0.shape[1]
    C_in_axis = 1
    N_axis = 0

    # treat channel groups like vectorization (v) and batch (n) axes
    mat = rearrange(mat, "v n (g c) ... -> (v n g) c ...", g=G, c=C_out // G)
    # repeat each group's output-channel block C_in // G times (one copy per
    # input channel of the group), then flatten into a filter axis
    repeat_pattern = [1, C_in // G] + [1 for _ in range(self.conv_dims)]
    mat = mat.repeat(*repeat_pattern)
    mat = rearrange(mat, "a b ... -> (a b) ...")
    mat = mat.unsqueeze(C_in_axis)

    # fold batch and channels into one axis, add a singleton batch axis, and
    # repeat V times so every vector sees the full input
    input = rearrange(
        subsample(module.input0, subsampling=subsampling), "n c ... -> (n c) ..."
    )
    input = input.unsqueeze(N_axis)
    repeat_pattern = [1, V] + [1 for _ in range(self.conv_dims)]
    input = input.repeat(*repeat_pattern)

    # NOTE: stride and dilation are passed swapped relative to the module
    # (stride=module.dilation, dilation=module.stride), which is the relation
    # used for the weight gradient of a convolution
    grad_weight = self.conv_func(
        input,
        mat,
        bias=None,
        stride=module.dilation,
        padding=module.padding,
        dilation=module.stride,
        groups=C_in * N * V,
    ).squeeze(0)

    # truncate the spatial axes to the kernel size (the result can be larger
    # due to shape rounding with non-default stride or dilation)
    for dim in range(self.conv_dims):
        axis = dim + 1
        size = module.weight.shape[2 + dim]
        grad_weight = grad_weight.narrow(axis, 0, size)

    dim = {"g": G, "v": V, "n": N, "i": C_in // G, "o": C_out // G}
    if sum_batch:
        return reduce(grad_weight, "(v n g i o) ... -> v (g o) i ...", "sum", **dim)
    else:
        return rearrange(grad_weight, "(v n g i o) ... -> v n (g o) i ...", **dim)
def df(
    self,
    module: ELU,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
):
    """First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`."""
    x = subsample(module.input0, subsampling=subsampling)
    derivative = ones_like(x)
    at_most_zero = le(x, 0)
    derivative[at_most_zero] = module.alpha * exp(x[at_most_zero])
    return derivative
def get_pooling_idx(
    self,
    module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
    subsampling: List[int] = None,
) -> Tensor:
    """Recompute the max-pooling indices of the (sub-sampled) layer input.

    Args:
        module: Max-pooling layer whose hyper-parameters are re-used.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Indices returned by ``self.maxpool(..., return_indices=True)``.
    """
    pool_kwargs = {
        "kernel_size": module.kernel_size,
        "stride": module.stride,
        "padding": module.padding,
        "dilation": module.dilation,
        "return_indices": True,
        "ceil_mode": module.ceil_mode,
    }
    layer_input = subsample(module.input0, subsampling=subsampling)
    _, indices = self.maxpool(layer_input, **pool_kwargs)
    return indices
def df(
    self,
    module: SELU,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    subsampling: List[int] = None,
) -> Tensor:
    """First SELU derivative: `SELU'(x) = scale if x > 0 else scale*alpha*e^x`."""
    x = subsample(module.input0, subsampling=subsampling)
    at_most_zero = le(x, 0)
    derivative = self.scale * ones_like(x)
    derivative[at_most_zero] = self.scale * self.alpha * exp(x[at_most_zero])
    return derivative
def _jac_t_mat_prod(
    self,
    module: Slicing,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    subsampling: List[int] = None,
) -> Tensor:
    """Scatter ``mat`` back into a zero tensor shaped like the (sub-sampled) input.

    Args:
        module: Slicing module; its ``slice_info`` must not slice the batch axis.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply; axis 0 is the vector axis.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Tensor of shape ``[mat.shape[0], *input_shape]``, zero outside the slice.
    """
    self.no_slice_batch_axis(module)
    input0 = module.input0
    input_shape = subsample(input0, subsampling=subsampling).shape
    result = zeros(
        (mat.shape[0], *input_shape), device=input0.device, dtype=input0.dtype
    )
    # the leading full slice covers the vector axis of ``mat``
    index = (slice(None),) + module.slice_info
    result[index] = mat
    return result
def _weight_jac_t_mat_prod(
    self,
    module: Embedding,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the weight.

    Builds a one-hot indicator ``delta[s, n, ...] = (input0[n, ...] == s)`` of the
    looked-up indices and contracts it with ``mat``.

    Args:
        module: Embedding layer with stored ``input0``.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply; layout ``[V, N, ..., H]``.
        sum_batch: Whether to sum over the batch axis. Default: ``True``.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product of shape ``[V, num_embeddings, H]`` if ``sum_batch``, else
        ``[V, N, num_embeddings, H]``.
    """
    self._check_parameters(module)

    input0 = subsample(module.input0, subsampling=subsampling)
    # ``delta`` must share ``mat``'s dtype: ``einsum`` does not mix dtypes, so a
    # default-float ``delta`` breaks for e.g. float64 ``mat``
    delta = zeros(
        module.num_embeddings, *input0.shape, device=mat.device, dtype=mat.dtype
    )
    for s in range(module.num_embeddings):
        delta[s] = input0 == s

    equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh"
    return einsum(equation, delta, mat)
def input_hessian_via_sqrt_hessian(
    self, mc_samples: int = None, chunks: int = 1, subsampling: List[int] = None
) -> Tensor:
    """Compute the Hessian w.r.t. the input from its matrix square root.

    Args:
        mc_samples: If int, uses an MC approximation with the specified number of
            samples. If None, uses the exact hessian. Defaults to None.
        chunks: Maximum sequential split of the computation. Default: ``1``.
            Only used if mc_samples is specified.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Hessian with respect to the input. Has shape
        ``[N, A, B, ..., N, A, B, ...]`` where ``N`` is the batch size or number
        of active samples when sub-sampling is used, and ``[A, B, ...]`` are the
        input's feature dimensions.
    """
    self.store_forward_io()

    if mc_samples is None:
        sqrt_hessian = self.problem.derivative.sqrt_hessian(
            self.problem.module, None, None, subsampling=subsampling
        )
        individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian)
    else:
        # weighted average over chunks; each chunk's weight is its share of samples
        individual_hessians = 0
        for samples in chunk_sizes(mc_samples, chunks):
            weight = samples / mc_samples
            sampled_sqrt = self.problem.derivative.sqrt_hessian_sampled(
                self.problem.module,
                None,
                None,
                mc_samples=samples,
                subsampling=subsampling,
            )
            individual_hessians = individual_hessians + weight * (
                self._sample_hessians_from_sqrt(sampled_sqrt)
            )

    input0 = subsample(self.problem.module.input0, subsampling=subsampling)
    return self._embed_sample_hessians(individual_hessians, input0)
def _weight_ih_l0_jac_t_mat_prod(
    self,
    module: LSTM,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. ``weight_ih_l0``.

    Args:
        module: LSTM layer.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply.
        sum_batch: Whether to sum over the batch axis. Default: ``True``.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product contracted with the (sub-sampled) layer input.
    """
    self._check_parameters(module)
    gate_product: Tensor = self._ifgo_jac_t_mat_prod(
        module, mat, subsampling=subsampling
    )
    layer_input = subsample(module.input0, dim=0, subsampling=subsampling)
    result_axes = "vhi" if sum_batch else "vnhi"
    return einsum(f"vnth,nti->{result_axes}", gate_product, layer_input)
def _a_jac_t_mat_prod(
    cls,
    module: RNN,
    weight_hh_l0: Tensor,
    mat: Tensor,
    subsampling: List[int] = None,
) -> Tensor:
    """Calculate the Jacobian-vector product w.r.t. the pre-activations ``a``.

    Back-propagates ``mat`` through time, starting at the last step. Each step
    adds the contribution of step ``t + 1`` through ``weight_hh_l0`` and
    multiplies by ``1 - output[:, t] ** 2`` (the tanh derivative written in
    terms of the output, assuming the tanh nonlinearity).

    Args:
        module: RNN module
        weight_hh_l0: weight matrix hidden-to-hidden
        mat: matrix to multiply, layout ``[V, N, T, H]``
        subsampling: subsampling

    Returns:
        jacobian vector product wrt a, layout ``[V, N, T, H]``
    """
    V, N, T, H = mat.shape
    output = subsample(module.output, dim=0, subsampling=subsampling)
    result: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype)
    for t in range(T - 1, -1, -1):
        incoming = mat[:, :, t]
        if t < T - 1:
            incoming = incoming + einsum(
                "vng,gh->vnh", result[:, :, t + 1], weight_hh_l0
            )
        result[:, :, t] = einsum("vnh,nh->vnh", incoming, 1 - output[:, t] ** 2)
    return result
def _weight_hh_l0_jac_t_mat_prod(
    self,
    module: LSTM,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    sum_batch: bool = True,
    subsampling: List[int] = None,
) -> Tensor:
    """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. ``weight_hh_l0``.

    Args:
        module: LSTM layer.
        g_inp: Input gradients (unused).
        g_out: Output gradients (unused).
        mat: Matrix to multiply; axis 1 is the batch and axis 3 the hidden size.
        sum_batch: Whether to sum over the batch axis. Default: ``True``.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product contracted with the previous hidden states.
    """
    self._check_parameters(module)
    _, N, _, H = mat.shape
    gate_product: Tensor = self._ifgo_jac_t_mat_prod(
        module, mat, subsampling=subsampling
    )
    hidden_states = subsample(module.output, dim=0, subsampling=subsampling)
    # hidden state entering step t is output[:, t - 1]; a zero step is used at t = 0
    zero_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
    previous_hidden = cat([zero_step, hidden_states[:, :-1]], dim=1)
    result_axes = "vhg" if sum_batch else "vnhg"
    return einsum(f"vnth,ntg->{result_axes}", gate_product, previous_hidden)
def test_subsample():
    """Test slicing operations for sub-sampling a tensor's batch axis."""
    manual_seed(0)
    tensor = rand(3, 4, 5, 6)

    # `subsampling = None` must hand back the very same tensor object
    assert subsample(tensor) is tensor
    assert allclose(subsample(tensor), tensor)

    # slicing must act on the requested dimension only
    idx = [2, 0]
    for dim in range(tensor.dim()):
        index = (slice(None),) * dim + (idx,)
        assert allclose(subsample(tensor, dim=dim, subsampling=idx), tensor[index])
def __higher_conv_weight_jac_t(
    self,
    module: Union[Conv1d, Conv2d, Conv3d],
    mat: Tensor,
    sum_batch: bool,
    subsampling: List[int] = None,
) -> Tensor:
    """Requires higher-order convolution.

    The algorithm is proposed in:

        - Rochette, G., Manoel, A., & Tramel, E. W., Efficient per-example
          gradient computations in convolutional neural networks (2019).

    The (sub-sampled) input is convolved with ``mat`` by a convolution with one
    more spatial dimension (``get_conv_function(self.conv_dims + 1)``). The
    input channels of a group act as that extra spatial axis, and the vector,
    batch, and group axes are folded into the convolution's ``groups``.

    Args:
        module: Convolution layer with stored ``input0`` and ``output``.
        mat: Matrix to multiply; rearranged with the pattern ``"v n c ..."``.
        sum_batch: Whether to sum the result over the batch axis.
        subsampling: Indices of active samples. ``None`` uses all samples.

    Returns:
        Product of shape ``[V, C_out, C_in // G, *kernel]`` if ``sum_batch``,
        else ``[V, N, C_out, C_in // G, *kernel]``.
    """
    G = module.groups
    V = mat.shape[0]
    C_out = module.output.shape[1]
    N = module.output.shape[0] if subsampling is None else len(subsampling)
    C_in = module.input0.shape[1]

    higher_conv_func = get_conv_function(self.conv_dims + 1)

    spatial_dim = (C_in // G,) + module.input0.shape[2:]
    # repeat pattern: copy the (N * G)-channel input V times along the channel axis
    spatial_dim_axis = (1, V) + tuple([1] * (self.conv_dims + 1))
    spatial_dim_new = (C_in // G,) + module.weight.shape[2:]

    # Reshape to extract groups from the convolutional layer
    # Channels are seen as an extra spatial dimension with kernel size 1
    input_conv = (
        subsample(module.input0, subsampling=subsampling)
        .reshape(1, N * G, *spatial_dim)
        .repeat(*spatial_dim_axis)
    )
    # Compute convolution between input and output; the batchsize is seen
    # as channels, taking advantage of the `groups` argument
    mat_conv = rearrange(mat, "v n c ... -> (v n c) ...").unsqueeze(1).unsqueeze(2)

    # leading entries cover the extra (channel) axis: unit stride/dilation, no padding
    stride = (1, *module.stride)
    dilation = (1, *module.dilation)
    padding = (0, *module.padding)

    # NOTE: stride and dilation are passed swapped (stride=dilation,
    # dilation=stride), which is the relation used for the weight gradient
    # of a convolution
    conv = higher_conv_func(
        input_conv,
        mat_conv,
        groups=V * N * G,
        stride=dilation,
        dilation=stride,
        padding=padding,
    ).squeeze(0)

    # Because of rounding shapes when using non-default stride or dilation,
    # convolution result must be truncated to convolution kernel size
    for axis in range(2, 2 + self.conv_dims):
        conv = conv.narrow(axis, 0, module.weight.shape[axis])

    new_shape = [V, N, C_out, *spatial_dim_new]
    weight_grad = conv.reshape(*new_shape)

    if sum_batch:
        weight_grad = weight_grad.sum(1)

    return weight_grad