Esempio n. 1
0
def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3,
                                               rescale=True):
    """Estimate, from a folder of images and matching label maps, Gaussian priors for the intensity of K classes.
    The intensities of each class are modelled by a Gaussian (mean and standard deviation). As these two parameters
    vary across images, each of them is itself summarised over the dataset by a mean and a standard deviation, which
    gives 4 numbers per class and per channel. Labels can be regrouped into classes sharing the same statistics, and
    every channel of a multi-modal image is processed separately.
    :param image_dir: path of the directory with the images used to estimate the intensity distributions.
    :param labels_dir: path of the directory with the segmentations of these images.
    Images and label maps are paired by their position in the two file listings (i.e. by sorting order).
    :param labels_list: labels for which to evaluate the mean and std intensity.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param classes_list: (optional) class index of each label, to regroup structures with similar intensities.
    Intensities of labels sharing a class contribute to the same Gaussian.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    It must have the same length as labels_list and use all the values between 0 and K-1, where K is the number of
    classes. By default every label gets its own class (K=len(labels_list)).
    :param max_channel: (optional) maximum number of channels to consider for multi-spectral data. Default is 3.
    :param rescale: (optional) whether to rescale each channel with edit_volumes.rescale_volume (between 0 and 255)
    before estimating the statistics. Default is True.
    :return: two numpy arrays of shape (2*n_channels, K): the first holds the mean/std (over images) of the class
    means, the second holds the mean/std (over images) of the class standard deviations.
    For channel c, row 2*c is the mean and row 2*c+1 is the std.
    """

    # collect the paths of the images and of their label maps
    path_images = utils.list_images_in_folder(image_dir)
    path_labels = utils.list_images_in_folder(labels_dir)
    n_subjects = len(path_images)
    assert n_subjects == len(path_labels), 'image and labels folders do not have the same number of files'

    # bring labels and classes to 1d integer arrays
    labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
    if classes_list is None:
        classes_list = np.arange(labels_list.shape[0])
    else:
        classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
    assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'

    # classes must be exactly 0, 1, ..., K-1
    unique_classes = np.unique(classes_list)
    n_classes = len(unique_classes)
    if not np.array_equal(unique_classes, np.arange(n_classes)):
        raise ValueError('classes_list should only contain values between 0 and K-1, '
                         'where K is the total number of classes. Here K = %d' % n_classes)

    # the number of channels is read from the first image and assumed to hold for the whole dataset
    first_shape = utils.load_volume(path_images[0]).shape
    _, n_channels = utils.get_dims(first_shape, max_channels=max_channel)

    # per-image statistics, indexed by [image, class, channel]
    stats_shape = (n_subjects, n_classes, n_channels)
    means = np.zeros(stats_shape)
    stds = np.zeros(stats_shape)

    # estimate the statistics of every image
    progress = utils.LoopInfo(n_subjects, 10, 'estimating', print_time=True)
    for subject, (path_image, path_label) in enumerate(zip(path_images, path_labels)):
        progress.update(subject)

        # load the image with its label map, making sure the image has a channel axis
        volume = utils.load_volume(path_image)
        label_map = utils.load_volume(path_label)
        if n_channels == 1:
            volume = utils.add_axis(volume, -1)

        # statistics are computed independently for each channel
        for channel in range(n_channels):
            channel_volume = volume[..., channel]
            if rescale:
                channel_volume = edit_volumes.rescale_volume(channel_volume)
            channel_stats = sample_intensity_stats_from_image(channel_volume, label_map, labels_list,
                                                              classes_list=classes_list)
            means[subject, :, channel] = channel_stats[0, :]
            stds[subject, :, channel] = channel_stats[1, :]

    # summarise over images and interleave the rows: even rows hold the mean, odd rows hold the std
    prior_means = np.zeros((2 * n_channels, n_classes))
    prior_stds = np.zeros((2 * n_channels, n_classes))
    prior_means[0::2, :] = np.mean(means, axis=0).T
    prior_means[1::2, :] = np.std(means, axis=0).T
    prior_stds[0::2, :] = np.mean(stds, axis=0).T
    prior_stds[1::2, :] = np.std(stds, axis=0).T

    return prior_means, prior_stds
Esempio n. 2
0
def _run_mri_convert(mri_convert, arguments, verbose):
    """Call mri_convert with a list of arguments, without going through a shell.
    :param mri_convert: path of the mri_convert executable.
    :param arguments: sequence of command-line arguments (each one is converted to a string).
    :param verbose: whether to let mri_convert print to the console. Its output is discarded otherwise.
    :raises RuntimeError: if mri_convert exits with a non-zero status.
    """
    import subprocess  # function-scope import, so this block does not rely on the file's import section

    # passing a list (no shell) keeps paths containing spaces or shell metacharacters intact
    command = [mri_convert] + [str(argument) for argument in arguments]
    sink = None if verbose else subprocess.DEVNULL
    completed = subprocess.run(command, stdout=sink, stderr=sink)
    if completed.returncode != 0:
        raise RuntimeError('mri_convert failed with exit status {}: {}'.format(completed.returncode,
                                                                               ' '.join(command)))


def preprocess_adni_hippo(path_t1,
                          path_t2,
                          path_aseg,
                          result_dir,
                          target_res,
                          padding_margin=85,
                          remove=False,
                          path_freesurfer='/usr/local/freesurfer/',
                          verbose=True,
                          recompute=True):
    """This function builds T1+T2 multimodal images of the left and right hippocampus from the ADNI dataset.
    It first rescales the intensities of each channel (between 0 and 255).
    It then resamples the T2 image (typically at 0.4*0.4*2.0 resolution) to the target resolution.
    The obtained T2 is then zero-padded in all directions by padding_margin voxels (typically large, 85).
    The T1 and aseg are then resampled like the padded T2 using mri_convert.
    Now that the T1, T2 and aseg are aligned and at the same resolution, they are cropped around the left (label 17)
    and right (label 53) hippocampus with a fixed margin of 30 voxels, and the images are masked with the cropped
    aseg. The right-side results are flipped along the right/left axis.
    Finally, the T1 and T2 are concatenated into one single multimodal image.
    The results are written in result_dir as hippo_{left,right}.nii.gz and hippo_{left,right}_aseg.nii.gz.
    :param path_t1: path input T1 (typically at 1mm isotropic)
    :param path_t2: path input T2 (typically cropped around the hippo in sagittal axis, 0.4x0.4x2.0)
    :param path_aseg: path input segmentation (typically at 1mm isotropic)
    :param result_dir: path of directory where prepared images and labels will be written.
    It is created (with its parents) if it does not exist.
    :param target_res: resolution at which to resample the label maps, and the images.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) number of voxels by which the resampled T2 is padded on every side.
    Default is 85.
    :param remove: (optional) whether to delete the intermediate files once done. Default is False.
    :param path_freesurfer: (optional) path of FreeSurfer home, to use mri_convert
    :param verbose: (optional) whether to print out mri_convert output when resampling images
    :param recompute: (optional) whether to recompute result files even if they already exist
    :raises RuntimeError: if a call to mri_convert exits with a non-zero status.
    """

    # create results dir (with missing parents, and without failing if it already exists)
    os.makedirs(result_dir, exist_ok=True)

    # nothing to do if all four results are already there and we were not asked to recompute them
    final_names = ['hippo_right.nii.gz', 'hippo_right_aseg.nii.gz', 'hippo_left.nii.gz', 'hippo_left_aseg.nii.gz']
    final_paths = [os.path.join(result_dir, name) for name in final_names]
    if not recompute and all(os.path.isfile(path) for path in final_paths):
        return

    # set up FreeSurfer
    os.environ['FREESURFER_HOME'] = path_freesurfer
    # NOTE(review): the setup script runs in a child shell, so whatever it exports is presumably lost when that
    # shell exits - confirm whether mri_convert needs more than FREESURFER_HOME
    os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
    mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert.bin')

    # rescale T1 and T2
    path_t1_rescaled = os.path.join(result_dir, 't1_rescaled.nii.gz')
    path_t2_rescaled = os.path.join(result_dir, 't2_rescaled.nii.gz')
    for path_in, path_out in ((path_t1, path_t1_rescaled), (path_t2, path_t2_rescaled)):
        if recompute or not os.path.isfile(path_out):
            im, aff, h = utils.load_volume(path_in, im_only=False)
            im = edit_volumes.rescale_volume(im)
            utils.save_volume(im, aff, h, path_out)

    # resample T2 to target res
    path_t2_resampled = os.path.join(result_dir, 't2_rescaled_resampled.nii.gz')
    if recompute or not os.path.isfile(path_t2_resampled):
        voxel_size = [str(r) for r in utils.reformat_to_list(target_res, length=3)]
        _run_mri_convert(mri_convert,
                         [path_t2_rescaled, path_t2_resampled, '--voxsize'] + voxel_size + ['-odt', 'float'],
                         verbose)

    # pad T2
    path_t2_padded = os.path.join(result_dir, 't2_rescaled_resampled_padded.nii.gz')
    if recompute or not os.path.isfile(path_t2_padded):
        t2, aff, h = utils.load_volume(path_t2_resampled, im_only=False)
        t2_padded = np.pad(t2, padding_margin, 'constant')
        # move the origin back by padding_margin voxels along each axis, so that the original voxels keep
        # the same world coordinates in the padded volume
        aff[:3, -1] = aff[:3, -1] - aff[:3, :3] @ (padding_margin * np.ones(3))
        utils.save_volume(t2_padded, aff, h, path_t2_padded)

    # resample T1 and aseg on the grid of the padded T2
    path_t1_resampled = os.path.join(result_dir, 't1_rescaled_resampled.nii.gz')
    if recompute or not os.path.isfile(path_t1_resampled):
        _run_mri_convert(mri_convert,
                         [path_t1_rescaled, path_t1_resampled, '-rl', path_t2_padded, '-odt', 'float'],
                         verbose)
    path_aseg_resampled = os.path.join(result_dir, 'aseg_resampled.nii.gz')
    if recompute or not os.path.isfile(path_aseg_resampled):
        # nearest-neighbour resampling, so that label values are not interpolated
        _run_mri_convert(mri_convert,
                         [path_aseg, path_aseg_resampled, '-rl', path_t2_padded, '-rt', 'nearest', '-odt', 'float'],
                         verbose)

    # crop images and concatenate T1 and T2
    for lab, side in zip([17, 53], ['left', 'right']):
        path_test_image = os.path.join(result_dir, 'hippo_{}.nii.gz'.format(side))
        path_test_aseg = os.path.join(result_dir, 'hippo_{}_aseg.nii.gz'.format(side))
        build_image = recompute or not os.path.isfile(path_test_image)
        if not build_image and os.path.isfile(path_test_aseg):
            continue

        aseg, aff, h = utils.load_volume(path_aseg_resampled, im_only=False)
        tmp_aseg, cropping, tmp_aff = edit_volumes.crop_volume_around_region(aseg,
                                                                             margin=30,
                                                                             masking_labels=lab,
                                                                             aff=aff)

        # crop and mask the channels BEFORE any flipping: the cropped aseg is only aligned with the cropped
        # T1/T2 as long as none of them has been flipped (masking with an already flipped aseg would mirror
        # the mask with respect to the image)
        channels = []
        if build_image:
            for path_channel in (path_t1_resampled, path_t2_padded):
                channel = utils.load_volume(path_channel)
                channel = edit_volumes.crop_volume_with_idx(channel, crop_idx=cropping)
                channels.append(edit_volumes.mask_volume(channel, tmp_aseg, dilate=6, erode=5))

        # right-side results are flipped along the right/left axis, segmentation and channels alike
        if side == 'right':
            tmp_aseg = edit_volumes.flip_volume(tmp_aseg, direction='rl', aff=tmp_aff)
            channels = [edit_volumes.flip_volume(channel, direction='rl', aff=tmp_aff) for channel in channels]

        utils.save_volume(tmp_aseg, tmp_aff, h, path_test_aseg)
        if build_image:
            # T1 is channel 0 and T2 is channel 1
            utils.save_volume(np.stack(channels, axis=-1), tmp_aff, h, path_test_image)

    # remove unnecessary files
    if remove:
        list_files_to_remove = [path_t1_rescaled,
                                path_t2_rescaled,
                                path_t2_resampled,
                                path_t2_padded,
                                path_t1_resampled,
                                path_aseg_resampled]
        for path in list_files_to_remove:
            os.remove(path)
Esempio n. 3
0
def preprocess_image(im_path,
                     n_levels,
                     target_res,
                     crop=None,
                     padding=None,
                     flip=False,
                     path_resample=None):
    """Load an image and prepare it for a network with n_levels resolution levels.
    The image is resampled to target_res if needed, aligned to an identity affine, optionally padded, cropped so
    that its spatial shape is divisible by 2**n_levels, normalised channel by channel between 0 and 1 (using the
    0.5 and 99.5 percentiles), optionally flipped, and finally given batch and channel axes.
    :param im_path: path of the image to preprocess.
    :param n_levels: number of levels of the network. The cropping shape is adjusted with
    utils.find_closest_number_divisible_by_m so that every spatial size is divisible by 2**n_levels.
    :param target_res: resolution to which the image is resampled when its own resolution differs from it by more
    than 0.05 along any axis. Set to None to skip resampling.
    :param crop: (optional) shape to which the image is cropped. It is clipped to the (padded) image shape, and
    adjusted to be divisible by 2**n_levels. If None, the image is only cropped when its shape is not divisible.
    :param padding: (optional) shape passed to edit_volumes.pad_volume. No padding if None (or any falsy value).
    :param flip: (optional) whether to also return a right/left flipped version of the image (images with more
    than 2 dimensions only). Default is False.
    :param path_resample: (optional) path where to save the resampled image, if resampling happened.
    :return: tuple (im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped) where
    im is the preprocessed image with batch and channel axes,
    aff is the affine of the (possibly resampled) image, as it was before alignment,
    header is the header of the input image,
    im_res is the resolution of the image after the optional resampling,
    shape is the shape of the aligned image as a list (it includes the channel axis of multi-channel images),
    pad_shape is the spatial shape of the image after the optional padding,
    crop_idx are the cropping indices returned by edit_volumes.crop_volume (None if no cropping occurred),
    im_flipped is the flipped image with batch and channel axes (None if no flipping occurred).
    """

    def _add_batch_and_channel_axes(volume):
        # multi-channel volumes already have a channel axis, so they only get the default axis of add_axis
        if n_channels > 1:
            return utils.add_axis(volume)
        return utils.add_axis(volume, axis=[0, -1])

    # read image and corresponding info
    im, _, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, True)

    # resample image if necessary (0.05 tolerance on the resolution along each axis)
    if target_res is not None:
        target_res = np.squeeze(utils.reformat_to_n_channels_array(target_res, n_dims))
        if np.any((im_res > target_res + 0.05) | (im_res < target_res - 0.05)):
            im_res = target_res
            im, aff = edit_volumes.resample_volume(im, aff, im_res)
            if path_resample is not None:
                utils.save_volume(im, aff, header, path_resample)

    # align image
    im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims)
    shape = list(im.shape)

    # pad image if specified
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        # spatial dimensions only: the channel axis of multi-channel images must not be mistaken for a
        # spatial size in the cropping logic below
        pad_shape = shape[:n_dims]

    # make sure that the cropping shape (or the image shape) is divisible by 2**n_levels
    divisor = 2 ** n_levels
    if crop is not None:
        crop = utils.reformat_to_list(crop, length=n_dims, dtype='int')
        # the cropping shape cannot exceed the image shape
        if not all(pad_size >= crop_size for pad_size, crop_size in zip(pad_shape, crop)):
            crop = [min(pad_size, crop_size) for pad_size, crop_size in zip(pad_shape, crop)]
        if not all(size % divisor == 0 for size in crop):
            crop = [utils.find_closest_number_divisible_by_m(size, divisor) for size in crop]
    elif not all(size % divisor == 0 for size in pad_shape):
        crop = [utils.find_closest_number_divisible_by_m(size, divisor) for size in pad_shape]

    # crop image if necessary
    if crop is not None:
        im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop, return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    if n_channels == 1:
        im = edit_volumes.rescale_volume(im, new_min=0., new_max=1., min_percentile=0.5, max_percentile=99.5)
    else:
        # NOTE(review): channels are written back into im, so the rescaled values are cast to the dtype of im -
        # this assumes a floating-point volume, confirm against utils.get_volume_info
        for i in range(im.shape[-1]):
            im[..., i] = edit_volumes.rescale_volume(im[..., i],
                                                     new_min=0.,
                                                     new_max=1.,
                                                     min_percentile=0.5,
                                                     max_percentile=99.5)

    # flip image along right/left axis
    if flip and n_dims > 2:
        im_flipped = edit_volumes.flip_volume(im, direction='rl', aff=np.eye(4))
        im_flipped = _add_batch_and_channel_axes(im_flipped)
    else:
        im_flipped = None

    # add batch and channel axes
    im = _add_batch_and_channel_axes(im)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped