class SegmentationInspector:
    """
    For image binary segmentation tasks, create an HTML report with the best
    segmented samples, the worst segmented samples, and the samples closest to
    the decision boundary.

    Args:
        nb_show (int): how many samples to show
        classes (list of str): class names
        post_each_batch (bool): whether to generate a new report on each batch
            or only at epoch end.
        threshold (float): value forwarded to ``SIVis`` as its third
            positional argument; defaults to 0.5 (the value that used to be
            hard-coded). NOTE(review): presumably the decision threshold
            applied to predictions — confirm against ``SIVis``.
    """

    def __init__(self, nb_show, classes, post_each_batch=True, threshold=0.5):
        # Passed positionally to keep the exact call shape used before the
        # threshold became configurable.
        self.vis = SIVis(nb_show, classes, threshold)
        self.post_each_batch = post_each_batch

    def on_epoch_start(self, state):
        """
        Remove any report left from the previous epoch and reset the
        visualizer's accumulated samples.

        Args:
            state (dict): training state; ``state['metrics']`` must be a dict.
        """
        # pop with a default: no error when no report has been produced yet.
        state['metrics'].pop('report', None)
        self.vis.reset()

    @torch.no_grad()
    def on_batch_end(self, state):
        """
        Feed the current batch to the visualizer and, when requested,
        publish an updated report.

        Runs with gradient tracking disabled.

        Args:
            state (dict): training state; reads ``state['pred']`` and
                ``state['batch']`` (inputs at index 0, targets at index 1),
                and optionally ``state['visdom_will_log']``.
        """
        pred = state['pred']
        x, y = state['batch'][0], state['batch'][1]
        self.vis.analyze(x, pred, y)
        # Only rebuild the report on batches that will actually be logged.
        if self.post_each_batch and state.get('visdom_will_log', False):
            state['metrics']['report'] = self.vis.show()

    def on_epoch_end(self, state):
        """
        Store the final report for the epoch in ``state['metrics']['report']``.

        Args:
            state (dict): training state; ``state['metrics']`` must be a dict.
        """
        state['metrics']['report'] = self.vis.show()
def __init__(self, nb_show, classes, post_each_batch=True):
    """
    Build the segmentation visualizer and remember the reporting mode.

    NOTE(review): this is a module-level copy of the inspector constructor;
    it looks like a stray duplicate — confirm whether anything references it.

    Args:
        nb_show (int): how many samples to show
        classes (list of str): class names
        post_each_batch (bool): whether a report is produced on every batch
            or only at epoch end.
    """
    visualizer = SIVis(nb_show, classes, 0.5)
    self.vis = visualizer
    self.post_each_batch = post_each_batch