コード例 #1
0
    def apply_transform(self, subject: tio.Subject) -> tio.Subject:
        """Average a random cluster of neighbouring DWI volumes.

        Volumes whose b-value lies strictly inside ``self.bval_range`` are
        kept. One of their gradient directions is drawn at random, the
        ``self.subset_size`` directions closest to it (squared Euclidean
        distance) are gathered, and a random number of those volumes is
        averaged. The average is written to the image named
        ``self.mean_dwi_image_name``, which is created as a deep copy of the
        full DWI image when the subject does not have it yet.

        Args:
            subject: Subject that may contain the image named
                ``self.full_dwi_image_name``. The gradient table is read from
                that image under the key ``self.bvec_name``; its columns 0-2
                are used as the direction and column 3 as the b-value.

        Returns:
            The same subject. It is returned untouched when it has no image
            named ``self.full_dwi_image_name``.

        Raises:
            ValueError: If ``self.subset_size`` is smaller than 1, or if no
                volume has a b-value inside ``self.bval_range``.
        """
        if self.full_dwi_image_name not in subject:
            return subject

        full_dwi_image = subject[self.full_dwi_image_name]
        full_dwi = full_dwi_image.data
        grad = full_dwi_image[self.bvec_name]

        bvals = grad[:, 3]
        bvecs = grad[:, :3]
        bval_min, bval_max = self.bval_range[0], self.bval_range[1]
        mask = (bvals > bval_min) & (bvals < bval_max)

        bvecs = bvecs[mask]
        full_dwi = full_dwi[mask]

        if self.subset_size < 1:
            raise ValueError(
                f'subset_size must be at least 1, got {self.subset_size}'
            )
        if bvecs.shape[0] == 0:
            # Without this check np.random.randint(0) fails with an opaque
            # "high <= 0" message.
            raise ValueError(
                'No DWI volume has a b-value inside the open interval'
                f' ({bval_min}, {bval_max})'
            )

        rand_bvec = bvecs[np.random.randint(bvecs.shape[0])]
        dist = torch.sum((bvecs - rand_bvec) ** 2, dim=1)
        # dist is a torch tensor, so sort it with torch rather than relying on
        # NumPy's implicit tensor-to-array conversion.
        closest_indices = torch.argsort(dist)[: self.subset_size]

        # randint's upper bound is exclusive, so it cannot be called with
        # low == high; a subset of size 1 always contributes its only volume.
        if self.subset_size > 1:
            number_of_selections = np.random.randint(
                low=1,
                high=self.subset_size,
            )
        else:
            number_of_selections = 1
        ids = torch.randperm(closest_indices.shape[0])[:number_of_selections]
        selected_indices = closest_indices[ids]
        mean_dwi = torch.mean(full_dwi[selected_indices], dim=0)

        if self.mean_dwi_image_name in subject:
            mean_dwi_image = subject[self.mean_dwi_image_name]
        else:
            mean_dwi_image = copy.deepcopy(full_dwi_image)
            subject.add_image(mean_dwi_image, self.mean_dwi_image_name)

        # Restore the leading axis removed by averaging over dim 0.
        mean_dwi_image.set_data(mean_dwi.unsqueeze(0))

        return subject
コード例 #2
0
    def apply_transform(self, subject: tio.Subject) -> tio.Subject:
        """Average DWI volumes sampled preferentially around random directions.

        Volumes whose b-value lies strictly inside ``self.bval_range`` are
        kept. Random unit vectors are drawn, each kept volume is weighted by
        the largest ``|bvec . direction| ** directionality`` over those
        vectors, and ``num_dwis`` volumes are drawn with these weights via
        ``np.random.choice`` and averaged. The average is written to the
        image named ``self.mean_dwi_image_name``, which is created as a deep
        copy of the full DWI image when the subject does not have it yet.

        Args:
            subject: Subject that may contain the image named
                ``self.full_dwi_image_name``. The gradient table is read from
                that image under the key ``self.bvec_name``; its columns 0-2
                are used as the direction and column 3 as the b-value.

        Returns:
            The same subject. It is returned untouched when it has no image
            named ``self.full_dwi_image_name``.
        """
        source_name = self.full_dwi_image_name
        if source_name not in subject:
            return subject

        source_image = subject[source_name]
        volumes = source_image.data.numpy()
        gradient_table = source_image[self.bvec_name].numpy()

        # Keep only the volumes whose b-value is inside the open interval.
        b_values = gradient_table[:, 3]
        directions = gradient_table[:, :3]
        lower, upper = self.bval_range[0], self.bval_range[1]
        in_range = (b_values > lower) & (b_values < upper)
        directions = directions[in_range]
        volumes = volumes[in_range]

        num_dwis = self.get_num_dwis()
        num_directions = self.get_num_directions()
        directionality = self.get_directionality()

        # One random unit vector per column.
        anchors = np.random.randn(3, num_directions)
        anchors /= np.linalg.norm(anchors, axis=0, keepdims=True)

        # Weight each volume by its best alignment with any anchor, then
        # normalise the weights into a probability distribution.
        alignment = np.abs(directions @ anchors)
        weights = (alignment ** directionality).max(axis=1)
        weights = weights / weights.sum()

        candidates = np.arange(volumes.shape[0])
        chosen = np.random.choice(candidates, size=num_dwis, p=weights)
        averaged = volumes[chosen].mean(axis=0, keepdims=True)

        target_name = self.mean_dwi_image_name
        if target_name not in subject:
            target_image = copy.deepcopy(source_image)
            subject.add_image(target_image, target_name)
        else:
            target_image = subject[target_name]

        target_image.set_data(averaged)

        return subject
コード例 #3
0
    def __call__(self, subject: tio.Subject):
        """Simulate a resection on ``subject`` and store the result in it.

        Random parameters are drawn with ``self.get_params``. The image named
        ``self.image_name`` and the masks of the sampled hemisphere are
        converted with ``as_sitk`` and passed to ``resect``. The resulting
        image replaces the data of ``self.image_name`` and the resulting mask
        is added as a ``tio.LabelMap`` named ``'label'``.

        Args:
            subject: Subject holding ``self.image_name`` together with the
                ``resection_gray_matter_<hemisphere>`` and
                ``resection_resectable_<hemisphere>`` images, plus
                ``resection_noise`` when the texture is ``'csf'`` or a white
                matter lesion or clot is added.

        Returns:
            The same subject, modified in place.
        """
        # Draw the random parameters for this resection
        params = self.get_params(
            self.volumes,
            self.volumes_range,
            self.sigmas_range,
            self.radii_ratio_range,
            self.angles_range,
            self.wm_lesion_p,
            self.clot_p,
        )

        # Images -> SimpleITK
        with timer('Convert to SITK', self.verbose):
            t1_pre = subject[self.image_name].as_sitk()
            side = params['hemisphere']
            gray_matter_mask = subject[
                f'resection_gray_matter_{side}'
            ].as_sitk()
            resectable_tissue_mask = subject[
                f'resection_resectable_{side}'
            ].as_sitk()

            wm_lesion = params['add_wm_lesion']
            clot = params['add_clot']
            # The noise image is only needed for these three cases
            needs_noise = self.texture == 'csf' or wm_lesion or clot
            noise_image = (
                subject['resection_noise'].as_sitk() if needs_noise else None
            )

        # Run the simulation
        with timer('Resection', self.verbose):
            resect_options = dict(
                noise_image=noise_image,
                shape=self.shape,
                texture=self.texture,
                angles=params['angles'],
                noise_offset=params['noise_offset'],
                sphere_poly_data=self.sphere_poly_data,
                wm_lesion=wm_lesion,
                clot=clot,
                simplex_path=self.simplex_path,
                center_ras=self.center_ras,
                verbose=self.verbose,
            )
            outputs = resect(
                t1_pre,
                gray_matter_mask,
                resectable_tissue_mask,
                params['sigmas'],
                params['radii'],
                **resect_options,
            )
        resected_brain, resection_mask, resection_center, clot_center = outputs

        # Keep the centers alongside the parameters for visualization
        params['resection_center'] = resection_center
        params['clot_center'] = clot_center

        # SimpleITK -> arrays with a channels axis
        with timer('Convert from SITK', self.verbose):
            arrays = [
                self.sitk_to_array(sitk_image)
                for sitk_image in (resected_brain, resection_mask)
            ]
            image_resected, resection_label = (
                self.add_channels_axis(array) for array in arrays
            )
        for array in (image_resected, resection_label):
            assert array.ndim == 4

        # Drop the helper images that were only needed for the simulation
        if self.delete_resection_keys:
            helper_names = (
                'resection_gray_matter_left',
                'resection_gray_matter_right',
                'resection_resectable_left',
                'resection_resectable_right',
            )
            for helper_name in helper_names:
                subject.remove_image(helper_name)
            if needs_noise:
                subject.remove_image('resection_noise')

        # Write the results into the subject
        if self.add_params:
            subject['random_resection'] = params
        if self.keep_original:
            # Copy before the image data is overwritten just below
            subject['image_original'] = copy.deepcopy(subject[self.image_name])
        subject[self.image_name].data = torch.from_numpy(image_resected)
        subject.add_image(
            tio.LabelMap(
                tensor=resection_label,
                affine=subject[self.image_name].affine,
            ),
            'label',
        )

        if self.add_resected_structures:
            structures = self.get_resected_structures(subject, resection_mask)
            subject['resected_structures'] = structures

        return subject