예제 #1
0
class SegmentationInspector:
    """
    For binary image segmentation tasks, create an HTML report with the
    best-segmented samples, the worst-segmented samples, and the samples
    closest to the decision boundary.

    The report is whatever ``SIVis.show()`` returns and is stored in
    ``state['metrics']['report']``.

    Args:
        nb_show (int): how many samples to show
        classes (list of str): class names
        post_each_batch (bool): whether to generate a new report on each batch
            or only on epoch end.
    """
    def __init__(self, nb_show, classes, post_each_batch=True):
        # NOTE(review): 0.5 is passed straight to SIVis; presumably the
        # decision threshold used to find boundary samples — confirm in SIVis.
        self.vis = SIVis(nb_show, classes, 0.5)
        self.post_each_batch = post_each_batch

    def on_epoch_start(self, state):
        """Discard the previous epoch's report and reset the visualizer.

        Args:
            state (dict): training state; ``state['metrics']`` must be a dict.
        """
        # Remove the stale report if present; no error when it is absent.
        state['metrics'].pop('report', None)
        self.vis.reset()

    @torch.no_grad()
    def on_batch_end(self, state):
        """Feed the current batch to the visualizer, optionally refreshing the report.

        Runs with gradient tracking disabled. Reads ``state['pred']`` and
        ``state['batch']``; ``batch[0]`` is passed as ``x`` and ``batch[1]``
        as ``y`` (presumably images and target masks — confirm with caller).

        Args:
            state (dict): training state.
        """
        pred, y, x = state['pred'], state['batch'][1], state['batch'][0]
        self.vis.analyze(x, pred, y)
        # Only rebuild the (potentially costly) report when it will actually
        # be logged on this step.
        if self.post_each_batch and state.get('visdom_will_log', False):
            state['metrics']['report'] = self.vis.show()

    def on_epoch_end(self, state):
        """Publish the report accumulated over the epoch into ``state['metrics']``.

        Args:
            state (dict): training state.
        """
        state['metrics']['report'] = self.vis.show()
예제 #2
0
 def __init__(self, nb_show, classes, post_each_batch=True):
     """Build the sample visualizer and record the reporting mode.

     Args:
         nb_show (int): how many samples to show.
         classes (list of str): class names.
         post_each_batch (bool): whether a report is produced after every
             batch rather than only at epoch end.
     """
     # NOTE(review): 0.5 is forwarded to SIVis unchanged; presumably a
     # decision threshold — confirm against SIVis.
     self.vis, self.post_each_batch = SIVis(nb_show, classes, 0.5), post_each_batch