Example #1
0
class JupyterPlottingContext(PlottingContextBase):
    """ plotting in a jupyter widget using the `inline` backend """

    supports_update = False
    """ flag indicating whether the context supports that plots can be updated
    without redrawing the entire plot. The jupyter backend (`inline`) requires
    replotting of the entire figure, so an update is not supported."""

    def __enter__(self):
        """Start capturing matplotlib output in an ipywidgets ``Output``.

        On the initial plot, all open matplotlib figures are closed and a new
        output widget is created; it is displayed only if ``self.show`` is
        set. Later entries reuse that widget. This class only creates it when
        ``self.initial_plot`` is true.
        """
        from IPython.display import display
        from ipywidgets import Output

        if self.initial_plot:
            # close all previous plots
            import matplotlib.pyplot as plt

            plt.close("all")

            # create output widget for capturing all plotting
            self._ipython_out = Output()

            if self.show:
                # only show the widget if necessary
                display(self._ipython_out)

        # capture plots in the output widget
        self._ipython_out.__enter__()

    def __exit__(self, *exc):
        """Finalize the plot, render it into the widget and stop capturing.

        The capture teardown and the figure close happen in ``finally``. If
        finalizing or showing the plot raises, the output widget still stops
        capturing, so later cell output is not swallowed by it, and the
        figure does not leak.
        """
        import matplotlib.pyplot as plt

        try:
            # finalize plot
            super().__exit__(*exc)

            if self.show:
                # show the plot, but ...
                plt.show()  # show the figure to make sure it can be captured
            # ... also clear it the next time something is done
            self._ipython_out.clear_output(wait=True)
        finally:
            # stop capturing plots in the output widget
            self._ipython_out.__exit__(*exc)

            # close the figure, so figure windows do not accumulate
            plt.close(self.fig)

    def close(self):
        """ close the plot and, best effort, its ipython output widget """
        super().close()
        # close ipython output; deliberately best-effort: the widget may never
        # have been created (no initial plot) or may already be gone
        try:
            self._ipython_out.close()
        except Exception:
            pass
Example #2
0
File: vis.py  Project: w-hc/pcv
class Visualizer():
    """Interactive matplotlib viewer for semantic / vote predictions.

    ``vis`` computes the data to display. It then lays out one subplot per
    entry of ``plot_device_registry`` and connects mouse callbacks. A press
    (and a subsequent drag) on one subplot is forwarded to every plot
    device.
    """

    def __init__(self, cfg, dset_meta, pcv):
        self.cfg = cfg
        # widget through which callback stdout/stderr is shown (see
        # display_stdout_and_err_in_curr_cell / connect)
        self.output_widget = Output()
        self.dset_meta = dset_meta
        self.pcv = pcv
        self.trainId_2_catName = dset_meta['trainId_2_catName']
        self.category_meta = dset_meta['cats']
        self.catId_2_trainId = dset_meta['catId_2_trainId']
        self.init_state()
        self.pressed = False

        # NOTE: this changes numpy's *global* print options for the process
        np.set_printoptions(
            formatter={'float': lambda x: "{:.2f}".format(x)}
        )

    def init_state(self):
        """Forget the current figure, canvas, plot devices and callback ids."""
        self.fig, self.canvas, self.plots = None, None, None
        # ids returned by canvas.mpl_connect; None means "not connected"
        self.cidpress = self.cidrelease = self.cidmotion = None

    def __del__(self):
        # __init__ may have raised before every attribute was assigned (e.g. a
        # missing key in dset_meta), so look the attributes up defensively
        if getattr(self, 'fig', None) is not None:
            self.clear_state()
        output_widget = getattr(self, 'output_widget', None)
        if output_widget is not None:
            output_widget.close()

    def clear_state(self):
        """Disconnect callbacks, close the current figure and reset the state."""
        if self.fig is not None:
            self.disconnect()
            plt.close(self.fig)
            self.init_state()

    def display_stdout_and_err_in_curr_cell(self):
        """
        in JLab, stdout and stderr from widget callbacks
        must be displayed through a specialized output widget
        """
        ipy_display(self.output_widget)

    def connect(self):
        """Register press/release/motion handlers on the current canvas.

        Each handler is wrapped in ``output_widget.capture()`` so that its
        output lands in the output widget.
        """
        decor = self.output_widget.capture()
        self.cidpress = self.canvas.mpl_connect(
            'button_press_event', decor(self.on_press))
        self.cidrelease = self.canvas.mpl_connect(
            'button_release_event', decor(self.on_release))
        self.cidmotion = self.canvas.mpl_connect(
            'motion_notify_event', decor(self.on_motion))

    def disconnect(self):
        """Disconnect all stored connection ids; ids never connected are skipped."""
        # a previous vis() may have failed after creating the figure but
        # before connect() ran, so only undo connections that were made
        for attr in ('cidpress', 'cidrelease', 'cidmotion'):
            cid = getattr(self, attr, None)
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
            setattr(self, attr, None)

    def on_press(self, event):
        """Forward a mouse press to every plot device.

        The device whose axes received the click gets ``press_coord``. Every
        other device gets ``query_coord`` with the same integer coordinates.
        """
        self.pressed = True
        ax_in_focus = event.inaxes
        if ax_in_focus is None:
            return
        x, y, button = int(event.xdata), int(event.ydata), event.button
        for plot in self.plots.values():
            if ax_in_focus == plot.ax:
                plot.press_coord(x, y, button)
            else:
                plot.query_coord(x, y, button)

    def on_motion(self, event):
        """While the button is held, forward the motion to the focused device."""
        if not self.pressed:
            return
        ax_in_focus = event.inaxes
        if ax_in_focus is None:
            return

        x, y = int(event.xdata), int(event.ydata)
        for plot in self.plots.values():
            if ax_in_focus == plot.ax:
                plot.motion_coord(x, y)

    def on_release(self, event):
        self.pressed = False

    def _make_mfv(self, sem_pred, vote_pred):
        """Build a MaskFromVote on private copies of the prediction tensors."""
        # every consumer gets fresh clones, presumably because MaskFromVote
        # may modify its inputs in place -- TODO confirm before sharing them
        return MaskFromVote(
            self.cfg.pcv, self.dset_meta, self.pcv,
            sem_pred.clone(), vote_pred.clone()
        )

    @torch.no_grad()
    def vis(
        self, im, pan_mask, segments_info, sem_pred, vote_pred,
        gt_prod_handle, loss_module, h_thresh
    ):
        """Compute all visualization data and (re)build the interactive figure.

        Args: these are possible data to visualize
            im:        [H, W, 3] of PIL Image
            pan_mask:  [H, W, 3] of PIL Image
            segments_info: dict
            sem_pred:  [1, num_classes, H, W] torch gpu tsr
            vote_pred: [1, num_bins, H, W] torch gpu tsr
            gt_prod_handle: called by process_data with
                (dset_meta, pcv, pan_img, segments_info) to build the gt
            loss_module: passed to compute_loss in process_data
            h_thresh: threshold for locate_peak_regions in process_data

        Side effects: stores the computed dict in ``self.data``. Stores the
        MaskFromVote it was computed from in ``self.mfv``. Closes the
        previous figure and connects mouse callbacks on the new one.
        """
        ins_mask = self._make_mfv(sem_pred, vote_pred).infer_panoptic_mask(
            instance_mask_only=True
        )[0]

        full_mask, _ = self._make_mfv(sem_pred, vote_pred).infer_panoptic_mask(
            instance_mask_only=False
        )

        # get each instance separately: pair every peak region with the
        # instance mask matched to it
        pairs = []
        tmp_mfv = self._make_mfv(sem_pred, vote_pred)
        peak_regions, _, peak_bbox = \
            tmp_mfv.locate_peak_regions(tmp_mfv.vote_hmap, tmp_mfv.hmap_thresh)
        _, instance_tsr, _ = tmp_mfv.peak_conv_mask_match(
            tmp_mfv.thing_trainIds, tmp_mfv.query_mask,
            tmp_mfv.vote_decision, tmp_mfv.sem_decision, peak_bbox
        )
        # region labels are compared against 1..N; the "- 1" presumably
        # discounts a background label 0. Pair only when the counts agree.
        if len(np.unique(peak_regions)) - 1 == len(instance_tsr):
            instance_masks = instance_tsr.cpu().numpy()
            for region_id, _ins_mask in enumerate(instance_masks, start=1):
                pairs.append((peak_regions == region_id, _ins_mask))

        self.mfv = self._make_mfv(sem_pred, vote_pred)
        data = self.process_data(
            im, pan_mask, segments_info, sem_pred, vote_pred,
            gt_prod_handle, loss_module, h_thresh
        )
        self.data = data  # store it so that it can be accessed externally
        data['ins_mask'] = ins_mask
        data['full_mask'] = full_mask
        data['pairs'] = pairs

        self.clear_state()
        num_plots = len(plot_device_registry)
        num_per_row = 3
        nrows = (num_plots + num_per_row - 1) // num_per_row  # ceil division
        fig = plt.figure(figsize=(20, 12), constrained_layout=True)
        self.fig = fig
        self.canvas = fig.canvas
        self.plots = dict()
        gs = GridSpec(nrows, num_per_row, figure=fig)
        for i, k in enumerate(plot_device_registry.keys()):
            ax = fig.add_subplot(gs[i // num_per_row, i % num_per_row])
            ax.set_title(k)
            device = plot_device_registry[k]
            self.plots[k] = device(ax, data, self)
        self.connect()

    def process_data(
        self, im, pan_img, segments_info, sem_pred, vote_pred,
        gt_prod_handle, loss_module, h_thresh
    ):
        """Collect gt-, loss- and prediction-derived arrays into one dict.

        Reads ``self.mfv``, which the caller (``vis``) must set beforehand.
        Returns the dict that ``vis`` later hands to every plot device.
        """
        data = {}
        mfv = self.mfv

        # 1. store data derived from gt; sem and vote pred are already softmaxed!
        generator = gt_prod_handle(
            self.dset_meta, self.pcv, pan_img, segments_info
        )
        gts = generator.generate_gt()
        sem_gt, vote_gt = gts[:2]  # the first 2 are always these
        centroids = generator.ins_centroids
        _, vote_tsr = generator.collect_prob_tsr()
        vote_tsr = vote_tsr[:, :-1, :, :]  # drop the last vote channel

        data['im'], data['pan_img'] = np.array(im), np.array(pan_img)
        data['pan_mask'] = rgb2id(data['pan_img'])
        data['sem_gt'] = sem_gt
        data['vote_gt_pred'] = vote_tsr.squeeze(axis=0).transpose(1, 2, 0)
        data['vote_gt'], data['ins_centroids'] = vote_gt, centroids
        data['vote_gt_hmap'] = mfv.pixel_consensus_voting(
            torch.as_tensor(vote_tsr).float().cuda()
        )

        # 2. compute and analyze loss
        loss_info = compute_loss(loss_module, gts, sem_pred, vote_pred)
        stats = SegmentLossStats(
            loss_info, data['pan_mask'], segments_info, self.dset_meta['cats']
        )
        stats.summarize()
        data['loss_info'] = loss_info
        data['seg_loss_stats'] = stats

        # 3. store data derived from pred
        data['sem_pred'], data['sem_decision'] = mfv.sem_pred.cpu().numpy(), mfv.sem_decision
        data['vote_pred'], data['vote_decision'] = mfv.vote_pred, mfv.vote_decision
        data['vote_pred_hmap'] = mfv.vote_hmap
        ws_mask, peaks, _ = mfv.locate_peak_regions(mfv.vote_hmap, h_thresh)
        data['ws_mask'], data['ws_peak_points'] = ws_mask, peaks

        return data