Ejemplo n.º 1
0
    def process(self, sample):
        """Replace each tagged volume with a stack of augmented ROI patches.

        For every tag in ``self.tags`` the volume ``sample[INPUT][tag]`` is
        replaced by an array of shape ``(self.max_rois,) + self.output_shape``
        (same dtype as the volume). Slot ``i`` holds the patch returned by
        ``augment_3d`` when centred on the i-th ROI of a random permutation of
        ``self.rois[patient_id]``; slots beyond the number of available ROIs
        stay zero. One set of augmentation parameters is sampled per call and
        shared by every patch of every tag.

        :param sample: dict whose ``INPUT`` sub-dict holds, for each tag with
            base name ``<base>`` (text before the first ``':'``), the keys
            ``<base>:pixelspacing`` and ``<base>:patient_id`` plus the volume
            under ``tag`` itself. Modified in place.
        :raises AssertionError: if a ``<base>:pixelspacing`` entry is missing.
        """
        augment_p = sample_augmentation_parameters(self.augmentation_params)

        for tag in self.tags:
            basetag = tag.split(':')[0]
            pixelspacingtag = basetag + ":pixelspacing"
            assert pixelspacingtag in sample[
                INPUT], "tag %s not found" % pixelspacingtag
            spacing = sample[INPUT][pixelspacingtag]
            # Loop-invariant divisor. The builtin float replaces np.float,
            # which was only an alias of it and no longer exists in NumPy.
            spacing_arr = np.asarray(spacing, float)

            volume = sample[INPUT][tag]
            new_vol = np.zeros((self.max_rois, ) + self.output_shape,
                               volume.dtype)

            patient_id = sample[INPUT][basetag + ":patient_id"]
            # permutation() shuffles a copy, consuming the RNG exactly like
            # shuffle(), so the cached self.rois entry is no longer mutated.
            rois = np.random.permutation(self.rois[patient_id])

            for i, roi in enumerate(rois[:self.max_rois]):
                # mm to input space
                center_to_shift = -roi / spacing_arr

                new_vol[i] = augment_3d(volume=volume,
                                        pixel_spacing=spacing,
                                        output_shape=self.output_shape,
                                        norm_patch_shape=self.norm_patch_shape,
                                        augment_p=augment_p,
                                        center_to_shift=center_to_shift)

            sample[INPUT][tag] = new_vol  # shape: (max_rois, X, Y, Z)
Ejemplo n.º 2
0
    def process(self, sample):
        """Crop every tagged volume around one randomly chosen labelled point.

        For each tag in ``self.tags`` with base name ``<base>`` (text before
        the first ``':'``), ``sample[INPUT]`` must contain
        ``<base>:pixelspacing``, ``<base>:labels`` and ``<base>:origin``.
        One label is drawn per base name per call; its first three entries
        are converted to voxel coordinates with
        ``LunaDataLoader.world_to_voxel_coordinates``. Every tag sharing that
        base name is then passed through ``augment_3d`` centred on the same
        point, with one shared set of augmentation parameters. Tags found in
        neither ``sample[INPUT]`` nor ``sample[OUTPUT]`` are skipped.

        :param sample: dict with ``INPUT`` and ``OUTPUT`` sub-dicts; modified
            in place.
        :raises AssertionError: if one of the required metadata keys is
            missing.
        """
        # Kept as a local import like the original (presumably to avoid an
        # import cycle), but executed once instead of once per tag.
        from application.luna import LunaDataLoader

        augment_p = sample_augmentation_parameters(self.augmentation_params)
        # Label location per base tag. Reusing it keeps an input volume and
        # its output target aligned; drawing a fresh label per tag could crop
        # them around different points.
        labelloc_by_base = {}

        for tag in self.tags:
            basetag = tag.split(':')[0]
            pixelspacingtag = basetag + ":pixelspacing"
            labelstag = basetag + ":labels"
            origintag = basetag + ":origin"

            assert pixelspacingtag in sample[
                INPUT], "tag %s not found" % pixelspacingtag
            assert labelstag in sample[INPUT], "tag %s not found" % labelstag
            assert origintag in sample[INPUT], "tag %s not found" % origintag

            spacing = sample[INPUT][pixelspacingtag]

            if basetag not in labelloc_by_base:
                labels = sample[INPUT][labelstag]
                origin = sample[INPUT][origintag]
                label = random.choice(labels)
                labelloc_by_base[basetag] = \
                    LunaDataLoader.world_to_voxel_coordinates(
                        label[:3], origin=origin, spacing=spacing)
            labelloc = labelloc_by_base[basetag]

            if tag in sample[INPUT]:
                section, extra_kwargs = INPUT, {}
            elif tag in sample[OUTPUT]:
                # Output tags are augmented with an explicit cval=0.0
                # (presumably the fill value outside the volume).
                section, extra_kwargs = OUTPUT, {'cval': 0.0}
            else:
                # Tag not present in this sample: deliberately left alone.
                continue

            sample[section][tag] = augment_3d(
                volume=sample[section][tag],
                pixel_spacing=spacing,
                output_shape=self.output_shape,
                norm_patch_shape=self.norm_patch_shape,
                augment_p=augment_p,
                center_to_shift=-labelloc,
                **extra_kwargs)
Ejemplo n.º 3
0
    def process(self, sample):
        """Crop the ``<base>:3d`` volume around one sampled candidate.

        Only the first tag in ``self.tags`` is used; ``<base>`` is its text
        before the first ``':'``. ``self.candidates[patient_id]`` must hold
        one or two candidate groups. With two groups, the second is chosen
        with probability 0.5, otherwise the first. A candidate is drawn
        uniformly from the chosen group. ``candidate[:3]`` is converted with
        ``LunaDataLoader.world_to_voxel_coordinates`` and used as the crop
        centre. ``candidate[3]`` is stored as ``np.int32`` under
        ``sample[OUTPUT]['<base>:target']``.

        :param sample: dict with ``INPUT`` (holding ``<base>:pixelspacing``,
            ``<base>:patient_id``, ``<base>:origin`` and ``<base>:3d``) and
            ``OUTPUT`` sub-dicts; modified in place.
        :raises ValueError: if the patient has zero or more than two
            candidate groups.
        """
        augment_p = sample_augmentation_parameters(self.augmentation_params)

        basetag = self.tags[0].split(':')[0]

        spacing = sample[INPUT][basetag + ":pixelspacing"]
        patient_id = sample[INPUT][basetag + ":patient_id"]
        candidates = self.candidates[patient_id]
        origin = sample[INPUT][basetag + ":origin"]

        n_groups = len(candidates)
        if n_groups == 1:
            group = candidates[0]
        elif n_groups == 2:
            # Sample both groups with equal probability.
            group = candidates[1] if random.random() < 0.5 else candidates[0]
        elif n_groups == 0:
            raise ValueError("candidates is empty")
        else:
            raise ValueError(
                "expected 1 or 2 candidate groups, got %d" % n_groups)
        candidate = random.choice(group)

        candidateloc = LunaDataLoader.world_to_voxel_coordinates(
            candidate[:3], origin=origin, spacing=spacing)

        volume = sample[INPUT][basetag + ":3d"]

        sample[INPUT][basetag + ":3d"] = augment_3d(
            volume=volume,
            pixel_spacing=spacing,
            output_shape=self.output_shape,
            norm_patch_shape=self.norm_patch_shape,
            augment_p=augment_p,
            center_to_shift=-candidateloc)

        # Expose the candidate's label as the training target.
        sample[OUTPUT][basetag + ":target"] = np.int32(candidate[3])
Ejemplo n.º 4
0
    def process(self, sample):
        """Replace each tagged volume with patches around its top-scored ROIs.

        For every tag in ``self.tags`` (base name ``<base>`` = text before
        the first ``':'``), ``self.rois[patient_id]`` must be a dict with the
        arrays ``'rois'``, ``'in_mask'`` and ``'fpr_p'``. Scores in ``'fpr_p'``
        whose ``'in_mask'`` entry is 0 are treated as 0. The ``self.max_rois``
        highest-scoring ROIs are selected, shuffled and each passed to
        ``augment_3d`` as a crop centre. The resulting stack of shape
        ``(self.max_rois,) + self.output_shape`` replaces
        ``sample[INPUT][tag]``; unused slots stay zero.

        :param sample: dict whose ``INPUT`` sub-dict holds
            ``<base>:pixelspacing``, ``<base>:patient_id`` and the volume
            under ``tag``; modified in place.
        :raises AssertionError: if a ``<base>:pixelspacing`` entry is missing.
        """
        augment_p = sample_augmentation_parameters(self.augmentation_params)

        for tag in self.tags:
            basetag = tag.split(':')[0]
            pixelspacingtag = basetag + ":pixelspacing"
            assert pixelspacingtag in sample[INPUT], "tag %s not found" % pixelspacingtag
            spacing = sample[INPUT][pixelspacingtag]
            # Builtin float: np.float was an alias of it and is gone from NumPy.
            spacing_arr = np.asarray(spacing, float)

            volume = sample[INPUT][tag]
            new_vol = np.zeros((self.max_rois,) + self.output_shape, volume.dtype)

            patient_id = sample[INPUT][basetag + ":patient_id"]
            d_rois = self.rois[patient_id]
            # Single parenthesised argument: identical output on Python 2 and 3.
            print('loaded rois from patient %s' % (patient_id,))

            rois = d_rois['rois']
            in_mask = d_rois['in_mask']
            # Copy so that zeroing below does not modify the cached self.rois.
            fpr_p = np.array(d_rois['fpr_p'])
            fpr_p[in_mask == 0] = 0.

            # Exactly the max_rois highest scores. A stable sort keeps the
            # selection deterministic on ties. The former rankdata threshold
            # averaged tied ranks, so it could pick more than max_rois
            # (overflowing new_vol) or fewer.
            top_idx = np.argsort(fpr_p, kind='mergesort')[::-1][:self.max_rois]
            top = rois[top_idx]
            np.random.shuffle(top)

            for idx, roi in enumerate(top):
                # mm to input space; use the selected ROI, not rois[idx].
                center_to_shift = -roi / spacing_arr

                new_vol[idx] = augment_3d(
                    volume=volume,
                    pixel_spacing=spacing,
                    output_shape=self.output_shape,
                    norm_patch_shape=self.norm_patch_shape,
                    augment_p=augment_p,
                    center_to_shift=center_to_shift
                )

            sample[INPUT][tag] = new_vol  # shape: (max_rois, X, Y, Z)