Esempio n. 1
0
def apply(image,
          model=None,
          force_cpu=False,
          batch_size=20,
          volume_postprocessing=True,
          show_process=True):
    """Segment a CT volume slice by slice with a 2D network.

    Args:
        image: SimpleITK image; converted to a numpy array with
            ``sitk.GetArrayFromImage``.
        model: segmentation network. Defaults to ``get_model('unet', 'R231')``.
        force_cpu: run on the CPU even when CUDA is available.
        batch_size: number of slices per forward pass. Forced to 1 when
            falling back to the CPU because no GPU is available.
        volume_postprocessing: run ``utils.postrocessing`` on the stacked
            label volume before mapping it back to the input geometry.
        show_process: show a tqdm progress bar over the batches.

    Returns:
        ``np.uint8`` array with one label map per preprocessed slice, each
        mapped back to the input's in-plane size with ``utils.reshape_mask``.
    """
    if model is None:
        model = get_model('unet', 'R231')

    inimg_raw = sitk.GetArrayFromImage(image)
    del image  # only the raw array is needed from here on

    if force_cpu:
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        logging.info(
            "No GPU support available, will use CPU. Note, that this is significantly slower!"
        )
        batch_size = 1
        device = torch.device('cpu')
    model.to(device)

    tvolslices, xnew_box = utils.preprocess(inimg_raw, resolution=[256, 256])
    # Clip values above 600 and map [-1024, 600] linearly onto [0, 1]
    # (presumably Hounsfield units -- confirm against utils.preprocess).
    tvolslices[tvolslices > 600] = 600
    tvolslices = np.divide((tvolslices + 1024), 1624)
    torch_ds_val = utils.LungLabelsDS_inf(tvolslices)
    dataloader_val = torch.utils.data.DataLoader(torch_ds_val,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=False)

    # Seed with an empty (0, H, W) array so an empty loader still yields a
    # correctly shaped result. The batches are collected and concatenated
    # once at the end; this avoids re-copying the growing volume on every
    # batch, which np.vstack in the loop would do.
    batches = [np.empty((np.append(0, tvolslices[0].shape)), dtype=np.uint8)]

    with torch.no_grad():
        for X in tqdm(dataloader_val, disable=not show_process):
            X = X.float().to(device)
            prediction = model(X)
            # Arg-max over dim 1 gives the predicted label per pixel.
            pls = torch.max(prediction,
                            1)[1].detach().cpu().numpy().astype(np.uint8)
            batches.append(pls)
    timage_res = np.concatenate(batches, axis=0)

    # postprocessing includes removal of small connected components, hole filling and mapping of small components to
    # neighbors
    if volume_postprocessing:
        outmask = utils.postrocessing(timage_res)
    else:
        outmask = timage_res

    outmask = np.asarray([
        utils.reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:])
        for i in range(outmask.shape[0])
    ],
                         dtype=np.uint8)

    return outmask
def inference(image, model=None, force_cpu=False):
    """Segment a CT volume slice by slice and measure per-label volumes.

    Args:
        image: SimpleITK image. The product of its spacing is used as the
            voxel volume (presumably in mm^3 -- confirm the spacing unit).
        model: segmentation network. Required despite the ``None`` default.
        force_cpu: keep input tensors on the CPU. Otherwise every slice is
            moved with ``.cuda()``, so the model must already be on the GPU.

    Returns:
        dict with:
            ``"mask"``: ``np.uint8`` label volume mapped back to the input's
                in-plane size.
            ``"ggo"``, ``"consolidation"``, ``"pleural_effusion"``: voxel
                counts of labels 1, 2 and 3 times the voxel volume times
                0.001 (cm^3 when the spacing is in mm).

    Raises:
        ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("inference() requires a model; got None")

    voxvol = np.prod(image.GetSpacing())
    inimg_raw = sitk.GetArrayFromImage(image)
    del image

    tvolslices, xnew_box = utils.preprocess(inimg_raw, resolution=[512, 512])
    # Clip values above 600 and map [-1024, 600] linearly onto [0, 1].
    tvolslices[tvolslices > 600] = 600
    tvolslices = np.divide((tvolslices + 1024), 1624)

    # These values match the common ImageNet RGB mean/std. How
    # broadcast_stats(1, 4, ...) reduces them for the single-channel input is
    # defined elsewhere -- TODO confirm.
    stats = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    mean, std = broadcast_stats(1, 4, stats)

    # Seed with an empty (0, H, W) array, collect per-slice predictions and
    # concatenate once. Calling np.vstack inside the loop would re-copy the
    # growing volume for every slice.
    slices = [np.empty((np.append(0, tvolslices[0].shape)), dtype=np.uint8)]
    with torch.no_grad():
        for idx in tqdm.tqdm(range(tvolslices.shape[0])):
            data = tvolslices[idx, :, :]
            # Add two leading singleton axes -> (1, 1, H, W).
            data = np.expand_dims(np.expand_dims(data, 0), 0)
            data = torch.tensor(data)
            data = (data - mean) / std
            data = data.float() if force_cpu else data.cuda().float()
            pred = model(data)
            pls = torch.max(pred, 1)[1].detach().cpu().numpy().astype(np.uint8)
            slices.append(pls)
    outmask = np.concatenate(slices, axis=0)

    outmask = np.asarray([
        reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:])
        for i in range(outmask.shape[0])
    ],
                         dtype=np.uint8)

    # spacing product * 0.001: mm^3 -> cm^3 (i.e. mL) when spacing is in mm.
    voxvol_cm3 = voxvol * 0.001
    ggo = np.sum(outmask == 1) * voxvol_cm3
    consolidation = np.sum(outmask == 2) * voxvol_cm3
    pleural_effusion = np.sum(outmask == 3) * voxvol_cm3

    return {
        "mask": outmask,
        "ggo": ggo,
        "consolidation": consolidation,
        "pleural_effusion": pleural_effusion
    }
Esempio n. 3
0
def apply(image,
          model=None,
          force_cpu=False,
          batch_size=20,
          volume_postprocessing=True,
          noHU=False,
          verbose=False):
    """Segment the lungs in an image with a 2D network and return a label map.

    Args:
        image: SimpleITK image. On the default path this is a volume run
            through ``utils.preprocess``. With ``noHU=True`` the array is
            passed to ``skimage.color.rgb2gray``, so presumably a 2D RGB
            image (TODO confirm the expected input).
        model: segmentation network. Defaults to ``get_model('unet', 'R231')``.
        force_cpu: run on the CPU even when CUDA is available.
        batch_size: slices per forward pass. Forced to 1 when falling back to
            the CPU because no GPU is available.
        volume_postprocessing: run ``utils.postrocessing`` on the stacked
            label volume.
        noHU: experimental path for inputs that are not in Hounsfield units.
        verbose: show a tqdm progress bar over the batches.

    Returns:
        ``np.uint8`` label array in the axis orientation of
        ``sitk.GetArrayFromImage(image)``; any flips applied for negative
        direction cosines are undone. With ``noHU=True`` this is a single
        ``(1, H, W)`` mask.

    Raises:
        ValueError: if ``noHU=True`` and no intensity-scaled copy of the image
            passes the sanity check.
    """
    if model is None:
        model = get_model('unet', 'R231')

    inimg_raw = sitk.GetArrayFromImage(image)
    directions = np.asarray(image.GetDirection())
    # A 3D image has a 3x3 direction matrix (9 entries). Flip every array axis
    # whose diagonal direction cosine is negative. The diagonal is reversed
    # because GetArrayFromImage returns the array in (z, y, x) order. The
    # same flips are undone on the output mask before returning.
    flip_axes = None
    if len(directions) == 9:
        flip_axes = np.where(directions[[0, 4, 8]][::-1] < 0)[0]
        inimg_raw = np.flip(inimg_raw, flip_axes)
    del image

    if force_cpu:
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        logging.info(
            "No GPU support available, will use CPU. Note, that this is significantly slower!"
        )
        batch_size = 1
        device = torch.device('cpu')
    model.to(device)

    if not noHU:
        tvolslices, xnew_box = utils.preprocess(inimg_raw,
                                                resolution=[256, 256])
        # Clip values above 600 and map [-1024, 600] linearly onto [0, 1].
        tvolslices[tvolslices > 600] = 600
        tvolslices = np.divide((tvolslices + 1024), 1624)
    else:
        # support for non HU images. This is just a hack. The models were not
        # trained with this in mind.
        # Build 20 copies of the grayscale image scaled by factors in
        # [0.3, 2] and clip them to 1. Keep only the copies with more than
        # 25000 pixels above 0.6.
        tvolslices = skimage.color.rgb2gray(inimg_raw)
        tvolslices = skimage.transform.resize(tvolslices, [256, 256])
        tvolslices = np.asarray(
            [tvolslices * x for x in np.linspace(0.3, 2, 20)])
        tvolslices[tvolslices > 1] = 1
        sanity = [(tvolslices[x] > 0.6).sum() > 25000
                  for x in range(len(tvolslices))]
        tvolslices = tvolslices[sanity]
        if len(tvolslices) == 0:
            raise ValueError(
                "noHU: no intensity-scaled copy of the image passed the "
                "sanity check; cannot run inference")
    torch_ds_val = utils.LungLabelsDS_inf(tvolslices)
    dataloader_val = torch.utils.data.DataLoader(torch_ds_val,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=False)

    # Seed with an empty (0, H, W) array, collect batches and concatenate
    # once. Calling np.vstack inside the loop would re-copy the growing
    # volume for every batch.
    batches = [np.empty((np.append(0, tvolslices[0].shape)), dtype=np.uint8)]

    with torch.no_grad():
        for X in tqdm(dataloader_val, disable=not verbose):
            X = X.float().to(device)
            prediction = model(X)
            pls = torch.max(prediction,
                            1)[1].detach().cpu().numpy().astype(np.uint8)
            batches.append(pls)
    timage_res = np.concatenate(batches, axis=0)

    # postprocessing includes removal of small connected components, hole filling and mapping of small components to
    # neighbors
    if volume_postprocessing:
        outmask = utils.postrocessing(timage_res)
    else:
        outmask = timage_res

    if noHU:
        # Keep the scaled copy with the most label-1 pixels. Resize it back
        # to the input size with nearest-neighbour interpolation (order=0) so
        # the label values are preserved.
        outmask = skimage.transform.resize(outmask[np.argmax(
            (outmask == 1).sum(axis=(1, 2)))],
                                           inimg_raw.shape[:2],
                                           order=0,
                                           anti_aliasing=False,
                                           preserve_range=True)[None, :, :]
    else:
        outmask = np.asarray([
            utils.reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:])
            for i in range(outmask.shape[0])
        ],
                             dtype=np.uint8)

    if flip_axes is not None:
        outmask = np.flip(outmask, flip_axes)

    return outmask.astype(np.uint8)
Esempio n. 4
0
def predict(input_file_path, base_dirname):
    """Segment a CT volume with a 5-class U-Net and save per-slice images.

    Args:
        input_file_path: path of the volume to segment, read with
            ``sitk.ReadImage``.
        base_dirname: forwarded to ``save_images`` as ``base_dirname``.

    Returns:
        dict with:
            ``'origUrls'``, ``'predUrls'``: the input-slice and prediction
                paths returned by ``save_images`` for each slice.
            ``'refUrls'``: currently always an empty list, because
                ``save_images`` is called with ``ref=None`` and its third
                return value is discarded.
    """
    import sys  # only used to pick the unit of ru_maxrss below

    # Load data
    image = sitk.ReadImage(input_file_path)
    inimg_raw = sitk.GetArrayFromImage(image)
    del image

    # Fixed label volumes from the training set.
    # NOTE(review): these are read on every call and only feed `label=` of
    # utils.preprocess; the returned Y is never used below. Confirm whether
    # preprocess needs them (e.g. for cropping) before removing.
    image = sitk.ReadImage(
        os.path.join(curr_dir, 'training_images', 'tr_mask.nii'))
    gt_raw = sitk.GetArrayFromImage(image)
    del image

    image = sitk.ReadImage(
        os.path.join(curr_dir, 'training_images',
                     'tr_lungmasks_updated.nii.gz'))
    lobe_raw = sitk.GetArrayFromImage(image)
    del image

    # Re-code lung-mask labels 1 -> 10 and 2 -> 20 so they can be summed
    # with the gt labels into one label map. This assumes gt values stay
    # below 10 -- TODO confirm.
    lobe_raw[lobe_raw == 1] = 10
    lobe_raw[lobe_raw == 2] = 20

    y_raw = lobe_raw + gt_raw

    X, xnew_box, Y = utils.preprocess(inimg_raw,
                                      label=y_raw,
                                      resolution=[256, 256])

    # Insert a singleton axis at dim 1 (presumably the channel axis).
    X = torch.from_numpy(X).unsqueeze(1).float()

    # Model
    n_classes = 5
    model = UNet(n_classes=n_classes,
                 padding=True,
                 depth=5,
                 up_mode='upsample',
                 batch_norm=True,
                 residual=False)
    model = torch.nn.DataParallel(model)
    # torch.load unpickles the checkpoint; only point it at trusted files.
    summary = torch.load(os.path.join(
        curr_dir, 'trained_models',
        'unet_lr0.0001_seed23_losstype0_augTrue_ver1.pth.rar'),
                         map_location=torch.device('cpu'))
    model.load_state_dict(summary["model"])

    model.eval()

    ct_scan_image_paths = []
    pred_image_paths = []
    ref_image_paths = []  # never filled: no reference mask is saved
    with torch.no_grad():
        for i in range(X.shape[0]):
            ct_slice = X[i].unsqueeze(0)  # add a batch axis of size 1
            pred = model(ct_slice)
            pred = hardlabels(pred).float()

            ct_scan_image_path, pred_image_path, _ref_image_path = save_images(
                str(i), ct_slice, pred, ref=None, base_dirname=base_dirname)
            ct_scan_image_paths.append(ct_scan_image_path)
            pred_image_paths.append(pred_image_path)

    # ru_maxrss is the peak resident set size. It is reported in bytes on
    # macOS but in kilobytes on Linux, so scale to bytes before converting
    # to GB.
    peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    rss_bytes = peak_rss if sys.platform == 'darwin' else peak_rss * 1024
    print('Memory usage: %s (GB)' % (rss_bytes / 1e+9))
    return {
        'origUrls': ct_scan_image_paths,
        'predUrls': pred_image_paths,
        'refUrls': ref_image_paths
    }