def undo_reshape(self, shapley_values: torch.Tensor) -> torch.Tensor:
    r"""
    This folds the representation back into the canonical image representation

    Args:
        shapley_values (): the representation after the Shapley Module

    Returns:
        The image representation
    """
    # prepare the output for folding, at the end should be
    # (batch_size, features_times_channels, patches)
    shapley_values = shapley_values.align_to(..., NAME_META_CHANNELS,
                                             NAME_FEATURES)
    shapley_values = shapley_values.flatten(
        [NAME_META_CHANNELS, NAME_FEATURES], NAME_FEATURES_META_CHANNEL)
    shapley_values = shapley_values.align_to(..., NAME_FEATURES_META_CHANNEL,
                                             NAME_PATCHES)
    # the folding operation
    shapley_values = self.folder.fold(shapley_values, self.folder.height,
                                      self.folder.width)
    return shapley_values
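
# Standalone sketch of the fold step above (not from the module; self.folder
# presumably wraps torch.nn.Fold/Unfold, and the sizes here are made up).
# With stride equal to the kernel size the patches do not overlap, so folding
# reassembles the image exactly:
import torch

x = torch.randn(2, 3, 8, 8)  # (batch, channels, height, width)
unfold = torch.nn.Unfold(kernel_size=4, stride=4)
fold = torch.nn.Fold(output_size=(8, 8), kernel_size=4, stride=4)
patches = unfold(x)  # (2, 3 * 4 * 4, 4) = (batch, features_times_channels, patches)
assert torch.allclose(fold(patches), x)  # no overlap, so no summed contributions
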
def _(x: Tensor) -> Tensor:
    # 3-d image: move to channel-first (C, H, W); use the names when present,
    # otherwise assume channels-last (H, W, C) input.
    if x.ndim == 3:
        if any(x.names):
            return x.align_to("C", "H", "W")
        return x.permute(2, 0, 1)  # .to(memory_format=torch.contiguous_format)
    # 4-d batch: move to channel-first (N, C, H, W); use the names when
    # present, otherwise assume channels-last (N, H, W, C) input.
    if x.ndim == 4:
        if any(x.names):
            return x.align_to("N", "C", "H", "W")
        return x.permute(0, 3, 1, 2).contiguous()
    return x
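
# Usage sketch (hypothetical inputs, assuming the converter is in scope as
# `_`): unnamed tensors fall back to positional permutes under the
# channels-last assumption, named tensors are aligned by name.
import torch

hwc = torch.randn(32, 32, 3)  # unnamed, assumed (H, W, C)
assert _(hwc).shape == (3, 32, 32)
named = torch.randn(32, 32, 3, names=("H", "W", "C"))
assert _(named).names == ("C", "H", "W")
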
def prune(self, shapley_values: torch.Tensor, id_stage: int) -> torch.Tensor:
    r"""
    Prune the output of a stage

    Args:
        id_stage (): the index of the stage
        shapley_values (): the current stage's pre-pruning output

    Returns:
        the pruned Shapley representation
    """
    # First check whether we need pruning at all
    pruning = self.pruning[id_stage]
    num_pixels = named_tensor_get_dim(shapley_values,
                                      [NAME_HEIGHT, NAME_WIDTH])
    num_pixels = num_pixels[0] * num_pixels[1]
    k = int(num_pixels * pruning)  # the number of pixels to prune
    if pruning * k == 0:  # no pruning configured, or too few pixels to prune
        return shapley_values
    name_size = {
        name: size
        for name, size in zip(shapley_values.names, shapley_values.shape)
    }
    shapley_values = shapley_values.align_to(
        ..., NAME_HEIGHT, NAME_WIDTH,
        NAME_META_CHANNELS).flatten([NAME_HEIGHT, NAME_WIDTH], NAME_FEATURES)
    # get the norm of the vector at each pixel, based on which the pruning
    # will be performed
    abs_values = torch.linalg.norm(shapley_values.rename(None), ord=1, dim=-1)
    # the largest of the k smallest norms is the pruning threshold
    top_k = torch.topk(abs_values.rename(None), k,
                       largest=False)[0].max(1, keepdim=True)[0]
    # keep only the pixels whose norm exceeds the threshold
    abs_values = abs_values.rename(NAME_BATCH_SIZE, NAME_FEATURES) > top_k
    shapley_values = shapley_values * abs_values.align_to(
        ..., NAME_META_CHANNELS)
    shapley_values = shapley_values.unflatten(
        NAME_FEATURES, [[NAME_HEIGHT, name_size[NAME_HEIGHT]],
                        [NAME_WIDTH, name_size[NAME_WIDTH]]])
    shapley_values = shapley_values.align_to(*name_size.keys())
    return shapley_values
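
# Standalone sketch of the thresholding trick above (made-up shapes): the
# largest of the k smallest norms is the k-th smallest norm, so a strict `>`
# comparison zeroes at least k entries per sample.
import torch

norms = torch.rand(2, 10)  # per-pixel L1 norms, (batch, pixels)
k = 3
threshold = torch.topk(norms, k, largest=False)[0].max(1, keepdim=True)[0]
mask = norms > threshold
assert int((~mask).sum(1).min()) >= k  # at least k pixels pruned per sample
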
def prune(self, shapley_values: torch.Tensor,
          id_stage: int) -> torch.Tensor:
    r"""
    Prune the output of a stage

    Args:
        id_stage (): the index of the stage
        shapley_values (): the current stage's pre-pruning output

    Returns:
        the pruned Shapley representation
    """
    # First check whether we need pruning at all
    pruning = self.pruning[id_stage]
    num_features = named_tensor_get_dim(shapley_values, NAME_FEATURES)
    k = int(num_features * pruning)  # the number of features to prune
    if pruning * k == 0:  # no pruning configured, or too few features to prune
        return shapley_values
    names = shapley_values.names
    shapley_values = shapley_values.align_to(..., NAME_FEATURES,
                                             NAME_META_CHANNELS)
    # get the norm of the vector at each feature, based on which the pruning
    # will be performed
    abs_values = torch.linalg.norm(
        shapley_values.rename(None), ord=1, dim=-1)
    # the largest of the k smallest norms is the pruning threshold
    top_k = torch.topk(
        abs_values.rename(None), k, largest=False)[0].max(1, keepdim=True)[0]
    # keep only the features whose norm exceeds the threshold
    shapley_values = shapley_values * (abs_values > top_k).unsqueeze(-1)
    shapley_values = shapley_values.align_to(*names)
    return shapley_values
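
# Standalone sketch of the masking step above (made-up sizes): multiplying a
# named tensor by an unnamed boolean mask works because unnamed dims unify
# with any name, and the trailing size-1 dim broadcasts over meta-channels.
import torch

values = torch.randn(2, 5, 3, names=("B", "F", "M"))
mask = torch.rand(2, 5) > 0.5  # unnamed, (batch, features)
masked = values * mask.unsqueeze(-1)
assert masked.names == ("B", "F", "M")
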
def _final_process(self, shapley_values: torch.Tensor, id_stage: int, *args,
                   **kwargs) -> torch.Tensor:
    r"""
    Depending on the usage, this method performs final processing before the
    values are output

    Args:
        shapley_values (): the Shapley values computed by the last stage
        args (): placeholder
        kwargs (): placeholder

    Returns:
        the final Shapley values for presentation
    """
    return shapley_values.align_to(NAME_BATCH_SIZE, ..., NAME_META_CHANNELS)
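
# Illustration (made-up names and sizes): align_to with an ellipsis pins the
# batch dim first and the meta-channels last, leaving the remaining dims in
# their original order -- the output convention used above.
import torch

t = torch.randn(4, 5, 2, 3, names=("F", "P", "N", "M"))
assert t.align_to("N", ..., "M").names == ("N", "F", "P", "M")
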
def compute_shapley(self, function_outputs: torch.Tensor) -> torch.Tensor:
    r"""
    Args:
        function_outputs (): of shape
            (2 ** features, batch_size, output_channels)

    Returns:
        Shapley values for each of the variables, of shape
        (batch_size, self.m, output_channels)
    """
    shapley_values = torch.matmul(
        function_outputs.align_to(..., NAME_NUM_PASSES),
        self.subtraction_matrix).align_to(NAME_BATCH_SIZE, ..., NAME_FEATURES,
                                          NAME_META_CHANNELS)
    return shapley_values
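
# A minimal sketch of one way to build the subtraction matrix used above.
# The module presumably precomputes self.subtraction_matrix; the helper name
# and construction here are illustrative, not the module's actual code.
# Row s is a coalition (bit i set means feature i is present); column i
# carries the Shapley weight of that coalition's output for feature i.
import math

import torch


def shapley_subtraction_matrix(m: int) -> torch.Tensor:
    weights = torch.zeros(2 ** m, m)
    for s in range(2 ** m):
        size = bin(s).count("1")
        for i in range(m):
            if s & (1 << i):
                # i in S: positive term with weight (|S|-1)! (m-|S|)! / m!
                weights[s, i] = (math.factorial(size - 1) *
                                 math.factorial(m - size) / math.factorial(m))
            else:
                # i not in S: negative term with weight -|S|! (m-|S|-1)! / m!
                weights[s, i] = -(math.factorial(size) *
                                  math.factorial(m - size - 1) /
                                  math.factorial(m))
    return weights


m = 3
v = torch.rand(2 ** m)  # toy set function: one output per coalition
phi = v @ shapley_subtraction_matrix(m)  # (m,) Shapley values
# efficiency axiom: the values sum to v(full set) - v(empty set)
assert torch.allclose(phi.sum(), v[-1] - v[0], atol=1e-5)
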
def reshape(self, shapley_values: torch.Tensor) -> torch.Tensor:
    r"""
    In this instantiation, this method unfolds the representation and
    prepares it for the Shapley Module.

    Args:
        shapley_values (): the input Shapley representation

    Returns:
        the prepared Shapley representation for the Shapley module that
        follows
    """
    # Unfold the input; the result should end up of shape
    # (batch_size, patches, features, meta-channels)
    shapley_values_un = self.folder(shapley_values)
    shapley_values = shapley_values_un.unflatten(
        NAME_FEATURES_META_CHANNEL, [
            (NAME_META_CHANNELS, self.dimensions.in_channel),
            (NAME_FEATURES, np.prod(self.kernel_size)),
        ])
    shapley_values = shapley_values.align_to(..., NAME_FEATURES,
                                             NAME_META_CHANNELS)
    return shapley_values
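
# Standalone sketch of the unfold-then-unflatten pattern above (made-up sizes;
# self.folder presumably wraps torch.nn.Unfold): the flat
# features-times-channels dim splits back into meta-channels and per-patch
# features.
import torch

x = torch.randn(2, 3, 8, 8)  # (batch, channels, height, width)
patches = torch.nn.Unfold(kernel_size=4, stride=4)(x)  # (2, 48, 4)
patches = patches.refine_names("B", "FC", "P")
patches = patches.unflatten("FC", [("M", 3), ("F", 16)])  # channels, kernel area
assert patches.names == ("B", "M", "F", "P")
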
def forward(self, emb_inputs: torch.Tensor) -> torch.Tensor:
    r"""Forward calculation of CompressInteractionNetworkLayer

    Args:
        emb_inputs (T), shape = (B, N, E), dtype = torch.float:
            Embedded features tensors.

    Returns:
        T, shape = (B, O), dtype = torch.float:
            Output of CompressInteractionNetworkLayer.
    """
    # Initialize two lists to temporarily store the output and
    # next-step tensors
    direct_list = list()
    hidden_list = list()

    # Transpose emb_inputs
    # inputs: emb_inputs, shape = (B, N, E)
    # output: x0, shape = (B, E, N)
    x0 = emb_inputs.align_to("B", "E", "N")
    hidden_list.append(x0)

    # Expand dimension N of x0 to Nx (= N) and H (= 1)
    # inputs: x0, shape = (B, E, N)
    # output: x0, shape = (B, E, Nx = N, H = 1)
    x0 = x0.unflatten("N", [("Nx", x0.size("N")), ("H", 1)])

    # Calculate with cin forwardly
    for i, layer_size in enumerate(self.layer_sizes[:-1]):
        # Get the tensor of the previous step and reshape it
        # inputs: hidden_list[-1], shape = (B, E, N)
        # output: xi, shape = (B, E, H = 1, Ny = N)
        xi = hidden_list[-1]
        xi = xi.unflatten("N", [("H", 1), ("Ny", xi.size("N"))])

        # Calculate the outer product of x0 and xi
        # inputs: x0, shape = (B, E, Nx = N, H = 1)
        # inputs: xi, shape = (B, E, H = 1, Ny = N)
        # output: out_prod, shape = (B, E, Nx = N, Ny = N)
        # out_prod = torch.matmul(x0, xi)
        out_prod = torch.einsum("ijkn,ijnh->ijkh",
                                [x0.rename(None), xi.rename(None)])
        out_prod.names = ("B", "E", "Nx", "Ny")

        # Reshape out_prod
        # inputs: out_prod, shape = (B, E, Nx = N, Ny = N)
        # output: out_prod, shape = (B, N = Nx * Ny, E)
        out_prod = out_prod.flatten(["Nx", "Ny"], "N")
        out_prod = out_prod.align_to("B", "N", "E")

        # Apply convolution, batch norm and activation
        # inputs: out_prod, shape = (B, N = Nx * Ny, E)
        # output: outputs, shape = (B, N = (Hi * 2 or Hi), E)
        outputs = self.model[i](out_prod.rename(None))
        outputs.names = ("B", "N", "E")

        if self.is_direct:
            # Pass to output directly
            # inputs: outputs, shape = (B, N = Hi, E)
            # output: direct, shape = (B, N = Hi, E)
            direct = outputs

            # Reshape and pass to the next step directly
            # inputs: outputs, shape = (B, Hi, E)
            # output: hidden, shape = (B, E, N = Hi)
            hidden = outputs.align_to("B", "E", "N")
        else:
            if i != (len(self.layer_sizes) - 1):
                # Split outputs into two parts and pass them to the outputs
                # and the next hidden step separately
                # inputs: outputs, shape = (B, Hi * 2, E)
                # output: direct, shape = (B, N = Hi, E)
                # output: hidden, shape = (B, N = Hi, E)
                direct, hidden = torch.chunk(outputs, 2, dim="N")

                # Reshape and pass to the next step
                # inputs: hidden, shape = (B, N = Hi, E)
                # output: hidden, shape = (B, E, N = Hi)
                hidden = hidden.align_to("B", "E", "N")
            else:
                # Pass to output directly
                # inputs: outputs, shape = (B, N = Hi, E)
                # output: direct, shape = (B, N = Hi, E)
                direct = outputs
                hidden = 0

        # Temporarily store the tensors in the lists
        direct_list.append(direct)
        hidden_list.append(hidden)

    # Concatenate direct_list into a tensor
    # inputs: direct_list, shape = (B, Hi, E)
    # output: outputs, shape = (B, sum(Hi), E)
    outputs = torch.cat(direct_list, dim="N")

    # Aggregate outputs over dimension E and pass to the dense layer
    # inputs: outputs, shape = (B, sum(Hi), E)
    # output: outputs, shape = (B, O)
    outputs = self.fc(outputs.sum("E"))
    outputs.names = ("B", "O")

    return outputs
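
# Quick standalone check (made-up sizes) that the einsum above reproduces the
# commented-out matmul: contracting the size-1 axes of (B, E, Nx, 1) and
# (B, E, 1, Ny) is exactly a batched outer product.
import torch

x0 = torch.randn(2, 4, 5, 1)
xi = torch.randn(2, 4, 1, 5)
assert torch.allclose(torch.einsum("ijkn,ijnh->ijkh", x0, xi),
                      torch.matmul(x0, xi))
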
def forward(
    self,
    input_image: torch.Tensor,
    masked_kspace: torch.Tensor,
    sensitivity_map: torch.Tensor,
    sampling_mask: torch.Tensor,
    previous_state: Optional[torch.Tensor] = None,
    loglikelihood_scaling: Optional[float] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    input_image : torch.Tensor
        Initial or intermediate guess of the input.
    masked_kspace : torch.Tensor
        K-space masked by the sampling mask.
    sensitivity_map : torch.Tensor
        Coil sensitivities.
    sampling_mask : torch.Tensor
        Sampling mask.
    previous_state : torch.Tensor
    loglikelihood_scaling : torch.Tensor

    Returns
    -------
    torch.Tensor
    """
    # TODO: This has to be made contiguous
    input_image = input_image.align_to(
        "batch", "complex", "height", "width").contiguous()  # type: ignore

    batch_size = input_image.size("batch")
    spatial_shape = [input_image.size("height"), input_image.size("width")]
    # Initialize the zero state for the RIM
    state_size = ([batch_size, self.num_hidden_channels] +
                  list(spatial_shape) + [self.depth])
    if previous_state is None:
        previous_state = torch.zeros(*state_size, dtype=input_image.dtype).to(
            input_image.device)

    cell_outputs = []
    intermediate_image = input_image

    for cell_idx in range(self.length):
        cell = self.cell_list[
            cell_idx] if self.no_sharing else self.cell_list[0]

        grad_loglikelihood = self.grad_likelihood(
            intermediate_image,
            masked_kspace,
            sensitivity_map,
            sampling_mask,
            loglikelihood_scaling,
        )

        if grad_loglikelihood.abs().max() > 150.0:
            warnings.warn(
                f"Very large values for the gradient loglikelihood "
                f"({grad_loglikelihood.abs().max()}). "
                f"Might cause difficulties.")

        cell_input = torch.cat(
            [
                intermediate_image.rename(None),
                grad_loglikelihood.rename(None)
            ],
            dim=1,
        )

        cell_output, previous_state = cell(cell_input, previous_state)

        if self.skip_connections:
            intermediate_image = intermediate_image + cell_output

        if not self.training:
            # When not training, memory can be significantly reduced by
            # clearing the previous cell output.
            cell_output.set_()
            grad_loglikelihood.rename(None).set_(
            )  # TODO: Fix when named tensors have this support.
            del cell_output, grad_loglikelihood

        # Only save intermediate reconstructions during training
        if self.training or cell_idx == (self.length - 1):
            cell_outputs.append(
                intermediate_image.refine_names(
                    "batch", "complex", "height", "width"))  # type: ignore

    return cell_outputs, previous_state
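
# Standalone sketch of the named/unnamed boundary used above (made-up sizes):
# ops without full named-tensor support, such as torch.cat along an existing
# dim here, get unnamed views via rename(None), and the names are restored
# afterwards with refine_names.
import torch

names = ("batch", "complex", "height", "width")
image = torch.randn(2, 2, 8, 8, names=names)
grad = torch.randn(2, 2, 8, 8, names=names)
stacked = torch.cat([image.rename(None), grad.rename(None)], dim=1)
assert stacked.refine_names(*names).size("complex") == 4
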