def _replace_output(self, output, attention_map, data_shape):
    """Replaces the model output with the current attention map."""
    if self.medcam_dict['_replace_output']:
        if len(attention_map.keys()) == 1:
            output = torch.tensor(self.medcam_dict['current_attention_map']).to(str(self.medcam_dict['device']))
            if data_shape is not None:  # If data_shape is None the task is classification -> return the attention map unchanged
                output = medcam_utils.interpolate(output, data_shape)
        else:
            raise ValueError("Replacing the output is not possible when layer is 'full', only with 'auto' or a manually set layer")
    return output
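# Hedged sketch of the resize step in _replace_output: assuming
# medcam_utils.interpolate behaves like torch.nn.functional.interpolate
# (an assumption for this demo), a coarse attention map is upsampled back
# to the spatial shape of the input data before being returned as output.
import torch
import torch.nn.functional as F

_demo_map = torch.rand(1, 1, 8, 8)    # coarse map from a conv layer
_demo_shape = (128, 128)              # original input spatial shape
_resized = F.interpolate(_demo_map, size=_demo_shape, mode="bilinear", align_corners=False)
# _resized.shape -> torch.Size([1, 1, 128, 128])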
def generate(self):  # TODO: Redo ggcam, find a solution for normalize_per_channel
    """Generates an attention map."""
    for layer_name in self.attention_map_GCAM.keys():
        if self.attention_map_GBP.shape == self.attention_map_GCAM[layer_name].shape:
            self.attention_map_GCAM[layer_name] = np.multiply(self.attention_map_GCAM[layer_name], self.attention_map_GBP)
        else:
            attention_map_GCAM_tmp = medcam_utils.interpolate(self.attention_map_GCAM[layer_name], self.attention_map_GBP.shape[2:])
            self.attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM_tmp, self.attention_map_GBP)
        self.attention_map_GCAM[layer_name] = self._normalize_per_channel(self.attention_map_GCAM[layer_name])
    return self.attention_map_GCAM
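# Minimal NumPy sketch of the guided Grad-CAM fusion done in generate():
# the guided backpropagation map is multiplied element-wise with the
# (resolution-matched) Grad-CAM map. Shapes and names are illustrative
# assumptions, not taken from the library.
import numpy as np

_gbp = np.random.rand(1, 1, 64, 64)   # guided backpropagation map
_gcam = np.random.rand(1, 1, 64, 64)  # Grad-CAM map at the same resolution
_ggcam = np.multiply(_gcam, _gbp)     # guided Grad-CAM, shape (1, 1, 64, 64)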
def _generate_helper(self, fmaps, grads, layer):
    """Computes a Grad-CAM++ attention map from feature maps and their gradients."""
    B, C, *data_shape = grads.size()
    # Grad-CAM++ alphas: alpha = g^2 / (2 * g^2 + sum_spatial(f * g^3))
    alpha_num = grads.pow(2)
    tmp = fmaps.mul(grads.pow(3))
    tmp = tmp.view(B, C, -1)
    tmp = tmp.sum(-1, keepdim=True)
    if self.input_dim == 2:
        tmp = tmp.view(B, C, 1, 1)
    else:
        tmp = tmp.view(B, C, 1, 1, 1)
    alpha_denom = grads.pow(2).mul(2) + tmp
    alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
    alpha = alpha_num.div(alpha_denom + 1e-7)

    if self.mask is not None:
        mask = self.mask.squeeze()
    if self.mask is None:  # Classification
        prob_weights = torch.tensor(1.0)
    elif len(mask.shape) == 1:  # Classification best/index
        prob_weights = self.logits.squeeze()[torch.argmax(mask)]
    else:  # Segmentation
        masked_logits = self.logits * self.mask
        prob_weights = medcam_utils.interpolate(masked_logits, grads.shape[2:])  # TODO: Still removes channels...

    positive_gradients = F.relu(torch.mul(prob_weights.exp(), grads))
    weights = (alpha * positive_gradients).view(B, C, -1).sum(-1)
    if self.input_dim == 2:
        weights = weights.view(B, C, 1, 1)
    else:
        weights = weights.view(B, C, 1, 1, 1)

    attention_map = weights * fmaps
    try:
        # Group the feature-map channels into the requested number of output channels
        attention_map = attention_map.view(B, self.output_channels, -1, *data_shape)
    except RuntimeError:
        raise RuntimeError("Feature map channels ({}) in layer {} are not a multiple of the set output channels ({})".format(fmaps.shape[1], layer, self.output_channels))
    attention_map = torch.sum(attention_map, dim=2)
    attention_map = F.relu(attention_map).detach()
    attention_map = self._normalize_per_channel(attention_map)
    return attention_map
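# Standalone sketch of the Grad-CAM++ alpha/weight computation above, using
# the paper's formula alpha = g^2 / (2*g^2 + sum_spatial(f * g^3)); the
# exp(score) factor is folded into prob_weights in the method and omitted
# here. All tensors are random stand-ins.
import torch

_B, _C, _H, _W = 1, 4, 7, 7
_fmaps = torch.rand(_B, _C, _H, _W)  # forward feature maps
_grads = torch.rand(_B, _C, _H, _W)  # gradients w.r.t. the feature maps
_num = _grads.pow(2)
_denom = 2 * _grads.pow(2) + (_fmaps * _grads.pow(3)).sum(dim=(2, 3), keepdim=True)
_denom = torch.where(_denom != 0, _denom, torch.ones_like(_denom))
_alpha = _num / _denom                                    # per-pixel coefficients
_weights = (_alpha * torch.relu(_grads)).sum(dim=(2, 3))  # one weight per channel
# _weights.shape -> torch.Size([1, 4])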
def _preprocessing(attention_map, mask, attention_threshold):
    """Interpolates, normalizes and binarizes the attention map."""
    if not np.isfinite(attention_map).all():
        raise ValueError("Attention map contains non-finite elements")
    if not np.isfinite(mask).all():
        raise ValueError("Mask contains non-finite elements")
    if np.sum(attention_map < 0) > 0:  # gbp and ggcam maps contain negative values, which would otherwise falsify the evaluation
        attention_map = np.abs(attention_map)
    attention_map = medcam_utils.interpolate(attention_map, mask.shape, squeeze=True)
    attention_map = medcam_utils.normalize(attention_map.astype(float))  # np.float is removed in NumPy >= 1.24
    weights = copy.deepcopy(attention_map)
    mask = np.array(mask, dtype=int)
    if np.min(attention_map) == np.max(attention_map):
        attention_threshold = 1
    elif attention_threshold == 'otsu':
        attention_threshold = threshold_otsu(attention_map.flatten())
    attention_map[attention_map < attention_threshold] = 0
    attention_map[attention_map >= attention_threshold] = 1
    attention_map = np.array(attention_map, dtype=int)
    return attention_map, mask, weights
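# Small sketch of the Otsu branch in _preprocessing: skimage's
# threshold_otsu picks a threshold on the normalized map, which is then
# binarized for the evaluation. The toy map below is an assumption.
import numpy as np
from skimage.filters import threshold_otsu

_amap = np.random.rand(64, 64)       # normalized attention map in [0, 1]
_t = threshold_otsu(_amap.flatten())
_binary = (_amap >= _t).astype(int)  # 1 = salient, 0 = background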