Example #1
0
def test(model, loader, device):
    """Evaluate `model` on `loader`: image-level classification metrics plus
    CAM-derived weak-segmentation metrics (Dice / IoU, global and per-image).

    Returns a dict of metrics; segmentation predictions are attached only for
    split 0 / fold 0 to bound memory use.
    """
    model.eval()

    # Hoist the experiment configuration lookups used throughout the loop.
    cfg = ex.current_run.config
    dataset_name = cfg['dataset']['name']
    pooling_name = cfg['model']['pooling']
    num_classes = cfg['model']['num_classes']

    labels, logits_list, predictions, losses = [], [], [], []
    seg_preds_list, dices, ious = [], [], []
    evaluator = Evaluator(num_classes)        # accumulates over all images
    image_evaluator = Evaluator(num_classes)  # reset after each image

    progress = tqdm(loader, ncols=80, desc='Test')

    with torch.no_grad():
        for image, segmentation, label in progress:
            image = image.to(device)

            logits = model(image).cpu()
            pred = model.pooling.predictions(logits=logits).item()
            loss = model.pooling.loss(logits=logits, labels=label)

            # Binary ground-truth mask; encoding is dataset-specific.
            if dataset_name == 'caltech_birds':
                gt_mask = (segmentation.squeeze() > 0.5)
            else:
                gt_mask = (segmentation.squeeze() != 0)

            cam = model.pooling.cam
            cam_interp = F.interpolate(cam,
                                       size=gt_mask.shape,
                                       mode='bilinear',
                                       align_corners=True).squeeze(0)

            label = label.item()
            labels.append(label)
            logits_list.append(logits)
            predictions.append(pred)
            losses.append(loss.item())

            # Pixel-level prediction.  The deepmil threshold is relative to
            # the size of the *un-interpolated* CAM.
            threshold = 1 / cam.numel()
            if dataset_name == 'glas':
                if pooling_name == 'deepmil_multi':
                    seg_pred = (cam_interp[label] > threshold).cpu()
                else:
                    seg_pred = (cam_interp.argmax(0) == label).cpu()
            elif pooling_name == 'deepmil':
                seg_pred = (cam_interp.squeeze(0) > threshold).cpu()
            elif pooling_name == 'deepmil_multi':
                seg_pred = (cam_interp[label] > threshold).cpu()
            else:
                seg_pred = cam_interp.argmax(0).cpu()

            seg_preds_list.append(seg_pred.numpy().astype('bool'))

            evaluator.add_batch(gt_mask, seg_pred)
            image_evaluator.add_batch(gt_mask, seg_pred)
            dices.append(image_evaluator.dice()[1].item())
            ious.append(image_evaluator.intersection_over_union()[1].item())
            image_evaluator.reset()

        all_logits = torch.cat(logits_list, 0)
        all_probabilities = model.pooling.probabilities(all_logits)

    metrics = metric_report(np.array(labels), all_probabilities.numpy(),
                            np.array(predictions))
    metrics['images_path'] = loader.dataset.samples
    metrics['labels'] = np.array(labels)
    metrics['logits'] = all_logits.numpy()
    metrics['probabilities'] = all_probabilities.numpy()
    metrics['predictions'] = np.array(predictions)
    metrics['losses'] = np.array(losses)

    metrics['dice_per_image'] = np.array(dices)
    metrics['mean_dice'] = metrics['dice_per_image'].mean()
    metrics['dice'] = evaluator.dice()[1].item()
    metrics['iou_per_image'] = np.array(ious)
    metrics['mean_iou'] = metrics['iou_per_image'].mean()
    metrics['iou'] = evaluator.intersection_over_union()[1].item()

    # Keep the (large) per-image masks only for the first split/fold.
    if cfg['dataset']['split'] == 0 and cfg['dataset']['fold'] == 0:
        metrics['seg_preds'] = seg_preds_list

    return metrics
def test(model, loader, device):
    """Evaluate `model` on `loader` and write visualizations to disk.

    Computes image-level classification metrics and CAM-derived segmentation
    metrics (Dice / IoU, global and per-image, plus the confusion matrix).

    Side effects: saves CAM overlays under ``cams/``, auto-encoder
    reconstructions under ``reconst/`` (when the backbone is a ResNet_AE),
    raw results under ``out/``, and a pickle of the segmentation
    predictions under ``test/``.
    """
    model.eval()

    # Hoist the experiment configuration lookups used throughout the loop.
    cfg = ex.current_run.config
    dataset_name = cfg['dataset']['name']
    pooling = cfg['model']['pooling']
    num_classes = cfg['model']['num_classes']
    arch_tag = cfg['model']['arch'] + str(cfg['balance'])

    all_labels = []
    all_logits = []
    all_predictions = []
    all_losses = []
    all_seg_logits_interp = []
    all_seg_preds_interp = []
    all_dices = []
    all_ious = []
    evaluator = Evaluator(num_classes)        # accumulates over all images
    image_evaluator = Evaluator(num_classes)  # reset after each image

    pbar = tqdm(loader, ncols=80, desc='Test')

    # Gradient-based CAM methods need autograd even at test time.
    # Use torch.enable_grad() rather than torch.set_grad_enabled(True):
    # the latter toggles the *global* grad mode already at construction
    # time, before the ``with`` block is entered.
    if pooling in requires_gradients:
        grad_policy = torch.enable_grad()
    else:
        grad_policy = torch.no_grad()

    eval_cams = pooling in requires_gradients or pooling == 'ablation'
    is_ae = isinstance(model.backbone, ResNet_AE)

    with grad_policy:
        for i, (image, segmentation, label) in enumerate(pbar):
            image, label = image.to(device), label.to(device)

            if eval_cams:
                model.pooling.eval_cams = True

            if is_ae:
                z, x_reconst = model.backbone(image)
                logits = model.pooling(z)
            else:
                logits = model(image)

            pred = model.pooling.predictions(logits=logits).item()
            loss = model.pooling.loss(logits=logits, labels=label)

            # Binary ground-truth mask; encoding is dataset-specific.
            if dataset_name == 'caltech_birds':
                segmentation_classes = (segmentation.squeeze() > 0.5)
            else:
                segmentation_classes = (segmentation.squeeze() != 0)

            seg_logits = model.pooling.cam.detach().cpu()
            seg_logits_interp = F.interpolate(seg_logits,
                                              size=segmentation_classes.shape,
                                              mode='bilinear',
                                              align_corners=True).squeeze(0)

            label = label.item()
            all_labels.append(label)
            all_logits.append(logits.cpu())
            all_predictions.append(pred)
            all_losses.append(loss.item())

            # Pixel-level prediction.  The deepmil threshold is relative to
            # the size of the *un-interpolated* CAM.
            threshold = 1 / seg_logits.numel()
            if dataset_name == 'glas':
                if pooling == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        threshold).cpu()
                else:
                    seg_preds_interp = (
                        seg_logits_interp.argmax(0) == label).cpu()
            else:
                if pooling == 'deepmil':
                    seg_preds_interp = (seg_logits_interp.squeeze(0) >
                                        threshold).cpu()
                elif pooling == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        threshold).cpu()
                else:
                    seg_preds_interp = seg_logits_interp.argmax(0).cpu()

            # Save CAM visualization.  NOTE(review): indexing channels 0 and
            # 1 assumes a binary task — confirm num_classes == 2 here.
            save_dir = 'cams/{}/{}'.format(arch_tag, pooling)
            os.makedirs(save_dir, exist_ok=True)
            file_path = os.path.join(save_dir, 'cam_{}.png'.format(i))
            seg_logits_interp_norm = (seg_logits_interp /
                                      seg_logits_interp.max())
            saliency_map_0, overlay_0 = visualize_cam(
                seg_logits_interp_norm[0], image)
            saliency_map_1, overlay_1 = visualize_cam(
                seg_logits_interp_norm[1], image)
            overlay = [overlay_0, overlay_1][label]
            save_visualization(image.squeeze().cpu(),
                               segmentation_classes.numpy(), saliency_map_0,
                               saliency_map_1, overlay,
                               seg_preds_interp.numpy() * 255, label,
                               file_path)

            if is_ae:
                x_reconst = x_reconst.detach()
                save_dir = 'reconst/{}/{}'.format(cfg['model']['arch'],
                                                  pooling)
                os.makedirs(save_dir, exist_ok=True)
                file_path = os.path.join(save_dir, 'reconst_{}.png'.format(i))
                save_reconst(
                    image.squeeze(0).cpu(),
                    x_reconst.squeeze(0).cpu(), file_path)

            all_seg_logits_interp.append(seg_logits_interp.numpy())
            all_seg_preds_interp.append(
                seg_preds_interp.numpy().astype('bool'))

            evaluator.add_batch(segmentation_classes, seg_preds_interp)
            image_evaluator.add_batch(segmentation_classes, seg_preds_interp)
            all_dices.append(image_evaluator.dice()[1].item())
            all_ious.append(
                image_evaluator.intersection_over_union()[1].item())
            image_evaluator.reset()

        if eval_cams:
            model.pooling.eval_cams = False
        all_logits = torch.cat(all_logits, 0).detach()
        all_probabilities = model.pooling.probabilities(all_logits)

    # Fix: ensure the output directory exists before pickling; open() does
    # not create intermediate directories.
    os.makedirs('test', exist_ok=True)
    with open('test/gradcampp_seg_preds.pkl', 'wb') as f:
        pkl.dump(all_seg_preds_interp, f)

    results_dir = 'out/{}/{}'.format(arch_tag, pooling)
    save_results(results_dir, loader.dataset.samples, np.array(all_labels),
                 np.array(all_predictions), all_seg_logits_interp,
                 all_seg_preds_interp, np.array(all_dices))

    metrics = metric_report(np.array(all_labels), all_probabilities.numpy(),
                            np.array(all_predictions))
    metrics['images_path'] = loader.dataset.samples
    metrics['labels'] = np.array(all_labels)
    metrics['logits'] = all_logits.numpy()
    metrics['probabilities'] = all_probabilities.numpy()
    metrics['predictions'] = np.array(all_predictions)
    metrics['losses'] = np.array(all_losses)

    metrics['dice_per_image'] = np.array(all_dices)
    metrics['mean_dice'] = metrics['dice_per_image'].mean()
    metrics['dice'] = evaluator.dice()[1].item()
    metrics['iou_per_image'] = np.array(all_ious)
    metrics['mean_iou'] = metrics['iou_per_image'].mean()
    metrics['iou'] = evaluator.intersection_over_union()[1].item()
    metrics['conf_mat'] = evaluator.cm.numpy()

    # Keep the (large) per-image masks only for the first split/fold.
    if cfg['dataset']['split'] == 0 and cfg['dataset']['fold'] == 0:
        metrics['seg_preds'] = all_seg_preds_interp

    return metrics
Example #3
0
    def evaluate(self, loader, test=False):
        """Run one full validation/test pass of ``self.model`` over ``loader``.

        Args:
            loader: yields ``(image, mask, label, file_name)`` batches.
            test: only affects the progress-bar caption.

        Returns:
            dict with per-image and aggregate Dice/IoU (foreground and
            background), labels, and per-image losses.
        """
        self.model.eval()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        all_labels = []
        all_losses = []
        all_dices = []
        all_ious = []
        evaluator = Evaluator(2)        # accumulates over all images
        image_evaluator = Evaluator(2)  # reset after each image

        pbar = tqdm(loader, ncols=80, desc='Test' if test else 'Validation')
        with torch.no_grad():
            for image, mask, label, f_name in pbar:
                image, mask = image.to(device), mask.squeeze(1).to(
                    device, non_blocking=True)
                # Binarize the mask: any non-zero value is foreground.
                class_masks = (mask != 0).long()

                seg_logits = self.model(image)
                loss = self.get_loss(seg_logits, class_masks).item()

                seg_preds = seg_logits.argmax(1)
                evaluator.add_batch(class_masks, seg_preds)
                image_evaluator.add_batch(class_masks, seg_preds)
                dices = image_evaluator.dice()
                ious = image_evaluator.intersection_over_union()
                image_evaluator.reset()

                all_labels.append(label[0])
                all_losses.append(loss)
                all_dices.append(dices.cpu())
                all_ious.append(ious.cpu())

            all_labels = np.array(all_labels)
            all_losses = np.array(all_losses)
            all_dices = torch.stack(all_dices, 0)
            all_ious = torch.stack(all_ious, 0)

        # Aggregate (dataset-level) scores from the global evaluator.
        dices = evaluator.dice()
        ious = evaluator.intersection_over_union()

        metrics = {
            'images_path': loader.dataset.rows,
            'labels': all_labels,
            'losses': all_losses,
            'dice_background_per_image': all_dices[:, 0].numpy(),
            'mean_dice_background': all_dices[:, 0].numpy().mean(),
            'dice_background': dices[0].item(),
            'dice_per_image': all_dices[:, 1].numpy(),
            'mean_dice': all_dices[:, 1].numpy().mean(),
            'dice': dices[1].item(),
            'iou_background_per_image': all_ious[:, 0].numpy(),
            'mean_iou_background': all_ious[:, 0].numpy().mean(),
            'iou_background': ious[0].item(),
            'iou_per_image': all_ious[:, 1].numpy(),
            'mean_iou': all_ious[:, 1].numpy().mean(),
            'iou': ious[1].item(),
        }

        return metrics
Example #4
0
    def test_bma(self, dataset, n_predictions):
        """Test with Bayesian model averaging: average ``n_predictions``
        stochastic forward passes per image before computing metrics.

        With ``n_predictions == 1`` this delegates to ``self.evaluate``.

        NOTE(review): the model is put in ``eval()`` mode before the MC
        passes — presumably dropout is kept active by some other mechanism
        for the predictions to actually differ; verify, otherwise all
        ``n_predictions`` passes are identical.
        """
        self.model.eval()
        # batch_size=1: metrics below assume one image per iteration.
        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            if n_predictions == 1:
                # Single deterministic pass — reuse the standard evaluation.
                return self.evaluate(loader=loader, test=True)
            else:

                all_labels = []
                all_losses = []
                all_seg_preds = []
                all_dices = []
                all_ious = []
                evaluator = Evaluator(2)        # accumulates over all images
                image_evaluator = Evaluator(2)  # reset after each image

                for image, mask, lbl, name in tqdm(loader,
                                                   ncols=80,
                                                   desc='Test MC predictions'):
                    image, mask = image.to(device), mask.squeeze(1).to(
                        device, non_blocking=True)
                    # Binarize the mask: any non-zero value is foreground.
                    class_masks = (mask != 0).long()

                    # MC predictions: average the logits of n stochastic passes.
                    preds = [self.model(image) for _ in range(n_predictions)]
                    stack = torch.stack(preds, dim=-1)
                    seg_logits = stack.mean(dim=-1)

                    # loss = F.cross_entropy(seg_logits, class_masks).item()
                    # # dice_loss = DiceLoss()
                    # # loss = dice_loss(seg_logits, class_masks).item()
                    loss = self.get_loss(seg_logits, class_masks).item()

                    seg_preds = seg_logits.argmax(1)
                    evaluator.add_batch(class_masks, seg_preds)
                    image_evaluator.add_batch(class_masks, seg_preds)
                    dices = image_evaluator.dice()
                    ious = image_evaluator.intersection_over_union()
                    image_evaluator.reset()

                    all_labels.append(lbl[0])
                    all_losses.append(loss)
                    all_dices.append(dices.cpu())
                    all_ious.append(ious.cpu())
                    all_seg_preds.append(
                        seg_preds.squeeze(0).byte().cpu().numpy().astype(
                            'bool'))

                all_labels = np.array(all_labels)
                all_losses = np.array(all_losses)
                all_dices = torch.stack(all_dices, 0)
                all_ious = torch.stack(all_ious, 0)

            # Aggregate (dataset-level) scores from the global evaluator.
            # Only reached on the MC path; the n_predictions == 1 branch
            # returned earlier.
            dices = evaluator.dice()
            ious = evaluator.intersection_over_union()

            metrics = {
                'images_path': loader.dataset.rows,
                'labels': all_labels,
                'losses': all_losses,
                'dice_background_per_image': all_dices[:, 0].numpy(),
                'mean_dice_background': all_dices[:, 0].numpy().mean(),
                'dice_background': dices[0].item(),
                'dice_per_image': all_dices[:, 1].numpy(),
                'mean_dice': all_dices[:, 1].numpy().mean(),
                'dice': dices[1].item(),
                'iou_background_per_image': all_ious[:, 0].numpy(),
                'mean_iou_background': all_ious[:, 0].numpy().mean(),
                'iou_background': ious[0].item(),
                'iou_per_image': all_ious[:, 1].numpy(),
                'mean_iou': all_ious[:, 1].numpy().mean(),
                'iou': ious[1].item(),
            }

            return metrics
def evaluate(model, loader, device, test=False):
    """Run one full validation/test pass of a segmentation ``model``.

    Args:
        model: segmentation network returning per-pixel class logits.
        loader: yields ``(image, mask, label)`` batches.
        device: torch device for model inputs.
        test: changes the progress-bar caption, and gates attaching the
            per-image segmentation predictions to the result.

    Returns:
        dict with per-image and aggregate Dice/IoU (foreground and
        background), labels, and per-image cross-entropy losses.
    """
    model.eval()
    all_labels = []
    all_losses = []
    all_seg_preds = []
    all_dices = []
    all_ious = []
    evaluator = Evaluator(2)        # accumulates over all images
    image_evaluator = Evaluator(2)  # reset after each image

    pbar = tqdm(loader, ncols=80, desc='Test' if test else 'Validation')
    with torch.no_grad():
        for image, mask, label in pbar:
            image = image.to(device, non_blocking=True)
            # Binarize the mask: any non-zero value is foreground.
            segmentation = (mask != 0).squeeze(1)
            t_segmentation = segmentation.to(device, non_blocking=True).long()

            seg_logits = model(image)
            loss = F.cross_entropy(seg_logits, t_segmentation).item()
            # Fix: dropped an unused softmax (`seg_probs`) computed here —
            # dead work on every iteration.
            seg_preds = seg_logits.argmax(1)
            evaluator.add_batch(t_segmentation, seg_preds)
            image_evaluator.add_batch(t_segmentation, seg_preds)
            dices = image_evaluator.dice()
            ious = image_evaluator.intersection_over_union()
            image_evaluator.reset()

            all_labels.append(label.item())
            all_losses.append(loss)
            all_dices.append(dices.cpu())
            all_ious.append(ious.cpu())
            all_seg_preds.append(
                seg_preds.squeeze(0).byte().cpu().numpy().astype('bool'))

        all_labels = np.array(all_labels)
        all_losses = np.array(all_losses)
        all_dices = torch.stack(all_dices, 0)
        all_ious = torch.stack(all_ious, 0)

    # Aggregate (dataset-level) scores from the global evaluator.
    dices = evaluator.dice()
    ious = evaluator.intersection_over_union()

    metrics = {
        'images_path': loader.dataset.samples,
        'labels': all_labels,
        'losses': all_losses,
        'dice_background_per_image': all_dices[:, 0].numpy(),
        'mean_dice_background': all_dices[:, 0].numpy().mean(),
        'dice_background': dices[0].item(),
        'dice_per_image': all_dices[:, 1].numpy(),
        'mean_dice': all_dices[:, 1].numpy().mean(),
        'dice': dices[1].item(),
        'iou_background_per_image': all_ious[:, 0].numpy(),
        'mean_iou_background': all_ious[:, 0].numpy().mean(),
        'iou_background': ious[0].item(),
        'iou_per_image': all_ious[:, 1].numpy(),
        'mean_iou': all_ious[:, 1].numpy().mean(),
        'iou': ious[1].item(),
    }

    # Keep the (large) per-image masks only for the first split/fold.
    if test and ex.current_run.config['dataset'][
            'split'] == 0 and ex.current_run.config['dataset']['fold'] == 0:
        metrics['seg_preds'] = all_seg_preds

    return metrics