def test_preprocess(metaimage_path):
    """Check that the new preprocessing + cropping matches the legacy pipeline.

    The legacy path reads the scan with SimpleITK, applies ``lum_trans`` and
    ``resample`` and crops with ``SimpleCrop``. The new path goes through
    ``load_ct``, two ``PreprocessCT`` passes and ``patches_from_ct``. Both
    must produce identical patches and coordinate grids.

    Args:
        metaimage_path: path to the scan, passed to ``sitk.ReadImage`` and
            ``load_ct.load_ct``.
    """
    nodule_list = [{"z": 556, "y": 100, "x": 0}]

    # --- legacy pipeline -------------------------------------------------
    itk_image = sitk.ReadImage(metaimage_path)
    voxels = sitk.GetArrayFromImage(itk_image)
    # Reverse the ITK tuples so they line up with the ("z", "y", "x") order
    # used for the nodule location below.
    source_spacing = np.array(itk_image.GetSpacing())[::-1]
    origin = np.array(itk_image.GetOrigin())[::-1]
    voxels = lum_trans(voxels)
    voxels = resample(voxels, source_spacing, np.array([1, 1, 1]), order=1)[0]

    legacy_crop = SimpleCrop()

    for nodule in nodule_list:
        world_zyx = np.array([nodule[axis] for axis in ("z", "y", "x")],
                             dtype=np.float32)
        voxel_zyx = np.ceil((world_zyx - origin) / 1.)
        cropped_image, coords = legacy_crop(voxels[np.newaxis], voxel_zyx)

    # --- new pipeline ----------------------------------------------------
    ct_array, meta = load_ct.load_ct(metaimage_path)

    intensity_pass = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                                clip_upper=600.,
                                                min_max_normalize=True,
                                                scale=255,
                                                dtype='uint8')
    ct_array, meta = intensity_pass(ct_array, meta)

    resample_pass = preprocess_ct.PreprocessCT(spacing=1., order=1)
    ct_array, meta = resample_pass(ct_array, meta)

    new_patches = crop_patches.patches_from_ct(ct_array, meta, 96, nodule_list,
                                               stride=4, pad_value=160)
    cropped_image_new, coords_new = new_patches[0]

    # Both pipelines must agree exactly, element by element.
    assert np.abs(cropped_image_new - cropped_image).sum() == 0
    assert np.abs(coords_new - coords).sum() == 0
예제 #2
0
def predict(ct_path,
            nodule_list,
            model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
    """Score every candidate nodule of a CT scan with the CaseNet classifier.

    Args:
      ct_path (str): path to a MetaImage or DICOM data.
      nodule_list: List of nodules; each entry is indexed by "x", "y" and "z".
      model_path: Path to the torch model (Default value = "src/algorithms/classify/assets/gtr123_model.ckpt")

    Returns:
      list[dict]: one dict per nodule with its "x", "y", "z" copied over and
      a float "p_concerning" taken from the network prediction. An empty list
      is returned, without loading the model, when ``nodule_list`` is empty.

    """
    if not nodule_list:
        return []

    casenet = CaseNet()
    # Deserialize onto the CPU regardless of the device the checkpoint was
    # saved from; otherwise a checkpoint holding CUDA tensors cannot be loaded
    # on a CPU-only host. The weights are copied into the freshly built
    # network either way and moved to the GPU below when one is available.
    state = torch.load(model_path, map_location=lambda storage, loc: storage)
    casenet.load_state_dict(state)
    casenet.eval()

    if torch.cuda.is_available():
        casenet = torch.nn.DataParallel(casenet).cuda()

    preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                            clip_upper=600.,
                                            spacing=1.,
                                            order=1,
                                            min_max_normalize=True,
                                            scale=255,
                                            dtype='uint8')
    ct_array, meta = preprocess(*load_ct.load_ct(ct_path))
    patches = crop_patches.patches_from_ct(ct_array,
                                           meta,
                                           config['crop_size'],
                                           nodule_list,
                                           stride=config['stride'],
                                           pad_value=config['filling_value'])
    results = []
    for nodule, (cropped_image, coords) in zip(nodule_list, patches):
        # Add two leading axes to the patch and one to the coordinate grid
        # before handing them to the network.
        # NOTE(review): `volatile` is the pre-0.4 PyTorch switch for disabling
        # autograd; confirm the targeted torch version still honours it.
        cropped_image = Variable(
            torch.from_numpy(cropped_image[np.newaxis, np.newaxis]).float())
        cropped_image.volatile = True
        coords = Variable(torch.from_numpy(coords[np.newaxis]).float())
        coords.volatile = True
        _, pred, _ = casenet(cropped_image, coords)
        results.append({
            "x": nodule["x"],
            "y": nodule["y"],
            "z": nodule["z"],
            "p_concerning": float(pred.data.cpu().numpy())
        })

    return results
예제 #3
0
def test_patches_from_ct(ct_path):
    """Crop three 12-voxel patches and check the container type, count and shapes.

    Args:
        ct_path: path to the scan, passed to ``load_ct.load_ct``.
    """
    zyx_points = ([556, 101, -70], [556, 121, -20], [556, 221, -77])
    centroids = [dict(zip(('z', 'y', 'x'), point)) for point in zyx_points]

    patches = crop_patches.patches_from_ct(*load_ct.load_ct(ct_path),
                                           patch_shape=12,
                                           centroids=centroids)

    assert isinstance(patches, list)
    assert len(patches) == 3
    # A scalar patch_shape is expected to yield cubic patches.
    assert all(patch.shape == (12, 12, 12) for patch in patches)
예제 #4
0
def test_patches_from_ct(ct_path):
    """Crop three 12-voxel patches using a ``MetaData``-wrapped header.

    Args:
        ct_path: path to the scan, passed to ``load_ct.load_ct``.
    """
    ct_array, raw_meta = load_ct.load_ct(ct_path)
    meta = load_ct.MetaData(raw_meta)

    xyz_points = ([507, -21, -177], [547, -121, -220], [530, -221, -277])
    centroids = [dict(zip(('x', 'y', 'z'), point)) for point in xyz_points]

    patches = crop_patches.patches_from_ct(ct_array,
                                           meta=meta,
                                           patch_shape=12,
                                           centroids=centroids)

    assert isinstance(patches, list)
    assert len(patches) == 3
    # A scalar patch_shape is expected to yield cubic patches.
    assert all(patch.shape == (12, 12, 12) for patch in patches)
예제 #5
0
def test_preprocess(metaimage_path):
    """Check that the single-pass ``PreprocessCT`` matches the legacy pipeline.

    The legacy path (SimpleITK read, ``lum_trans``, ``resample``, cast to
    uint8, ``SimpleCrop``) and the new path (``load_ct``, one ``PreprocessCT``
    call, ``patches_from_ct``) must give identical patches and coordinates.

    Args:
        metaimage_path: path to the scan, passed to ``sitk.ReadImage`` and
            ``load_ct.load_ct``.
    """
    nodule_list = [{"z": 556, "y": 100, "x": 0}]

    # --- legacy pipeline -------------------------------------------------
    itk_image = sitk.ReadImage(metaimage_path)
    voxels = sitk.GetArrayFromImage(itk_image)
    # Reverse the ITK tuples so they line up with the ("z", "y", "x") order
    # used for the nodule location below.
    source_spacing = np.array(itk_image.GetSpacing())[::-1]
    origin = np.array(itk_image.GetOrigin())[::-1]
    voxels = lum_trans(voxels)
    voxels = resample(voxels, source_spacing, np.array([1, 1, 1]), order=1)[0]
    # After resampling, the volume is addressed with unit spacing.
    resampled_spacing = np.array([1, 1, 1])
    voxels = voxels.astype('uint8')

    legacy_crop = SimpleCrop()

    for nodule in nodule_list:
        world_zyx = np.array([nodule[axis] for axis in ("z", "y", "x")],
                             dtype=np.float32)
        # World coordinate -> array index: shift by the origin, then divide
        # by the voxel spacing.
        voxel_zyx = (world_zyx - origin) / resampled_spacing
        cropped_image, coords = legacy_crop(voxels, voxel_zyx)

    # --- new pipeline ----------------------------------------------------
    pipeline = preprocess_ct.PreprocessCT(clip_lower=-1200.,
                                          clip_upper=600.,
                                          min_max_normalize=True,
                                          scale=255,
                                          spacing=True,
                                          order=1,
                                          dtype='uint8')

    ct_array, meta = load_ct.load_ct(metaimage_path)
    ct_array, meta = pipeline(ct_array, meta)

    new_patches = crop_patches.patches_from_ct(
        ct_array, meta, 96, nodule_list, stride=4, pad_value=160)
    cropped_image_new, coords_new = new_patches[0]

    # Both pipelines must agree exactly, element by element.
    assert np.abs(cropped_image_new - cropped_image).sum() == 0
    assert np.abs(coords_new - coords).sum() == 0
예제 #6
0
    def feed(self,
             annotations,
             sampling_pure=1.,
             sampling_cancerous=1.,
             train_mode=True):  # noqa: C901
        """Endlessly yield batches of patches cropped around annotated centroids.

        Each pass over the annotations loads scans in groups of at most
        ``self.pull_ct.maxlen``, crops one 42x42x42 patch per centroid into
        ``self.pull_patches`` and then drains that pool ``self.batch_size``
        patches at a time. Patches that do not fill a whole batch stay in the
        pool for the next group.

        Args:
            annotations (list[dict]): A list of entries of the form::
                 {'file_path': str,
                  'centroids': [{'x': int,
                                 'y': int,
                                 'z': int,
                                 'cancerous': bool}, ..]}.
            sampling_pure (float): forwarded to ``self.on_epoch_start``
                (train mode only).
            sampling_cancerous (float): forwarded to ``self.on_epoch_start``
                (train mode only).
            train_mode (bool): when true, the annotations go through
                ``self.on_epoch_start``, the patch pool is shuffled, the
                training data generator is used and labels are yielded.

        Yields:
            ``(batch, labels)`` as returned by ``self._batch_process`` in train
            mode, otherwise only ``batch``.
        """
        while True:
            if train_mode:
                epoch_patients = self.on_epoch_start(annotations, sampling_pure,
                                                     sampling_cancerous)
            else:
                epoch_patients = annotations

            total_findings = sum(
                len(entry['centroids']) for entry in epoch_patients)
            capacity = self.pull_ct.maxlen
            group_count = int(np.ceil(total_findings / capacity))

            for group_idx in range(group_count):
                first = group_idx * capacity
                # Load the scans of this group.
                for entry in epoch_patients[first:first + capacity]:
                    ct_array, meta = self._ct_preprocess(entry['file_path'])
                    self.pull_ct.append((entry['centroids'], ct_array, meta))

                # Turn every loaded scan into labelled patches.
                while self.pull_ct:
                    centroids, ct_array, meta = self.pull_ct.pop()
                    crops = crop_patches.patches_from_ct(
                        ct_array,
                        meta=meta,
                        centroids=centroids,
                        patch_shape=(42, 42, 42))

                    if train_mode:
                        labelled = [(centroid['cancerous'], crop)
                                    for centroid, crop in zip(centroids, crops)]
                    else:
                        # No ground truth is read outside train mode; -1 is a
                        # placeholder label.
                        labelled = [(-1, crop)
                                    for _, crop in zip(centroids, crops)]
                    self.pull_patches.extend(labelled)

                if train_mode:
                    np.random.shuffle(self.pull_patches)

                full_batches = len(self.pull_patches) // self.batch_size
                for _ in range(full_batches):
                    drawn = [
                        self.pull_patches.pop()
                        for _ in range(self.batch_size)
                    ]
                    labels = [label for label, _ in drawn]
                    stacked = np.asarray([crop for _, crop in drawn])
                    batch = np.expand_dims(stacked, self.channel_axis)

                    generator = (self.train_data_generator
                                 if train_mode else self.test_data_generator)
                    batch, labels = self._batch_process(
                        generator, batch, labels)

                    if train_mode:
                        yield batch, labels
                    else:
                        yield batch
예제 #7
0
def predict(ct_path, nodule_list, model_path=None):
    """Score every candidate nodule of a CT scan with the CaseNet classifier.

    Args:
      ct_path (str): path to a MetaImage or DICOM data.
      nodule_list: List of nodules; each entry is indexed by 'x', 'y' and 'z'.
      model_path: Path to the torch model. When falsy (Default value = None)
        it resolves to ``<Config.ALGOS_DIR>/classify/assets/gtr123_model.ckpt``.

    Returns:
      list[dict]: one dict per nodule with its 'x', 'y', 'z' copied over and
      a float 'p_concerning' taken from the network prediction. An empty list
      is returned, without loading the model, when ``nodule_list`` is empty.

    """
    # Nothing to score: skip path resolution and model loading.
    if not nodule_list:
        return []

    if not model_path:
        CLASSIFY_DIR = os.path.join(Config.ALGOS_DIR, 'classify')
        model_path = os.path.join(CLASSIFY_DIR, 'assets', 'gtr123_model.ckpt')

    casenet = CaseNet()
    # Deserialize onto the CPU regardless of the device the checkpoint was
    # saved from; otherwise a checkpoint holding CUDA tensors cannot be loaded
    # on a CPU-only host. The weights are copied into the freshly built
    # network either way and moved to the GPU below when one is available.
    state = torch.load(model_path, map_location=lambda storage, loc: storage)
    casenet.load_state_dict(state)
    casenet.eval()

    if torch.cuda.is_available():
        casenet = torch.nn.DataParallel(casenet).cuda()

    preprocess = PreprocessCT(clip_lower=-1200.,
                              clip_upper=600.,
                              spacing=True,
                              order=1,
                              min_max_normalize=True,
                              scale=255,
                              dtype='uint8')

    # Convert the image to voxels (apply the real spacing between pixels).
    ct_array, meta = preprocess(*load_ct(ct_path))

    patches = patches_from_ct(ct_array,
                              meta,
                              config['crop_size'],
                              nodule_list,
                              stride=config['stride'],
                              pad_value=config['filling_value'])

    results = []

    for nodule, (cropped_image, coords) in zip(nodule_list, patches):
        # Add two leading axes to the patch and one to the coordinate grid
        # before handing them to the network.
        # NOTE(review): `volatile` is the pre-0.4 PyTorch switch for disabling
        # autograd; confirm the targeted torch version still honours it.
        cropped_image = Variable(
            torch.from_numpy(cropped_image[np.newaxis, np.newaxis]).float())
        cropped_image.volatile = True
        coords = Variable(torch.from_numpy(coords[np.newaxis]).float())
        coords.volatile = True
        _, pred, _ = casenet(cropped_image, coords)
        results.append({
            'x': nodule['x'],
            'y': nodule['y'],
            'z': nodule['z'],
            'p_concerning': float(pred.data.cpu().numpy()),
        })

    return results