Example #1
0
def runInference(args: argparse.Namespace):
    """Run model inference over a folder of PNG slices and save class maps.

    Expects `args` to provide: model_weights, batch_size, num_classes,
    data_folder and save_folder.
    """
    print('>>> Loading model')
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # weights from trusted sources.
    net = torch.load(args.model_weights)
    device = torch.device("cuda")
    net.to(device)
    net.eval()  # BUGFIX: switch dropout/batch-norm to inference behaviour

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes

    # Grayscale PNG -> (1, H, W) float tensor scaled to [0, 1]
    transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])

    folders: List[Path] = [Path(args.data_folder)]
    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(names,
                          folders,
                          transforms=[transform],
                          debug=False,
                          C=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,
                        drop_last=False)

    print('>>> Starting the inference')
    savedir: str = args.save_folder
    # BUGFIX: make sure the output directory exists before saving into it
    Path(savedir).mkdir(parents=True, exist_ok=True)
    total_iteration = len(loader)
    desc = ">> Inference"  # no interpolation needed, plain string
    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    with torch.no_grad():
        for j, (filenames, image, _) in tq_iter:
            image = image.to(device)

            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, "", 0)
def runInference(args: argparse.Namespace):
    """Run inference and export predictions in the format selected by args.mode.

    Modes:
        argmax    -- save the per-pixel argmax class map
        probs     -- save P(class == args.probs_class) rescaled to uint8
        threshold -- binarize probabilities at args.threshold, then argmax
        softmax   -- dump the full probability tensor per image as .npy
    """
    print('>>> Loading model')
    # NOTE(review): torch.load unpickles arbitrary objects — trusted weights only.
    net = torch.load(args.model_weights)
    device = torch.device("cuda")
    net.to(device)
    net.eval()  # BUGFIX: switch dropout/batch-norm to inference behaviour

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes

    folders: list[Path] = [Path(args.data_folder)]
    names: list[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(
        names,
        folders * 2,  # Duplicate for compatibility reasons
        are_hots=[False, False],
        transforms=[png_transform, dummy_gt_transform
                    ],  # So it is happy about the target size
        bounds_generators=[],
        debug=args.debug,
        K=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,
                        drop_last=False,
                        collate_fn=custom_collate)

    print('>>> Starting the inference')
    savedir: Path = Path(args.save_folder)
    savedir.mkdir(parents=True, exist_ok=True)
    total_iteration = len(loader)
    desc = ">> Inference"
    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    with torch.no_grad():
        for j, data in tq_iter:
            filenames: list[str] = data["filenames"]
            image: Tensor = data["images"].to(device)

            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                predicted_class: Tensor
                if args.mode == 'argmax':
                    predicted_class = probs2class(pred_probs)
                elif args.mode == 'probs':
                    predicted_class = (pred_probs[:, args.probs_class, ...] *
                                       255).type(torch.uint8)
                elif args.mode == 'threshold':
                    thresholded: Tensor = pred_probs[:, ...] > args.threshold
                    predicted_class = thresholded.argmax(dim=1)
                elif args.mode == 'softmax':
                    for i, filename in enumerate(filenames):
                        np.save((savedir / filename).with_suffix(".npy"),
                                pred_probs[i].cpu().numpy())
                else:
                    # BUGFIX: unknown modes previously crashed later with an
                    # UnboundLocalError on `predicted_class`; fail fast instead.
                    raise ValueError(f"Unknown mode: {args.mode}")

                if args.mode != 'softmax':
                    save_images(predicted_class, filenames, savedir)
Example #3
0
def runInference(args: argparse.Namespace, pred_folder: str):
    """Compare one prediction folder against the ground truth.

    Prints per-class slice-wise Dice, batch-wise Dice, and foreground sizes
    (class-1 pixel counts) averaged over the whole dataset.
    """
    # print('>>> Loading the data')
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = torch.device("cuda" if use_cuda else "cpu")
    C: int = args.num_classes

    # Let's just reuse some code
    png_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])
    gt_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=C),
        itemgetter(0)
    ])

    # Dummy bound generators: required by the dataset API, unused here
    bounds_gen = [(lambda *a: torch.zeros(C, 1, 2)) for _ in range(2)]

    # First entry is a dummy placeholder for the image slot
    folders: List[Path] = [
        Path(pred_folder),
        Path(pred_folder),
        Path(args.gt_folder)
    ]
    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))

    dt_set = SliceDataset(
        names,
        folders,
        transforms=[png_transform, gt_transform, gt_transform],
        debug=False,
        C=C,
        are_hots=[False, True, True],
        in_memory=False,
        bounds_generators=bounds_gen)
    loader = DataLoader(dt_set,
                        batch_sampler=PatientSampler(dt_set, args.grp_regex),
                        num_workers=11)

    # print('>>> Computing the metrics')
    total_iteration = len(loader)
    total_images = len(loader.dataset)
    zeros = partial(torch.zeros, dtype=torch.float64, device=device)
    metrics = {
        "all_dices": zeros(total_images, C),
        "batch_dices": zeros(total_iteration, C),
        "sizes": zeros(total_images, 1)
    }

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=">> Computing")
    done: int = 0  # images already written into the per-image metric rows
    for j, (filenames, _, pred, gt, _) in tq_iter:
        B = len(pred)
        pred, gt = pred.to(device), gt.to(device)
        # Both tensors must be one-hot encodings
        assert simplex(pred) and sset(pred, [0, 1])
        assert simplex(gt) and sset(gt, [0, 1])

        dices: Tensor = dice_coef(pred, gt)
        b_dices: Tensor = dice_batch(pred, gt)
        assert dices.shape == (B, C)
        assert b_dices.shape == (C, ), b_dices.shape

        sm_slice = slice(done, done + B)  # Values only for current batch
        metrics["all_dices"][sm_slice, ...] = dices
        metrics["sizes"][sm_slice, :] = \
            torch.einsum("bwh->b", gt[:, 1, ...])[..., None]
        metrics["batch_dices"][j] = b_dices
        done += B

    print(f">>> {pred_folder}")
    for key, v in metrics.items():
        print(key, map_("{:.4f}".format, v.mean(dim=0)))
def runInference(args: argparse.Namespace):
    """Evaluate every epoch's predictions against the ground truth.

    For each `iter*` folder under args.pred_root, computes per-slice Dice
    and Hausdorff distances on CPU, prints per-epoch per-class means, and
    finally dumps the accumulated metric tensors as .npy files into
    args.save_folder.
    """
    # print('>>> Loading the data')
    # device = torch.device("cuda") if torch.cuda.is_available() and not args.cpu else torch.device("cpu")
    device = torch.device("cpu")
    C: int = args.num_classes

    # Let's just reuse some code
    # Grayscale PNG -> (1, H, W) float tensor scaled to [0, 1]
    png_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])
    # Label PNG -> one-hot (C, H, W) int64 tensor
    gt_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=C),
        itemgetter(0)
    ])

    # Dummy bound generators: required by the dataset API, unused here
    bounds_gen = [(lambda *a: torch.zeros(C, 1, 2)) for _ in range(2)]
    # Allocated lazily on the first processed epoch (needs the dataset size)
    metrics = None

    pred_folders = sorted(list(Path(args.pred_root).glob('iter*')))
    assert len(pred_folders) == args.epochs, (len(pred_folders), args.epochs)
    for epoch, pred_folder in enumerate(pred_folders):
        # Optionally restrict the evaluation to a subset of epochs
        if args.do_only and epoch not in args.do_only:
            continue

        # First one is dummy:
        folders: List[Path] = [Path(pred_folder, 'val'), Path(pred_folder, 'val'), Path(args.gt_folder)]
        names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
        are_hots = [False, True, True]

        # spacing_dict = pickle.load(open(Path(args.gt_folder, "..", "spacing.pkl"), 'rb'))
        spacing_dict = None

        dt_set = SliceDataset(names,
                              folders,
                              transforms=[png_transform, gt_transform, gt_transform],
                              debug=False,
                              C=C,
                              are_hots=are_hots,
                              in_memory=False,
                              spacing_dict=spacing_dict,
                              bounds_generators=bounds_gen,
                              quiet=True)
        loader = DataLoader(dt_set,
                            num_workers=2)

        # print('>>> Computing the metrics')
        total_iteration, total_images = len(loader), len(loader.dataset)
        if not metrics:
            metrics = {"all_dices": torch.zeros((args.epochs, total_images, C), dtype=torch.float64, device=device),
                       "hausdorff": torch.zeros((args.epochs, total_images, C), dtype=torch.float64, device=device)}

        desc = f">> Computing"
        tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
        done: int = 0  # number of images already written into the metric rows
        for j, (filenames, _, pred, gt, _) in tq_iter:
            B = len(pred)
            pred = pred.to(device)
            gt = gt.to(device)
            # Both tensors must be one-hot encodings (values in {0, 1})
            assert simplex(pred) and sset(pred, [0, 1])
            assert simplex(gt) and sset(gt, [0, 1])

            dices: Tensor = dice_coef(pred, gt)
            assert dices.shape == (B, C)

            haussdorf_res: Tensor = haussdorf(pred, gt)
            assert haussdorf_res.shape == (B, C)

            sm_slice = slice(done, done + B)  # Values only for current batch
            metrics["all_dices"][epoch, sm_slice, ...] = dices
            metrics["hausdorff"][epoch, sm_slice, ...] = haussdorf_res
            done += B

        # Per-epoch per-class means, printed as soon as they are available
        for key, v in metrics.items():
            print(epoch, key, map_("{:.4f}".format, v[epoch].mean(dim=0)))

    if metrics:
        # Persist the raw metric tensors for later analysis/plotting
        savedir: Path = Path(args.save_folder)
        for k, e in metrics.items():
            np.save(Path(savedir, f"{k}.npy"), e.cpu().numpy())
Example #5
0
def runInference(args: argparse.Namespace):
    """Run inference on a folder of PNGs and save the result selected by args.mode.

    Modes:
        argmax    -- per-pixel argmax class map
        probs     -- P(class == args.probs_class) rescaled to uint8
        threshold -- binarize probabilities at args.threshold, then argmax
    """
    print('>>> Loading model')
    # NOTE(review): torch.load unpickles arbitrary objects — trusted weights only.
    net = torch.load(args.model_weights)
    device = torch.device("cuda")
    net.to(device)
    net.eval()  # BUGFIX: switch dropout/batch-norm to inference behaviour

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes

    # Grayscale PNG -> (1, H, W) float tensor scaled to [0, 1]
    png_transform = transforms.Compose([
        lambda img: img.convert('L'),
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])
    # All-zero target, only there to satisfy the dataset's ground-truth slot
    dummy_gt = transforms.Compose([
        lambda img: np.array(img), lambda nd: torch.zeros(
            (num_classes, *(nd.shape)), dtype=torch.int64)
    ])

    folders: List[Path] = [Path(args.data_folder)]
    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(
        names,
        folders * 2,  # Duplicate for compatibility reasons
        are_hots=[False, False],
        transforms=[png_transform,
                    dummy_gt],  # So it is happy about the target size
        bounds_generators=[],
        debug=False,
        C=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,
                        drop_last=False)

    print('>>> Starting the inference')
    savedir: str = args.save_folder
    total_iteration = len(loader)
    desc = ">> Inference"  # no interpolation needed, plain string
    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    with torch.no_grad():
        for j, (filenames, image, _) in tq_iter:
            image = image.to(device)

            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                predicted_class: Tensor
                if args.mode == "argmax":
                    predicted_class = probs2class(pred_probs)
                elif args.mode == 'probs':
                    predicted_class = (pred_probs[:, args.probs_class, ...] *
                                       255).type(torch.uint8)
                elif args.mode == "threshold":
                    thresholded: Tensor = pred_probs[:, ...] > args.threshold
                    predicted_class = thresholded.argmax(dim=1)
                else:
                    # BUGFIX: unknown modes previously crashed at save_images
                    # with an UnboundLocalError on `predicted_class`.
                    raise ValueError(f"Unknown mode: {args.mode}")

                save_images(predicted_class, filenames, savedir, "", 0)
def main() -> None:
    """Compute 3D validation metrics (Dice, Hausdorff) for every epoch folder.

    Scans `iter*` folders under args.basefolder, groups images into patients
    with args.grp_regex, evaluates the requested args.metrics per epoch and
    patient, prints class-1 means, and saves each metric as val_<name>.npy.
    """
    args = get_args()

    iterations_paths: List[Path] = sorted(Path(args.basefolder).glob("iter*"))
    print(f">>> Found {len(iterations_paths)} epoch folders")

    # Handle gracefully if not all folders are there (early stop)
    EPC: int = args.n_epoch if args.n_epoch >= 0 else len(iterations_paths)
    K: int = args.num_classes

    # Get the patient number, and image names, from the GT folder
    gt_path: Path = Path(args.gt_folder)
    names: List[str] = map_(lambda p: str(p.name), gt_path.glob("*"))
    n_img: int = len(names)

    grouping_regex: Pattern = re.compile(args.grp_regex)
    stems: List[str] = [Path(filename).stem
                        for filename in names]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)  # type: ignore
    patients: List[str] = [match.group(1) for match in matches]

    unique_patients: List[str] = list(set(patients))
    n_patients: int = len(unique_patients)

    print(
        f">>> Found {len(unique_patients)} unique patients out of {n_img} images ; regex: {args.grp_regex}"
    )

    # First, quickly assert all folders have the same numbers of predicted images
    n_img_epoc: List[int] = [
        len(list(Path(p, "val").glob("*.png"))) for p in iterations_paths
    ]
    assert len(set(n_img_epoc)) == 1
    assert all(
        len(list(Path(p, "val").glob("*.png"))) == n_img
        for p in iterations_paths)

    # BUGFIX: Dict['str', ...] annotated with the string literal 'str'
    # instead of the str type; also dedupe the two identical allocations.
    metrics: Dict[str, Tensor] = {}
    for metric_name in ('3d_dsc', '3d_hausdorff'):
        if metric_name in args.metrics:
            metrics[metric_name] = torch.zeros((EPC, n_patients, K),
                                               dtype=torch.float32)
            print(f">> Will compute {metric_name} metric")

    gen_dataset = partial(
        SliceDataset,
        transforms=[png_transform, gt_transform, gt_transform],
        are_hots=[False, True, True],
        K=K,
        in_memory=False,
        bounds_generators=[(lambda *a: torch.zeros(K, 1, 2))
                           for _ in range(1)],
        box_prior=False,
        box_priors_arg='{}',
        dimensions=2)
    data_loader = partial(DataLoader,
                          num_workers=cpu_count(),
                          pin_memory=False,
                          collate_fn=custom_collate)

    # Will replace live dataset.folders and call again load_images to update dataset.files
    print(gt_path, gt_path, Path(iterations_paths[0], 'val'))
    dataset: SliceDataset = gen_dataset(
        names,
        [gt_path, gt_path, Path(iterations_paths[0], 'val')])
    sampler: PatientSampler = PatientSampler(dataset,
                                             args.grp_regex,
                                             shuffle=False)
    dataloader: DataLoader = data_loader(dataset, batch_sampler=sampler)

    current_path: Path
    for e, current_path in enumerate(iterations_paths):
        # Point the live dataset at this epoch's predictions, then reload
        # the file list so the loader serves the new folder's images.
        dataset.folders = [gt_path, gt_path, Path(current_path, 'val')]
        dataset.files = SliceDataset.load_images(dataset.folders,
                                                 dataset.filenames, False)

        print(f">>> Doing epoch {str(current_path)}")

        # One iteration per patient (batch_sampler groups slices by patient)
        for i, data in enumerate(tqdm(dataloader, leave=None)):
            target: Tensor = data["gt"]
            prediction: Tensor = data["labels"][0]

            assert target.shape == prediction.shape

            if '3d_dsc' in args.metrics:
                dsc: Tensor = dice_batch(target, prediction)
                assert dsc.shape == (K, )

                metrics['3d_dsc'][e, i, :] = dsc
            if '3d_hausdorff' in args.metrics:
                np_pred: np.ndarray = prediction[:, 1, :, :].cpu().numpy()
                np_target: np.ndarray = target[:, 1, :, :].cpu().numpy()

                if np_pred.sum() > 0:
                    hd_: float = hd(np_pred, np_target)

                    metrics["3d_hausdorff"][e, i, 1] = hd_
                else:
                    # Empty prediction: fall back to the volume diagonal,
                    # the largest possible distance in this volume
                    x, y, z = np_pred.shape
                    metrics["3d_hausdorff"][e, i,
                                            1] = (x**2 + y**2 + z**2)**0.5

        for metric in args.metrics:
            # For now, hardcode the fact we care about class 1 only
            print(f">> {metric}: {metrics[metric][e].mean(dim=0)[1]:.04f}")

    k: str
    el: Tensor
    for k, el in metrics.items():
        np.save(Path(args.basefolder, f"val_{k}.npy"), el.cpu().numpy())