Example #1
0
    def recalculate_indexes(self, x_input: Tensor) -> None:
        """Rebuild the fold indexes for ``x_input`` and hand them to the analog tile.

        Each element of a single input sample is tagged with a unique integer
        (starting at 2). The tags are pushed through an ``Unfold`` built with
        this layer's kernel/stride/padding/dilation, so the result records
        which input elements every sliding-window patch reads.

        Args:
            x_input: input batch. ``size(2)`` and ``size(3)`` are read as
                height and width, i.e. an NCHW layout is assumed.

        Side effects:
            Sets ``self.fold_indices`` and ``self.input_size`` and calls
            ``self.analog_tile.set_indexed`` with the indexes and image sizes.
        """
        # Elements per sample; true division keeps this a float, and that
        # float is what gets cached in ``self.input_size``.
        input_size = x_input.numel() / x_input.size(0)

        # NOTE(review): tags start at 2, presumably leaving 0 for zero padding
        # (Unfold pads with zeros) and 1 for the bias entries appended below.
        sample_shape = [1, *x_input.shape[1:]]
        tags = arange(2, input_size + 2, dtype=float64).detach().reshape(*sample_shape)

        patch_extractor = Unfold(kernel_size=self.kernel_size,
                                 stride=self.stride,
                                 padding=self.padding,
                                 dilation=self.dilation)
        indices = patch_extractor(tags).flatten().round().to(dtype=int32)

        if self.use_bias:
            kernel_area = self.kernel_size[0] * self.kernel_size[1]
            # NOTE(review): assuming 4-D NCHW input, the unfolded tags hold
            # in_channels * kernel_area rows per output position. This count is
            # therefore in_channels times the number of output positions.
            # Confirm that is what set_indexed expects for the bias entries.
            bias_indices = ones(indices.numel() // kernel_area, dtype=int32)
            indices = cat((indices, bias_indices), 0)

        self.fold_indices = indices.to(x_input.device)

        height, width = x_input.size(2), x_input.size(3)
        out_height = self.get_image_size(height, 0)
        out_width = self.get_image_size(width, 1)

        self.input_size = input_size
        self.analog_tile.set_indexed(
            self.fold_indices,
            [self.in_channels, height, width, out_height, out_width])
Example #2
0
def unfold_func(module):
    """Return an ``Unfold`` that reuses ``module``'s sliding-window geometry.

    ``module`` only needs ``kernel_size``, ``dilation``, ``padding`` and
    ``stride`` attributes (for example a ``torch.nn.Conv2d``). They are
    forwarded unchanged as keyword arguments.
    """
    window_args = {
        attr: getattr(module, attr)
        for attr in ("kernel_size", "dilation", "padding", "stride")
    }
    return Unfold(**window_args)
Example #3
0
def im2col(input_tensor: torch.Tensor, filter: torch.Tensor) -> torch.Tensor:
    """Compute a stride-1, unpadded 2-D convolution via im2col + matmul.

    Args:
        input_tensor: input batch of shape (N, C, H, W).
        filter: convolution weights of shape (M, C, R, S). Non-contiguous
            tensors (e.g. transposed views) are accepted.

    Returns:
        Tensor of shape (N, M, H - R + 1, W - S + 1). Up to floating-point
        rounding, it equals ``torch.nn.functional.conv2d(input_tensor, filter)``
        with no bias.

    Raises:
        ValueError: if the channel counts of ``input_tensor`` and ``filter``
            differ.
    """
    # ``filter`` shadows the builtin but is kept: callers may pass it by name.
    _, in_channels, height, width = input_tensor.shape
    out_channels, filter_channels, kernel_h, kernel_w = filter.shape

    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if in_channels != filter_channels:
        raise ValueError(
            f'Input tensor channels ({in_channels}) do not match with '
            f'filter channels ({filter_channels})')

    # (N, C*R*S, L): one column per output position, L = (H-R+1)*(W-S+1).
    columns = Unfold(kernel_size=(kernel_h, kernel_w))(input_tensor)

    # ``reshape`` (not ``view``) so non-contiguous filters also work.
    # The flattened (C, R, S) order matches Unfold's row order.
    weights = filter.reshape(out_channels, -1)  # (M, C*R*S)

    # (M, C*R*S) @ (N, C*R*S, L) broadcasts over the batch -> (N, M, L).
    out_as_col = weights.matmul(columns)

    return out_as_col.reshape(out_as_col.size(0), out_channels,
                              height - kernel_h + 1, width - kernel_w + 1)
Example #4
0
    def forward(self, x_input: Tensor) -> Tensor:
        """Compute the forward pass through the analog tile.

        Fold indexes are rebuilt lazily. That happens when
        ``self.fold_indices`` is still empty, or when the element count per
        sample differs from the cached ``self.input_size``. The output itself
        comes from ``AnalogIndexedFunction.apply``.

        Args:
            x_input: input batch. ``size(2)`` and ``size(3)`` are read as
                height and width, i.e. an NCHW layout is assumed.

        Returns:
            The tensor produced by ``AnalogIndexedFunction.apply``.
        """

        # pylint: disable=arguments-differ

        per_sample = x_input.numel() / x_input.size(0)

        # ``or`` short-circuits, so ``self.input_size`` is only consulted once
        # indexes have been built at least once.
        indexes_missing = not self.fold_indices.numel()
        if indexes_missing or self.input_size != per_sample:

            def out_extent(size: int, dim: int) -> int:
                """Number of sliding-window positions along ``dim``."""
                window = self.dilation[dim] * (self.kernel_size[dim] - 1) + 1
                return (size + 2 * self.padding[dim] - window) // self.stride[dim] + 1

            # NOTE(review): tags start at 2, presumably leaving 0 for zero
            # padding (Unfold pads with zeros) and 1 for the bias entries.
            tags = arange(2, per_sample + 2, dtype=float64).detach()
            tags = tags.reshape(1, *x_input.shape[1:])

            patches = Unfold(kernel_size=self.kernel_size,
                             stride=self.stride,
                             padding=self.padding,
                             dilation=self.dilation)(tags)
            indices = patches.flatten().round().to(dtype=int32)

            if self.use_bias:
                # NOTE(review): this divides by out_channels. The number of
                # output positions is out_extent(H, 0) * out_extent(W, 1).
                # Confirm which length set_indexed expects for the bias entries.
                bias_count = indices.numel() // self.out_channels
                indices = cat((indices, ones(bias_count, dtype=int32)), 0)

            self.fold_indices = indices.to(x_input.device)

            height, width = x_input.size(2), x_input.size(3)
            sizes = [self.in_channels, height, width,
                     out_extent(height, 0), out_extent(width, 1)]
            self.input_size = per_sample
            self.analog_tile.set_indexed(self.fold_indices, sizes)  # type: ignore

        return AnalogIndexedFunction.apply(
            self.analog_tile, x_input, self.weight, self.bias, not self.training)
Example #5
0
    def _calculate_indexes(self, x_input: Tensor,
                           in_channels: int) -> Tuple[Tensor, List[int], int]:
        """Calculate and return the fold indexes and sizes.

        Args:
            x_input: input matrix
            in_channels: number of input channel

        Returns:
            fold_indices: indices for the analog tile
            image_sizes: image sizes for the analog tile
            input_size: size of the current input
        """
        input_size = x_input.numel() / x_input.size(0)

        # pytorch just always uses NCHW order
        fold_indices = arange(2, input_size + 2, dtype=float64).detach()
        shape = [1] + list(x_input.shape[1:])
        fold_indices = fold_indices.reshape(*shape)
        unfold = Unfold(kernel_size=self.kernel_size,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation)
        fold_indices = unfold(fold_indices).flatten().round().to(dtype=int32)

        if self.analog_bias:
            out_image_size = fold_indices.numel() // (self.kernel_size[0] *
                                                      self.kernel_size[1])
            fold_indices = cat(
                (fold_indices, ones(out_image_size, dtype=int32)), 0)

        fold_indices = fold_indices.to(x_input.device)

        x_height = x_input.size(2)
        x_width = x_input.size(3)

        d_height = self.get_image_size(x_height, 0)
        d_width = self.get_image_size(x_width, 1)

        image_sizes = [in_channels, x_height, x_width, d_height, d_width]
        return (fold_indices, image_sizes, input_size)
Example #6
0
    def conv2d(input,
               module,
               super_opt='false',
               reduce_sum='false',
               diag='false'):
        """Unfold ``input`` with ``module``'s window geometry and optionally
        precompute the pairwise product of the unfolded patches.

        Args:
            input: batch passed to ``torch.nn.Unfold`` (shape (N, C, H, W)).
            module: object providing ``kernel_size``, ``dilation``,
                ``padding``, ``stride`` and ``out_channels``.
                ``param_shapes`` (``[N, K, L, M]``) and ``optimized`` are
                written back onto it.
            super_opt: use the batch-aware size comparison when deciding
                whether to precompute the product.
            reduce_sum: sum patches over the L positions before forming the
                product.
            diag: with ``reduce_sum``, divide the summed patches by L and
                return only each sample's squared norm.

            The three flags accept the original string API (``'true'`` /
            ``'false'``) and also plain booleans. Previously ``True`` was
            silently treated as off.

        Returns:
            ``(product, cols)``:

            - With ``reduce_sum``: ``cols`` is (N, K) and ``product`` is
              (N, N), or (N,) when ``diag`` is set.
            - Otherwise, if precomputation is chosen: ``product`` is
              ``einsum("nkl,qkp->nqlp")`` of shape (N, N, L, L) and ``cols``
              is (N, K, L).
            - Otherwise ``product`` is ``None``.
        """
        def _enabled(option):
            # 'true' is the original string API; a real bool is also accepted.
            return option is True or option == 'true'

        unfold = Unfold(
            kernel_size=module.kernel_size,
            dilation=module.dilation,
            padding=module.padding,
            stride=module.stride,
        )
        cols = unfold(input)
        batch = cols.shape[0]
        patch_dim = cols.shape[1]
        positions = cols.shape[2]
        out_ch = module.out_channels
        module.param_shapes = [batch, patch_dim, positions, out_ch]

        if _enabled(reduce_sum):
            cols = einsum("nkl->nk", cols)
            if _enabled(diag):
                cols /= positions
                product = torch.sum(cols * cols, dim=1)
            else:
                product = einsum("nk,qk->nq", cols, cols)
            module.optimized = True
            return product, cols

        # NOTE(review): presumably a cost heuristic comparing the size of the
        # precomputed product against the work it replaces; confirm with the caller.
        if _enabled(super_opt):
            worthwhile = (batch * (positions * positions) * (patch_dim + out_ch)
                          < patch_dim * out_ch * positions
                          + batch * patch_dim * out_ch)
        else:
            worthwhile = (positions * positions) * (patch_dim + out_ch) < patch_dim * out_ch

        if worthwhile:
            module.optimized = True
            return einsum("nkl,qkp->nqlp", cols, cols), cols

        module.optimized = False
        return None, cols
Example #7
0
def evaluate(attribution,
             data_loader,
             tile_size=14,
             baseline="mean",
             perturb_stride=5,
             checkpointer=None,
             reduce_mode="absmean",
             progbar=True):
    """Score an attribution method with MoRF / LeRF tile-perturbation curves.

    For every batch:

    1. Attributions are computed, reduced per ``reduce_mode``, bilinearly
       resized to the input size and summed over channels.
    2. Each image is split into a square ``tile_num x tile_num`` grid of
       ``tile_size x tile_size`` tiles, and the tiles are ranked by summed
       attribution (descending).
    3. ``_perturb`` runs in "morf" and "lerf" mode.

    The per-batch results are summed over the whole loader.

    Args:
        attribution: object exposing ``forward_func`` (the model; its
            parameters determine the device) and ``attribute(inputs, target=...)``.
        data_loader: iterable of ``(inputs, targets)``. When ``checkpointer``
            is used, its ``sampler`` must support ``state_dict`` and
            ``load_state_dict``.
        tile_size: side length of a tile in pixels. Inputs must be square
            with sides divisible by it.
        baseline: forwarded to ``_baseline`` to build the replacement tiles.
        perturb_stride: forwarded to ``_perturb`` as ``stride``.
        checkpointer: optional. It is used to resume ``morf``/``lerf`` and the
            sampler state, and to save them after every batch.
        reduce_mode: ``"absmax"`` (max of |attr| over channels),
            ``"absmean"`` (|attr|) or ``"mean"`` (raw attr).
        progbar: show a tqdm progress bar over batches.

    Returns:
        ``(norm_morf, norm_lerf)`` on the CPU. Each is
        ``(curve - curve[-1]) / (curve[0] - curve[-1])``. If the first and
        last entries are equal, this division yields non-finite values.

    Raises:
        ValueError: for non-square inputs, sides not divisible by
            ``tile_size``, an unsupported ``reduce_mode``, or when no batch
            was ever processed (empty loader and no checkpointed curves).
    """
    child_progbar = False  # progbar

    model, attribute = attribution.forward_func, attribution.attribute
    device = next(model.parameters()).device

    if checkpointer is None:
        morf, lerf = None, None
    else:
        # Resume the accumulated curves and the sampler position.
        state_dict = checkpointer.state_dict(device=device)
        morf, lerf = state_dict["morf"], state_dict["lerf"]
        if state_dict["sampler_state"] is not None:
            data_loader.sampler.load_state_dict(state_dict["sampler_state"])

    for inputs, targets in tqdm(data_loader,
                                desc=f"{type(attribution).__name__:15s}",
                                disable=not progbar,
                                ncols=70):
        height, width = inputs.shape[2], inputs.shape[3]
        # The Unfold/Fold grid below is tile_num x tile_num with tile_num
        # taken from the height. A non-square input would therefore yield
        # overlapping tiles that Fold sums twice, so reject it explicitly.
        if height != width or height % tile_size != 0:
            raise ValueError(
                f"Size mismatch: inputs must be square with sides divisible "
                f"by tile_size={tile_size}, got {height}x{width}.")

        inputs, targets = inputs.to(device=device), targets.to(device=device)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            if reduce_mode == "absmax":
                raw_heatmaps = attribute(inputs, target=targets).abs().max(
                    dim=1, keepdim=True)[0]
            elif reduce_mode == "absmean":
                raw_heatmaps = attribute(inputs, target=targets).abs()
            elif reduce_mode == "mean":
                raw_heatmaps = attribute(inputs, target=targets)
            else:
                raise ValueError(
                    f"reduce_mode '{reduce_mode}' is not supported.")

        with torch.no_grad():
            heatmaps = F.interpolate(raw_heatmaps,
                                     size=inputs.shape[2:4],
                                     mode="bilinear",
                                     align_corners=False)

            tile_num = height // tile_size

            # kernel_size=tile_num with dilation=tile_size: each of the
            # tile_num**2 kernel offsets gathers one contiguous tile of
            # tile_size x tile_size pixels.
            unfold = Unfold(kernel_size=tile_num, dilation=tile_size)
            fold = Fold(output_size=inputs.shape[2:4],
                        kernel_size=tile_num,
                        dilation=tile_size)

            heatmaps_tile = unfold(heatmaps.sum(dim=1, keepdim=True))
            tile_rank = heatmaps_tile.sum(dim=2).argsort(dim=1,
                                                         descending=True)
            inputs_tile = unfold(inputs).view(inputs.shape[0], inputs.shape[1],
                                              tile_num * tile_num, -1)
            covers_tile = _baseline(inputs_tile, baseline)

            morf_b = _perturb(model,
                              inputs_tile.clone(),
                              covers_tile,
                              targets,
                              tile_rank,
                              fold,
                              stride=perturb_stride,
                              mode="morf",
                              progbar=child_progbar)
            morf = morf_b if morf is None else morf + morf_b

            lerf_b = _perturb(model,
                              inputs_tile,
                              covers_tile,
                              targets,
                              tile_rank,
                              fold,
                              stride=perturb_stride,
                              mode="lerf",
                              progbar=child_progbar)
            lerf = lerf_b if lerf is None else lerf + lerf_b

        if checkpointer is not None:
            checkpointer.save(morf, lerf, data_loader.sampler.state_dict())

    if morf is None or lerf is None:
        raise ValueError(
            "No batches were evaluated: data_loader is empty and no "
            "checkpointed results were available.")

    norm_morf = (morf - morf[-1]) / (morf[0] - morf[-1])
    norm_lerf = (lerf - lerf[-1]) / (lerf[0] - lerf[-1])

    return norm_morf.cpu(), norm_lerf.cpu()