Example 1
def validate(opt, model, test_transform):
    # meter to store the evaluation metrics: [pixel_acc, aji]
    results = utils.AverageMeter(2)

    # switch to evaluate mode
    model.eval()

    img_dir = '{:s}/val'.format(opt.train['img_dir'])
    label_dir = opt.test['label_dir']

    img_names = os.listdir(img_dir)
    for img_name in img_names:
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        name = os.path.splitext(img_name)[0]

        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        if opt.with_uncertainty:
            # the model predicts per-pixel logits and a log-variance map
            output, log_var = get_probmaps(input, model, opt)
            output = output.astype(np.float32)
            log_var = log_var.astype(np.float32)
            # cap the log-variance so that exp() stays finite in float32
            log_var = np.clip(log_var, a_min=np.min(log_var), a_max=700)
            prob_maps = np.zeros(output.shape)
            sigma = np.exp(log_var / 2)
            sigma = np.clip(sigma, a_min=0, a_max=700)  # cap sigma as well
            # Monte Carlo sampling: perturb the logits with Gaussian noise
            # scaled by sigma and average the softmax over T draws
            for t in range(opt.T):
                x_t = output + sigma * np.random.normal(0, 1, output.shape)
                x_t = np.clip(x_t, a_min=0, a_max=700)  # keep exp(x_t) finite
                prob_maps += np.exp(x_t) / (np.sum(np.exp(x_t), axis=0) + 1e-8)
            prob_maps /= opt.T
        else:
            prob_maps = get_probmaps(input, model, opt)

        pred = np.argmax(prob_maps, axis=0)  # prediction

        pred_labeled = measure.label(pred)
        pred_labeled = ski_morph.remove_small_objects(pred_labeled,
                                                      opt.post['min_area'])
        pred_labeled = binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        metrics = compute_metrics(pred_labeled, gt, ['acc', 'aji'])
        result = [metrics['acc'], metrics['aji']]

        results.update(result, input.size(0))

    logger.info(
        '\t=> Val Avg:\tAcc {r[0]:.4f}\tAJI {r[1]:.4f}'.format(r=results.avg))

    return results.avg
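
Note: all four examples accumulate their scores with a utils.AverageMeter helper whose source is not shown on this page. A minimal sketch consistent with how it is called here (constructed with the number of metrics, updated with a list of values plus an optional sample count, read back via .avg) might look as follows; the repository's actual utility may differ.

import numpy as np

class AverageMeter(object):
    # Hypothetical sketch of the assumed interface: running averages
    # of `num` metrics.
    def __init__(self, num):
        self.num = num
        self.reset()

    def reset(self):
        self.sum = np.zeros(self.num)
        self.avg = np.zeros(self.num)
        self.count = 0

    def update(self, vals, n=1):
        self.sum += np.asarray(vals, dtype=float) * n
        self.count += n
        self.avg = self.sum / self.count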
Example 2
def validate(opt, model, test_transform, mode):
    # meter to store the evaluation metrics: [pixel_acc, aji]
    results = utils.AverageMeter(2)

    # switch to evaluate mode
    model.eval()

    img_dir = '{:s}/test'.format(opt.train['img_dir'])
    label_dir = opt.test['label_dir']

    img_names = os.listdir(img_dir)
    for img_name in img_names:
        # load test image
        # print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        name = os.path.splitext(img_name)[0]

        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = io.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)
        prob_maps = utils.get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)  # prediction

        # pred = binary_fill_holes(pred)
        pred_label = measure.label(pred)
        pred = ski_morph.remove_small_objects(pred_label, opt.post['min_area'])
        # pred = pred > 0

        metrics = compute_metrics(pred, gt, ['acc', 'aji'])
        result = [metrics['acc'], metrics['aji']]

        results.update(result, input.size(0))

    logger.info('\t=> {:s} Avg:\tAcc {r[0]:.4f}\tAJI {r[1]:.4f}'.format(
        mode.upper(), r=results.avg))

    return results.avg
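
Note: both validate functions delegate the forward pass to get_probmaps. A plausible minimal sketch, assuming the model emits raw per-class logits and that whole-image inference fits in GPU memory, is shown below. In Example 1 the same helper returns a (logits, log-variance) pair when opt.with_uncertainty is set, and the real implementation likely tiles large images, so treat this purely as an illustration.

import torch
import torch.nn.functional as F

def get_probmaps(input, model, opt):
    # Hypothetical sketch: one forward pass, returning per-class
    # probabilities as a (C, H, W) numpy array.
    with torch.no_grad():
        logits = model(input.cuda())
    return F.softmax(logits, dim=1).squeeze(0).cpu().numpy()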
Example 3
def main():
    opt = Options(isTrain=False)
    opt.parse()
    opt.save_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpus'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    save_dir = opt.test['save_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    model = ResUNet34(pretrained=opt.model['pretrained'])
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    test_results = dict()
    all_result = utils.AverageMeter(len(metric_names))

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)  # prediction

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        print('\tComputing metrics...')
        metrics = compute_metrics(pred_labeled, gt, metric_names)

        # save result for each image
        test_results[name] = [
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ]

        # update the average result
        all_result.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 255)
            misc.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name),
                        prob_maps[1, :, :])
            final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            for k in range(1, pred_labeled.max() + 1):
                pred_colored_instance[pred_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    print('Average Acc: {r[0]:.4f}\nF1: {r[1]:.4f}\nRecall: {r[2]:.4f}\n'
          'Precision: {r[3]:.4f}\nDice: {r[4]:.4f}\nAJI: {r[5]:.4f}\n'.format(
              r=all_result.avg))

    header = metric_names
    utils.save_results(header, all_result.avg, test_results,
                       '{:s}/test_results.txt'.format(save_dir))
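
Note: the colored-instance overlays in Examples 3 and 4 rely on utils.get_random_color, also not shown. A sketch matching both call sites (no argument in Example 3, seed=k in Example 4), assumed to return an RGB triple in [0, 1], could be:

import numpy as np

def get_random_color(seed=None):
    # Hypothetical helper: one random RGB color in [0, 1],
    # reproducible when a seed is given.
    rng = np.random.RandomState(seed)
    return tuple(rng.rand(3))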
Example 4
def val(img_dir, label_dir, model, transform, opt, tb_writer, epoch):
    model.eval()
    img_names = os.listdir(img_dir)
    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    val_results = dict()
    all_results = utils.AverageMeter(len(metric_names))

    plot_num = 10  # number of images to visualize; use len(img_names) to plot all
    for img_name in img_names:
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = transform((img, ))[0].unsqueeze(0)

        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        metrics = compute_metrics(pred_labeled, gt, metric_names)

        if plot_num > 0:
            # undo the input normalization to recover the RGB image for display
            unNorm = get_transforms({
                'unnormalize': np.load(
                    '{:s}/mean_std.npy'.format(opt.train['data_dir']))
            })
            img_tensor = unNorm(input.squeeze(0))
            img_np = img_tensor.permute(1, 2, 0).numpy()
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', 42)
            metrics_text = Image.new("RGB", (512, 512), (255, 255, 255))
            draw = ImageDraw.Draw(metrics_text)
            draw.text(
                (32, 128),
                'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'
                .format(metrics['acc'], metrics['p_F1'], metrics['p_recall'],
                        metrics['p_precision'], metrics['dice'],
                        metrics['aji']),
                fill='rgb(0,0,0)',
                font=font)
            #tb_writer.add_scalars('{:s}'.format(name), 'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'.format(metrics['acc'], metrics['p_F1'], metrics['p_recall'], metrics['p_precision'], metrics['dice'], metrics['aji']), epoch)
            metrics_text = metrics_text.resize((ori_w, ori_h), Image.ANTIALIAS)
            trans_to_tensor = transforms.ToTensor()
            text_tensor = trans_to_tensor(metrics_text).float()
            colored_gt = np.zeros((ori_h, ori_w, 3))
            colored_pred = np.zeros((ori_h, ori_w, 3))
            img_w_colored_gt = img_np.copy()
            img_w_colored_pred = img_np.copy()
            alpha = 0.5
            for k in range(1, gt.max() + 1):
                mask = gt == k
                colored_gt[mask, :] = np.array(utils.get_random_color(seed=k))
                img_w_colored_gt[mask, :] = (
                    img_w_colored_gt[mask, :] * (1 - alpha)
                    + colored_gt[mask, :] * alpha)
            for k in range(1, pred_labeled.max() + 1):
                mask = pred_labeled == k
                colored_pred[mask, :] = np.array(
                    utils.get_random_color(seed=k))
                img_w_colored_pred[mask, :] = (
                    img_w_colored_pred[mask, :] * (1 - alpha)
                    + colored_pred[mask, :] * alpha)

            gt_tensor = torch.from_numpy(colored_gt).permute(2, 0, 1).float()
            pred_tensor = torch.from_numpy(colored_pred).permute(2, 0, 1).float()
            img_w_gt_tensor = torch.from_numpy(img_w_colored_gt).permute(2, 0, 1).float()
            img_w_pred_tensor = torch.from_numpy(img_w_colored_pred).permute(2, 0, 1).float()
            tb_writer.add_image(
                '{:s}'.format(name),
                make_grid([img_tensor, img_w_gt_tensor, img_w_pred_tensor,
                           text_tensor, gt_tensor, pred_tensor],
                          nrow=3, padding=10, pad_value=1),
                epoch)
            plot_num -= 1

        # update the average result
        all_results.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])
    logger.info('\t=> Val Avg: Acc {r[0]:.4f}'
                '\tF1 {r[1]:.4f}'
                '\tRecall {r[2]:.4f}'
                '\tPrecision {r[3]:.4f}'
                '\tDice {r[4]:.4f}'
                '\tAJI {r[5]:.4f}'.format(r=all_results.avg))

    return all_results.avg
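
Note: every example scores its predictions with compute_metrics, whose implementation is not shown either. The object-level scores (p_F1, p_recall, p_precision, AJI) require matching predicted instances to ground-truth instances and do not fit a short sketch, but the pixel-level 'acc' and 'dice' entries can be illustrated on binarized masks as below; this is an assumption about their definitions, not the repository's code.

import numpy as np

def compute_pixel_metrics(pred_labeled, gt):
    # Hypothetical partial sketch of compute_metrics: pixel accuracy
    # and Dice over the binarized foreground masks only.
    pred_fg = pred_labeled > 0
    gt_fg = gt > 0
    acc = float(np.mean(pred_fg == gt_fg))
    inter = np.logical_and(pred_fg, gt_fg).sum()
    dice = 2.0 * inter / (pred_fg.sum() + gt_fg.sum() + 1e-8)
    return {'acc': acc, 'dice': dice}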