Beispiel #1
0
def dice3d(base_folder, folder, subfoldername, grp_regex, gt_folder, C):
    """Compute per-patient 3D Dice statistics from 2D PNG slices.

    Predicted slices are read from ``[base_folder/]folder/subfoldername`` and
    grouped into patients with ``grp_regex``. Ground-truth slices with the
    same file name are read from ``gt_folder``. Each patient's slices are
    stacked into a 256x256xN volume and scored with ``dice_batch``.

    Args:
        base_folder: Optional root prefix; ignored when it is ``''``.
        folder: Folder containing ``subfoldername``.
        subfoldername: Sub-folder holding the predicted ``*.png`` slices.
        grp_regex: Regex whose match at the start of a file stem (group 0)
            is the patient id. The second element of splitting the stem
            with it is parsed as the slice index.
        gt_folder: Folder holding the ground-truth PNGs (same file names).
        C: Number of columns in the per-patient result table.

    Returns:
        ``(mean, std)`` over patients of the per-patient scores, each of
        shape ``(C,)``.
    """
    if base_folder == '':
        work_folder = Path(folder, subfoldername)
    else:
        work_folder = Path(base_folder, folder, subfoldername)
    filenames = map_(lambda p: str(p.name), work_folder.glob("*.png"))
    grouping_regex: Pattern = re.compile(grp_regex)

    stems: List[str] = [Path(filename).stem for filename in filenames]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [match.group(0) for match in matches]

    unique_patients: List[str] = list(set(patients))
    batch_dice = torch.zeros((len(unique_patients), C))
    w, h = 256, 256
    for i, patient in enumerate(unique_patients):
        # Compare the matched patient id exactly. A prefix test would pull
        # "p10..." slices into patient "p1".
        patient_slices = [stem for stem, pid in zip(stems, patients) if pid == patient]
        n = len(patient_slices)
        # np.zeros rather than np.ndarray, so that a slice index never
        # written stays 0 instead of holding uninitialised memory.
        t_seg = np.zeros((w, h, n))
        t_gt = np.zeros((w, h, n))
        for stem in patient_slices:
            slice_nb = int(grouping_regex.split(stem)[1])
            seg = imageio.imread(str(work_folder / f"{stem}.png"))
            gt = imageio.imread(str(Path(gt_folder) / f"{stem}.png"))
            if seg.shape != (w, h):
                seg = resize_im(seg, 36)
            if gt.shape != (w, h):
                gt = resize_im(gt, 36)
            seg[seg == 255] = 1  # predictions saved as 0/255 -> class ids 0/1
            t_seg[:, :, slice_nb] = seg
            t_gt[:, :, slice_nb] = gt
        t_seg = torch.from_numpy(t_seg)
        t_gt = torch.from_numpy(t_gt)
        # NOTE(review): one-hot uses a hard-coded 3 classes, and only element
        # [0] of dice_batch's result is kept (broadcast over the C columns).
        # Confirm that this is the intended score.
        batch_dice[i, ...] = dice_batch(class2one_hot(t_seg, 3), class2one_hot(t_gt, 3))[0]  # do not save the interclasses etcetc
    return batch_dice.mean(dim=0), batch_dice.std(dim=0)
Beispiel #2
0
def runInference(args: argparse.Namespace, pred_folder: str):
    """Score every predicted PNG in ``pred_folder`` against ``args.gt_folder``.

    ``PatientSampler`` groups slices into batches using ``args.grp_regex``.
    For each batch, three things are recorded: per-slice Dice
    (``dice_coef``), the batch-level Dice (``dice_batch``), and the
    foreground (class 1) pixel count of the ground truth. The per-column
    mean of each table is then printed.

    Args:
        args: Parsed CLI namespace. Reads ``cpu``, ``num_classes``,
            ``gt_folder`` and ``grp_regex``.
        pred_folder: Folder containing the predicted ``*.png`` label maps.
    """
    use_gpu = torch.cuda.is_available() and not args.cpu
    device = torch.device("cuda" if use_gpu else "cpu")
    C: int = args.num_classes

    def _add_axis(img):
        # PIL image -> ndarray with a leading singleton axis
        return np.array(img)[np.newaxis, ...]

    # Let's just reuse some code
    png_transform = transforms.Compose([
        _add_axis,
        lambda nd: nd / 255,  # scale into [0, 1]
        lambda nd: torch.tensor(nd, dtype=torch.float32),
    ])
    gt_transform = transforms.Compose([
        _add_axis,
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=C),
        itemgetter(0),
    ])

    bounds_gen = [(lambda *unused: torch.zeros(C, 1, 2)) for _ in range(2)]

    pred_path = Path(pred_folder)
    # The first folder is a dummy "image" input; the predictions stand in for it.
    folders: List[Path] = [pred_path, pred_path, Path(args.gt_folder)]
    names: List[str] = map_(lambda p: str(p.name), pred_path.glob("*.png"))

    dataset = SliceDataset(names,
                           folders,
                           transforms=[png_transform, gt_transform, gt_transform],
                           debug=False,
                           C=C,
                           are_hots=[False, True, True],
                           in_memory=False,
                           bounds_generators=bounds_gen)
    loader = DataLoader(dataset,
                        batch_sampler=PatientSampler(dataset, args.grp_regex),
                        num_workers=11)

    n_batches = len(loader)
    n_images = len(loader.dataset)
    f64 = dict(dtype=torch.float64, device=device)
    metrics = {
        "all_dices": torch.zeros((n_images, C), **f64),
        "batch_dices": torch.zeros((n_batches, C), **f64),
        "sizes": torch.zeros((n_images, 1), **f64),
    }

    progress = tqdm_(enumerate(loader), total=n_batches, desc=">> Computing")
    offset: int = 0
    for batch_idx, (_, _, pred, gt, _) in progress:
        batch_size = len(pred)
        pred, gt = pred.to(device), gt.to(device)
        assert simplex(pred) and sset(pred, [0, 1])
        assert simplex(gt) and sset(gt, [0, 1])

        slice_dices: Tensor = dice_coef(pred, gt)
        group_dice: Tensor = dice_batch(pred, gt)
        assert slice_dices.shape == (batch_size, C)
        assert group_dice.shape == (C, ), group_dice.shape

        rows = slice(offset, offset + batch_size)  # rows of the current batch
        metrics["all_dices"][rows, ...] = slice_dices
        fg_pixels = torch.einsum("bwh->b", gt[:, 1, ...])
        metrics["sizes"][rows, :] = fg_pixels[..., None]
        metrics["batch_dices"][batch_idx] = group_dice
        offset += batch_size

    print(f">>> {pred_folder}")
    for name, table in metrics.items():
        print(name, map_("{:.4f}".format, table.mean(dim=0)))
def do_epoch(mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Run one training or validation epoch over ``loader``.

    Each loader item is ``[filenames, image, target, *labels, *bounds]``,
    with one label and one bound per loss function. The total loss is
    ``sum(w * loss_fn(probs, label, bound))``. When ``optimizer`` is given,
    the loss is back-propagated and an optimizer step is taken.

    Args:
        mode: ``"train"`` (calls ``net.train()``) or ``"val"`` (``net.eval()``).
        net: Network producing logits; softmax is taken over dim 1.
        device: Device the tensors and metric buffers are placed on.
        loader: Data loader yielding items as described above.
        epc: Epoch number, used for the progress label and saved images.
        loss_fns: Loss callables ``(probs, label, bound) -> scalar``.
        loss_weights: One weight per loss function.
        C: Number of classes.
        savedir: If non-empty, predicted class maps are saved there.
        optimizer: Optional optimizer; no update is done when ``None``.
        metric_axis: Class indices shown in the progress bar.
        compute_haussdorf: Also compute per-image Hausdorff distances.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log)``:
        - ``loss_log``: per-iteration loss, shape ``(total_iteration,)``.
        - ``all_dices``: per-image Dice, shape ``(total_images, C)``.
        - ``batch_dices``: per-batch Dice, shape ``(total_iteration, C)``;
          only filled in validation for batches with more than one image.
        - ``haussdorf_log``: per-image Hausdorff distance, shape
          ``(total_images, C)``; zeros unless ``compute_haussdorf``.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, C),
                                    dtype=torch.float32,
                                    device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C),
                                      dtype=torch.float32,
                                      device=device)
    loss_log: Tensor = torch.zeros((total_iteration),
                                   dtype=torch.float32,
                                   device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C),
                                        dtype=torch.float32,
                                        device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0
    nice_dict: dict = {}  # stays empty if the loader yields nothing
    for j, data in tq_iter:
        data[1:] = [e.to(device)
                    for e in data[1:]]  # Move all tensors to device
        filenames, image, target = data[:3]
        labels = data[3:3 + L]
        bounds = data[3 + L:]
        assert len(labels) == len(bounds)

        B = len(image)

        # Reset gradients
        if optimizer:
            optimizer.zero_grad()

        # Forward
        pred_logits: Tensor = net(image)
        pred_probs: Tensor = F.softmax(pred_logits, dim=1)
        predicted_mask: Tensor = probs2one_hot(
            pred_probs.detach())  # Used only for dice computation

        assert len(bounds) == len(loss_fns) == len(loss_weights)
        ziped = zip(loss_fns, labels, loss_weights, bounds)
        losses = [
            w * loss_fn(pred_probs, label, bound)
            for loss_fn, label, w, bound in ziped
        ]
        loss = reduce(add, losses)
        assert loss.shape == (), loss.shape

        # Backward
        if optimizer:
            loss.backward()
            optimizer.step()

        # Compute and log metrics
        loss_log[j] = loss.detach()

        sm_slice = slice(done, done + B)  # Values only for current batch

        dices: Tensor = dice_coef(predicted_mask, target.detach())
        assert dices.shape == (B, C), (dices.shape, B, C)
        all_dices[sm_slice, ...] = dices

        if B > 1 and mode == "val":
            batch_dice: Tensor = dice_batch(predicted_mask, target.detach())
            assert batch_dice.shape == (C, ), (batch_dice.shape, B, C)
            batch_dices[j] = batch_dice

        if compute_haussdorf:
            haussdorf_res: Tensor = haussdorf(predicted_mask.detach(),
                                              target.detach())
            assert haussdorf_res.shape == (B, C)
            haussdorf_log[sm_slice] = haussdorf_res

        # Save images
        if savedir:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, mode, epc)

        # Logging: running means over everything processed so far, including
        # the current batch. Slicing with [:j] dropped batch j and averaged an
        # empty slice (NaN) on the first iteration.
        big_slice = slice(0, done + B)  # images of current and previous batches
        seen_batches = slice(0, j + 1)  # current and previous iterations

        dsc_dict = {
            f"DSC{n}": all_dices[big_slice, n].mean()
            for n in metric_axis
        }
        hauss_dict = {
            f"HD{n}": haussdorf_log[big_slice, n].mean()
            for n in metric_axis
        } if compute_haussdorf else {}
        batch_dict = {
            f"bDSC{n}": batch_dices[seen_batches, n].mean()
            for n in metric_axis
        } if B > 1 and mode == "val" else {}

        mean_dict = {}
        if len(metric_axis) > 1:
            mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
            if compute_haussdorf:  # otherwise haussdorf_log is all zeros
                mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()

        stat_dict = {
            **dsc_dict,
            **hauss_dict,
            **mean_dict,
            **batch_dict, "loss": loss_log[seen_batches].mean()
        }
        nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

        tq_iter.set_postfix(nice_dict)
        done += B
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return loss_log, all_dices, batch_dices, haussdorf_log
Beispiel #4
0
def do_epoch(mode: str, net: Any, device: Any, loaders: List[DataLoader], epc: int,
             list_loss_fns: List[List[Callable]], list_loss_weights: List[List[float]], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False, compute_miou: bool = False,
             temperature: float = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tuple[None, Tensor]]:
    """Run one training or validation epoch over several loaders in sequence.

    Loader ``i`` is paired with ``list_loss_fns[i]`` and ``list_loss_weights[i]``.
    Each loader item is ``[filenames, image, target, *labels, *bounds]``, with
    one label and one bound per loss of that loader. Probabilities are
    ``softmax(temperature * logits)`` over dim 1.

    Args:
        mode: ``"train"`` (``net.train()``) or ``"val"`` (``net.eval()``).
        net: Network producing logits.
        device: Device for tensors and metric buffers.
        loaders: Loaders iterated one after another.
        epc: Epoch number, used for the progress label and saved images.
        list_loss_fns: Per-loader list of ``(probs, label, bound)`` losses.
        list_loss_weights: Per-loader list of loss weights.
        C: Number of classes.
        savedir: If non-empty, predicted class maps are saved there.
        optimizer: Optional optimizer; no update is done when ``None``.
        metric_axis: Class indices reported in the progress bar.
        compute_haussdorf: Also compute per-image Hausdorff distances.
        compute_miou: Also compute per-image IoU and dataset-level mIoU.
        temperature: Multiplier applied to the logits before softmax.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log, mIoUs)``.
        ``loss_log`` has shape ``(total_iteration, n_loss)``; loaders with
        fewer losses leave zeros in the extra columns. ``mIoUs`` is a ``(C,)``
        tensor when ``compute_miou`` is set, otherwise ``None``.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    iiou_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    intersections: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    unions: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    # Per-class entries are shown in the progress bar only for a few classes.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict: dict = {}  # stays empty if no loader yields anything
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        L: int = len(loss_fns)

        for data in loader:
            data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
            filenames, image, target = data[:3]
            assert not target.requires_grad
            labels = data[3:3 + L]
            bounds = data[3 + L:]
            assert len(labels) == len(bounds)

            B = len(image)

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights, bounds)
            losses = [w * loss_fn(pred_probs, label, bound) for loss_fn, label, w, bound in ziped]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j in range(len(loss_fns)):
                loss_log[done_batch, j] = losses[j].detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, C), (dices.shape, B, C)
            all_dices[sm_slice, ...] = dices

            if B > 1 and mode == "val":
                batch_dice: Tensor = dice_batch(predicted_mask, target)
                assert batch_dice.shape == (C,), (batch_dice.shape, B, C)
                batch_dices[done_batch] = batch_dice

            if compute_haussdorf:
                haussdorf_res: Tensor = haussdorf(predicted_mask, target)
                assert haussdorf_res.shape == (B, C)
                haussdorf_log[sm_slice] = haussdorf_res
            if compute_miou:
                IoUs: Tensor = iIoU(predicted_mask, target)
                assert IoUs.shape == (B, C), IoUs.shape
                iiou_log[sm_slice] = IoUs
                intersections[sm_slice] = inter_sum(predicted_mask, target)
                unions[sm_slice] = union_sum(predicted_mask, target)

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging: running means that include the current batch. done_batch
            # is only incremented below, so [:done_batch] would skip the
            # current row (and average an empty slice, NaN, on the first batch).
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            seen_batches = slice(0, done_batch + 1)

            dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} if few_axis else {}

            hauss_dict = {f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis} \
                if compute_haussdorf and few_axis else {}

            batch_dict = {f"bDSC{n}": batch_dices[seen_batches, n].mean() for n in metric_axis} \
                if B > 1 and mode == "val" and few_axis else {}

            # Zero rows of not-yet-seen images do not change the running sums.
            miou_dict = {"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if compute_miou else {}

            if len(metric_axis) > 1:
                mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
                if compute_haussdorf:
                    mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()
            else:
                mean_dict = {}

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                         "loss": loss_log[seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    if compute_miou:
        mIoUs: Tensor = (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10))
        assert mIoUs.shape == (C,), mIoUs.shape
    else:
        mIoUs = None

    # Per-class summary for many classes; deliberately disabled (`and False`).
    if not few_axis and False:
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        if mIoUs is not None:  # truth-testing a (C,) tensor would raise
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return loss_log, all_dices, batch_dices, haussdorf_log, mIoUs
Beispiel #5
0
def do_epoch(
        mode: str,
        net: Any,
        device: Any,
        loaders: list[DataLoader],
        epc: int,
        list_loss_fns: list[list[Callable]],
        list_loss_weights: list[list[float]],
        K: int,
        savedir: str = "",
        optimizer: Any = None,
        metric_axis: list[int] = [1],
        compute_3d_dice: bool = False,
        temperature: float = 1) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Run one epoch over several loaders whose batches are dicts.

    Each batch provides ``"images"``, ``"gt"``, ``"filenames"`` and
    ``"labels"`` (one label per loss of that loader). Losses are called as
    ``loss_fn(probs, label)``, where probs is ``softmax(temperature * logits)``
    over dim 1. When ``optimizer`` is given, the weighted sum is
    back-propagated and an optimizer step is taken.

    Args:
        mode: ``"train"`` (``net.train()``), ``"val"`` (``net.eval()``) or
            ``"dual"`` (network mode left as set by the caller).
        net: Network producing logits.
        device: Device for tensors and metric buffers.
        loaders: Loaders iterated one after another.
        epc: Epoch number, used for labels and saved images.
        list_loss_fns: Per-loader list of ``(probs, label)`` losses.
        list_loss_weights: Per-loader list of loss weights.
        K: Number of classes.
        savedir: If non-empty, predicted class maps are saved there.
        optimizer: Optional optimizer; no update is done when ``None``.
        metric_axis: Class indices reported in the progress bar.
        compute_3d_dice: Also compute ``dice_batch`` for every batch.
        temperature: Multiplier applied to the logits before softmax.

    Returns:
        CPU tensors ``(loss_log, all_dices, three_d_dices)``:
        - ``loss_log``: shape ``(total_iteration, n_loss)``.
        - ``all_dices``: shape ``(total_images, K)``.
        - ``three_d_dices``: shape ``(total_iteration, K)``, or ``None``
          when ``compute_3d_dice`` is false.
    """
    assert mode in ["train", "val", "dual"]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    else:
        # "dual" is accepted by the assert but previously left `desc` unbound
        # (UnboundLocalError). NOTE(review): the network mode is left as the
        # caller set it; confirm that this is the intent.
        desc = f">> Dual       ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K),
                                    dtype=torch.float32,
                                    device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss),
                                   dtype=torch.float32,
                                   device=device)

    three_d_dices: Optional[Tensor]
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K),
                                    dtype=torch.float32,
                                    device=device)
    else:
        three_d_dices = None

    done_img: int = 0
    done_batch: int = 0
    nice_dict: dict = {}  # stays empty if no loader yields anything
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(
            zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad
            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            B, *_ = image.shape

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(
                pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights)
            losses = [
                w * loss_fn(pred_probs, label) for loss_fn, label, w in ziped
            ]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j in range(len(loss_fns)):
                loss_log[done_batch, j] = losses[j].detach()

            sm_slice = slice(done_img,
                             done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(predicted_mask, target)
                assert three_d_DSC.shape == (K, )

                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging: running means that include the current batch. done_batch
            # is incremented only below, so [:done_batch] would skip this row
            # (and average an empty slice, NaN, on the first batch).
            big_slice = slice(0, done_img + B)  # current and previous images
            seen_batches = slice(0, done_batch + 1)

            dsc_dict: dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} | \
                ({f"3d_DSC{n}": three_d_dices[seen_batches, n].mean() for n in metric_axis}
                 if three_d_dices is not None else {})

            loss_means: Tensor = loss_log[seen_batches].mean(dim=0)
            loss_dict = {f"loss_{k}": loss_means[k] for k in range(n_loss)}

            stat_dict = dsc_dict | loss_dict
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return (loss_log.detach().cpu(), all_dices.detach().cpu(),
            three_d_dices.detach().cpu()
            if three_d_dices is not None else None)
def do_epoch(mode: str, args, net, device, use_cuda, loader, optimizer,
             num_classes, epoch):
    """Run one train/val epoch of the six-input MRI segmentation model.

    Each loader item is ``(image_f, image_i, image_d, image_o, image_w,
    image_c, labels, img_names)``. The six images are concatenated along
    dim 1, cast to float and divided by 65535. Targets are
    ``cat(1 - labels, labels)`` along dim 1. The loss is
    ``crossEntropy_f(probs, targets) + args.lam * kl(target_aver, probs_aver)``,
    where both averages are per-class spatial sums divided by
    ``targets.sum()``.

    Args:
        mode: ``"train"`` (``net.train()`` and optimizer steps) or ``"val"``
            (``net.eval()``).
        args: Namespace; reads ``dtype`` (evaluated as a Python expression)
            and ``lam``.
        net: Network producing two-class logits.
        device: Device for metric buffers (and inputs when ``use_cuda``).
        use_cuda: Whether to move inputs and targets to ``device``.
        loader: Data loader yielding the tuples above.
        optimizer: Optimizer used when ``mode == "train"``.
        num_classes: Number of columns of the Dice tables.
        epoch: Epoch number for the progress label.

    Returns:
        ``(loss_log, entropy_log, KL_log, all_dices, batch_dices)``. The three
        loss logs hold one entry per image: each batch's scalar value is
        repeated for every image of that batch.
    """
    # SECURITY NOTE: eval() of a CLI-provided string. Only safe with trusted
    # arguments such as "torch.float32". It is now evaluated once instead of
    # once per buffer.
    dtype = eval(args.dtype)

    if mode == "train":
        net.train()
        desc = f">> Training   ({epoch})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epoch})"

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, num_classes),
                                    dtype=dtype,
                                    device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, num_classes),
                                      dtype=dtype,
                                      device=device)
    loss_log: Tensor = torch.zeros((total_images),
                                   dtype=dtype,
                                   device=device)
    entropy_log: Tensor = torch.zeros((total_images),
                                      dtype=dtype,
                                      device=device)
    KL_log: Tensor = torch.zeros((total_images),
                                 dtype=dtype,
                                 device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0

    for j, data in tq_iter:
        image_f, image_i, image_d, image_o, image_w, image_c, labels, img_names = data
        # Stack the six MRI inputs as channels. The / 65535 presumably rescales
        # 16-bit intensities into [0, 1] (TODO confirm the image bit depth).
        MRI = torch.cat((image_f, image_i, image_d, image_o, image_w, image_c),
                        dim=1)
        MRI = MRI.type(torch.FloatTensor) / 65535.0
        targets = torch.cat((1 - labels, labels), dim=1)
        B = len(image_f)
        if use_cuda:
            MRI, targets = MRI.to(device), targets.to(device)

        # forward
        outputs = net(MRI)
        pred_probs = F.softmax(outputs, dim=1)
        predicted_mask = probs2one_hot(pred_probs)

        entropy = crossEntropy_f(pred_probs, targets)

        # Per-class spatial sums, normalised by the total target mass.
        target_mass = torch.sum(targets).float()
        pred_probs_aver: Tensor = torch.sum(pred_probs, dim=(2, 3)) / target_mass
        target_aver: Tensor = torch.sum(targets, dim=(2, 3)).float() / target_mass
        KL_loss = args.lam * kl(target_aver, pred_probs_aver)

        loss = entropy + KL_loss

        if mode == "train":
            # zero the parameter gradients
            optimizer.zero_grad()
            # backward + optimize
            loss.backward()
            optimizer.step()

        # Compute and log metrics. Cast to int32 on the targets' own device:
        # the former `.type(torch.cuda.IntTensor)` failed on CPU-only runs and
        # moved targets to the GPU when use_cuda was False.
        int_targets = targets.to(torch.int32).detach()
        mask = predicted_mask.detach()
        dices: Tensor = dice_coef(mask, int_targets)
        batch_dice: Tensor = dice_batch(mask, int_targets)
        assert batch_dice.shape == (num_classes, ) and dices.shape == (
            B, num_classes), (batch_dice.shape, dices.shape, B, num_classes)

        sm_slice = slice(done, done + B)  # Values only for current batch
        all_dices[sm_slice, ...] = dices
        entropy_log[sm_slice] = entropy.detach()
        loss_log[sm_slice] = loss.detach()
        KL_log[sm_slice] = KL_loss.detach()
        batch_dices[j] = batch_dice

        # Logging
        big_slice = slice(0,
                          done + B)  # Value for current and previous batches
        stat_dict = {
            "dice": all_dices[big_slice, -1].mean(),
            "total loss": loss_log[big_slice].mean(),
            "entropy loss": entropy_log[big_slice].mean(),
            "KL loss": KL_log[big_slice].mean(),
            "b dice": batch_dices[:j + 1, -1].mean()
        }
        nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

        done += B
        tq_iter.set_postfix(nice_dict)

    return loss_log, entropy_log, KL_log, all_dices, batch_dices
def main() -> None:
    """Compute 3D validation metrics for every saved epoch folder.

    Reads CLI arguments with ``get_args()``. Each ``iter*`` folder under
    ``args.basefolder`` must contain a ``val`` sub-folder with one predicted
    PNG per ground-truth image in ``args.gt_folder``. Images are grouped into
    patients by ``args.grp_regex`` (capture group 1 is the patient id). For
    every epoch and patient, the requested metrics (``'3d_dsc'``,
    ``'3d_hausdorff'``) are computed. Each metric is saved as
    ``<basefolder>/val_<metric>.npy`` with shape ``(EPC, n_patients, K)``.
    """
    args = get_args()

    iterations_paths: List[Path] = sorted(Path(args.basefolder).glob("iter*"))
    print(f">>> Found {len(iterations_paths)} epoch folders")

    # Handle gracefully if not all folders are there (early stop)
    EPC: int = args.n_epoch if args.n_epoch >= 0 else len(iterations_paths)
    K: int = args.num_classes

    # Get the patient number, and image names, from the GT folder
    gt_path: Path = Path(args.gt_folder)
    names: List[str] = map_(lambda p: str(p.name), gt_path.glob("*"))
    n_img: int = len(names)

    grouping_regex: Pattern = re.compile(args.grp_regex)
    stems: List[str] = [Path(filename).stem
                        for filename in names]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)  # type: ignore
    patients: List[str] = [match.group(1) for match in matches]

    unique_patients: List[str] = list(set(patients))
    n_patients: int = len(unique_patients)

    print(
        f">>> Found {len(unique_patients)} unique patients out of {n_img} images ; regex: {args.grp_regex}"
    )

    # First, quickly assert all folders have the same number of predicted
    # images, and that it matches the GT count. Glob each folder only once.
    n_img_epoc: List[int] = [
        len(list(Path(p, "val").glob("*.png"))) for p in iterations_paths
    ]
    assert len(set(n_img_epoc)) == 1, \
        f"Inconsistent (or missing) prediction counts per epoch folder: {n_img_epoc}"
    assert all(n == n_img for n in n_img_epoc), \
        f"Expected {n_img} predictions per epoch (GT count), got {n_img_epoc}"

    metrics: Dict['str', Tensor] = {}
    if '3d_dsc' in args.metrics:
        metrics['3d_dsc'] = torch.zeros((EPC, n_patients, K),
                                        dtype=torch.float32)
        print(f">> Will compute {'3d_dsc'} metric")
    if '3d_hausdorff' in args.metrics:
        metrics['3d_hausdorff'] = torch.zeros((EPC, n_patients, K),
                                              dtype=torch.float32)
        print(f">> Will compute {'3d_hausdorff'} metric")

    gen_dataset = partial(
        SliceDataset,
        transforms=[png_transform, gt_transform, gt_transform],
        are_hots=[False, True, True],
        K=K,
        in_memory=False,
        bounds_generators=[(lambda *a: torch.zeros(K, 1, 2))
                           for _ in range(1)],
        box_prior=False,
        box_priors_arg='{}',
        dimensions=2)
    data_loader = partial(DataLoader,
                          num_workers=cpu_count(),
                          pin_memory=False,
                          collate_fn=custom_collate)

    # Will replace live dataset.folders and call again load_images to update dataset.files
    print(gt_path, gt_path, Path(iterations_paths[0], 'val'))
    dataset: SliceDataset = gen_dataset(
        names,
        [gt_path, gt_path, Path(iterations_paths[0], 'val')])
    sampler: PatientSampler = PatientSampler(dataset,
                                             args.grp_regex,
                                             shuffle=False)
    dataloader: DataLoader = data_loader(dataset, batch_sampler=sampler)

    current_path: Path
    # The metric tensors only have EPC rows; extra folders beyond n_epoch would
    # index past them, so only the first EPC folders are evaluated.
    for e, current_path in enumerate(iterations_paths[:EPC]):
        dataset.folders = [gt_path, gt_path, Path(current_path, 'val')]
        dataset.files = SliceDataset.load_images(dataset.folders,
                                                 dataset.filenames, False)

        print(f">>> Doing epoch {str(current_path)}")

        # One batch per patient (PatientSampler, no shuffle), so i is the patient row.
        for i, data in enumerate(tqdm(dataloader, leave=None)):
            target: Tensor = data["gt"]
            prediction: Tensor = data["labels"][0]

            assert target.shape == prediction.shape

            if '3d_dsc' in args.metrics:
                dsc: Tensor = dice_batch(target, prediction)
                assert dsc.shape == (K, )

                metrics['3d_dsc'][e, i, :] = dsc
            if '3d_hausdorff' in args.metrics:
                np_pred: np.ndarray = prediction[:, 1, :, :].cpu().numpy()
                np_target: np.ndarray = target[:, 1, :, :].cpu().numpy()

                if np_pred.sum() > 0:
                    hd_: float = hd(np_pred, np_target)

                    metrics["3d_hausdorff"][e, i, 1] = hd_
                else:
                    # Empty prediction: fall back to the volume diagonal.
                    x, y, z = np_pred.shape
                    metrics["3d_hausdorff"][e, i,
                                            1] = (x**2 + y**2 + z**2)**0.5

        for metric in args.metrics:
            # For now, hardcode the fact we care about class 1 only
            print(f">> {metric}: {metrics[metric][e].mean(dim=0)[1]:.04f}")

    k: str
    el: Tensor
    for k, el in metrics.items():
        np.save(Path(args.basefolder, f"val_{k}.npy"), el.cpu().numpy())
Beispiel #8
0
def do_epoch(
    mode: str,
    net: Any,
    device: Any,
    loaders: List[DataLoader],
    epc: int,
    list_loss_fns: List[List[Callable]],
    list_loss_weights: List[List[float]],
    K: int,
    savedir: str = "",
    optimizer: Any = None,
    metric_axis: List[int] = [1],
    compute_hausdorff: bool = False,
    compute_miou: bool = False,
    compute_3d_dice: bool = False,
    temperature: float = 1
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor],
           Optional[Tensor]]:
    """Run one training or validation pass over every loader.

    Each loader is paired with its own list of loss functions and weights.
    For every batch, the network is applied to each sub-sampling listed in
    ``data["samplings"]`` (only one sampling is allowed when the batch size
    is > 1). The weighted losses are summed and, if an optimizer is given,
    back-propagated. Per-image DSC is always logged; 3D DSC, Hausdorff
    distance and IoU metrics are optional.

    Args:
        mode: "train" (calls ``net.train()``) or "val" (calls ``net.eval()``).
        net: the network; called on image crops, its output is treated as
            logits and soft-maxed over dim 1.
        device: device on which the batches and metric logs are placed.
        loaders: data loaders to iterate, in order.
        epc: epoch index, used in the progress description and passed to
            ``save_images``.
        list_loss_fns: one list of loss callables per loader.
        list_loss_weights: one list of weights per loader, aligned with
            ``list_loss_fns``.
        K: number of classes (size of the label/channel dimension).
        savedir: if non-empty, predicted class maps are saved there.
        optimizer: if given, ``zero_grad``/``backward``/``step`` are run for
            every sub-sampling; pass ``None`` to only evaluate.
        metric_axis: class indices whose metrics are shown in the progress bar.
        compute_hausdorff: also compute a per-image Hausdorff distance.
        compute_miou: also compute per-image IoU and intersections/unions.
        compute_3d_dice: also compute one DSC per batch with ``dice_batch``.
        temperature: factor applied to the logits before the softmax.

    Returns:
        ``(loss_log, all_dices, hausdorff_log, mIoUs, three_d_dices)``, all on
        CPU. ``loss_log`` is (total_iterations, n_loss) and ``all_dices`` is
        (total_images, K). ``hausdorff_log`` is (total_images, K), ``mIoUs``
        is (K,) and ``three_d_dices`` is (total_iterations, K); each of these
        three is ``None`` when not requested.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    # Width of the loss log: the longest loss list among all loaders.
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K),
                                    dtype=torch.float32,
                                    device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss),
                                   dtype=torch.float32,
                                   device=device)

    iiou_log: Optional[Tensor]
    intersections: Optional[Tensor]
    unions: Optional[Tensor]
    if compute_miou:
        iiou_log = torch.zeros((total_images, K),
                               dtype=torch.float32,
                               device=device)
        intersections = torch.zeros((total_images, K),
                                    dtype=torch.float32,
                                    device=device)
        unions = torch.zeros((total_images, K),
                             dtype=torch.float32,
                             device=device)
    else:
        iiou_log = None
        intersections = None
        unions = None

    three_d_dices: Optional[Tensor]
    if compute_3d_dice:
        # One 3D DSC per batch (not per image), hence total_iteration rows.
        three_d_dices = torch.zeros((total_iteration, K),
                                    dtype=torch.float32,
                                    device=device)
    else:
        three_d_dices = None

    hausdorff_log: Optional[Tensor]
    if compute_hausdorff:
        hausdorff_log = torch.zeros((total_images, K),
                                    dtype=torch.float32,
                                    device=device)
    else:
        hausdorff_log = None

    # Per-class entries are only displayed when there are few of them.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    # Defined up-front so the final summary works even when no batch is seen.
    nice_dict: Dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(
            zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            spacings: Tensor = data["spacings"]  # Keep that one on CPU
            assert not target.requires_grad
            labels: List[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: List[Tensor] = [e.to(device) for e in data["bounds"]]
            box_priors: List[List[Tuple[
                Tensor, Tensor]]]  # one more level for the batch
            box_priors = [[(m.to(device), b.to(device)) for (m, b) in B]
                          for B in data["box_priors"]]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: List[List[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(
                image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = -torch.ones_like(
                target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = -torch.ones_like(
                target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and
                        B > 1), samplings  # No subsampling if batch size > 1
            loss_sub_log: Tensor = torch.zeros(
                (len(samplings[0]), len(loss_fns)),
                dtype=torch.float32,
                device=device)
            for k, sampling in enumerate(samplings[0]):
                img_sampling = [slice(0, B), slice(0, C)] + list(sampling)
                label_sampling = [slice(0, B), slice(0, K)] + list(sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling,
                                                               image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits,
                                               dim=1)
                predicted_mask: Tensor = probs2one_hot(
                    pred_probs.detach())  # Used only for dice computation
                assert not predicted_mask.requires_grad

                probs_receptacle[label_sampling] = pred_probs[...]
                mask_receptacle[label_sampling] = predicted_mask[...]

                assert len(bounds) == len(loss_fns) == len(
                    loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [
                    w * loss_fn(pred_probs, label[label_sampling], bound,
                                box_priors)
                    for loss_fn, label, w, bound in ziped
                ]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j in range(len(loss_fns)):
                    loss_sub_log[k, j] = losses[j].detach()
            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns), ), (
                reduced_loss_sublog.shape, len(loss_fns))
            # A loader may have fewer losses than n_loss: fill only its own
            # columns instead of broadcasting (or failing) into the others.
            loss_log[done_batch, :len(loss_fns)] = reduced_loss_sublog[...]
            del loss_sub_log

            sm_slice = slice(done_img,
                             done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K, )

                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            if compute_hausdorff:
                hausdorff_res: Tensor
                try:
                    hausdorff_res = hausdorff(mask_receptacle, target,
                                              spacings)
                except RuntimeError:
                    # Deliberate best-effort: log zeros when HD fails.
                    hausdorff_res = torch.zeros((B, K), device=device)
                assert hausdorff_res.shape == (B, K)
                hausdorff_log[sm_slice] = hausdorff_res  # type: ignore
            if compute_miou:
                IoUs: Tensor = iIoU(mask_receptacle, target)
                assert IoUs.shape == (B, K), IoUs.shape
                iiou_log[sm_slice] = IoUs  # type: ignore
                intersections[sm_slice] = inter_sum(mask_receptacle,
                                                    target)  # type: ignore
                unions[sm_slice] = union_sum(mask_receptacle,
                                             target)  # type: ignore

            # if False and target[0, 1].sum() > 0:  # Useful template for quick and dirty inspection
            #     import matplotlib.pyplot as plt
            #     from pprint import pprint
            #     from mpl_toolkits.axes_grid1 import ImageGrid
            #     from utils import soft_length

            #     print(data["filenames"])
            #     pprint(data["bounds"])
            #     pprint(soft_length(mask_receptacle))

            #     fig = plt.figure()
            #     fig.clear()

            #     grid = ImageGrid(fig, 211, nrows_ncols=(1, 2))

            #     grid[0].imshow(data["images"][0, 0], cmap="gray")
            #     grid[0].contour(data["gt"][0, 1], cmap='jet', alpha=.75, linewidths=2)

            #     grid[1].imshow(data["images"][0, 0], cmap="gray")
            #     grid[1].contour(mask_receptacle[0, 1], cmap='jet', alpha=.75, linewidths=2)
            #     plt.show()

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, data["filenames"], savedir,
                                mode, epc)

            # Logging: running averages over everything seen so far,
            # *including* the current batch (done_* are bumped only below).
            big_slice = slice(0, done_img + B)
            seen_batches: int = done_batch + 1

            dsc_dict: Dict
            if few_axis:
                dsc_dict = {
                    **{
                        f"DSC{n}": all_dices[big_slice, n].mean()
                        for n in metric_axis
                    },
                    **({
                        f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean()
                        for n in metric_axis
                    } if three_d_dices is not None else {})
                }
            else:
                dsc_dict = {}

            hauss_dict = {f"HD{n}": hausdorff_log[big_slice, n].mean() for n in metric_axis} \
                if hausdorff_log is not None and few_axis else {}

            miou_dict = {"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if iiou_log is not None and intersections is not None and unions is not None else {}

            if len(metric_axis) > 1:
                mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
                # Explicit None check: a multi-element tensor has no truth value.
                if hausdorff_log is not None:
                    mean_dict["HD"] = hausdorff_log[big_slice,
                                                    metric_axis].mean()
            else:
                mean_dict = {}

            stat_dict = {
                **miou_dict,
                **dsc_dict,
                **hauss_dict,
                **mean_dict, "loss": loss_log[:seen_batches].mean()
            }
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs: Optional[Tensor]
    # Explicit None checks: truth-testing (K,)-shaped tensors raises.
    if intersections is not None and unions is not None:
        mIoUs = (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10))
        assert mIoUs.shape == (K, ), mIoUs.shape
    else:
        mIoUs = None

    if not few_axis and False:  # Per-class summary, currently disabled
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        if iiou_log is not None:
            print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        if mIoUs is not None:
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return (
        loss_log.detach().cpu(), all_dices.detach().cpu(),
        hausdorff_log.detach().cpu() if hausdorff_log is not None else None,
        mIoUs.detach().cpu() if mIoUs is not None else None,
        three_d_dices.detach().cpu() if three_d_dices is not None else None)
Beispiel #9
0
def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
             list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int,
             savedir: Path = None, optimizer: Any = None,
             metric_axis: list[int] = [1], requested_metrics: list[str] = None,
             temperature: float = 1) -> dict[str, Tensor]:
    """Run one pass (train, val or dual) over every loader and log metrics.

    Each loader is paired with its own list of loss functions and weights.
    For every batch, the network is applied to each sub-sampling listed in
    ``data["samplings"]`` (only one sampling is allowed when the batch size
    is > 1). The weighted losses are summed and, if an optimizer is given,
    back-propagated.

    Args:
        mode: "train" calls ``net.train()``; "val" and "dual" call
            ``net.eval()``. The mode is also used as a sub-folder when saving.
        net: the network; its output is treated as logits and soft-maxed
            over dim 1.
        device: device on which the batches and metric logs are placed.
        loaders: data loaders to iterate, in order.
        epc: epoch index, used in the progress description and save path.
        list_loss_fns: one list of loss callables per loader; each is called
            as ``loss_fn(probs, label, bound, filenames)``.
        list_loss_weights: one list of weights per loader.
        K: number of classes (size of the label/channel dimension).
        savedir: if given, predictions go to ``savedir/iterXXX/<mode>``.
        optimizer: if given, ``zero_grad``/``backward``/``step`` are run for
            every sub-sampling.
        metric_axis: class indices whose DSC is shown in the progress bar.
        requested_metrics: extra metrics to compute; "3d_dsc" is supported.
        temperature: factor applied to the logits before the softmax.

    Returns:
        Dict of CPU tensors. It always contains "dice" (total_images, K) and
        "loss" (total_iterations, n_loss), plus "3d_dsc" (total_iterations, K)
        when requested.
    """
    assert mode in ["train", "val", "dual"]
    if requested_metrics is None:
        requested_metrics = []

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    elif mode == "dual":
        net.eval()
        desc = f">> Dual       ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    # Width of the loss log: the longest loss list among all loaders.
    n_loss: int = max(map(len, list_loss_fns))

    epoch_metrics: dict[str, Tensor]
    epoch_metrics = {"dice": torch.zeros((total_images, K), dtype=torch.float32, device=device),
                     "loss": torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)}

    if "3d_dsc" in requested_metrics:
        # One 3D DSC per batch (not per image), hence total_iteration rows.
        epoch_metrics["3d_dsc"] = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    # Per-class entries are only displayed when there are few of them.
    few_axis: bool = len(metric_axis) <= 4

    done_img: int = 0
    done_batch: int = 0
    # Defined up-front so the final summary works even when no batch is seen.
    nice_dict: dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad

            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: list[Tensor] = [e.to(device) for e in data["bounds"]]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: list[list[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = - torch.ones_like(target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = - torch.ones_like(target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and B > 1), samplings  # No subsampling if batch size > 1
            loss_sub_log: Tensor = torch.zeros((len(samplings[0]), len(loss_fns)),
                                               dtype=torch.float32, device=device)
            for k, sampling in enumerate(samplings[0]):
                img_sampling = [slice(0, B), slice(0, C)] + list(sampling)
                label_sampling = [slice(0, B), slice(0, K)] + list(sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)

                # Used only for dice computation:
                predicted_mask: Tensor = probs2one_hot(pred_probs.detach())
                assert not predicted_mask.requires_grad

                probs_receptacle[label_sampling] = pred_probs[...]
                mask_receptacle[label_sampling] = predicted_mask[...]

                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [w * loss_fn(pred_probs, label[label_sampling], bound, filenames)
                          for loss_fn, label, w, bound in ziped]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j in range(len(loss_fns)):
                    loss_sub_log[k, j] = losses[j].detach()
            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns),), (reduced_loss_sublog.shape, len(loss_fns))
            # A loader may have fewer losses than n_loss: fill only its own
            # columns instead of broadcasting (or failing) into the others.
            epoch_metrics["loss"][done_batch, :len(loss_fns)] = reduced_loss_sublog[...]
            del loss_sub_log

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            epoch_metrics["dice"][sm_slice, ...] = dices

            if "3d_dsc" in requested_metrics:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K,)

                epoch_metrics["3d_dsc"][done_batch] = three_d_DSC  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class,
                                data["filenames"],
                                savedir / f"iter{epc:03d}" / mode)

            # Logging: running averages over everything seen so far,
            # *including* the current batch (done_* are bumped only below).
            big_slice = slice(0, done_img + B)
            seen_batches: int = done_batch + 1

            stat_dict: dict[str, Any] = {}
            # The order matters for the final display -- it is easy to change

            if few_axis:
                stat_dict |= {f"DSC{n}": epoch_metrics["dice"][big_slice, n].mean()
                              for n in metric_axis}

                if "3d_dsc" in requested_metrics:
                    stat_dict |= {f"3d_DSC{n}": epoch_metrics["3d_dsc"][:seen_batches, n].mean()
                                  for n in metric_axis}

            if len(metric_axis) > 1:
                stat_dict |= {"DSC": epoch_metrics["dice"][big_slice, metric_axis].mean()}

            # Mean of each loss column, computed once rather than per column.
            running_loss: Tensor = epoch_metrics["loss"][:seen_batches].mean(dim=0)
            stat_dict |= {f"loss_{j}": running_loss[j] for j in range(n_loss)}

            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return {k: v.detach().cpu() for (k, v) in epoch_metrics.items()}