def recalculate_indexes(self, x_input: Tensor) -> None:
    """Rebuild the fold indexes for ``x_input`` and hand them to the analog tile.

    The result is stored in ``self.fold_indices`` and ``self.input_size`` and
    passed, together with the image sizes, to ``self.analog_tile.set_indexed``.

    Args:
        x_input: input tensor; dimension 0 is treated as the batch and
            dimensions 2 and 3 are read as height and width.
    """
    per_sample_size = x_input.numel() / x_input.size(0)

    # Give every element of a single sample its own label, counting from 2.
    # Presumably 0 stays free for the zeros Unfold pads with and 1 for the
    # bias entries appended below -- TODO confirm against the tile backend.
    labels = arange(2, per_sample_size + 2, dtype=float64).detach()
    labels = labels.reshape(1, *x_input.shape[1:])

    patch_extractor = Unfold(
        kernel_size=self.kernel_size,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
    )
    indexes = patch_extractor(labels).flatten().round().to(dtype=int32)

    if self.use_bias:
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        bias_entries = ones(indexes.numel() // kernel_area, dtype=int32)
        indexes = cat((indexes, bias_entries), 0)

    self.fold_indices = indexes.to(x_input.device)

    height = x_input.size(2)
    width = x_input.size(3)
    out_height = self.get_image_size(height, 0)
    out_width = self.get_image_size(width, 1)

    self.input_size = per_sample_size
    self.analog_tile.set_indexed(
        self.fold_indices,
        [self.in_channels, height, width, out_height, out_width],
    )
def unfold_func(module):
    """Create an ``Unfold`` layer using ``module``'s sliding-window settings.

    Reads ``kernel_size``, ``dilation``, ``padding`` and ``stride`` from
    ``module`` and passes them unchanged to ``Unfold``.
    """
    window_options = {
        option: getattr(module, option)
        for option in ("kernel_size", "dilation", "padding", "stride")
    }
    return Unfold(**window_options)
def im2col(input_tensor: torch.Tensor, filter: torch.Tensor) -> torch.Tensor:
    """Convolve ``input_tensor`` with ``filter`` via unfold + matrix multiply.

    Despite the name, the return value is the convolution result, not the
    intermediate column matrix. The ``Unfold`` layer is built with only a
    kernel size, so its default stride, padding and dilation apply.

    Args:
        input_tensor: tensor of shape ``(N, C, H, W)``.
        filter: tensor of shape ``(M, C, R, S)``; it may be non-contiguous.

    Returns:
        A contiguous tensor of shape ``(N, M, H - R + 1, W - S + 1)``.

    Raises:
        AssertionError: if the channel counts of the two tensors differ.
    """
    _, C1, H, W = input_tensor.shape
    _, C2, R, S = filter.shape
    assert C1 == C2, f'Input tensor channels ({C1}) do not match with filter channels ({C2})'

    unfold = Unfold(kernel_size=(R, S))
    im_as_col = unfold(input_tensor)  # (N, C*R*S, L)

    # reshape (not view) so that non-contiguous filters, e.g. permuted
    # weights, are flattened correctly instead of raising a RuntimeError.
    filter_as_rows = filter.reshape(filter.size(0), -1)  # (M, C*R*S)

    out_as_col = im_as_col.transpose(1, 2).matmul(filter_as_rows.t()).transpose(1, 2)
    out = out_as_col.reshape(
        out_as_col.size(0), out_as_col.size(1), H - R + 1, W - S + 1
    ).contiguous()
    return out
def forward(self, x_input: Tensor) -> Tensor:
    """Run the indexed analog forward pass on ``x_input``.

    The fold indexes given to the analog tile are kept on the module and are
    rebuilt only when none exist yet or when the per-sample input size differs
    from the one recorded in ``self.input_size``.
    """
    # pylint: disable=arguments-differ

    def output_extent(extent: int, axis: int) -> int:
        """Return the output image length along ``axis``."""
        kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
        numerator = extent + 2 * self.padding[axis] - kernel_span - 1
        return numerator // self.stride[axis] + 1

    per_sample_size = x_input.numel() / x_input.size(0)
    needs_refresh = (not self.fold_indices.numel()
                     or self.input_size != per_sample_size)

    if needs_refresh:
        # pytorch just always uses NCHW order?
        labels = arange(2, per_sample_size + 2, dtype=float64).detach()
        labels = labels.reshape(1, *x_input.shape[1:])

        patch_extractor = Unfold(
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        indexes = patch_extractor(labels).flatten().round().to(dtype=int32)

        if self.use_bias:
            # NOTE(review): the number of appended bias entries is derived
            # from out_channels -- confirm this is what the tile expects.
            bias_entries = ones(indexes.numel() // self.out_channels,
                                dtype=int32)
            indexes = cat((indexes, bias_entries), 0)

        self.fold_indices = indexes.to(x_input.device)

        height = x_input.size(2)
        width = x_input.size(3)
        out_height = output_extent(height, 0)
        out_width = output_extent(width, 1)
        image_sizes = [self.in_channels, height, width, out_height, out_width]

        self.input_size = per_sample_size
        self.analog_tile.set_indexed(self.fold_indices, image_sizes)  # type: ignore

    return AnalogIndexedFunction.apply(
        self.analog_tile, x_input, self.weight, self.bias, not self.training
    )
def _calculate_indexes(self, x_input: Tensor,
                       in_channels: int) -> Tuple[Tensor, List[int], int]:
    """Compute the fold indexes and size information for ``x_input``.

    Nothing is stored on ``self``; all results are returned to the caller.

    Args:
        x_input: input matrix; dimensions 2 and 3 are read as height and width
        in_channels: number of input channels

    Returns:
        fold_indices: indices for the analog tile, on ``x_input``'s device
        image_sizes: ``[in_channels, height, width, out_height, out_width]``
        input_size: number of elements of ``x_input`` divided by its first
            dimension
    """
    per_sample_size = x_input.numel() / x_input.size(0)

    # pytorch just always uses NCHW order
    sample_shape = [1, *x_input.shape[1:]]
    labels = arange(2, per_sample_size + 2, dtype=float64).detach()
    labels = labels.reshape(*sample_shape)

    patch_extractor = Unfold(
        kernel_size=self.kernel_size,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
    )
    indexes = patch_extractor(labels).flatten().round().to(dtype=int32)

    if self.analog_bias:
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        bias_entries = ones(indexes.numel() // kernel_area, dtype=int32)
        indexes = cat((indexes, bias_entries), 0)

    indexes = indexes.to(x_input.device)

    height = x_input.size(2)
    width = x_input.size(3)
    out_height = self.get_image_size(height, 0)
    out_width = self.get_image_size(width, 1)

    return (indexes,
            [in_channels, height, width, out_height, out_width],
            per_sample_size)
def conv2d(input, module, super_opt='false', reduce_sum='false', diag='false'):
    """Unfold ``input`` with ``module``'s settings and optionally form patch products.

    Side effects on ``module``: ``param_shapes`` is set to ``[N, K, L, M]``
    (the three dimensions of the unfolded input followed by
    ``module.out_channels``) and ``optimized`` records whether a product was
    computed.

    Args:
        input: tensor passed to ``Unfold``.
        module: object providing ``kernel_size``, ``dilation``, ``padding``,
            ``stride`` and ``out_channels``.
        super_opt: the string ``'true'`` selects the alternative size test.
        reduce_sum: the string ``'true'`` sums the unfolded input over its
            last dimension before forming the product.
        diag: with ``reduce_sum='true'``, the string ``'true'`` divides the
            summed patches by ``L`` and returns per-sample squared sums
            instead of the full ``(n, q)`` product.

    Returns:
        A pair ``(product, patches)``; ``product`` is ``None`` when the size
        test decides against computing it.
    """
    patch_extractor = Unfold(
        kernel_size=module.kernel_size,
        dilation=module.dilation,
        padding=module.padding,
        stride=module.stride,
    )
    patches = patch_extractor(input)

    batch = patches.shape[0]
    patch_len = patches.shape[1]
    num_patches = patches.shape[2]
    out_channels = module.out_channels
    module.param_shapes = [batch, patch_len, num_patches, out_channels]

    if reduce_sum == 'true':
        patches = einsum("nkl->nk", patches)
        if diag == 'true':
            patches /= num_patches
            product = (patches * patches).sum(dim=1)
        else:
            product = einsum("nk,qk->nq", (patches, patches))
        module.optimized = True
        return product, patches

    # Compare a size estimate for the patch-by-patch product against the
    # alternative; the product is only formed when its estimate is smaller.
    pairwise_size = (num_patches * num_patches) * (patch_len + out_channels)
    if super_opt == 'true':
        alternative = (patch_len * out_channels * num_patches
                       + batch * patch_len * out_channels)
        worthwhile = batch * pairwise_size < alternative
    else:
        worthwhile = pairwise_size < patch_len * out_channels

    if not worthwhile:
        module.optimized = False
        return None, patches

    product = einsum("nkl,qkp->nqlp", (patches, patches))
    module.optimized = True
    return product, patches
def evaluate(attribution, data_loader, tile_size=14, baseline="mean",
             perturb_stride=5, checkpointer=None, reduce_mode="absmean",
             progbar=True):
    """Accumulate tile-perturbation curves for an attribution method.

    For every batch the attribution heatmap is resized to the input size, the
    image is split into tiles, the tiles are ranked by summed heatmap value,
    and ``_perturb`` is run once in ``"morf"`` and once in ``"lerf"`` mode
    (presumably most-/least-relevant-first -- confirm against ``_perturb``).
    The per-batch results are summed over the whole loader.

    Args:
        attribution: object exposing ``forward_func`` (the model) and
            ``attribute``.
        data_loader: iterable of ``(inputs, targets)`` batches; its
            ``sampler`` is used for checkpointing.
        tile_size: tile side length; must divide input dimensions 2 and 3.
        baseline: forwarded to ``_baseline`` to build the covering tiles.
        perturb_stride: forwarded to ``_perturb`` as ``stride``.
        checkpointer: optional object with ``state_dict(device=...)`` and
            ``save(morf, lerf, sampler_state)``; when given, accumulation
            resumes from its state and is saved after every batch.
        reduce_mode: one of ``"absmax"``, ``"absmean"`` or ``"mean"``.
        progbar: whether to show the outer progress bar.

    Returns:
        The two accumulated curves, each rescaled so that its first entry
        maps to 1 and its last to 0, moved to the CPU.

    Raises:
        ValueError: if ``reduce_mode`` is not supported.
    """
    inner_progbar = False  # progbar
    model = attribution.forward_func
    attribute = attribution.attribute
    device = next(model.parameters()).device

    morf = lerf = None
    if checkpointer is not None:
        saved = checkpointer.state_dict(device=device)
        morf, lerf = saved["morf"], saved["lerf"]
        if saved["sampler_state"] is not None:
            data_loader.sampler.load_state_dict(saved["sampler_state"])

    batches = tqdm(data_loader,
                   desc=f"{type(attribution).__name__:15s}",
                   disable=not progbar,
                   ncols=70)
    for inputs, targets in batches:
        assert inputs.shape[2] % tile_size == 0 and inputs.shape[
            3] % tile_size == 0, "Size mismatch"
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            # Reject unknown modes before the attribution method is invoked.
            if reduce_mode not in ("absmax", "absmean", "mean"):
                raise ValueError(
                    f"reduce_mode '{reduce_mode}' is not supported.")
            raw_heatmaps = attribute(inputs, target=targets)
            if reduce_mode == "absmax":
                raw_heatmaps = raw_heatmaps.abs().max(dim=1, keepdim=True)[0]
            elif reduce_mode == "absmean":
                raw_heatmaps = raw_heatmaps.abs()

        with torch.no_grad():
            image_size = inputs.shape[2:4]
            heatmaps = F.interpolate(raw_heatmaps,
                                     size=image_size,
                                     mode="bilinear",
                                     align_corners=False)

            # kernel_size counts the tiles along each side, and dilation is
            # the tile side length in pixels.
            tile_num = int(inputs.shape[2] / tile_size)
            unfold = Unfold(kernel_size=tile_num, dilation=tile_size)
            fold = Fold(output_size=image_size,
                        kernel_size=tile_num,
                        dilation=tile_size)

            tile_scores = unfold(heatmaps.sum(dim=1, keepdim=True)).sum(dim=2)
            tile_rank = tile_scores.argsort(dim=1, descending=True)

            inputs_tile = unfold(inputs).view(inputs.shape[0],
                                              inputs.shape[1],
                                              tile_num * tile_num, -1)
            covers_tile = _baseline(inputs_tile, baseline)

            # The first pass receives a clone, the second the tiles themselves.
            morf_batch = _perturb(model, inputs_tile.clone(), covers_tile,
                                  targets, tile_rank, fold,
                                  stride=perturb_stride, mode="morf",
                                  progbar=inner_progbar)
            if morf is None:
                morf = morf_batch
            else:
                morf = morf + morf_batch

            lerf_batch = _perturb(model, inputs_tile, covers_tile,
                                  targets, tile_rank, fold,
                                  stride=perturb_stride, mode="lerf",
                                  progbar=inner_progbar)
            if lerf is None:
                lerf = lerf_batch
            else:
                lerf = lerf + lerf_batch

        if checkpointer is not None:
            checkpointer.save(morf, lerf, data_loader.sampler.state_dict())

    def rescale(curve):
        """Shift and scale ``curve`` by its last and first entries."""
        return (curve - curve[-1]) / (curve[0] - curve[-1])

    norm_morf = rescale(morf)
    norm_lerf = rescale(lerf)
    return norm_morf.cpu(), norm_lerf.cpu()