    def undo_reshape(self, shapley_values: torch.Tensor) -> torch.Tensor:
        r"""
        Fold the representation back into the canonical image representation.

        Args:
            shapley_values (torch.Tensor): the representation after the Shapley module

        Returns:
            The image representation

        """
        # prepare the output for folding, at the end should be
        # (batch_size, features_times_channels, patches)
        shapley_values = shapley_values.align_to(..., NAME_META_CHANNELS,
                                                 NAME_FEATURES)
        shapley_values = shapley_values.flatten(
            [NAME_META_CHANNELS, NAME_FEATURES], NAME_FEATURES_META_CHANNEL)
        shapley_values = shapley_values.align_to(...,
                                                 NAME_FEATURES_META_CHANNEL,
                                                 NAME_PATCHES)

        # the folding operation
        shapley_values = self.folder.fold(shapley_values, self.folder.height,
                                          self.folder.width)

        return shapley_values
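# A minimal sketch (toy names and sizes, not the module's real configuration)
# of the flatten/align_to steps above on a dummy tensor: the (meta-channels,
# features) pair is merged into one axis and moved in front of the patches
# axis, the layout a fold operation (e.g. torch.nn.functional.fold) expects.
import torch

x = torch.randn(2, 9, 4, 3,
                names=("batch", "patches", "features", "meta_channels"))
x = x.align_to(..., "meta_channels", "features")
x = x.flatten(["meta_channels", "features"], "features_meta_channel")
x = x.align_to(..., "features_meta_channel", "patches")
print(x.shape)  # (2, 12, 9) = (batch_size, features_times_channels, patches)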
Example #2
# Convert an image tensor to channels-first layout, using dimension names when
# they are set and a positional permute otherwise.
def _(x: Tensor) -> Tensor:
    if x.ndim == 3:
        if any(x.names):
            return x.align_to("C", "H", "W")
        return x.permute(2, 0, 1)  # .to(memory_format=torch.contiguous_format)
    if x.ndim == 4:
        if any(x.names):
            return x.align_to("N", "C", "H", "W")
        return x.permute(0, 3, 1, 2).contiguous()
    return x
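# A hedged usage sketch with dummy data: the converter above moves a
# channels-last image to channels-first, using names when they are present and
# a positional permute otherwise (registration of `_`, e.g. via singledispatch,
# is assumed to happen elsewhere).
import torch

hwc = torch.randn(32, 32, 3)                        # unnamed (H, W, C)
chw = hwc.permute(2, 0, 1)                          # what the 3-dim branch does
named = torch.randn(8, 32, 32, 3, names=("N", "H", "W", "C"))
nchw = named.align_to("N", "C", "H", "W")           # what the 4-dim branch does
print(chw.shape, nchw.shape)                        # (3, 32, 32) (8, 3, 32, 32)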
    def prune(self, shapley_values: torch.Tensor,
              id_stage: int) -> torch.Tensor:
        r"""
        Prune the output of a stage

        Args:
            id_stage (): the index of the stage
            shapley_values (): the current stage's pre-pruning output

        Returns:
            the pruned Shapley representation

        """
        # First check if we need pruning in the first place
        pruning = self.pruning[id_stage]
        num_pixels = named_tensor_get_dim(shapley_values,
                                          [NAME_HEIGHT, NAME_WIDTH])
        num_pixels = num_pixels[0] * num_pixels[1]
        k = int(num_pixels * pruning)  # the number of pixels to prune
        if pruning * k == 0:  # nothing to prune at this stage
            return shapley_values

        name_size = {
            name: size
            for name, size in zip(shapley_values.names, shapley_values.shape)
        }

        shapley_values = shapley_values.align_to(..., NAME_HEIGHT, NAME_WIDTH,
                                                 NAME_META_CHANNELS).flatten(
                                                     [NAME_HEIGHT, NAME_WIDTH],
                                                     NAME_FEATURES)
        # get the norm of the vectors for each of the pixels, based on
        # which the pruning will be performed
        abs_values = torch.linalg.norm(shapley_values.rename(None),
                                       ord=1,
                                       dim=-1)
        # generate top-k
        top_k = torch.topk(abs_values.rename(None), k, largest=False)[0].max(
            1, keepdim=True)[0]  # threshold of the values to prune
        abs_values = abs_values.rename(NAME_BATCH_SIZE, NAME_FEATURES) > top_k
        shapley_values = shapley_values * abs_values.align_to(
            ..., NAME_META_CHANNELS)
        shapley_values = shapley_values.unflatten(
            NAME_FEATURES, [[NAME_HEIGHT, name_size[NAME_HEIGHT]],
                            [NAME_WIDTH, name_size[NAME_WIDTH]]])

        shapley_values = shapley_values.align_to(*name_size.keys())

        return shapley_values
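# A minimal sketch (toy shapes, plain tensors) of the pruning rule used above:
# rank the pixels by the L1 norm of their meta-channel vector and zero out the
# k weakest ones. This is an illustration, not the module's exact code path.
import torch

values = torch.randn(2, 16, 4)                # (batch, pixels, meta-channels)
k = 4                                         # number of pixels to prune
norms = torch.linalg.norm(values, ord=1, dim=-1)            # (batch, pixels)
threshold = torch.topk(norms, k, largest=False)[0].max(1, keepdim=True)[0]
pruned = values * (norms > threshold).unsqueeze(-1)
print((pruned.abs().sum(-1) == 0).sum(1))     # at least k zeroed pixels each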
Example #4
    def prune(
            self, shapley_values: torch.Tensor, id_stage: int
    ) -> torch.Tensor:
        r"""
        Prune the output of a stage

        Args:
            id_stage (): the index of the stage
            shapley_values (): the current stage's pre-pruning output

        Returns:
            the pruned Shapley representation

        """
        # First check if we need pruning in the first place
        pruning = self.pruning[id_stage]
        num_features = named_tensor_get_dim(shapley_values, NAME_FEATURES)
        k = int(num_features * pruning)  # the number of features to prune
        if pruning * k == 0:  # nothing to prune at this stage
            return shapley_values

        names = shapley_values.names
        shapley_values = shapley_values.align_to(
            ..., NAME_FEATURES, NAME_META_CHANNELS)
        # get the norm of the vectors for each of the features, based on
        # which the pruning will be performed
        abs_values = torch.linalg.norm(
            shapley_values.rename(None), ord=1, dim=-1)
        # generate top-k
        top_k = torch.topk(
            abs_values.rename(None), k, largest=False
        )[0].max(1, keepdim=True)[0]  # threshold of the values to prune
        shapley_values = shapley_values * (
                abs_values > top_k).unsqueeze(-1)

        shapley_values = shapley_values.align_to(*names)

        return shapley_values
    def _final_process(self, shapley_values: torch.Tensor, id_stage: int,
                       *args, **kwargs) -> torch.Tensor:
        r"""
        Depending on the usage, this method is for final processing before
        output the values

        Args:
            shapley_values (): the shapley values computed from the last stage
            args (): placeholder
            kwargs (): placeholder

        Returns:
            Shapley values final for presenting

        """
        return shapley_values.align_to(NAME_BATCH_SIZE, ...,
                                       NAME_META_CHANNELS)
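# A tiny sketch (dummy names) of the align_to(NAME_BATCH_SIZE, ...,
# NAME_META_CHANNELS) call above: the batch axis is moved to the front, the
# meta-channels axis to the back, and everything caught by the ellipsis keeps
# its relative order.
import torch

x = torch.randn(4, 3, 5, 2, names=("height", "meta_channels", "batch", "width"))
print(x.align_to("batch", ..., "meta_channels").names)
# ('batch', 'height', 'width', 'meta_channels')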
    def compute_shapley(self, function_outputs: torch.Tensor) -> torch.Tensor:
        r"""

        Args:
            function_outputs (): of shape (2 ** features, batch_size,
            output_channels)

        Returns:
            Shapley values for each of the variables
                should be of shape (batch_size, self.m, output_channels)
        """
        shapley_values = torch.matmul(
            function_outputs.align_to(..., NAME_NUM_PASSES),
            self.subtraction_matrix
        ).align_to(NAME_BATCH_SIZE, ..., NAME_FEATURES, NAME_META_CHANNELS)

        return shapley_values
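# A hedged toy example (m = 2 features, plain tensors, assumed coalition
# ordering {}, {1}, {2}, {1, 2}) of the matmul above: each column of the
# subtraction matrix holds the Shapley weights for one feature, so a single
# matmul turns the 2 ** m coalition outputs into per-feature Shapley values.
import torch

# v holds the function output for every coalition:
# shape (batch_size, output_channels, num_passes = 2 ** m)
v = torch.randn(3, 5, 4)
# For m = 2, phi_1 = 1/2 * [v({1}) - v({})] + 1/2 * [v({1, 2}) - v({2})],
# and symmetrically for phi_2, which gives the columns below.
subtraction_matrix = torch.tensor([[-0.5, -0.5],   # v({})
                                   [ 0.5, -0.5],   # v({1})
                                   [-0.5,  0.5],   # v({2})
                                   [ 0.5,  0.5]])  # v({1, 2})
shapley = torch.matmul(v, subtraction_matrix)      # (batch, channels, m)
print(shapley.shape)                               # torch.Size([3, 5, 2])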
    def reshape(self, shapley_values: torch.Tensor) -> torch.Tensor:
        r"""
        In this instantiation, this method tries to unfold the representation
        and prepare for the Shapley Module. 

        Args:
            shapley_values (): the input Shapley representation

        Returns:
            the prepared Shapley representation for the Shapley module that 
            follows

        """
        # Prepare the input; at the end it should be of shape
        # (batch_size, patches, features, meta-channels)
        shapley_values_un = self.folder(shapley_values)
        shapley_values = shapley_values_un.unflatten(
            NAME_FEATURES_META_CHANNEL, [
                (NAME_META_CHANNELS, self.dimensions.in_channel),
                (NAME_FEATURES, np.prod(self.kernel_size)),
            ])
        shapley_values = shapley_values.align_to(..., NAME_FEATURES,
                                                 NAME_META_CHANNELS)
        return shapley_values
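# A minimal sketch (dummy sizes) of the unflatten + align_to steps above: the
# unfolded axis of size in_channels * kernel_area is split back into
# (meta-channels, features), and the features axis is moved in front of the
# meta-channels axis.
import torch

in_channels, kernel_area, patches = 3, 4, 9
unfolded = torch.randn(2, in_channels * kernel_area, patches,
                       names=("batch", "features_meta_channel", "patches"))
x = unfolded.unflatten("features_meta_channel",
                       (("meta_channels", in_channels),
                        ("features", kernel_area)))
x = x.align_to(..., "features", "meta_channels")
print(x.shape)  # (2, 9, 4, 3) = (batch_size, patches, features, meta-channels)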
    def forward(self, emb_inputs: torch.Tensor) -> torch.Tensor:
        r"""Forward calculation of CompressInteractionNetworkLayer
        
        Args:
            emb_inputs (T), shape = (B, N, E), dtype = torch.float: Embedded features tensors.
        
        Returns:
            T, shape = (B, O), dtype = torch.float: Output of CompressInteractionNetworkLayer.
        """
        # Initialize two lists to store tensors of outputs and next steps temporarily
        direct_list = list()
        hidden_list = list()

        # Transpose emb_inputs
        # inputs: emb_inputs, shape = (B, N, E)
        # output: x0, shape = (B, E, N)
        x0 = emb_inputs.align_to("B", "E", "N")
        hidden_list.append(x0)
        
        # Expand dimension N of x0 to Nx (= N) and H (= 1)
        # inputs: x0, shape = (B, E, N)
        # output: x0, shape = (B, E, Nx = N, H = 1)
        x0 = x0.unflatten("N", [("Nx", x0.size("N")), ("H", 1)])

        # Calculate with cin forwardly
        for i, layer_size in enumerate(self.layer_sizes[:-1]):
            # Get tensors of previous step and reshape it
            # inputs: hidden_list[-1], shape = (B, E, N)
            # output: xi, shape = (B, E, H = 1, Ny = N)
            xi = hidden_list[-1]
            xi = xi.unflatten("N", [("H", 1), ("Ny", xi.size("N"))])

            # Calculate the outer product of x0 and xi
            # inputs: x0, shape = (B, E, Nx = N, H = 1)
            # inputs: xi, shape = (B, E, H = 1, Ny = N)
            # output: out_prod, shape = (B, E, Nx = N, Ny = N)
            ## out_prod = torch.matmul(x0, xi)
            out_prod = torch.einsum("ijkn,ijnh->ijkh", [x0.rename(None), xi.rename(None)])
            out_prod.names = ("B", "E", "Nx", "Ny")
            
            # Reshape out_prod
            # inputs: out_prod, shape = (B, E, Nx = N, Ny = N)
            # output: out_prod, shape = (B, N = Nx * Ny, E)
            out_prod = out_prod.flatten(["Nx", "Ny"], "N")
            out_prod = out_prod.align_to("B", "N", "E")

            # Apply convolution, batch norm and activation
            # inputs: out_prod, shape = (B, N = Nx * Ny, E)
            # output: outputs, shape = (B, N = (Hi * 2 or Hi), E)
            outputs = self.model[i](out_prod.rename(None))
            outputs.names = ("B", "N", "E")
            
            if self.is_direct:
                # Pass to output directly
                # inputs: outputs, shape = (B, N = Hi, E)
                # output: direct, shape = (B, N = Hi, E)
                direct = outputs

                # Reshape and pass to next step directly
                # inputs: outputs, shape = (B, Hi, E)
                # output: hidden, shape = (B, E, N = Hi)
                hidden = outputs.align_to("B", "E", "N")
            else:
                if i != (len(self.layer_sizes) - 1):
                    # Split outputs into two part and pass them to outputs and hidden separately
                    # inputs: outputs, shape = (B, Hi * 2, E)
                    # output: direct, shape = (B, N = Hi, E)
                    # output: hidden, shape = (B, N = Hi, E)
                    direct, hidden = torch.chunk(outputs, 2, dim="N")
                    
                    # Reshape and pass to next step
                    # inputs: hidden, shape = (B, N = Hi, E)
                    # output: hidden, shape = (B, E, N = Hi)
                    hidden = hidden.align_to("B", "E", "N")
                else:
                    # Pass to output directly
                    # inputs: outputs, shape = (B, N = Hi, E)
                    # output: direct, shape = (B, N = Hi, E)
                    direct = outputs
                    hidden = 0

            # Store tensors to lists temporarily
            direct_list.append(direct)
            hidden_list.append(hidden)
        
        # Concatenate direct_list into a tensor
        # inputs: direct_list, shape = (B, Hi, E)
        # output: outputs, shape = (B, sum(Hi), E)
        outputs = torch.cat(direct_list, dim="N")

        # Aggregate outputs on dimension E and pass to dense layer
        # inputs: outputs, shape = (B, sum(Hi), E)
        # output: outputs, shape = (B, O)
        outputs = self.fc(outputs.sum("E"))
        outputs.names = ("B", "O")
        
        return outputs
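# A hedged toy sketch (plain tensors, made-up sizes) of the einsum outer
# product used in the loop above: for every batch sample and embedding slot,
# x0 of shape (Nx, 1) times xi of shape (1, Ny) yields all pairwise feature
# interactions of shape (Nx, Ny), which are then flattened into one axis.
import torch

B, E, N = 2, 8, 5
x0 = torch.randn(B, E, N, 1)          # (B, E, Nx = N, H = 1)
xi = torch.randn(B, E, 1, N)          # (B, E, H = 1, Ny = N)
out_prod = torch.einsum("ijkn,ijnh->ijkh", x0, xi)         # (B, E, Nx, Ny)
out_prod = out_prod.reshape(B, E, N * N).permute(0, 2, 1)  # (B, Nx * Ny, E)
print(out_prod.shape)                 # torch.Size([2, 25, 8])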
Example #9
    def forward(
        self,
        input_image: torch.Tensor,
        masked_kspace: torch.Tensor,
        sensitivity_map: torch.Tensor,
        sampling_mask: torch.Tensor,
        previous_state: Optional[torch.Tensor] = None,
        loglikelihood_scaling: Optional[float] = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        input_image : torch.Tensor
            Initial or intermediate guess of input.
        masked_kspace : torch.Tensor
            Kspace masked by the sampling mask.
        sensitivity_map : torch.Tensor
            Coil sensitivities.
        sampling_mask : torch.Tensor
            Sampling mask.
        previous_state : torch.Tensor
        loglikelihood_scaling : torch.Tensor

        Returns
        -------
        torch.Tensor
        """

        # TODO: This has to be made contiguous
        input_image = input_image.align_to(
            "batch", "complex", "height", "width").contiguous()  # type: ignore

        batch_size = input_image.size("batch")
        spatial_shape = [input_image.size("height"), input_image.size("width")]
        # Initialize zero state for RIM
        state_size = ([batch_size, self.num_hidden_channels] +
                      list(spatial_shape) + [self.depth])
        if previous_state is None:
            previous_state = torch.zeros(
                *state_size, dtype=input_image.dtype).to(input_image.device)

        cell_outputs = []
        intermediate_image = input_image
        for cell_idx in range(self.length):
            cell = self.cell_list[
                cell_idx] if self.no_sharing else self.cell_list[0]

            grad_loglikelihood = self.grad_likelihood(
                intermediate_image,
                masked_kspace,
                sensitivity_map,
                sampling_mask,
                loglikelihood_scaling,
            )

            if grad_loglikelihood.abs().max() > 150.0:
                warnings.warn(
                    f"Very large values for the gradient loglikelihood ({grad_loglikelihood.abs().max()}). "
                    f"Might cause difficulties.")

            cell_input = torch.cat(
                [
                    intermediate_image.rename(None),
                    grad_loglikelihood.rename(None)
                ],
                dim=1,
            )
            cell_output, previous_state = cell(cell_input, previous_state)
            if self.skip_connections:
                intermediate_image = intermediate_image + cell_output
            if not self.training:
                # If not training, memory can be significantly reduced by clearing the previous cell.
                cell_output.set_()
                grad_loglikelihood.rename(None).set_(
                )  # TODO: Fix when named tensors have this support.
                del cell_output, grad_loglikelihood
            # Only save intermediate reconstructions at training step
            if self.training or cell_idx == (self.length - 1):
                cell_outputs.append(
                    intermediate_image.refine_names("batch", "complex",
                                                    "height",
                                                    "width"))  # type: ignore
        return cell_outputs, previous_state
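# A minimal sketch (assumed sizes; the RIM cell itself is not shown here) of
# two steps from the forward pass above: the zero hidden state, and the cell
# input built by concatenating the current image estimate with the
# log-likelihood gradient along the channel ("complex") dimension.
import torch

batch, channels, height, width, depth, hidden = 2, 2, 32, 32, 2, 64
previous_state = torch.zeros(batch, hidden, height, width, depth)
intermediate_image = torch.randn(batch, channels, height, width)
grad_loglikelihood = torch.randn(batch, channels, height, width)
cell_input = torch.cat([intermediate_image, grad_loglikelihood], dim=1)
print(cell_input.shape, previous_state.shape)
# torch.Size([2, 4, 32, 32]) torch.Size([2, 64, 32, 32, 2])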