예제 #1
0
def prepare_hippo_training_atlases(labels_dir,
                                   result_dir,
                                   image_dir=None,
                                   image_result_dir=None,
                                   smooth=True,
                                   crop_margin=50,
                                   recompute=True):
    """Prepare training label maps (and optionally the matching images) from CobraLab atlases.

    Each atlas is first cropped around the left and the right hippocampus with a margin, the right crop being flipped
    along the right/left axis. All crops are then centre-cropped to the smallest shape found among them, flipped along
    their third axis, and given a fixed 0.6mm rotation/scaling block in their affine matrix. Finally all labels are
    mapped to their left-hemisphere values, and the label maps are optionally smoothed.
    Intermediate crops are written in a 'first_cropping' sub-folder of the result directories.
    :param labels_dir: path of directory with label maps to prepare
    :param result_dir: path of directory where prepared atlases will be written
    :param image_dir: (optional) path of directory with images corresponding to the label maps to prepare.
    This can be used to prepare a dataset of real images for supervised training.
    :param image_result_dir: (optional) path of directory where images corresponding to prepared atlases will be
    written. Must be provided if image_dir is given.
    :param smooth: (optional) whether to smooth the final cropped label maps
    :param crop_margin: (optional) margin to add around hippocampi when cropping
    :param recompute: (optional) whether to recompute result files even if they already exist"""

    # create result dirs (makedirs also creates the parent result directories if needed)
    tmp_result_dir = os.path.join(result_dir, 'first_cropping')
    os.makedirs(tmp_result_dir, exist_ok=True)
    if image_dir is not None:
        assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified'
        tmp_image_result_dir = os.path.join(image_result_dir, 'first_cropping')
        os.makedirs(tmp_image_result_dir, exist_ok=True)
    else:
        tmp_image_result_dir = None

    # list labels and images
    labels_paths = utils.list_images_in_folder(labels_dir)
    if image_dir is not None:
        path_images = utils.list_images_in_folder(image_dir)
    else:
        path_images = [None] * len(labels_paths)

    # description of the two crops: (label file suffix, image file suffix, hippocampal labels, flip along rl axis)
    sides = [('_left.nii', '_left.nii', list(range(20101, 20109)), False),
             ('_right_flipped.nii', '_right.nii', list(range(20001, 20009)), True)]

    # crop all atlases around hippo
    print('\ncropping around hippo')
    shape_array = np.zeros((len(labels_paths)*2, 3))
    for idx, (path_label, path_image) in enumerate(zip(labels_paths, path_images)):
        utils.print_loop_info(idx, len(labels_paths), 1)
        lab, lab_aff, lab_h = utils.load_volume(path_label, im_only=False)

        for side_idx, (label_suffix, image_suffix, hippo_labels, flip) in enumerate(sides):

            # crop label map around the hippocampus of the current side
            # (a copy of the affine is given, so that the loaded one stays valid for the other side whatever the
            # cropping function does with its input)
            path_label_first_crop = os.path.join(tmp_result_dir,
                                                 os.path.basename(path_label).replace('.nii', label_suffix))
            lab_crop, cropping_idx, crop_aff = edit_volumes.crop_volume_around_region(lab, crop_margin,
                                                                                      hippo_labels,
                                                                                      aff=np.copy(lab_aff))
            if (not os.path.exists(path_label_first_crop)) or recompute:
                if flip:
                    lab_crop = edit_volumes.flip_volume(lab_crop, direction='rl', aff=crop_aff)
                utils.save_volume(lab_crop, crop_aff, lab_h, path_label_first_crop)
            else:
                lab_crop = utils.load_volume(path_label_first_crop)

            # crop the corresponding image with the same indices
            if path_image is not None:
                path_image_first_crop = os.path.join(tmp_image_result_dir,
                                                     os.path.basename(path_image).replace('.nii', image_suffix))
                if (not os.path.exists(path_image_first_crop)) or recompute:
                    im, im_aff, im_h = utils.load_volume(path_image, im_only=False)
                    im, im_aff = edit_volumes.crop_volume_with_idx(im, cropping_idx, im_aff)
                    if flip:
                        im = edit_volumes.flip_volume(im, direction='rl', aff=im_aff)
                    utils.save_volume(im, im_aff, im_h, path_image_first_crop)

            # left crops are stored in even rows, right crops in odd rows
            shape_array[2*idx + side_idx, :] = np.array(lab_crop.shape)

    # list cropped files
    path_labels_first_cropped = utils.list_images_in_folder(tmp_result_dir)
    if tmp_image_result_dir is not None:
        path_images_first_cropped = utils.list_images_in_folder(tmp_image_result_dir)
    else:
        path_images_first_cropped = [None] * len(path_labels_first_cropped)

    # rotation/scaling block written in the affine of all final volumes
    # NOTE(review): this hard-codes a 0.6mm voxel size - confirm it matches the input atlases
    realigned_rotation = np.array([[-0.6, 0, 0], [0, 0, -0.6], [0, -0.6, 0]])

    # crop all label maps to same size
    print('\nequalising shapes')
    new_shape = np.min(shape_array, axis=0).astype('int32')
    for idx, (path_label, path_image) in enumerate(zip(path_labels_first_cropped, path_images_first_cropped)):
        utils.print_loop_info(idx, len(path_labels_first_cropped), 1)

        # get cropping indices (centred crop to the smallest shape)
        path_lab_cropped = os.path.join(result_dir, os.path.basename(path_label))
        lab, lab_aff, lab_h = utils.load_volume(path_label, im_only=False)
        min_cropping = np.array([np.maximum(int((lab.shape[d]-new_shape[d])/2), 0) for d in range(3)])
        max_cropping = min_cropping + new_shape
        cropping_idx = np.concatenate([min_cropping, max_cropping])

        # crop and flip labels; this is always done in memory, because the result is needed to mask the image even
        # when the label file itself does not have to be written again
        lab, lab_aff = edit_volumes.crop_volume_with_idx(lab, cropping_idx, lab_aff)
        lab = np.flip(lab, axis=2)
        if (not os.path.exists(path_lab_cropped)) or recompute:
            lab_aff[0:3, 0:3] = realigned_rotation
            utils.save_volume(lab, lab_aff, lab_h, path_lab_cropped)

        # crop image and realign it in the same way
        if path_image is not None:
            path_im_cropped = os.path.join(image_result_dir, os.path.basename(path_image))
            if (not os.path.exists(path_im_cropped)) or recompute:
                im, im_aff, im_h = utils.load_volume(path_image, im_only=False)
                im, im_aff = edit_volumes.crop_volume_with_idx(im, cropping_idx, im_aff)
                im = np.flip(im, axis=2)
                im_aff[0:3, 0:3] = realigned_rotation
                im = edit_volumes.mask_volume(im, lab)
                utils.save_volume(im, im_aff, im_h, path_im_cropped)

    # correct all labels to left values
    print('\ncorrecting labels')
    list_incorrect_labels = [77, 80, 251, 252, 253, 254, 255, 29, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 54, 58, 60,
                             61, 62, 63, 7012, 20001, 20002, 20004, 20005, 20006, 20007, 20008]
    list_correct_labels = [2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 18, 26, 28, 2, 30, 31, 20108,
                           20101, 20102, 20104, 20105, 20106, 20107, 20108]
    edit_volumes.correct_labels_in_dir(result_dir, list_incorrect_labels, list_correct_labels, result_dir)

    # smooth labels
    if smooth:
        print('\nsmoothing labels')
        edit_volumes.smooth_labels_in_dir(result_dir, result_dir)
예제 #2
0
파일: predict.py 프로젝트: BBillot/SynthSeg
def postprocess(post_patch,
                pad_shape,
                im_shape,
                crop,
                n_dims,
                segmentation_labels,
                lr_indices,
                keep_biggest_component,
                aff,
                topology_classes=True,
                post_patch_flip=None):
    """Turn the posteriors predicted on a patch into a segmentation and posteriors of the original image size.

    :param post_patch: predicted posteriors, with the label channels on the last axis. Singleton axes are squeezed.
    Note that this array may be modified in place.
    :param pad_shape: spatial shape of the (possibly padded) image in which the patch was cropped.
    :param im_shape: spatial shape of the image before padding.
    :param crop: cropping indices of the patch in the padded image, given as [min indices, max indices], or None if
    the image was not cropped.
    :param n_dims: number of spatial dimensions of the image.
    :param segmentation_labels: numpy array mapping each posterior channel to its label value.
    :param lr_indices: indices of the posterior channels to swap in the flipped posteriors, or None for no swapping.
    Only used if post_patch_flip is given. Presumably a 2-row array of left/right channels - confirm with caller.
    :param keep_biggest_component: whether to set all non-background posteriors to zero outside the largest connected
    component of the foreground (voxels where non-background posteriors sum to more than 0.25).
    :param aff: affine matrix to which the outputs are aligned back (only for n_dims > 2).
    :param topology_classes: (optional) array associating each posterior channel to a topological class. Posteriors
    are set to zero outside the largest connected component of each class (the class with the lowest value is
    skipped). Use None to disable. With the default value no class is processed, but posteriors are renormalised.
    :param post_patch_flip: (optional) posteriors predicted on the right/left flipped patch, which are flipped back
    and averaged with post_patch.
    :return: the segmentation (integer label values) and the posteriors.
    """

    # get posteriors and segmentation
    post_patch = np.squeeze(post_patch)
    if post_patch_flip is not None:
        post_patch_flip = edit_volumes.flip_volume(np.squeeze(post_patch_flip),
                                                   direction='rl',
                                                   aff=np.eye(4))
        if lr_indices is not None:
            # the right-hand side is a copy (fancy indexing), so the swap is safe
            post_patch_flip[..., lr_indices.flatten()] = post_patch_flip[
                ..., lr_indices[::-1].flatten()]
        post_patch = 0.5 * (post_patch + post_patch_flip)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        tmp_post_patch = post_patch[..., 1:]
        post_patch_mask = np.sum(tmp_post_patch, axis=-1) > 0.25
        post_patch_mask = edit_volumes.get_largest_connected_component(
            post_patch_mask)
        post_patch_mask = np.stack([post_patch_mask] *
                                   tmp_post_patch.shape[-1],
                                   axis=-1)
        tmp_post_patch = edit_volumes.mask_volume(tmp_post_patch,
                                                  mask=post_patch_mask)
        post_patch[..., 1:] = tmp_post_patch

    # reset posteriors to zero outside the largest connected component of each topological class
    if topology_classes is not None:
        post_patch_mask = post_patch > 0.25
        for topology_class in np.unique(topology_classes)[1:]:
            tmp_topology_indices = np.where(
                topology_classes == topology_class)[0]
            tmp_mask = np.any(post_patch_mask[..., tmp_topology_indices],
                              axis=-1)
            tmp_mask = edit_volumes.get_largest_connected_component(tmp_mask)
            for idx in tmp_topology_indices:
                post_patch[..., idx] *= tmp_mask

    # renormalise posteriors and get hard segmentation
    # (logical operators are used so that a None flag is simply treated as False)
    if (post_patch_flip is not None) or keep_biggest_component or (
            topology_classes is not None):
        totals = np.sum(post_patch, axis=-1, keepdims=True)
        # voxels where every posterior is zero are left at zero instead of being turned into NaNs
        totals[totals == 0] = 1
        post_patch /= totals
    seg_patch = post_patch.argmax(-1)

    # paste patches back to matrix of original image size
    if crop is not None:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, segmentation_labels.shape[0]])
        posteriors[...,
                   0] = np.ones(pad_shape)  # place background around patch
        if n_dims == 2:
            seg[crop[0]:crop[2], crop[1]:crop[3]] = seg_patch
            posteriors[crop[0]:crop[2], crop[1]:crop[3], :] = post_patch
        elif n_dims == 3:
            seg[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5]] = seg_patch
            posteriors[crop[0]:crop[3], crop[1]:crop[4],
                       crop[2]:crop[5], :] = post_patch
    else:
        seg = seg_patch
        posteriors = post_patch
    seg = segmentation_labels[seg.astype('int')].astype('int')

    # remove padding (shapes are compared as lists, so that a list and a tuple with the same values are equal)
    if list(im_shape) != list(pad_shape):
        bounds = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        bounds += [p + i for (p, i) in zip(bounds, im_shape)]
        seg = edit_volumes.crop_volume_with_idx(seg, bounds)
        posteriors = edit_volumes.crop_volume_with_idx(posteriors,
                                                       bounds,
                                                       n_dims=n_dims)

    # align prediction back to first orientation
    if n_dims > 2:
        seg = edit_volumes.align_volume_to_ref(seg,
                                               aff=np.eye(4),
                                               aff_ref=aff,
                                               n_dims=n_dims,
                                               return_aff=False)
        posteriors = edit_volumes.align_volume_to_ref(posteriors,
                                                      aff=np.eye(4),
                                                      aff_ref=aff,
                                                      n_dims=n_dims)

    return seg, posteriors
예제 #3
0
def preprocess_adni_hippo(path_t1,
                          path_t2,
                          path_aseg,
                          result_dir,
                          target_res,
                          padding_margin=85,
                          remove=False,
                          path_freesurfer='/usr/local/freesurfer/',
                          verbose=True,
                          recompute=True):
    """This function builds a T1+T2 multimodal image from the ADNI dataset.
    It first rescales the intensities of each channel.
    It then resamples the T2 image (which are 0.4*0.4*2.0 resolution) to target resolution.
    The obtained T2 is then padded in all directions by the padding_margin param (typically large 85).
    The T1 and aseg are then resampled like the T2 using mri_convert.
    Now that the T1, T2 and asegs are aligned and at the same resolution, we crop them around the left and right hippo
    (the right crops are flipped along the right/left axis).
    Finally, the T1 and T2 are masked and concatenated into one single multimodal image.
    The results are saved in result_dir as hippo_{left,right}.nii.gz and hippo_{left,right}_aseg.nii.gz.
    :param path_t1: path input T1 (typically at 1mm isotropic)
    :param path_t2: path input T2 (typically cropped around the hippo in sagittal axis, 0.4x0.4x2.0)
    :param path_aseg: path input segmentation (typically at 1mm isotropic)
    :param result_dir: path of directory where prepared images and labels will be written.
    :param target_res: resolution at which to resample the label maps, and the images.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) number of voxels by which the resampled T2 is padded in every direction.
    Note that the margin used when cropping around the hippocampi is fixed to 30 voxels.
    :param remove: (optional) whether to delete temporary files. Default is False.
    :param path_freesurfer: (optional) path of FreeSurfer home, to use mri_convert
    :param verbose: (optional) whether to print out mri_convert output when resampling images
    :param recompute: (optional) whether to recompute result files even if they already exist
    """
    import shlex

    # create results dir
    if not os.path.isdir(result_dir):
        os.mkdir(result_dir)

    # nothing to do if all the final files already exist and we do not want to recompute them
    path_test_im_right = os.path.join(result_dir, 'hippo_right.nii.gz')
    path_test_aseg_right = os.path.join(result_dir, 'hippo_right_aseg.nii.gz')
    path_test_im_left = os.path.join(result_dir, 'hippo_left.nii.gz')
    path_test_aseg_left = os.path.join(result_dir, 'hippo_left_aseg.nii.gz')
    final_paths = [path_test_im_right, path_test_aseg_right, path_test_im_left, path_test_aseg_left]
    if all(os.path.isfile(path) for path in final_paths) and not recompute:
        return

    # set up FreeSurfer
    os.environ['FREESURFER_HOME'] = path_freesurfer
    os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
    mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert.bin')

    def run_mri_convert(arguments):
        # every token is shell-quoted, so that paths with spaces or special characters are passed intact
        cmd = ' '.join(shlex.quote(str(token)) for token in [mri_convert] + list(arguments))
        if not verbose:
            cmd += ' >/dev/null 2>&1'
        _ = os.system(cmd)

    # rescale T1
    path_t1_rescaled = os.path.join(result_dir, 't1_rescaled.nii.gz')
    if (not os.path.isfile(path_t1_rescaled)) or recompute:
        im, aff, h = utils.load_volume(path_t1, im_only=False)
        im = edit_volumes.rescale_volume(im)
        utils.save_volume(im, aff, h, path_t1_rescaled)
    # rescale T2
    path_t2_rescaled = os.path.join(result_dir, 't2_rescaled.nii.gz')
    if (not os.path.isfile(path_t2_rescaled)) or recompute:
        im, aff, h = utils.load_volume(path_t2, im_only=False)
        im = edit_volumes.rescale_volume(im)
        utils.save_volume(im, aff, h, path_t2_rescaled)

    # resample T2 to target res
    path_t2_resampled = os.path.join(result_dir, 't2_rescaled_resampled.nii.gz')
    if (not os.path.isfile(path_t2_resampled)) or recompute:
        voxel_sizes = [str(r) for r in utils.reformat_to_list(target_res, length=3)]
        run_mri_convert([path_t2_rescaled, path_t2_resampled, '--voxsize'] + voxel_sizes + ['-odt', 'float'])

    # pad T2
    path_t2_padded = os.path.join(result_dir, 't2_rescaled_resampled_padded.nii.gz')
    if (not os.path.isfile(path_t2_padded)) or recompute:
        t2, aff, h = utils.load_volume(path_t2_resampled, im_only=False)
        t2_padded = np.pad(t2, padding_margin, 'constant')
        # shift the origin of the affine by padding_margin voxels along each axis to account for the padding
        aff[:3, -1] = aff[:3, -1] - (aff[:3, :3] @ (padding_margin * np.ones(3)))
        utils.save_volume(t2_padded, aff, h, path_t2_padded)

    # resample T1 and aseg accordingly
    path_t1_resampled = os.path.join(result_dir, 't1_rescaled_resampled.nii.gz')
    if (not os.path.isfile(path_t1_resampled)) or recompute:
        run_mri_convert([path_t1_rescaled, path_t1_resampled, '-rl', path_t2_padded, '-odt', 'float'])
    path_aseg_resampled = os.path.join(result_dir, 'aseg_resampled.nii.gz')
    if (not os.path.isfile(path_aseg_resampled)) or recompute:
        # nearest neighbour interpolation to keep valid label values
        run_mri_convert([path_aseg, path_aseg_resampled, '-rl', path_t2_padded,
                         '-rt', 'nearest', '-odt', 'float'])

    # crop images and concatenate T1 and T2
    for lab, side in zip([17, 53], ['left', 'right']):
        path_test_image = os.path.join(result_dir, 'hippo_{}.nii.gz'.format(side))
        path_test_aseg = os.path.join(result_dir, 'hippo_{}_aseg.nii.gz'.format(side))
        if (not os.path.isfile(path_test_image)) or (not os.path.isfile(path_test_aseg)) or recompute:
            aseg, aff, h = utils.load_volume(path_aseg_resampled, im_only=False)
            tmp_aseg, cropping, tmp_aff = edit_volumes.crop_volume_around_region(aseg,
                                                                                 margin=30,
                                                                                 masking_labels=lab,
                                                                                 aff=aff)
            if side == 'right':
                tmp_aseg = edit_volumes.flip_volume(tmp_aseg, direction='rl', aff=tmp_aff)
            utils.save_volume(tmp_aseg, tmp_aff, h, path_test_aseg)
            if (not os.path.isfile(path_test_image)) or recompute:
                channels = []
                for path_channel in [path_t1_resampled, path_t2_padded]:
                    channel = utils.load_volume(path_channel)
                    channel = edit_volumes.crop_volume_with_idx(channel, crop_idx=cropping)
                    # flip before masking: tmp_aseg is already flipped for the right side, so the image must be in
                    # the same orientation for the mask to be applied at the right location
                    if side == 'right':
                        channel = edit_volumes.flip_volume(channel, direction='rl', aff=tmp_aff)
                    channel = edit_volumes.mask_volume(channel, tmp_aseg, dilate=6, erode=5)
                    channels.append(channel)
                test_image = np.stack(channels, axis=-1)
                utils.save_volume(test_image, tmp_aff, h, path_test_image)

    # remove unnecessary files
    if remove:
        list_files_to_remove = [path_t1_rescaled,
                                path_t2_rescaled,
                                path_t2_resampled,
                                path_t2_padded,
                                path_t1_resampled,
                                path_aseg_resampled]
        for path in list_files_to_remove:
            # a temporary file may be absent (e.g. if mri_convert failed), which should not break the clean-up
            if os.path.isfile(path):
                os.remove(path)
예제 #4
0
파일: predict.py 프로젝트: BBillot/SynthSeg
def preprocess_image(im_path,
                     n_levels,
                     target_res,
                     crop=None,
                     padding=None,
                     flip=False,
                     path_resample=None):
    """Load an image and prepare it for prediction.

    The image is optionally resampled, aligned to the identity affine, optionally padded, cropped to a shape
    divisible by 2**n_levels, rescaled between 0 and 1 (per channel), and given batch (and channel) axes.
    :param im_path: path of the image to preprocess.
    :param n_levels: number of levels of the network; the spatial shape of the output is made divisible by
    2**n_levels.
    :param target_res: resolution at which to resample the image, or None for no resampling. The image is only
    resampled if its resolution differs from target_res by more than 0.05 along at least one axis.
    :param crop: (optional) shape to which the image is cropped. It is clipped to the (padded) image shape, and
    adjusted to be divisible by 2**n_levels.
    :param padding: (optional) shape to which the image is padded before cropping.
    :param flip: (optional) whether to also return a right/left flipped version of the image (only for 3D images).
    :param path_resample: (optional) path where the resampled image is saved, if it was resampled.
    :return: the preprocessed image, its original affine matrix, its header, its resolution (after resampling),
    the number of channels, the number of spatial dimensions, the spatial shape after alignment, the spatial shape
    after padding, the cropping indices (None if not cropped), and the flipped image (None if not flipped).
    """

    def add_batch_and_channel_axes(volume):
        # multi-channel images already have a channel axis, so only the batch axis is added
        return utils.add_axis(volume) if n_channels > 1 else utils.add_axis(volume, axis=[0, -1])

    # read image and corresponding info
    im, _, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True)

    # resample image if necessary
    if target_res is not None:
        target_res = np.squeeze(
            utils.reformat_to_n_channels_array(target_res, n_dims))
        if np.any((im_res > target_res + 0.05) | (im_res < target_res - 0.05)):
            im_res = target_res
            im, aff = edit_volumes.resample_volume(im, aff, im_res)
            if path_resample is not None:
                utils.save_volume(im, aff, header, path_resample)

    # align image
    im = edit_volumes.align_volume_to_ref(im,
                                          aff,
                                          aff_ref=np.eye(4),
                                          n_dims=n_dims)
    # spatial shape only: the channel axis is excluded, so that shape, pad_shape and crop all have n_dims elements
    shape = list(im.shape[:n_dims])

    # pad image if specified
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    divisor = 2**n_levels
    if crop is not None:
        crop = utils.reformat_to_list(crop, length=n_dims, dtype='int')
        # the cropping shape cannot be bigger than the image
        if not all(p >= c for (p, c) in zip(pad_shape, crop)):
            crop = [min(p, c) for (p, c) in zip(pad_shape, crop)]
        if not all(size % divisor == 0 for size in crop):
            crop = [
                utils.find_closest_number_divisible_by_m(size, divisor)
                for size in crop
            ]
    elif not all(size % divisor == 0 for size in pad_shape):
        crop = [
            utils.find_closest_number_divisible_by_m(size, divisor)
            for size in pad_shape
        ]

    # crop image if necessary
    if crop is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image (each channel separately)
    if n_channels == 1:
        im = edit_volumes.rescale_volume(im,
                                         new_min=0.,
                                         new_max=1.,
                                         min_percentile=0.5,
                                         max_percentile=99.5)
    else:
        # NOTE(review): rescaled channels are written back in place, which assumes im has a floating dtype - confirm
        for i in range(im.shape[-1]):
            im[..., i] = edit_volumes.rescale_volume(im[..., i],
                                                     new_min=0.,
                                                     new_max=1.,
                                                     min_percentile=0.5,
                                                     max_percentile=99.5)

    # flip image along right/left axis
    if flip and (n_dims > 2):
        im_flipped = edit_volumes.flip_volume(im,
                                              direction='rl',
                                              aff=np.eye(4))
        im_flipped = add_batch_and_channel_axes(im_flipped)
    else:
        im_flipped = None

    # add batch and channel axes
    im = add_batch_and_channel_axes(im)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped