def log_jaccard(im_id: str, cls: int, true_mask: np.ndarray, mask: np.ndarray, poly_mask: np.ndarray, true_poly: MultiPolygon, poly: MultiPolygon, valid_polygons=False):
    """Compute and log pixel- and polygon-level Jaccard statistics for one image/class.

    The pixel statistics always come from ``utils.mask_tp_fp_fn(mask, true_mask, 0.5)``.
    The polygon statistics come from one of two sources:

    * ``valid_polygons`` is true: exact areas from the polygon geometries
      (intersection for tp, predicted-minus-true for fp, true-minus-predicted
      for fn). Invalid geometries are first repaired with ``buffer(0)`` and
      passed through ``utils.to_multipolygon``.
    * otherwise: the rasterized ``poly_mask`` is compared against
      ``true_mask`` with the same helper and threshold as the pixel stats.

    Args:
        im_id: Image identifier, used only in the log line.
        cls: Class identifier, used only in the log line.
        true_mask: Ground-truth mask.
        mask: Predicted mask; must be 2-D.
        poly_mask: Rasterized predicted polygons (used when ``valid_polygons`` is false).
        true_poly: Ground-truth polygons (used when ``valid_polygons`` is true).
        poly: Predicted polygons (used when ``valid_polygons`` is true).
        valid_polygons: Selects geometry-based instead of raster-based polygon stats.

    Returns:
        ``(pixel_jc, poly_jc)``, each presumably a ``(tp, fp, fn)`` triple
        (the geometry branch builds exactly that triple).
    """
    assert len(mask.shape) == 2
    pixel_jc = utils.mask_tp_fp_fn(mask, true_mask, 0.5)

    if not valid_polygons:
        # Raster comparison of the polygonized prediction.
        poly_jc = utils.mask_tp_fp_fn(poly_mask, true_mask, 0.5)
    else:
        # Repair invalid geometries before doing set operations on them.
        reference = (true_poly if true_poly.is_valid
                     else utils.to_multipolygon(true_poly.buffer(0)))
        predicted = (poly if poly.is_valid
                     else utils.to_multipolygon(poly.buffer(0)))
        overlap_area = reference.intersection(predicted).area
        missed_area = reference.difference(predicted).area
        extra_area = predicted.difference(reference).area
        poly_jc = (overlap_area, extra_area, missed_area)

    pixel_score = jaccard(pixel_jc)
    poly_score = jaccard(poly_jc)
    message = '{} cls-{} pixel jaccard: {:.5f}, polygon jaccard: {:.5f}'.format(
        im_id, cls, pixel_score, poly_score)
    logger.info(message)
    return pixel_jc, poly_jc
def _update_jaccard(self, stats, mask, pred):
    """Append per-class, per-threshold (tp, fp, fn) counts into ``stats`` in place.

    ``stats`` maps a class name (looked up in ``self.hps.classes``) to a
    mapping ``threshold -> (tp_list, fp_list, fn_list)``. For each entry, the
    class channel is sliced out of ``pred`` and ``mask``, scored with
    ``utils.mask_tp_fp_fn(pred_channel, true_channel, threshold)``, and the
    three results are appended to the corresponding lists.

    ``mask`` and ``pred`` must share a shape of rank 3 or 4. Rank-3 arrays
    carry the class channel on axis 0, and rank-4 arrays carry it on axis 1
    (presumably after a leading batch axis; confirm with callers). The
    channel-count check runs once per class entry.

    Returns None; the lists inside ``stats`` are mutated.
    """
    assert mask.shape == pred.shape
    assert len(mask.shape) in {3, 4}
    unbatched = len(mask.shape) == 3
    class_axis = 0 if unbatched else 1

    for cls_name, per_threshold in stats.items():
        idx = self.hps.classes.index(cls_name)
        assert mask.shape[class_axis] == self.hps.n_classes
        if unbatched:
            cls_pred, cls_true = pred[idx], mask[idx]
        else:
            cls_pred, cls_true = pred[:, idx], mask[:, idx]

        for threshold, accumulators in per_threshold.items():
            # Unpack first so a malformed accumulator fails before scoring.
            tp_acc, fp_acc, fn_acc = accumulators
            new_tp, new_fp, new_fn = utils.mask_tp_fp_fn(cls_pred, cls_true, threshold)
            tp_acc.append(new_tp)
            fp_acc.append(new_fp)
            fn_acc.append(new_fn)