def _weight_to_image(weight: Tensor, kernel_channels_last: bool = False) -> Optional[Tensor]:
    """Convert a model weight into a rank-4 image tensor suitable for TensorBoard image logging.

    Adapted from the TensorFlow codebase; their implementation could not be invoked directly because it is not exposed
    as a static method.

    The weight is squeezed first, then laid out according to its remaining rank:

    * rank 1 (e.g. a bias vector of length n) -> [1, n, 1, 1]
    * rank 2 (e.g. a dense kernel) -> [1, h, w, 1], transposed first when needed so that h <= w
    * rank 3 -> [k, a, b, 1], where k is the leading axis (the last axis is moved to the front first when
      `kernel_channels_last` is True)
    * rank 4 -> [k, a, b, 1], obtained by summing absolute values over axis 1 (after permuting the axes with
      [3, 2, 0, 1] when `kernel_channels_last` is True)

    Args:
        weight: The weight tensor to visualize.
        kernel_channels_last: Whether rank-3 / rank-4 kernels store their kernel axis last. When False, the leading
            axis is treated as the kernel axis without any permutation.

    Returns:
        A rank-4 tensor whose last dimension is 1, 3, or 4, or None if the weight cannot be visualized (for example a
        scalar, or a weight of rank 5 or more such as a 3D convolution kernel).
    """
    w_img = squeeze(weight)
    shape = backend.int_shape(w_img)
    if len(shape) == 1:  # Bias case
        w_img = reshape(w_img, [1, shape[0], 1, 1])
    elif len(shape) == 2:  # Dense layer kernel case
        if shape[0] > shape[1]:
            # Transpose (swap the two axes) so the resulting image is at least as wide as it is tall
            w_img = permute(w_img, [1, 0])
            shape = backend.int_shape(w_img)
        w_img = reshape(w_img, [1, shape[0], shape[1], 1])
    elif len(shape) == 3:  # ConvNet case
        if kernel_channels_last:
            # Switch to channels-first so that every kernel is displayed as a separate image
            w_img = permute(w_img, [2, 0, 1])
        w_img = expand_dims(w_img, axis=-1)
    elif len(shape) == 4:  # Conv filter with multiple input channels
        if kernel_channels_last:
            # Switch to channels-first so that every kernel is displayed as a separate image
            w_img = permute(w_img, [3, 2, 0, 1])
        # Collapse axis 1 (the channels within each kernel) by summing absolute values
        w_img = reduce_sum(abs(w_img), axis=1)
        w_img = expand_dims(w_img, axis=-1)
    shape = backend.int_shape(w_img)
    # Anything that did not end up as a displayable rank-4 image (e.g. 3D convnet kernels) is skipped
    if len(shape) == 4 and shape[-1] in (1, 3, 4):
        return w_img
    return None
def _convert_for_visualization(tensor: Tensor, tile: int = 99) -> Tensor:
    """Rescale a mask into the [0, 1] range so that it can be displayed as an image.

    The absolute values of `tensor` are summed over its channel axis: axis 1 for torch tensors, the last axis
    otherwise. Each sample in the batch is then min-max normalized independently. The `tile`-th percentile is used as
    the upper bound instead of the true maximum, so that a few extreme values do not wash out the rest of the image.
    Values above that percentile are clipped to 1.

    Args:
        tensor: Input masks, whose channel values are to be reduced by absolute value summation.
        tile: The percentile [0-100] used to set the max value of the image.

    Returns:
        A tensor with the same rank as `tensor`, with values in [0, 1]. Its channel axis is kept with size 1; for
        example, a torch input laid out as (batch, channels, height, width) yields (batch, 1, height, width).
        When a sample's percentile equals its minimum (e.g. a constant mask), its minimum maps to 0 and any larger
        value maps to 1, instead of producing NaNs.
    """
    if isinstance(tensor, torch.Tensor):
        channel_axis = 1
    else:
        channel_axis = -1
    flattened_mask = reduce_sum(abs(tensor), axis=channel_axis, keepdims=True)
    non_batch_axes = list(range(len(flattened_mask.shape)))[1:]
    vmax = percentile(flattened_mask, tile, axis=non_batch_axes, keepdims=True)
    vmin = reduce_min(flattened_mask, axis=non_batch_axes, keepdims=True)
    # A zero range (vmax == vmin) would divide by zero and yield NaN / inf pixels. Flooring the range at a tiny
    # epsilon avoids that while leaving every non-degenerate range untouched.
    value_range = clip_by_value(vmax - vmin, 1e-12, float("inf"))
    return clip_by_value((flattened_mask - vmin) / value_range, 0, 1)