Beispiel #1
0
def predict(dicom_path, centroids):
    """Predict nodule boundaries for a set of centroids.

    The intended pipeline, given a path to a DICOM image and a list of
    centroids, is:
        (1) load the segmentation model from its serialized state
        (2) pre-process the DICOM data into the format the model expects
        (3) mark every pixel with 0 or 1 depending on whether it is cancerous
        (4) write the binary mask to disk and return the path to it

    The current body is a placeholder: it loads the CT (discarding the
    result) and always answers with the ``assets/test_mask.npy`` file that
    sits next to this module.

    Args:
        dicom_path (str): a path to a DICOM directory
        centroids (list[dict]): A list of centroids of the form::
            {'x': int,
             'y': int,
             'z': int}

    Returns:
        dict: Path to the serialized binary mask plus the volumes computed
            by ``calculate_volume`` for the centroids, of the form::
            {'binary_mask_path': str,
             'volumes': list[float]}
    """
    # The result is unused; the call is kept so an unreadable path still fails.
    load_ct(dicom_path)

    assets_dir = os.path.join(os.path.dirname(__file__), 'assets')
    mask_path = os.path.join(assets_dir, 'test_mask.npy')
    return {
        'binary_mask_path': mask_path,
        'volumes': calculate_volume(mask_path, centroids),
    }
def predict(dicom_path):
    """Predict centroids of nodules in a DICOM image.

    The intended pipeline is:
        (1) load the identification model from its serialized state
        (2) pre-process the DICOM data into the format the model expects
        (3) return centroids, each with the probability that it is a nodule
            (as opposed to not a nodule)

    The current body is a placeholder: it loads the CT (discarding the
    result) and returns a single fixed centroid.

    Note:
        This model doesn't detect whether or not a nodule is cancerous, that
        is done in the ``classify`` model.

    Args:
        dicom_path (str): a path to a DICOM image

    Returns:
        list(dict): a list of centroids in the form::
            {'x': int,
             'y': int,
             'z': int,
             'p_nodule': float}
    """
    # The result is unused; the call is kept so an unreadable path still fails.
    load_ct(dicom_path)

    placeholder_centroid = {'x': 0, 'y': 0, 'z': 0, 'p_nodule': 0.5}
    return [placeholder_centroid]
def test_load_meta(metaimage_path, dicom_path):
    """Check the metadata-only (``voxel=False``) return types of ``load_ct``.

    A DICOM path must yield a non-empty list of ``FileDataset`` slices and a
    MetaImage path must yield a SimpleITK image.
    """
    dicom_meta = load_ct.load_ct(dicom_path, voxel=False)
    assert isinstance(dicom_meta, list)
    assert len(dicom_meta) > 0
    for dicom_slice in dicom_meta:
        assert isinstance(dicom_slice, dicom.dataset.FileDataset)

    itk_meta = load_ct.load_ct(metaimage_path, voxel=False)
    assert isinstance(itk_meta, SimpleITK.SimpleITK.Image)
def test_load_ct(metaimage_path, dicom_path):
    """Check ``load_ct`` for DICOM input, MetaImage input and a bad path.

    For both formats the voxel data must come back as a numpy array. For
    DICOM the first array axis must match the number of metadata entries,
    and for MetaImage the metadata must be a SimpleITK image. Loading from
    the current directory must raise ``ValueError`` with a message naming
    the missing ``.mhd``/``.dcm`` files.
    """
    ct_array, meta = load_ct.load_ct(dicom_path)
    assert isinstance(ct_array, np.ndarray)
    assert ct_array.shape[0] == len(meta)

    ct_array, meta = load_ct.load_ct(metaimage_path)
    assert isinstance(ct_array, np.ndarray)
    assert isinstance(meta, SimpleITK.SimpleITK.Image)

    try:
        load_ct.load_ct('.')
    except ValueError as e:
        assert 'contain any .mhd or .dcm files' in str(e)
    else:
        # Without this branch the test would pass even if nothing was raised.
        raise AssertionError("load_ct('.') was expected to raise ValueError")
Beispiel #5
0
def test_metadata(metaimage_path, dicom_path):
    """Check ``MetaData.spacing`` for DICOM and MetaImage, and bad input.

    The DICOM fixture is compared against hard-coded spacing values. The
    MetaImage spacing must equal the reversed SimpleITK spacing as a list.
    Wrapping an unsupported object must raise ``ValueError``.
    """
    meta = load_ct.load_ct(dicom_path, voxel=False)
    meta = load_ct.MetaData(meta)
    zipped = zip(meta.spacing, (0.703125, 0.703125, 2.5))
    assert all([m_axis == o_axis for m_axis, o_axis in zipped])

    meta = load_ct.load_ct(metaimage_path, voxel=False)
    spacing = list(reversed(meta.GetSpacing()))
    meta = load_ct.MetaData(meta)
    assert meta.spacing == spacing

    try:
        load_ct.MetaData([1, 2, 3])
    except ValueError as e:
        assert 'either list[dicom.dataset.FileDataset] or SimpleITK' in str(e)
    else:
        # Without this branch the test would pass even if nothing was raised.
        raise AssertionError(
            "MetaData([1, 2, 3]) was expected to raise ValueError")
def test_preprocess(metaimage_path):
    """Compare the hand-written crop pipeline with the ``PreprocessCT`` one.

    The reference patch is produced with ``lum_trans`` + ``resample`` +
    ``SimpleCrop``; the candidate patch is produced with two ``PreprocessCT``
    passes followed by ``crop_patches.patches_from_ct``. Both the patch and
    its coordinates must match exactly.
    """
    nodule_list = [{"z": 556, "y": 100, "x": 0}]

    # Reference pipeline, working directly on the SimpleITK image.
    itk_image = sitk.ReadImage(metaimage_path)
    voxels = sitk.GetArrayFromImage(itk_image)
    # SimpleITK values are reversed before use here.
    reversed_spacing = np.array(itk_image.GetSpacing())[::-1]
    reversed_origin = np.array(itk_image.GetOrigin())[::-1]
    voxels = lum_trans(voxels)
    voxels = resample(voxels, reversed_spacing, np.array([1, 1, 1]),
                      order=1)[0]

    cropper = SimpleCrop()

    for nodule in nodule_list:
        location = np.array(
            [np.float32(nodule[axis]) for axis in ["z", "y", "x"]])
        location = np.ceil((location - reversed_origin) / 1.)
        expected_patch, expected_coords = cropper(voxels[np.newaxis],
                                                  location)

    # Candidate pipeline: intensity transform first, then resampling.
    ct_array, meta = load_ct.load_ct(metaimage_path)

    to_uint8 = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                          clip_upper=600.,
                                          min_max_normalize=True,
                                          scale=255,
                                          dtype='uint8')
    ct_array, meta = to_uint8(ct_array, meta)

    to_unit_spacing = preprocess_ct.PreprocessCT(spacing=1., order=1)
    ct_array, meta = to_unit_spacing(ct_array, meta)

    patches = crop_patches.patches_from_ct(ct_array, meta, 96, nodule_list,
                                           stride=4, pad_value=160)
    actual_patch, actual_coords = patches[0]

    assert np.abs(actual_patch - expected_patch).sum() == 0
    assert np.abs(actual_coords - expected_coords).sum() == 0
def test_metadata(metaimage_path, dicom_path):
    """Check ``MetaData.spacing`` for DICOM and MetaImage, and bad input.

    The DICOM fixture is compared against hard-coded spacing values. The
    MetaImage spacing must equal the reversed SimpleITK spacing. Wrapping an
    unsupported object must raise ``ValueError``.
    """
    meta = load_ct.load_ct(dicom_path, voxel=False)
    meta = load_ct.MetaData(meta)
    zipped = zip(meta.spacing, (2.5, 0.703125, 0.703125))
    assert all([m_axis == o_axis for m_axis, o_axis in zipped])

    meta = load_ct.load_ct(metaimage_path, voxel=False)
    # the default axes order which is used is: (z, y, x)
    spacing = meta.GetSpacing()[::-1]
    meta = load_ct.MetaData(meta)
    assert meta.spacing == spacing

    try:
        load_ct.MetaData([1, 2, 3])
    except ValueError as e:
        assert 'either list[dicom.dataset.FileDataset] or SimpleITK' in str(e)
    else:
        # Without this branch the test would pass even if nothing was raised.
        raise AssertionError(
            "MetaData([1, 2, 3]) was expected to raise ValueError")
Beispiel #8
0
def calculate_volume(segment_path, centroids, ct_path=None):
    """Calculate the volume of the segmented component under each centroid.

    The mask stored at ``segment_path`` is split into connected components
    with ``scipy.ndimage.label``. For every centroid, the size of the
    component it falls into is reported. When ``ct_path`` is given, the voxel
    counts are multiplied by the product of the scan's spacing values so the
    result is in real-world units instead of voxels.

    Args:
        segment_path (str): a path to a mask file readable by ``np.load``
        centroids (list[dict]): A list of centroids of the form::
            {'x': int,
             'y': int,
             'z': int}
            The mask is indexed as ``mask[x, y, z]``.
        ct_path (str): optional path passed to ``load_ct`` to obtain the
            spacing of the scan. If it is None (or empty) the volumes are
            returned in voxels.

    Returns:
        list: one volume per centroid, in the order of ``centroids``. A
            centroid that lies on background (outside every connected
            component) gets a volume of 0.
    """
    labelled, _ = scipy.ndimage.label(np.load(segment_path))

    # voxel_counts[k] is the number of voxels carrying label k.
    voxel_counts = np.bincount(labelled.ravel())
    if voxel_counts.size:
        # Label 0 is the background. Zero it so a centroid that misses every
        # component does not report the background size as a tumor volume.
        voxel_counts[0] = 0

    labels = [labelled[centroid['x'], centroid['y'], centroid['z']]
              for centroid in centroids]
    volumes = voxel_counts[labels].tolist()

    if ct_path:
        meta = MetaData(load_ct(ct_path, voxel=False))
        voxel_volume = np.prod(meta.spacing)
        volumes = [volume * voxel_volume for volume in volumes]

    return volumes
def test_lum_trans(metaimage_path):
    """``lum_trans`` must match ``PreprocessCT`` configured for 0-255 uint8.

    The equivalent ``PreprocessCT`` settings are clipping to [-1200, 600],
    min-max normalization, scaling by 255 and a ``uint8`` dtype.
    """
    ct_array, meta = load_ct.load_ct(metaimage_path)
    reference = lum_trans(ct_array)

    settings = dict(clip_lower=-1200.,
                    clip_upper=600.,
                    min_max_normalize=True,
                    scale=255,
                    dtype='uint8')
    transform = preprocess_ct.PreprocessCT(**settings)
    candidate, _ = transform(ct_array, meta)

    difference = np.abs(reference - candidate)
    assert difference.sum() == 0
def test_preprocess_dicom_min_max_scale(dicom_path):
    """Min-max normalized DICOM data must stay within [0, 1]."""
    settings = dict(clip_lower=-1000, clip_upper=400, min_max_normalize=True)
    normalize = preprocess_ct.PreprocessCT(**settings)

    scaled, _ = normalize(*load_ct.load_ct(dicom_path))

    assert isinstance(scaled, np.ndarray)
    assert scaled.max() <= 1
    assert scaled.min() >= 0
Beispiel #11
0
 def _ct_preprocess(self, ct_path):
     """Load the CT at ``ct_path`` and run it through ``PreprocessCT``.

     The preprocessor is configured with ``to_hu=True``, the instance's
     ``clip_lower``/``clip_upper`` bounds, a spacing of ``[.9, .7, .7]`` and
     no min-max normalization.

     Returns:
         tuple: the two values produced by the preprocessor (array, meta).
     """
     options = dict(to_hu=True,
                    clip_lower=self.clip_lower,
                    clip_upper=self.clip_upper,
                    spacing=[.9, .7, .7],
                    min_max_normalize=False)
     transform = preprocess_ct.PreprocessCT(**options)
     volume, metadata = transform(*load_ct.load_ct(ct_path))
     return volume, metadata
def test_segmentation_over_LIDC(full_dicom_path):
    """Smoke-test ``improved_lung_segmentation`` on a full DICOM scan.

    The scan is loaded and passed through ``PreprocessCT(to_hu=True)``
    before segmentation. The test only checks that the call succeeds and
    yields four values (lung, left lung, right lung, trachea); the values
    themselves are not inspected.
    """
    to_hu = preprocess_ct.PreprocessCT(to_hu=True)
    patient, _ = to_hu(*load_ct.load_ct(full_dicom_path))

    segmentation = improved_lung_segmentation(patient)
    lung, lung_left, lung_right, trachea = segmentation
Beispiel #13
0
def predict(dicom_path,
            centroids,
            model_path=None,
            preprocess_ct=None,
            preprocess_model_input=None):
    """Predict whether the given centroids are concerning.

    Given the path to a DICOM image and a list of centroids:
        (1) read the DICOM series with SimpleITK
        (2) optionally pre-process the image and the model input with the
            supplied callables
        (3) hand the result to ``gtr123_model.predict`` which yields, for
            each centroid (which represents a nodule), a probability that
            the nodule is concerning

    Args:
        dicom_path (str): A path to the DICOM image
        centroids (list[dict]): A list of centroids of the form::
            {'x': int,
             'y': int,
             'z': int}
        model_path (str): A path to the serialized model. Falls back to
            "src/algorithms/classify/assets/gtr123_model.ckpt" when falsy.
        preprocess_ct (callable): Optional preprocessing step, called with
            the SimpleITK image and a ``MetaData`` wrapper of the scan's
            metadata. Skipped when falsy.
        preprocess_model_input (callable[ndarray, list[dict]]): Optional
            preprocessing of the model input, called with the (possibly
            preprocessed) voxel data and the centroids. Skipped when falsy.

    Returns:
        list[dict]: a list of centroids with the probability they are
        concerning of the form::
            {'x': int,
             'y': int,
             'z': int,
             'p_concerning': float}

    Raises:
        ValueError: if SimpleITK finds no series files under ``dicom_path``.
    """
    series_reader = sitk.ImageSeriesReader()
    series_files = series_reader.GetGDCMSeriesFileNames(dicom_path)
    if not series_files:
        raise ValueError(
            "The path doesn't contain neither .mhd nor .dcm files")

    series_reader.SetFileNames(series_files)
    voxel_data = series_reader.Execute()

    if preprocess_ct:
        # Only the metadata half of load_ct's result is needed here.
        raw_meta = load_ct(dicom_path)[1]
        voxel_data = preprocess_ct(voxel_data, MetaData(raw_meta))

    model_input = voxel_data
    if preprocess_model_input:
        model_input = preprocess_model_input(voxel_data, centroids)

    if not model_path:
        model_path = "src/algorithms/classify/assets/gtr123_model.ckpt"

    return gtr123_model.predict(model_input, centroids, model_path)
Beispiel #14
0
def test_resample(metaimage_path):
    """``resample`` to unit spacing must match ``PreprocessCT(spacing=True)``.

    Both use interpolation order 1 and are compared element-wise.
    """
    ct_array, meta = load_ct.load_ct(metaimage_path)

    source_spacing = np.array(load_ct.MetaData(meta).spacing)
    target_spacing = np.array([1, 1, 1])
    reference, _ = resample(ct_array, source_spacing, target_spacing, order=1)

    transform = preprocess_ct.PreprocessCT(spacing=True, order=1)
    candidate, _ = transform(ct_array, meta)

    difference = np.abs(reference - candidate)
    assert difference.sum() == 0
Beispiel #15
0
def predict(ct_path,
            nodule_list,
            model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
    """Score every nodule in ``nodule_list`` with the ``CaseNet`` model.

    The CT is loaded, clipped to [-1200, 600], resampled with spacing=1.,
    normalized to 0-255 ``uint8`` and cropped around every nodule. Each
    crop is then passed through the network.

    Args:
      ct_path (str): path to a MetaImage or DICOM data.
      nodule_list: List of nodules, each a dict with 'x', 'y' and 'z' keys
      model_path: Path to the torch model (Default value = "src/algorithms/classify/assets/gtr123_model.ckpt")

    Returns:
      List of dicts holding the nodule's 'x', 'y', 'z' and the predicted
      'p_concerning'. An empty ``nodule_list`` gives an empty list without
      loading the model.

    """
    if not nodule_list:
        return []

    def as_volatile(array):
        # Wrap a numpy array as a float Variable flagged volatile.
        wrapped = Variable(torch.from_numpy(array).float())
        wrapped.volatile = True
        return wrapped

    casenet = CaseNet()
    state = torch.load(model_path)
    casenet.load_state_dict(state)
    casenet.eval()

    if torch.cuda.is_available():
        casenet = torch.nn.DataParallel(casenet).cuda()

    preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                            clip_upper=600.,
                                            spacing=1.,
                                            order=1,
                                            min_max_normalize=True,
                                            scale=255,
                                            dtype='uint8')
    ct_array, meta = preprocess(*load_ct.load_ct(ct_path))

    patches = crop_patches.patches_from_ct(
        ct_array,
        meta,
        config['crop_size'],
        nodule_list,
        stride=config['stride'],
        pad_value=config['filling_value'])

    results = []
    for nodule, (patch, patch_coords) in zip(nodule_list, patches):
        # Two leading axes are added to the patch, one to the coordinates.
        patch_input = as_volatile(patch[np.newaxis, np.newaxis])
        coords_input = as_volatile(patch_coords[np.newaxis])
        _, pred, _ = casenet(patch_input, coords_input)

        scored = {
            "x": nodule["x"],
            "y": nodule["y"],
            "z": nodule["z"],
            "p_concerning": float(pred.data.cpu().numpy())
        }
        results.append(scored)

    return results
def test_preprocess_dicom_min_max_scale(dicom_path):
    """Min-max normalized DICOM data must stay within [0, 1].

    This variant configures ``PreprocessCT`` through a ``Params`` object and
    wraps the metadata in ``MetaData`` before preprocessing.
    """
    settings = preprocess_ct.Params(clip_lower=-1000,
                                    clip_upper=400,
                                    min_max_normalize=True)
    normalize = preprocess_ct.PreprocessCT(settings)

    voxels, raw_meta = load_ct.load_ct(dicom_path)
    scaled = normalize(voxels, load_ct.MetaData(raw_meta))

    assert isinstance(scaled, np.ndarray)
    assert scaled.max() <= 1
    assert scaled.min() >= 0
def test_preprocess_dicom_clips(dicom_path):
    """Clipped DICOM data must stay within the [-1, 40] bounds.

    This variant configures ``PreprocessCT`` through a ``Params`` object and
    wraps the metadata in ``MetaData`` before preprocessing.
    """
    settings = preprocess_ct.Params(clip_lower=-1, clip_upper=40)
    clip = preprocess_ct.PreprocessCT(settings)

    voxels, raw_meta = load_ct.load_ct(dicom_path)
    clipped = clip(voxels, load_ct.MetaData(raw_meta))

    assert isinstance(clipped, np.ndarray)
    assert clipped.max() <= 40
    assert clipped.min() >= -1
Beispiel #18
0
def test_generators_shift_in_flow(ct_path):
    """A seeded ``flow`` with shifting must equal a seeded ``random_shift``.

    The batch is slices 45-89 of the CT with two leading axes added. It is
    run through ``DataGenerator.flow`` with ``seed=21`` and compared with
    calling ``generators.random_shift`` directly after seeding numpy with
    the same value.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    sub_volume = np.expand_dims(ct_array[45:90], 0)
    batch = np.expand_dims(sub_volume, 0)

    generator = generators.DataGenerator(shift_range=[.1, .4, .4],
                                         data_format='channels_first')
    flow = generator.flow(batch, batch_size=1, seed=21)
    augmented = next(flow)

    assert len(augmented.shape) == 5
    assert augmented.shape == batch.shape

    np.random.seed(21)
    shifted = generators.random_shift(batch[0], (.1, .4, .4))
    assert not np.abs(shifted - augmented).sum()
Beispiel #19
0
def test_patches_from_ct(ct_path):
    """``patches_from_ct`` must return one 12x12x12 patch per centroid."""
    zyx_triples = [[556, 101, -70], [556, 121, -20], [556, 221, -77]]
    centroids = [{'z': z, 'y': y, 'x': x} for z, y, x in zyx_triples]

    patches = crop_patches.patches_from_ct(*load_ct.load_ct(ct_path),
                                           patch_shape=12,
                                           centroids=centroids)

    assert isinstance(patches, list)
    assert len(patches) == 3
    for patch in patches:
        assert patch.shape == (12, 12, 12)
Beispiel #20
0
def test_generators_zoom_in_flow(ct_path):
    """A seeded ``flow`` with zooming must equal a seeded ``random_zoom``.

    The batch is slices 45-89 of the CT with two leading axes added. It is
    run through ``DataGenerator.flow`` with ``seed=21`` and compared with
    calling ``generators.random_zoom`` directly after seeding numpy with the
    same value.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    sub_volume = np.expand_dims(ct_array[45:90], 0)
    batch = np.expand_dims(sub_volume, 0)

    generator = generators.DataGenerator(zoom_lower=[1, .8, .8],
                                         zoom_upper=[1, 1.2, 1.2],
                                         zoom_independent=True,
                                         data_format='channels_first')
    flow = generator.flow(batch, batch_size=1, seed=21)
    augmented = next(flow)

    assert len(augmented.shape) == 5
    assert augmented.shape == batch.shape

    np.random.seed(21)
    zoomed = generators.random_zoom(batch[0],
                                    zoom_lower=[1, .8, .8],
                                    zoom_upper=[1, 1.2, 1.2],
                                    independent=True)
    assert not np.abs(zoomed - augmented).sum()
Beispiel #21
0
def test_patches_from_ct(ct_path):
    """``patches_from_ct`` must return one 12x12x12 patch per centroid.

    This variant passes a ``MetaData`` wrapper through the ``meta`` keyword.
    """
    ct_array, raw_meta = load_ct.load_ct(ct_path)
    meta = load_ct.MetaData(raw_meta)

    xyz_triples = [[507, -21, -177], [547, -121, -220], [530, -221, -277]]
    centroids = [{'x': x, 'y': y, 'z': z} for x, y, z in xyz_triples]

    patches = crop_patches.patches_from_ct(ct_array,
                                           patch_shape=12,
                                           centroids=centroids,
                                           meta=meta)

    assert isinstance(patches, list)
    assert len(patches) == 3
    for patch in patches:
        assert patch.shape == (12, 12, 12)
Beispiel #22
0
def test_generators_zoom(ct_path):
    """``random_zoom`` with fixed factors must equal a center-cropped zoom.

    With identical lower and upper bounds of ``[1, 1.2, 1.2]``, the reference
    is built by zooming the patch with ``scipy.ndimage.zoom`` and cutting the
    central region of the original patch size out of the enlarged volume.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    patch = np.expand_dims(ct_array[45:90], 0)
    # Public entry point; the ``scipy.ndimage.interpolation`` namespace is
    # deprecated.
    zoomed = scipy.ndimage.zoom(patch[0], [1, 1.2, 1.2],
                                order=0,
                                mode='nearest')
    augmented = generators.random_zoom(patch, [1, 1.2, 1.2], [1, 1.2, 1.2],
                                       True)
    # Per-axis surplus of the zoomed volume over the original patch ...
    offsets = [(i - j) for i, j in zip(zoomed.shape, patch[0].shape)]
    # ... turned into [start, stop] bounds of a centered crop.
    offsets = [[i // 2, j - (i - (i // 2))]
               for i, j in zip(offsets, zoomed.shape)]
    assert len(augmented.shape) == 4
    assert augmented.shape == patch.shape
    assert not np.abs(zoomed[offsets[0][0]:offsets[0][1],
                             offsets[1][0]:offsets[1][1],
                             offsets[2][0]:offsets[2][1]] - augmented).sum()
Beispiel #23
0
def test_generators_rotate(ct_path):
    """A seeded ``random_rotation`` must equal ``scipy.ndimage.rotate``.

    Numpy is seeded with 1 before drawing the reference angle and again
    before calling ``generators.random_rotation``. The reference rotates by
    the negated angle over axes (1, 2).
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    patch = np.expand_dims(ct_array[45:90], 0)

    np.random.seed(1)
    angle = np.random.uniform(-45, 45)
    rotate_options = dict(axes=(1, 2),
                          order=0,
                          mode='nearest',
                          reshape=False)
    reference = scipy.ndimage.rotate(patch[0], -angle, **rotate_options)

    np.random.seed(1)
    augmented = generators.random_rotation(patch, (45, 0, 0))

    assert len(augmented.shape) == 4
    assert augmented.shape == patch.shape
    assert not np.abs(reference - augmented).sum()
Beispiel #24
0
def test_generators_standardization(ct_path):
    """Check the fit warning and output shape of a standardizing generator.

    With warnings promoted to errors, any warning raised by ``flow`` before
    ``fit`` must mention that the generator has to be fitted first. If no
    warning is raised the check is skipped. After fitting, ``flow`` must
    return a 5-dimensional batch with the same shape as the input.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    sub_volume = np.expand_dims(ct_array[45:90], 0)
    batch = np.expand_dims(sub_volume, 0)

    options = dict(featurewise_center=True,
                   featurewise_std_normalization=True,
                   samplewise_center=True,
                   samplewise_std_normalization=True)
    generator = generators.DataGenerator(**options)

    with warnings.catch_warnings():
        warnings.filterwarnings('error')
        try:
            next(generator.flow(batch, batch_size=1))
        except Warning as raised:
            assert "Fit it first by calling `.fit(numpy_data)`" in str(raised)

    generator.fit(batch)
    normalized = next(generator.flow(batch, batch_size=1))

    assert len(normalized.shape) == 5
    assert normalized.shape == batch.shape
Beispiel #25
0
def test_generators_shift(ct_path):
    """A seeded ``random_shift`` must equal an explicit ``ndimage.shift``.

    Numpy is seeded with 1 before drawing the reference translations and
    again before calling ``generators.random_shift``. Each translation is a
    negated uniform draw within the per-axis range, multiplied by the length
    of that axis.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    patch = np.expand_dims(ct_array[45:90], 0)
    np.random.seed(1)
    translations = [
        -np.random.uniform(-rg, rg) * side
        for rg, side in zip([.1, .4, .4], patch.shape[1:])
    ]
    # Public entry point; the ``scipy.ndimage.interpolation`` namespace is
    # deprecated.
    shifted = scipy.ndimage.shift(patch[0],
                                  translations,
                                  order=0,
                                  mode='nearest')

    np.random.seed(1)
    augmented = generators.random_shift(patch, (.1, .4, .4))
    assert len(augmented.shape) == 4
    assert augmented.shape == patch.shape
    assert not np.abs(shifted - augmented).sum()
Beispiel #26
0
def test_generators_save_to_dir(ct_path, tmpdir):
    """``flow(save_to_dir=...)`` must write the produced sample to disk.

    The expected file name is rebuilt from the prefix, index 0 and a random
    integer drawn after seeding numpy with 1. The saved array must equal the
    first sample of the batch returned by ``flow``.
    """
    ct_array, meta = load_ct.load_ct(ct_path)
    sub_volume = np.expand_dims(ct_array[45:90], 0)
    batch = np.expand_dims(sub_volume, 0)
    generator = generators.DataGenerator()

    np.random.seed(1)
    prefix = 'test'
    expected_hash = np.random.randint(1e4)
    file_name = '{prefix}_{index}_{hash}.{format}'.format(
        prefix=prefix, index=0, hash=expected_hash, format='npy')
    out_dir = tmpdir.mkdir("processed")
    saved_path = str(out_dir.join(file_name))

    flow = generator.flow(batch,
                          seed=1,
                          batch_size=1,
                          save_to_dir=os.path.dirname(saved_path),
                          save_prefix=prefix)
    produced = next(flow)

    saved = np.load(saved_path)
    assert len(saved.shape) == 4
    assert saved.shape == produced[0].shape
    assert not np.abs(saved - produced[0]).sum()
Beispiel #27
0
def test_preprocess(metaimage_path):
    """Compare the hand-written crop pipeline with the ``PreprocessCT`` one.

    The reference patch is produced with ``lum_trans`` + ``resample`` +
    ``SimpleCrop``; the candidate patch is produced with a single
    ``PreprocessCT`` pass followed by ``crop_patches.patches_from_ct``. Both
    the patch and its coordinates must match exactly.
    """
    nodule_list = [{"z": 556, "y": 100, "x": 0}]

    # Reference pipeline, working directly on the SimpleITK image.
    itk_image = sitk.ReadImage(metaimage_path)
    voxels = sitk.GetArrayFromImage(itk_image)
    # SimpleITK values are reversed before use here.
    reversed_spacing = np.array(itk_image.GetSpacing())[::-1]
    reversed_origin = np.array(itk_image.GetOrigin())[::-1]
    voxels = lum_trans(voxels)
    voxels = resample(voxels, reversed_spacing, np.array([1, 1, 1]),
                      order=1)[0]
    # After resampling the spacing used below is 1 on every axis.
    unit_spacing = np.array([1, 1, 1])
    voxels = voxels.astype('uint8')

    cropper = SimpleCrop()

    for nodule in nodule_list:
        location = np.array(
            [np.float32(nodule[axis]) for axis in ["z", "y", "x"]])
        # Convert the real-world point into array coordinates.
        location = (location - reversed_origin) / unit_spacing
        expected_patch, expected_coords = cropper(voxels, location)

    # Candidate pipeline: one PreprocessCT pass does clipping, scaling and
    # resampling.
    transform = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                           clip_upper=600.,
                                           min_max_normalize=True,
                                           scale=255,
                                           spacing=True,
                                           order=1,
                                           dtype='uint8')

    ct_array, meta = load_ct.load_ct(metaimage_path)
    ct_array, meta = transform(ct_array, meta)

    patches = crop_patches.patches_from_ct(
        ct_array, meta, 96, nodule_list, stride=4, pad_value=160)
    actual_patch, actual_coords = patches[0]

    assert np.abs(actual_patch - expected_patch).sum() == 0
    assert np.abs(actual_coords - expected_coords).sum() == 0
def predict(ct_path, model_path=None):
    """Detect nodule candidates in a CT scan.

    The scan is loaded, preprocessed to the 0-255 range at spacing=1.,
    split into chunks, run through ``Net`` chunk by chunk, recombined and
    reduced to a list of proposals.

    Args:
      ct_path: Path passed to ``load_ct.load_ct`` to read the scan
      model_path: Path to the file containing the model state. When falsy,
                 ``<Config.ALGOS_DIR>/identify/assets/dsb2017_detector.ckpt``
                 is used.

    Returns:
      List of dicts with the integer 'x', 'y', 'z' location of a nodule
      candidate and its 'p_nodule' probability

    """
    if not model_path:
        INDENTIFY_DIR = path.join(Config.ALGOS_DIR, 'identify')
        model_path = path.join(INDENTIFY_DIR, 'assets',
                               'dsb2017_detector.ckpt')

    ct_array, meta = load_ct.load_ct(ct_path)
    meta = load_ct.MetaImage(meta)
    # Spacing of the scan as loaded; used at the end to rescale the proposals.
    spacing = np.array(meta.spacing)
    # NOTE(review): masked_image and mask are not referenced again below.
    masked_image, mask = filter_lungs(ct_array)

    # masked_image = image
    net = Net()
    net.load_state_dict(torch.load(model_path)["state_dict"])
    # NOTE(review): net.eval() is not called here - confirm this is intended.

    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net).cuda()

    split_comber = SplitComb(side_len=int(144),
                             margin=32,
                             max_stride=16,
                             stride=4,
                             pad_value=170)

    # We have to use small batches until the next release of PyTorch, as bigger ones will segfault for CPU
    # split_comber = SplitComb(side_len=int(32), margin=16, max_stride=16, stride=4, pad_value=170)
    # Transform image to the 0-255 range and resample to 1x1x1mm
    preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                            clip_upper=600.,
                                            spacing=1.,
                                            order=1,
                                            min_max_normalize=True,
                                            scale=255,
                                            dtype='uint8')

    ct_array, meta = preprocess(ct_array, meta)
    # Add a leading axis before splitting.
    ct_array = ct_array[np.newaxis, ...]

    imgT, coords, nzhw = split_data(ct_array, split_comber=split_comber)
    results = []

    # Loop over the image chunks
    for img, coord in zip(imgT, coords):
        var = Variable(img[np.newaxis])
        var.volatile = True
        coord = Variable(coord[np.newaxis])
        coord.volatile = True
        resvar = net(var, coord)
        res = resvar.data.cpu().numpy()
        results.append(res)

    # Stitch the per-chunk outputs back together.
    results = np.concatenate(results, 0)
    results = split_comber.combine(results, nzhw=nzhw)
    pbb = GetPBB()
    # Column 0 of proposals is the score; columns 1-3 are the coordinates
    # (mapped to z, y, x in the return value below); column 4 is presumably
    # the radius - TODO confirm against GetPBB
    proposals, _ = pbb(results, ismask=True)

    # proposals = proposals[proposals[:,4] < 40]
    proposals = nms(proposals)
    # Filter out proposals outside the actual lung
    # prop_int = proposals[:, 1:4].astype(np.int32)
    # wrong = [imgs[0, x[0], x[1], x[2]] > 180 for x in prop_int]
    # proposals = proposals[np.logical_not(wrong)]

    # Apply the sigmoid to get probabilities
    proposals[:, 0] = expit(proposals[:, 0])
    # Remove really weak proposals?
    # proposals = proposals[proposals[:,0] > 0.5]

    # Rescale back to image space coordinates
    proposals[:, 1:4] /= spacing[np.newaxis]
    return [{
        "x": int(p[3]),
        "y": int(p[2]),
        "z": int(p[1]),
        "p_nodule": float(p[0])
    } for p in proposals]
Beispiel #29
0
def predict(ct_path, nodule_list, model_path=None):
    """Score every nodule in ``nodule_list`` with the ``CaseNet`` model.

    The CT is loaded, clipped to [-1200, 600], resampled (``spacing=True``),
    normalized to 0-255 ``uint8`` and cropped around every nodule. Each
    crop is then passed through the network.

    Args:
      ct_path (str): path to a MetaImage or DICOM data.
      nodule_list: List of nodules, each a dict with 'x', 'y' and 'z' keys
      model_path: Path to the torch model. When falsy,
                  ``<Config.ALGOS_DIR>/classify/assets/gtr123_model.ckpt``
                  is used.

    Returns:
      List of dicts holding the nodule's 'x', 'y', 'z' and the predicted
      'p_concerning'. An empty ``nodule_list`` gives an empty list without
      loading the model.

    """
    if not model_path:
        classify_dir = os.path.join(Config.ALGOS_DIR, 'classify')
        model_path = os.path.join(classify_dir, 'assets', 'gtr123_model.ckpt')

    if not nodule_list:
        return []

    casenet = CaseNet()
    weights = torch.load(model_path)
    casenet.load_state_dict(weights)
    casenet.eval()

    if torch.cuda.is_available():
        casenet = torch.nn.DataParallel(casenet).cuda()

    settings = dict(clip_lower=-1200.,
                    clip_upper=600.,
                    spacing=True,
                    order=1,
                    min_max_normalize=True,
                    scale=255,
                    dtype='uint8')
    preprocess = PreprocessCT(**settings)

    # convert the image to voxels (apply the real spacing between pixels)
    ct_array, meta = preprocess(*load_ct(ct_path))

    patches = patches_from_ct(
        ct_array,
        meta,
        config['crop_size'],
        nodule_list,
        stride=config['stride'],
        pad_value=config['filling_value'])

    results = []

    for nodule, (patch, patch_coords) in zip(nodule_list, patches):
        # Two leading axes are added to the patch, one to the coordinates.
        patch_input = Variable(
            torch.from_numpy(patch[np.newaxis, np.newaxis]).float())
        patch_input.volatile = True
        coords_input = Variable(
            torch.from_numpy(patch_coords[np.newaxis]).float())
        coords_input.volatile = True

        _, pred, _ = casenet(patch_input, coords_input)

        scored = {axis: nodule[axis] for axis in ('x', 'y', 'z')}
        scored['p_concerning'] = float(pred.data.cpu().numpy())
        results.append(scored)

    return results
def test_preprocess_dicom_clips(dicom_path):
    """Clipped DICOM data must stay within the [-1, 40] bounds."""
    clip = preprocess_ct.PreprocessCT(clip_lower=-1, clip_upper=40)

    clipped, _ = clip(*load_ct.load_ct(dicom_path))

    assert isinstance(clipped, np.ndarray)
    assert clipped.max() <= 40
    assert clipped.min() >= -1