コード例 #1
0
 def _create_singlerow_heatmaps_picture(self, idx: List[int], name: str, inpshp: torch.Size, lbl: int, subdir: str,
                                        res: int, imgs: Tensor, ascores: Tensor, grads: Tensor, gtmaps: Tensor,
                                        labels: List[int]):
     """
     Creates a picture of inputs, heatmaps and ground-truth maps for the samples selected by idx.
     Heatmaps are drawn from ascores (unless the objective is 'hsc') and from grads (if grads is not None);
     ground-truth maps are drawn only if gtmaps is not None.
     Row-wise: (1) inputs (2) ascores heatmaps (3) grads heatmaps (4) ground-truth maps, omitting rows that
     do not apply; every row holds all len(idx) selected samples.
     Creates one version with 'local' and one with 'global' normalization. For 'global', the heatmaps are
     processed with a reference computed by balance_labels over *all* available heatmaps (not only those in
     idx); for 'local', no reference is passed.
     Each version is saved twice under the name '<name>_paper_<norm>_lbl<lbl>': the stacked rows via
     self.logger.single_save (in 'tims/<subdir>') and the concatenated picture via self.logger.imsave (in subdir).
     :param idx: limit the inputs (and corresponding other rows) to these indices.
     :param name: name to be used to store the picture.
     :param inpshp: the input shape (heatmaps will be resized to this).
     :param lbl: label of samples (indices), only used for naming.
     :param subdir: some subdirectory to store the data in.
     :param res: maximum allowed resolution in pixels (images are downsampled if they exceed this threshold).
     :param imgs: the input images.
     :param ascores: anomaly scores; not drawn if the objective is 'hsc'.
     :param grads: gradients; not drawn if None.
     :param gtmaps: ground-truth maps; not drawn if None.
     :param labels: labels of all samples, handed to balance_labels to build the 'global' reference.
     """
     # Flag passed to balance_labels for the ascores reference; loop-invariant, so computed once.
     # NOTE(review): the grads reference below uses balance_labels' default for this flag instead —
     # confirm this asymmetry is intended.
     err = self.objective != 'ae'
     for norm in ('local', 'global'):
         is_global = norm == 'global'
         rows = [self._image_processing(imgs[idx], inpshp, maxres=res, qu=1)]
         if self.objective != 'hsc':
             rows.append(
                 self._image_processing(
                     ascores[idx], inpshp, maxres=res, colorize=True,
                     ref=balance_labels(ascores, labels, err) if is_global else None,
                     norm=norm,
                 )
             )
         if grads is not None:
             rows.append(
                 self._image_processing(
                     grads[idx], inpshp, self.blur_heatmaps, res, colorize=True,
                     ref=balance_labels(grads, labels) if is_global else None,
                     norm=norm,
                 )
             )
         if gtmaps is not None:
             rows.append(self._image_processing(gtmaps[idx], inpshp, maxres=res, norm=None))
         tim = torch.cat(rows)
         imname = '{}_paper_{}_lbl{}'.format(name, norm, lbl)
         # single_save receives the rows stacked (kept separate); imsave receives them concatenated.
         self.logger.single_save(imname, torch.stack(rows), subdir=pt.join('tims', subdir))
         self.logger.imsave(imname, tim, nrow=len(idx), scale_mode='none', subdir=subdir)
コード例 #2
0
ファイル: bases.py プロジェクト: liznerski/fcdd
 def _create_heatmaps_picture(self,
                              idx: List[int],
                              name: str,
                              inpshp: torch.Size,
                              subdir: str,
                              nrow: int,
                              imgs: Tensor,
                              ascores: Tensor,
                              grads: Tensor,
                              gtmaps: Tensor,
                              labels: List[int],
                              norm: str = 'global'):
     """
     Creates a picture of inputs, heatmaps and ground-truth maps for the samples selected by idx.
     Heatmaps are drawn from ascores (unless the objective is 'hsc') and from grads (if grads is not None);
     ground-truth maps are drawn only if gtmaps is not None.
     The selected samples are split into chunks of nrow samples (the last chunk may hold fewer). Each chunk
     yields one group of rows: (1) inputs (2) ascores heatmaps (3) grads heatmaps (4) ground-truth maps
     (5) blank, omitting rows that do not apply. One row contains only one of these kinds.
     For instance, for 20 samples, nrow=10, grads=None and gtmaps given (objective != 'hsc'), the picture shows:
         - 10 inputs
         - 10 corresponding heatmaps
         - 10 corresponding ground-truth maps
         - blank
         - 10 inputs
         - 10 corresponding heatmaps
         - 10 corresponding ground-truth maps
         - blank
     The picture is stored via self.logger.imsave under the name '<name>_<norm>'.
     :param idx: limit the inputs (and corresponding other rows) to these indices.
     :param name: name to be used to store the picture.
     :param inpshp: the input shape (heatmaps will be resized to this).
     :param subdir: some subdirectory to store the data in.
     :param nrow: number of images per row.
     :param imgs: the input images.
     :param ascores: anomaly scores; not drawn if the objective is 'hsc'.
     :param grads: gradients; not drawn if None.
     :param gtmaps: ground-truth maps; not drawn if None.
     :param labels: labels of all samples, handed to balance_labels to build the 'global' reference.
     :param norm: what type of normalization to apply.
         None: no normalization.
         'local': normalizes each heatmap w.r.t. itself only.
         'global': normalizes each heatmap w.r.t. all heatmaps available (without taking idx into account),
             though it is ensured to consider equally many anomalous and nominal samples (if there are e.g. more
             nominal samples, randomly chosen nominal samples are ignored to match the correct amount).
             The balanced reference is computed once per call, so all chunks share the same reference.
         'semi_global': normalizes each heatmap w.r.t. all heatmaps chosen in idx.
     """
     draw_ascores = self.objective != 'hsc'
     # 'semi_global' reaches _image_processing as 'global'; the restriction to idx is realised by the
     # reference chosen below (the selected heatmaps instead of the balanced full set).
     # None must stay None (calling .replace on it would raise AttributeError).
     img_norm = None if norm is None else norm.replace('semi_', '')

     # Select the samples once instead of re-indexing with idx for every chunk.
     sel_imgs = imgs[idx]
     sel_ascores = ascores_ref = sel_grads = grads_ref = None
     if draw_ascores:
         sel_ascores = ascores[idx]
         # Computed once (not per chunk) so every chunk is normalized against the same reference,
         # even though balance_labels may drop samples at random.
         ascores_ref = balance_labels(ascores, labels, False) if norm == 'global' else sel_ascores
     if grads is not None:
         sel_grads = grads[idx]
         grads_ref = balance_labels(grads, labels, False) if norm == 'global' else sel_grads
     sel_gtmaps = gtmaps[idx] if gtmaps is not None else None

     number_of_rows = -(-len(idx) // nrow)  # ceil(len(idx) / nrow) in integer arithmetic
     rows = []
     for s in range(number_of_rows):
         chunk = slice(s * nrow, s * nrow + nrow)
         rows.append(self._image_processing(sel_imgs[chunk], inpshp, maxres=self.resdown, qu=1))
         if draw_ascores:
             rows.append(
                 self._image_processing(
                     sel_ascores[chunk],
                     inpshp,
                     maxres=self.resdown,
                     qu=self.quantile,
                     colorize=True,
                     ref=ascores_ref,
                     norm=img_norm,
                 ))
         if grads is not None:
             rows.append(
                 self._image_processing(
                     sel_grads[chunk],
                     inpshp,
                     self.blur_heatmaps,
                     self.resdown,
                     qu=self.quantile,
                     colorize=True,
                     ref=grads_ref,
                     norm=img_norm,
                 ))
         if sel_gtmaps is not None:
             rows.append(self._image_processing(sel_gtmaps[chunk], inpshp, maxres=self.resdown, norm=None))
         # Blank separator after each group, shaped like the row just before it.
         rows.append(torch.zeros_like(rows[-1]))
     name = '{}_{}'.format(name, norm)
     self.logger.imsave(name,
                        torch.cat(rows),
                        nrow=nrow,
                        scale_mode='none',
                        subdir=subdir)