Пример #1
0
    def _weight_to_image(
            weight: Tensor,
            kernel_channels_last: bool = False) -> Optional[Tensor]:
        """Convert a weight tensor into a 4D tensor which can be logged as a TensorBoard image.

        Implementation adapted from the TensorFlow codebase, would have invoked theirs directly but they didn't make
        it a static method.

        Args:
            weight: The weight tensor to convert. It is passed through `squeeze` before its rank is inspected.
            kernel_channels_last: Whether rank 3 / rank 4 kernels should be permuted so that their trailing axes are
                moved to the front before being rendered (see the per-branch comments below).

        Returns:
            A rank 4 tensor whose last dimension is 1, 3, or 4, or None if the weight could not be arranged into such
            a shape (for example 3D convolution kernels).
        """
        w_img = squeeze(weight)
        shape = backend.int_shape(w_img)
        rank = len(shape)
        if rank == 1:  # Bias case
            w_img = reshape(w_img, [1, shape[0], 1, 1])
        elif rank == 2:  # Dense layer kernel case
            if shape[0] > shape[1]:
                # Transpose tall matrices so that the smaller dimension comes first. The permutation must be [1, 0];
                # [0, 1] would be the identity and leave the matrix untouched.
                w_img = permute(w_img, [1, 0])
                shape = backend.int_shape(w_img)
            w_img = reshape(w_img, [1, shape[0], shape[1], 1])
        elif rank == 3:  # ConvNet case
            if kernel_channels_last:
                # Switch to channels first to display every kernel as a separate image
                w_img = permute(w_img, [2, 0, 1])
            w_img = expand_dims(w_img, axis=-1)
        elif rank == 4:  # Conv filter with multiple input channels
            if kernel_channels_last:
                # Switch to channels first to display kernels as separate images
                # (presumably (H, W, in, out) -> (out, in, H, W); confirm against the caller's kernel layout)
                w_img = permute(w_img, [3, 2, 0, 1])
            # Collapse axis 1 by summing absolute values so that each kernel becomes a single 2D map
            w_img = reduce_sum(abs(w_img), axis=1)
            w_img = expand_dims(w_img, axis=-1)
        shape = backend.int_shape(w_img)
        # Not possible to handle 3D convnets etc.: only rank 4 tensors with 1, 3, or 4 trailing channels are returned
        if len(shape) == 4 and shape[-1] in [1, 3, 4]:
            return w_img
        return None
Пример #2
0
    def _convert_for_visualization(tensor: Tensor,
                                   tile: int = 99) -> np.ndarray:
        """Rescale the values of an input `tensor` into the range [0, 1] so that it can be visualized.

        The channel axis (axis 1 for torch tensors, the last axis otherwise) is collapsed by summing absolute values.
        Each batch element is then normalized between its own minimum and its `tile`-th percentile, and the result is
        clipped to [0, 1].

        Args:
            tensor: Input masks, whose channel values are to be reduced by absolute value summation.
            tile: The percentile [0-100] used to set the max value of the image.

        Returns:
            The normalized masks after visualization clipping is applied. The reductions are invoked with
            `keepdims=True`, so the channel axis is presumably retained with size 1 rather than dropped.
        """
        channel_axis = 1 if isinstance(tensor, torch.Tensor) else -1
        summed = reduce_sum(abs(tensor), axis=channel_axis, keepdims=True)

        # Every axis except the leading (batch) one, so that statistics are computed per batch element
        per_sample_axes = list(range(1, len(summed.shape)))

        upper = percentile(summed, tile, axis=per_sample_axes, keepdims=True)
        lower = reduce_min(summed, axis=per_sample_axes, keepdims=True)

        # NOTE(review): if the percentile equals the minimum for some sample the denominator is zero - confirm that
        # callers tolerate the resulting non-finite values.
        shifted = summed - lower
        span = upper - lower
        return clip_by_value(shifted / span, 0, 1)