예제 #1
0
    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        """Render the input image and its attribution heat map for display.

        Args:
            attribution: Attribution tensor for ``data``. Presumably the same
                shape as ``data`` (e.g. (1, C, H, W)) -- TODO confirm.
            data: Input image tensor; after dropping singleton dims it is
                expected to be channels-first (C, H, W) or a single-channel
                (H, W) image.
            contribution_frac: Passed through unchanged as the output's
                ``contribution``.

        Returns:
            FeatureOutput whose ``base`` is the base64-encoded original image
            figure and whose ``modified`` is the base64-encoded absolute-value
            heat map of ``attribution``.
        """

        def _to_hwc(tensor):
            # Detach before moving to CPU so the copy is not tracked by autograd.
            arr = tensor.detach().squeeze().cpu().numpy()
            if arr.ndim == 2:
                # Single-channel image: squeeze() also removed the channel
                # axis, so put it back before moving channels last.
                arr = arr[np.newaxis, ...]
            # Channels-first (C, H, W) -> channels-last (H, W, C) for plotting.
            return np.transpose(arr, (1, 2, 0))

        data_t = _to_hwc(data)
        attribution_t = _to_hwc(attribution)

        orig_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="original_image",
            use_pyplot=False,
        )
        attr_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="heat_map",
            sign="absolute_value",
            use_pyplot=False,
        )

        img_64 = _convert_figure_base64(orig_fig)
        attr_img_64 = _convert_figure_base64(attr_fig)

        return FeatureOutput(
            name=self.name,
            base=img_64,
            modified=attr_img_64,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )
예제 #2
0
    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        """Render the input image and its attribution heat map for display.

        If ``self.visualization_transform`` is set it is applied to ``data``
        only (not to ``attribution``) before plotting.

        Args:
            attribution: Attribution tensor for ``data``. Presumably the same
                shape as ``data`` (e.g. (1, C, H, W)) -- TODO confirm.
            data: Input image tensor; after dropping singleton dims it is
                expected to be channels-first (C, H, W) or a single-channel
                (H, W) image.
            contribution_frac: Passed through unchanged as the output's
                ``contribution``.

        Returns:
            FeatureOutput whose ``base`` is the base64-encoded original image
            figure and whose ``modified`` is the base64-encoded absolute-value
            heat map of ``attribution``.
        """
        if self.visualization_transform:
            data = self.visualization_transform(data)

        def _to_hwc(tensor):
            squeezed = tensor.detach().squeeze()
            if squeezed.dim() == 2:
                # Single-channel image: squeeze() also removed the channel
                # axis, so restore it before permuting to channels-last.
                squeezed = squeezed.unsqueeze(0)
            # Channels-first (C, H, W) -> channels-last (H, W, C) for plotting.
            return squeezed.permute((1, 2, 0)).cpu().numpy()

        data_t = _to_hwc(data)
        attribution_t = _to_hwc(attribution)

        orig_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="original_image",
            use_pyplot=False,
        )
        attr_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="heat_map",
            sign="absolute_value",
            use_pyplot=False,
        )

        img_64 = _convert_figure_base64(orig_fig)
        attr_img_64 = _convert_figure_base64(attr_fig)

        return FeatureOutput(
            name=self.name,
            base=img_64,
            modified=attr_img_64,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )