コード例 #1
0
ファイル: eval_utils.py プロジェクト: chbinb/mrs
    def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False,
              densecrf=False, crf_params=None):
        """
        Predict every image in ``self.rgb_files`` tile by tile and write the results to ``pred_dir``.

        :param model: a single model, or a list/tuple of models whose ``self.infer_tile`` outputs are summed
        :param pred_dir: output directory; created if it does not exist
        :param patch_size: patch size forwarded to ``patch_extractor.make_grid`` and ``self.infer_tile``
        :param overlap: overlap forwarded to ``patch_extractor.make_grid``
        :param ext: suffix appended to the file name stem of each saved prediction
        :param file_ext: file extension of each saved prediction
        :param visualize: if True, show the image next to its prediction via ``vis_utils.compare_figures``
        :param densecrf: if True, refine the prediction with a dense CRF before taking the argmax
        :param crf_params: keyword arguments for ``addPairwiseBilateral``; defaults to
                           ``{'sxy': 3, 'srgb': 3, 'compat': 5}`` when ``densecrf`` is set
        """
        is_ensemble = isinstance(model, (list, tuple))
        # the label margin of the first model is used for the whole ensemble
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if densecrf and crf_params is None:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        misc_utils.make_dir_if_not_exist(pred_dir)
        pbar = tqdm(self.rgb_files)
        for rgb_file in pbar:
            # the output name is the part of the file stem before the first underscore
            file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
            pbar.set_description('Inferring {}'.format(file_name))
            # read data, keeping only the first three channels
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)

            if is_ensemble:
                # sum() starts from 0, which matches the original accumulation
                tile_preds = sum(
                    self.infer_tile(m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)
                    for m in model)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)

            if densecrf:
                height, width, n_class = tile_preds.shape
                # DenseCRF2D takes (width, height, n_labels); unpacking the (H, W, C) shape
                # directly swaps the two spatial sizes for non-square tiles
                d = dcrf.DenseCRF2D(width, height, n_class)
                # NOTE(review): unary_from_softmax expects probabilities -- confirm that
                # infer_tile returns softmax scores rather than raw logits here
                U = unary_from_softmax(np.ascontiguousarray(
                    data_utils.change_channel_order(tile_preds, False)))
                d.setUnaryEnergy(U)
                # the channel slice above may be a non-contiguous view; the CRF needs a C-contiguous image
                d.addPairwiseBilateral(rgbim=np.ascontiguousarray(rgb), **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(height, width)
            else:
                tile_preds = np.argmax(tile_preds, -1)

            pred_img = self.encode_func(tile_preds) if self.encode_func else tile_preds

            if visualize:
                vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))

            misc_utils.save_file(os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)), pred_img)
コード例 #2
0
ファイル: network_utils.py プロジェクト: bohaohuang/pgm_torch
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False):
        """
        Evaluate the model on the label files in ``self.lbl_files`` and report the IoU.

        Only label files whose path contains ``'NZ'`` and whose file-name digits form an
        id of at most 3 are evaluated; all others are skipped.

        :param model: model exposing ``inference(patch)``, called once per patch
        :param patch_size: patch size forwarded to the ``patch_extractor`` helpers
        :param overlap: overlap forwarded to ``patch_extractor.make_grid``
        :param pred_dir: if set, the directory is created and each argmax prediction is saved there as png
        :param report_dir: if set, the directory is created and ``result.txt`` is written there
        :param save_conf: if True, the stitched raw predictions are saved as npy in ``pred_dir``
                          (``pred_dir`` must be set in that case)
        """

        def _iou_percent(inter, union):
            # A zero denominator (e.g. no tile passed the filter) yields NaN instead of
            # raising ZeroDivisionError before the report is written.
            return inter / union * 100 if union else float('nan')

        iou_a, iou_b = 0, 0
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for lbl_file in self.lbl_files:
            # only do NZ image
            if 'NZ' not in lbl_file:
                continue

            # parse ground truth
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]
            boxes, lines = data_utils.csv_to_annotation(lbl_file)
            # the tile id is made of every digit found in the file name
            tile_id = int(''.join([a for a in file_name if a.isdigit()]))

            if tile_id > 3:
                continue

            # read rgb image, keeping only the first three channels
            rgb_file = os.path.join(self.image_dir, '{}.tif'.format(file_name))
            rgb = misc_utils.load_file(rgb_file)[..., :3]
            # make line map
            lbl, _ = data_utils.render_line_graph(rgb.shape[:2], boxes, lines)

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            grid_list = patch_extractor.make_grid(tile_dim, patch_size,
                                                  overlap)
            tile_preds = []
            for patch in patch_extractor.patch_block(rgb, 0, grid_list,
                                                     patch_size, False):
                # apply the transforms in order, each one consuming the previous output
                for tsfm in self.tsfm:
                    patch = tsfm(image=patch)['image']
                patch = torch.unsqueeze(patch, 0).to(self.device)
                pred = model.inference(patch).detach().cpu().numpy()
                tile_preds.append(change_channel_order(pred, True)[0, :, :, :])
            # stitch back to tiles
            # NOTE(review): the grid is built with `overlap` but stitching passes overlap=0 --
            # confirm this is what unpatch_block expects
            tile_preds = patch_extractor.unpatch_block(np.array(tile_preds),
                                                       tile_dim,
                                                       patch_size,
                                                       tile_dim,
                                                       patch_size,
                                                       overlap=0)
            if save_conf:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    tile_preds)
            tile_preds = np.argmax(tile_preds, -1)
            a, b = metric_utils.iou_metric(lbl, tile_preds)
            tile_iou = _iou_percent(a, b)
            print('{}: IoU={:.2f}'.format(file_name, tile_iou))
            report.append('{},{},{},{}\n'.format(file_name, a, b, tile_iou))
            iou_a += a
            iou_b += b
            if pred_dir:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.png'.format(file_name)),
                    tile_preds)
        overall_iou = _iou_percent(iou_a, iou_b)
        print('Overall: IoU={:.2f}'.format(overall_iou))
        report.append('Overall,{},{},{}\n'.format(iou_a, iou_b, overall_iou))
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
コード例 #3
0
ファイル: eval_utils.py プロジェクト: xyt556/mrs
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False,
                 delta=1e-6,
                 eval_class=(1, ),
                 visualize=False,
                 densecrf=False,
                 crf_params=None,
                 verbose=True):
        """
        Evaluate the model on the paired ``self.rgb_files`` / ``self.lbl_files`` and report the IoU.

        :param model: a single model, or a list/tuple of models whose ``self.infer_tile`` outputs are summed
        :param patch_size: patch size forwarded to ``patch_extractor.make_grid`` and ``self.infer_tile``
        :param overlap: overlap forwarded to ``patch_extractor.make_grid``
        :param pred_dir: if set, the directory is created and each prediction is saved there as png
        :param report_dir: if set, the directory is created and ``result.txt`` is written there
        :param save_conf: if True, the softmax score of channel 1 is saved as npy in ``pred_dir``
                          (``pred_dir`` must be set in that case)
        :param delta: small constant added to the IoU denominator; also passed to ``self.get_result_strings``
        :param eval_class: classes forwarded to ``metric_utils.iou_metric``; one accumulator per entry
        :param visualize: if True, show image, ground truth and prediction via ``vis_utils.compare_figures``
        :param densecrf: if True, refine the prediction with a dense CRF before taking the argmax
        :param crf_params: keyword arguments for ``addPairwiseBilateral``; defaults to
                           ``{'sxy': 3, 'srgb': 3, 'compat': 5}`` when ``densecrf`` is set
        :param verbose: forwarded to ``verb_print`` together with each result string
        :return: mean over ``eval_class`` of the accumulated IoU, in percent
        """
        is_ensemble = isinstance(model, (list, tuple))
        # the label margin of the first model is used for the whole ensemble
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if densecrf and crf_params is None:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]

            # read data, keeping only the first three image channels
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]
            lbl = misc_utils.load_file(lbl_file)
            if self.decode_func:
                lbl = self.decode_func(lbl)

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [
                tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin
            ]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size,
                                                  overlap)

            if is_ensemble:
                # sum() starts from 0, which matches the original accumulation
                tile_preds = sum(
                    self.infer_tile(m, rgb, grid_list, patch_size, tile_dim,
                                    tile_dim_pad, lbl_margin) for m in model)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                             tile_dim, tile_dim_pad,
                                             lbl_margin)

            # the softmax is needed by both the confidence dump and the CRF; compute it once
            if save_conf or densecrf:
                tile_probs = scipy.special.softmax(tile_preds, axis=-1)

            if save_conf:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    tile_probs[:, :, 1])

            if densecrf:
                height, width, n_class = tile_preds.shape
                # DenseCRF2D takes (width, height, n_labels); unpacking the (H, W, C) shape
                # directly swaps the two spatial sizes for non-square tiles
                d = dcrf.DenseCRF2D(width, height, n_class)
                U = unary_from_softmax(
                    np.ascontiguousarray(
                        data_utils.change_channel_order(tile_probs, False)))
                d.setUnaryEnergy(U)
                # the channel slice above may be a non-contiguous view; the CRF needs a C-contiguous image
                d.addPairwiseBilateral(rgbim=np.ascontiguousarray(rgb),
                                       **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(height, width)
            else:
                tile_preds = np.argmax(tile_preds, -1)
            # the ground truth is divided by self.truth_val before comparison
            iou_score = metric_utils.iou_metric(lbl / self.truth_val,
                                                tile_preds,
                                                eval_class=eval_class)
            pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
            tm.misc_utils.verb_print(pstr, verbose)
            report.append(rstr)
            iou_a += iou_score[0, :]
            iou_b += iou_score[1, :]
            if visualize:
                if self.encode_func:
                    figures = [
                        rgb,
                        self.encode_func(lbl),
                        self.encode_func(tile_preds)
                    ]
                else:
                    figures = [rgb, lbl, tile_preds]
                vis_utils.compare_figures(figures, (1, 3), fig_size=(15, 5))
            if pred_dir:
                pred_img = self.encode_func(
                    tile_preds) if self.encode_func else tile_preds
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.png'.format(file_name)),
                    pred_img)
        pstr, rstr = self.get_result_strings('Overall',
                                             np.stack([iou_a, iou_b], axis=0),
                                             delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
        return np.mean(iou_a / (iou_b + delta)) * 100