示例#1
0
def labeled_data_test_params(path_in='sct_testing_data/t2/t2.nii.gz',
                             path_seg='sct_testing_data/t2/labels.nii.gz'):
    """Build the (image, labeled segmentation) pairs used to parametrize
    test_sagittal_slice_get_center_spit.

    Three cases are derived from the same anatomical image:
      - 'multiple_labels': the labeled segmentation exactly as loaded,
      - 'one_label': only the first non-zero voxel (np.argwhere order) kept,
      - 'no_labels': every non-zero voxel set to 0.

    :param path_in: path of the anatomical image.
    :param path_seg: path of the labeled segmentation (needs >= 2 non-zero voxels).
    :return: list of pytest.param(image, segmentation, id=...) objects.
    """
    anat = Image(path_in)
    seg_multi = Image(path_seg)
    n_labeled_voxels = np.count_nonzero(seg_multi.data)
    assert n_labeled_voxels >= 2, "Labeled segmentation image has fewer than 2 labels"

    # Keep the first labeled voxel, wipe all the others
    seg_single = seg_multi.copy()
    extra_voxels = np.argwhere(seg_single.data)[1:]
    for i, j, k in extra_voxels:
        seg_single.data[i, j, k] = 0

    # Wipe every labeled voxel
    seg_empty = seg_multi.copy()
    labeled_voxels = np.argwhere(seg_empty.data)
    for i, j, k in labeled_voxels:
        seg_empty.data[i, j, k] = 0

    cases = [
        (seg_multi, 'multiple_labels'),
        (seg_single, 'one_label'),
        (seg_empty, 'no_labels'),
    ]
    return [pytest.param(anat, seg, id=case_id) for seg, case_id in cases]
示例#2
0
def detect_c2c3_from_file(fname_im,
                          fname_seg,
                          contrast,
                          fname_c2c3=None,
                          verbose=1):
    """
    Detect the posterior edge of C2-C3 disc and save it as a label image.
    :param fname_im: path of the anatomical image.
    :param fname_seg: path of the spinal cord segmentation.
    :param contrast: image contrast, forwarded unchanged to detect_c2c3().
    :param fname_c2c3: output path; when None, "label_c2c3.nii.gz" is written in
        the directory of the input image.
    :param verbose: verbosity, used by sct.printv and forwarded to detect_c2c3().
    :return: fname_c2c3, the path that was written.
    """
    sct.printv('Load data...', verbose)
    im_anat = Image(fname_im)
    im_cord = Image(fname_seg)

    # detect_c2c3 gets a copy of the anatomical image (presumably because it may
    # modify its input; the loaded image is still used below for its path)
    im_label = detect_c2c3(im_anat.copy(), im_cord, contrast, verbose=verbose)

    sct.printv('Generate output file...', verbose)
    if fname_c2c3 is None:
        # Default output location: next to the input image
        out_dir = os.path.dirname(im_anat.absolutepath)
        fname_c2c3 = os.path.join(out_dir, "label_c2c3.nii.gz")
    im_label.save(fname_c2c3)

    return fname_c2c3
def detect_c2c3_from_file(fname_im, fname_seg, contrast, fname_c2c3=None, verbose=1):
    """
    Detect the posterior edge of C2-C3 disc and save it as a label image.
    :param fname_im: path of the anatomical image.
    :param fname_seg: path of the spinal cord segmentation.
    :param contrast: image contrast, forwarded unchanged to detect_c2c3().
    :param fname_c2c3: output path; when None, "label_c2c3.nii.gz" is written in
        the directory of the input image.
    :param verbose: verbosity forwarded to detect_c2c3() (progress goes to logger).
    :return: fname_c2c3, the path that was written.
    """
    logger.info('Load data...')
    im_anat, im_cord = Image(fname_im), Image(fname_seg)

    # Hand a copy to detect_c2c3; the original image is still needed for its path
    im_label = detect_c2c3(im_anat.copy(), im_cord, contrast, verbose=verbose)

    logger.info('Generate output file...')
    if fname_c2c3 is None:
        # Default output location: next to the input image
        fname_c2c3 = os.path.join(os.path.dirname(im_anat.absolutepath),
                                  "label_c2c3.nii.gz")
    im_label.save(fname_c2c3)
    return fname_c2c3
def test_crop_image_around_centerline():
    """crop_image_around_centerline must return a (crop_size, crop_size, nz)
    volume whose slice z=0 equals a manual crop around the centerline point."""
    shape = (100, 100, 100)
    crop_size = 20

    volume = np.random.rand(*shape)
    nii = nib.nifti1.Nifti1Image(volume, np.eye(4))
    img = Image(volume, hdr=nii.header, dim=nii.header.get_data_shape())

    ctr, _, _ = dummy_centerline(size_arr=shape)

    core = sct.deepseg_sc.core
    _, _, _, img_out = core.crop_image_around_centerline(
        im_in=img.copy(), ctr_in=ctr.copy(), crop_size=crop_size)

    # Reference crop: slice z=0 around its first centerline voxel
    xs, ys = np.where(ctr.data[:, :, 0])
    x_start, x_end = core._find_crop_start_end(xs[0], crop_size, img.dim[0])
    y_start, y_end = core._find_crop_start_end(ys[0], crop_size, img.dim[1])
    expected_z0 = img.data[:, :, 0][x_start:x_end, y_start:y_end]

    assert img_out.data.shape == (crop_size, crop_size, shape[2])
    assert np.allclose(expected_z0, img_out.data[:, :, 0])
示例#5
0
def multicomponent_merge(fname_list):
    """
    Merge several images into one multicomponent (vector) image.

    Each input image is stored along the last (5th) axis of the output array,
    whose shape is (x, y, z, t, n_images). 2D and 3D inputs are padded with
    singleton axes to fit that layout. The first image provides the header; its
    intent is set to 'vector' and '_multicomponent' is appended to its path.

    :param fname_list: list of image file names; the first defines the geometry.
    :return: Image with float32 data (not saved to disk).
    """
    from numpy import zeros
    # WARNING: output multicomponent is not optimal yet, some issues may be related to the use of this function

    im_0 = Image(fname_list[0])
    # Pad the reference shape to 4D (x, y, z, t) before appending the component
    # axis. Only 3D shapes used to be padded, so a 2D first image produced a 3D
    # buffer and the 5D assignment below raised an IndexError.
    new_shape = list(im_0.data.shape)
    while len(new_shape) < 4:
        new_shape.append(1)
    new_shape.append(len(fname_list))

    # Allocate directly in the output dtype (avoids a float64 buffer + cast copy)
    data_out = zeros(tuple(new_shape), dtype='float32')
    for i, fname in enumerate(fname_list):
        im = Image(fname)
        dat = im.data
        # Assignment into the float32 buffer performs the float32 conversion
        if dat.ndim == 2:
            data_out[:, :, 0, 0, i] = dat
        elif dat.ndim == 3:
            data_out[:, :, :, 0, i] = dat
        elif dat.ndim == 4:
            data_out[:, :, :, :, i] = dat
        del im
        del dat
    im_out = im_0.copy()
    im_out.data = data_out
    im_out.hdr.set_intent('vector', (), '')
    im_out.absolutepath = sct.add_suffix(im_out.absolutepath,
                                         '_multicomponent')
    return im_out
def clean_straight(dataset_info, contrast='t1', ratio=0.9):
    """
    Zero out low-intensity axial slices in each subject's straightened image.

    For each subject, the in-plane window [100:150, 100:150] of
    ``<contrast>_straight.nii.gz`` is analysed. The 80th percentile of the
    voxels brighter than 100 gives a reference intensity. Every slice whose
    90th percentile is below ``ratio * reference`` is set to 0. The result is
    saved as ``<contrast>_straight_clean.nii.gz`` in the same folder.

    :param dataset_info: dict providing 'path_data' (root folder) and
        'subjects' (list of subject folder names).
    :param contrast: contrast name, used both as sub-folder and file prefix.
    :param ratio: fraction of the reference intensity below which a slice is
        cleaned.
    :return: None
    """
    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']

    fname_in = contrast + '_straight.nii.gz'
    fname_out = contrast + '_straight_clean.nii.gz'

    # The context manager closes the progress bar even if a subject fails
    with tqdm(total=len(list_subjects), unit='subject', desc="Status", ascii=True) as tqdm_bar:
        for subject_name in list_subjects:
            # os.path.join no longer requires path_data to end with a separator
            path_data_subject = os.path.join(path_data, subject_name, contrast)

            im = Image(os.path.join(path_data_subject, fname_in))
            im_out = im.copy()

            data = im.data[100:150, 100:150]
            max_max_val = np.percentile(data[data > 100], 80)
            thr = ratio * max_max_val

            z2clean = [iz for iz in range(data.shape[2])
                       if np.percentile(data[:, :, iz], 90) < thr]

            # Assign instead of multiplying by 0.0: in-place float multiplication
            # raises on integer arrays, and nan/inf * 0.0 would not give 0.
            im_out.data[:, :, z2clean] = 0
            im_out.save(os.path.join(path_data_subject, fname_out))
            del im, im_out, data

            tqdm_bar.update(1)
示例#7
0
def multicomponent_merge(fname_list):
    """
    Merge several images into one multicomponent (vector) image.

    Each input image is stored along the last (5th) axis of the output array,
    whose shape is (x, y, z, t, n_images). 2D and 3D inputs are padded with
    singleton axes to fit that layout. The first image provides the header; its
    intent is set to 'vector' and '_multicomponent' is appended to its path.

    :param fname_list: list of image file names; the first defines the geometry.
    :return: Image with float32 data (not saved to disk).
    """
    from numpy import zeros
    # WARNING: output multicomponent is not optimal yet, some issues may be related to the use of this function

    im_0 = Image(fname_list[0])
    # Pad the reference shape to 4D (x, y, z, t) before appending the component
    # axis. Only 3D shapes used to be padded, so a 2D first image produced a 3D
    # buffer and the 5D assignment below raised an IndexError.
    new_shape = list(im_0.data.shape)
    while len(new_shape) < 4:
        new_shape.append(1)
    new_shape.append(len(fname_list))

    # Allocate directly in the output dtype (avoids a float64 buffer + cast copy)
    data_out = zeros(tuple(new_shape), dtype='float32')
    for i, fname in enumerate(fname_list):
        im = Image(fname)
        dat = im.data
        # Assignment into the float32 buffer performs the float32 conversion
        if dat.ndim == 2:
            data_out[:, :, 0, 0, i] = dat
        elif dat.ndim == 3:
            data_out[:, :, :, 0, i] = dat
        elif dat.ndim == 4:
            data_out[:, :, :, :, i] = dat
        del im
        del dat
    im_out = im_0.copy()
    im_out.data = data_out
    im_out.hdr.set_intent('vector', (), '')
    im_out.absolutepath = sct.add_suffix(im_out.absolutepath, '_multicomponent')
    return im_out
示例#8
0
def clean_labeled_segmentation(fname_labeled_seg, fname_seg,
                               fname_labeled_seg_new):
    """
    Clean labeled segmentation by:
      (i)  removing voxels in segmentation_labeled that are not in segmentation and
      (ii) adding voxels in segmentation that are not in segmentation_labeled
    Voxels added in (ii) take the value of a dilated copy of the labeled
    segmentation at that position.
    :param fname_labeled_seg: labeled segmentation file
    :param fname_seg: segmentation file
    :param fname_labeled_seg_new: output file name
    :return: none
    """
    img_labeled_seg = Image(fname_labeled_seg)
    img_seg = Image(fname_seg)
    # (i) remove voxels in segmentation_labeled that are not in segmentation
    data_labeled_seg_mul = img_labeled_seg.data * img_seg.data
    # Dilated labels provide a value for voxels added back in (ii)
    # (dilate(..., [2]) -- presumably a radius of 2 voxels; see dilate())
    data_labeled_seg_dil = dilate(img_labeled_seg.data, [2])
    data_labeled_seg_mul_bin = data_labeled_seg_mul > 0
    # (ii) segmentation voxels that are unlabeled after masking. Cast to float
    # first: numpy refuses '-' between two boolean arrays (bool segmentation).
    data_diff = img_seg.data.astype(float) - data_labeled_seg_mul_bin
    ind_nonzero = np.nonzero(data_diff)
    img_labeled_seg_corr = img_labeled_seg.copy()
    img_labeled_seg_corr.data = data_labeled_seg_mul
    # Assign the dilated label value to all such voxels at once (was a per-voxel loop)
    img_labeled_seg_corr.data[ind_nonzero] = data_labeled_seg_dil[ind_nonzero]
    # save new label file (overwrite)
    img_labeled_seg_corr.absolutepath = fname_labeled_seg_new
    img_labeled_seg_corr.save()
def dummy_centerline(size_arr=(9, 9, 9), subsampling=1, dilate_ctl=0, hasnan=False, zeroslice=None, outlier=None,
                     orientation='RPI', debug=False):
    """
    Create a dummy Image centerline of small size. Return the full and sub-sampled version along z.
    :param size_arr: tuple: (nx, ny, nz)
    :param subsampling: int >=1. Subsampling factor along z. 1: no subsampling. 2: centerline defined every other z.
    :param dilate_ctl: Dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
                         if dilate_ctl=0, result will be a single pixel per slice.
    :param hasnan: Bool: put nan/inf at voxels (0,0,0)/(1,0,0) of the full image only
                   (the sub-sampled image is built before they are added).
    :param zeroslice: list int: zero all slices listed in this param (default: none)
    :param outlier: list int: replace the current point with an outlier at the corner of the image for the slices
                    listed (default: none)
    :param orientation: orientation passed to Image.change_orientation (default 'RPI')
    :param debug: Bool: Write temp files
    :return: img, img_sub, arr_ctl -- full image, z-subsampled image, and the 3 x nz uint16 array of the
             undilated centerline coordinates (x, y, z) before orientation change.
    """
    from numpy import poly1d, polyfit
    # None defaults avoid shared mutable default arguments
    zeroslice = [] if zeroslice is None else zeroslice
    outlier = [] if outlier is None else outlier
    nx, ny, nz = size_arr
    # define array based on a polynomial function, within X-Z plane, located at y=ny/4, based on the following points:
    x = np.array([round(nx/4.), round(nx/2.), round(3*nx/4.)])
    z = np.array([0, round(nz/2.), nz-1])
    p = poly1d(polyfit(z, x, deg=3))
    data = np.zeros((nx, ny, nz))
    # Builtin int replaces np.int, which was removed in NumPy 1.24
    arr_ctl = np.array([p(range(nz)).astype(int),
                        [round(ny / 4.)] * nz,
                        range(nz)], dtype=np.uint16)
    # Signed copies: adding a negative offset to a uint16 array overflows under NumPy 2 promotion rules
    ctl_x = arr_ctl[0].astype(int)
    ctl_y = arr_ctl[1].astype(int)
    ctl_z = arr_ctl[2].tolist()
    # Loop across dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
    offsets = range(-dilate_ctl, dilate_ctl + 1)
    for dx, dy in itertools.product(offsets, offsets):
        data[(ctl_x + dx).tolist(), (ctl_y + dy).tolist(), ctl_z] = 1
    # Zero specified slices. Indexing with an empty list is a no-op; the former
    # `if zeroslice is not []` guard compared identity and was always true anyway.
    data[:, :, zeroslice] = 0
    # Add outlier: first zero the whole slice, then add a point in the corner
    data[:, :, outlier] = 0
    data[0, 0, outlier] = 1
    # Create image with default orientation LPI
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    # subsample data
    img_sub = img.copy()
    img_sub.data = np.zeros((nx, ny, nz))
    for iz in range(0, nz, subsampling):
        img_sub.data[..., iz] = data[..., iz]
    # Add non-numerical values at the top corner of the image
    if hasnan:
        img.data[0, 0, 0] = np.nan
        img.data[1, 0, 0] = np.inf
    # Update orientation
    img.change_orientation(orientation)
    img_sub.change_orientation(orientation)
    if debug:
        img_sub.save('tmp_dummy_seg_'+datetime.now().strftime("%Y%m%d%H%M%S%f")+'.nii.gz')
    return img, img_sub, arr_ctl
    def label_lesion(self):
        """
        Label the connected regions of the lesion mask.

        Reads self.fname_mask, saves the labeled volume to self.fname_label and
        stores the non-zero label values in self.measure_pd['label'].
        """
        printv('\nLabel connected regions of the masked image...', self.verbose, 'normal')
        mask_img = Image(self.fname_mask)
        labeled_img = mask_img.copy()
        # connectivity=2 as passed to label() (if this is skimage.measure.label,
        # voxels sharing a face or an edge are connected -- confirm the import)
        labeled_img.data = label(mask_img.data, connectivity=2)
        labeled_img.save(self.fname_label)

        # Falsy values (the 0 background) are dropped
        self.measure_pd['label'] = list(filter(None, np.unique(labeled_img.data)))
        n_lesions = len(self.measure_pd['label'])
        printv('Lesion count = ' + str(n_lesions), self.verbose, 'info')
    def label_lesion(self):
        """
        Label the connected regions of the lesion mask.

        Reads self.fname_mask, saves the labeled volume to self.fname_label and
        stores the non-zero label values in self.measure_pd['label'].
        """
        printv('\nLabel connected regions of the masked image...', self.verbose, 'normal')
        mask_img = Image(self.fname_mask)
        labeled_img = mask_img.copy()
        # connectivity=2 as passed to label() (if this is skimage.measure.label,
        # voxels sharing a face or an edge are connected -- confirm the import)
        labeled_img.data = label(mask_img.data, connectivity=2)
        labeled_img.save(self.fname_label)

        # Falsy values (the 0 background) are dropped
        self.measure_pd['label'] = list(filter(None, np.unique(labeled_img.data)))
        n_lesions = len(self.measure_pd['label'])
        printv('Lesion count = ' + str(n_lesions), self.verbose, 'info')
    def measure(self):
        """
        Compute lesion statistics from the labeled lesion image (self.fname_label).

        For each non-zero label: volume, length and diameter via the _measure_*
        helpers (they receive the matching self.measure_pd index).
        When self.path_template is set: per-lesion and total lesion distribution,
        using the vertebral levels image (self.path_levels) and the atlas ROI files
        (self.atlas_roi_lst).
        When self.fname_ref is set: statistics of that image within each lesion.
        self.volumes is (re)initialised to zeros of shape (dim[2], n_lesions).
        """
        im_lesion = Image(self.fname_label)
        im_lesion_data = im_lesion.data
        p_lst = im_lesion.dim[4:7]  # voxel size (px, py, pz)

        label_lst = [l for l in np.unique(im_lesion_data) if l]  # lesion label IDs list

        if self.path_template is not None:
            if os.path.isfile(self.path_levels):
                img_vert = Image(self.path_levels)
                im_vert_data = img_vert.data
                self.vert_lst = [v for v in np.unique(im_vert_data) if v]  # list of vertebral levels available in the input image

            else:
                im_vert_data = None
                printv('ERROR: the file ' + self.path_levels + ' does not exist. Please make sure the template was correctly registered and warped (sct_register_to_template or sct_register_multimodal and sct_warp_template)', type='error')

            # In order to open atlas images only one time
            atlas_data_dct = {}  # dict containing the np.array of the registrated atlas
            for fname_atlas_roi in self.atlas_roi_lst:
                # Tract ID = integer between the last '_' and '.nii.gz' (e.g. '..._05.nii.gz' -> 5)
                tract_id = int(fname_atlas_roi.split('_')[-1].split('.nii.gz')[0])
                img_cur = Image(fname_atlas_roi)
                img_cur_copy = img_cur.copy()
                atlas_data_dct[tract_id] = img_cur_copy.data
                del img_cur

        self.volumes = np.zeros((im_lesion.dim[2], len(label_lst)))

        # iteration across each lesion to measure statistics
        for lesion_label in label_lst:
            # The comparison already returns a new boolean array; np.copy only doubled the allocation
            im_lesion_data_cur = im_lesion_data == lesion_label
            printv('\nMeasures on lesion #' + str(lesion_label) + '...', self.verbose, 'normal')

            label_idx = self.measure_pd[self.measure_pd.label == lesion_label].index
            self._measure_volume(im_lesion_data_cur, p_lst, label_idx)
            self._measure_length(im_lesion_data_cur, p_lst, label_idx)
            self._measure_diameter(im_lesion_data_cur, p_lst, label_idx)

            # compute lesion distribution for each lesion
            if self.path_template is not None:
                self._measure_eachLesion_distribution(lesion_id=lesion_label,
                                                      atlas_data=atlas_data_dct,
                                                      im_vert=im_vert_data,
                                                      im_lesion=im_lesion_data_cur,
                                                      p_lst=p_lst)

        if self.path_template is not None:
            # compute total lesion distribution (binary mask of all lesions, freshly allocated)
            self._measure_totLesion_distribution(im_lesion=im_lesion_data > 0,
                                                 atlas_data=atlas_data_dct,
                                                 im_vert=im_vert_data,
                                                 p_lst=p_lst)

        if self.fname_ref is not None:
            # Compute mean and std value in each labeled lesion
            self._measure_within_im(im_lesion=im_lesion_data, im_ref=Image(self.fname_ref).data, label_lst=label_lst)
    def measure(self):
        """
        Compute lesion statistics from the labeled lesion image (self.fname_label).

        For each non-zero label: volume, length and diameter via the _measure_*
        helpers (they receive the matching self.measure_pd index).
        When self.path_template is set: per-lesion and total lesion distribution,
        using the vertebral levels image (self.path_levels) and the atlas ROI files
        (self.atlas_roi_lst).
        When self.fname_ref is set: statistics of that image within each lesion.
        self.volumes is (re)initialised to zeros of shape (dim[2], n_lesions).
        """
        im_lesion = Image(self.fname_label)
        im_lesion_data = im_lesion.data
        p_lst = im_lesion.dim[4:7]  # voxel size (px, py, pz)

        label_lst = [l for l in np.unique(im_lesion_data) if l]  # lesion label IDs list

        if self.path_template is not None:
            if os.path.isfile(self.path_levels):
                img_vert = Image(self.path_levels)
                im_vert_data = img_vert.data
                self.vert_lst = [v for v in np.unique(im_vert_data) if v]  # list of vertebral levels available in the input image

            else:
                im_vert_data = None
                printv('ERROR: the file ' + self.path_levels + ' does not exist. Please make sure the template was correctly registered and warped (sct_register_to_template or sct_register_multimodal and sct_warp_template)', type='error')

            # In order to open atlas images only one time
            atlas_data_dct = {}  # dict containing the np.array of the registrated atlas
            for fname_atlas_roi in self.atlas_roi_lst:
                # Tract ID = integer between the last '_' and '.nii.gz' (e.g. '..._05.nii.gz' -> 5)
                tract_id = int(fname_atlas_roi.split('_')[-1].split('.nii.gz')[0])
                img_cur = Image(fname_atlas_roi)
                img_cur_copy = img_cur.copy()
                atlas_data_dct[tract_id] = img_cur_copy.data
                del img_cur

        self.volumes = np.zeros((im_lesion.dim[2], len(label_lst)))

        # iteration across each lesion to measure statistics
        for lesion_label in label_lst:
            # The comparison already returns a new boolean array; np.copy only doubled the allocation
            im_lesion_data_cur = im_lesion_data == lesion_label
            printv('\nMeasures on lesion #' + str(lesion_label) + '...', self.verbose, 'normal')

            label_idx = self.measure_pd[self.measure_pd.label == lesion_label].index
            self._measure_volume(im_lesion_data_cur, p_lst, label_idx)
            self._measure_length(im_lesion_data_cur, p_lst, label_idx)
            self._measure_diameter(im_lesion_data_cur, p_lst, label_idx)

            # compute lesion distribution for each lesion
            if self.path_template is not None:
                self._measure_eachLesion_distribution(lesion_id=lesion_label,
                                                      atlas_data=atlas_data_dct,
                                                      im_vert=im_vert_data,
                                                      im_lesion=im_lesion_data_cur,
                                                      p_lst=p_lst)

        if self.path_template is not None:
            # compute total lesion distribution (binary mask of all lesions, freshly allocated)
            self._measure_totLesion_distribution(im_lesion=im_lesion_data > 0,
                                                 atlas_data=atlas_data_dct,
                                                 im_vert=im_vert_data,
                                                 p_lst=p_lst)

        if self.fname_ref is not None:
            # Compute mean and std value in each labeled lesion
            self._measure_within_im(im_lesion=im_lesion_data, im_ref=Image(self.fname_ref).data, label_lst=label_lst)
示例#14
0
def add(img: Image, value: int) -> Image:
    """
    Add ``value`` to every non-zero voxel of a copy of ``img``.

    Zero-valued voxels keep the value 0. The work is done on ``img.copy()``.
    :param img: source image
    :param value: numeric value to add
    :returns: new image with value added
    """
    result = img.copy()
    nonzero_mask = result.data != 0
    result.data[nonzero_mask] += value
    return result
示例#15
0
def create_labels(img: Image, coordinates: Sequence[Coordinate]) -> Image:
    """
    Add user-provided labels to a copy of an image.

    No validation happens in this function: the coordinates must already be
    correct. A single label still has to be passed inside a list ('[]').
    :param img: source image (labels are written on img.copy())
    :param coordinates: list of Coordinate objects (see spinalcordtoolbox.types)
    :returns: labeled copy of the source image
    """
    img_copy = img.copy()
    return _add_labels(img_copy, coordinates)
示例#16
0
def register_to_template(img_path, sc_path, contrast, label_path, label_flag, ofolder, qc_folder):
    """
    Register an anatomical image to the template with sct_register_to_template.

    If the first run fails, the segmentation and the label file are rewritten
    with the anatomical image's header (their data is kept; labels are cast to
    uint8), and the registration is attempted once more.

    :param img_path: anatomical image ('-i').
    :param sc_path: spinal cord segmentation ('-s'); overwritten before the retry.
    :param contrast: contrast ('-c').
    :param label_path: label file; overwritten before the retry.
    :param label_flag: command-line flag used to pass label_path.
    :param ofolder: output folder ('-ofolder').
    :param qc_folder: QC folder ('-qc').
    :return: 1 if the registration succeeded, 0 if both attempts failed.
    """
    # Same command for the first attempt and the retry
    cmd = ['sct_register_to_template', '-i', img_path,
           '-s', sc_path,
           '-c', contrast,
           label_flag, label_path,
           '-param', PARAM_REG,
           '-ofolder', ofolder,
           '-qc', qc_folder]

    registration_status = 1

    # `except Exception` instead of a bare `except`: Ctrl-C (KeyboardInterrupt)
    # and SystemExit must not be swallowed and turned into a retry.
    try:
        sct.run(list(cmd))
    except Exception:
        im_ana, im_seg = Image(img_path), Image(sc_path)
        im_seg_new = im_ana.copy()  # copy hdr --> segmentation
        im_seg_new.data = im_seg.data
        im_seg_new.save(sc_path)

        im_labels = Image(label_path)
        im_labels_new = im_ana.copy()  # copy hdr --> labels
        im_labels_new.data = im_labels.data
        im_labels_new.change_type(type='uint8')
        im_labels_new.save(label_path)

        try:  # re-run
            sct.run(list(cmd))
        except Exception:
            registration_status = 0
            sct.printv('ERROR: Could not complete registration for anat. --> template! Path: %s' % img_path)

    return registration_status
示例#17
0
def remove_missing_labels(img: Image, ref: Image):
    """
    Zero every label of ``img`` that has no equal counterpart in ``ref``.

    Labels are the Coordinate objects returned by
    getNonZeroCoordinates(coordValue=True). The match uses Coordinate equality
    (presumably position and value -- confirm in spinalcordtoolbox.types).
    :param img: source image
    :param ref: reference image
    :returns: copy of img with labels missing from reference removed
    """
    cleaned = img.copy()

    candidate_labels = img.getNonZeroCoordinates(coordValue=True)
    reference_labels = ref.getNonZeroCoordinates(coordValue=True)

    orphan_labels = (lab for lab in candidate_labels if lab not in reference_labels)
    for lab in orphan_labels:
        cleaned.data[int(lab.x), int(lab.y), int(lab.z)] = 0

    return cleaned
示例#18
0
def remove_labels_from_image(img: Image, labels: Sequence[int]) -> Image:
    """
    Remove specified labels (set to 0) from an image.

    Every voxel whose value equals one of ``labels`` is zeroed. The previous
    version stopped at the first matching voxel, which left other voxels
    carrying the same label in place. A warning is logged for each requested
    label that is not present in the image.
    :param img: source image
    :param labels: list of specified labels to remove
    :returns: image with labels specified removed
    """
    out = img.copy()

    # Query the non-zero voxels once instead of once per label
    nonzero_voxels = [(x, y, z, v) for x, y, z, v in img.getNonZeroCoordinates()]

    for label_value in labels:
        matches = [(x, y, z) for x, y, z, v in nonzero_voxels if label_value == v]
        if not matches:
            logger.warning(f"Label {label_value} not found in input image!")
        for x, y, z in matches:
            out.data[int(x), int(y), int(z)] = 0.0

    return out
def test_segment():
    """deepseg_lesion.segment_3d must detect a synthetic lesion with no false positives."""
    contrast_test = 't2'
    model_path = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models',
                              '{}_lesion.h5'.format(contrast_test))

    # Synthetic axial pattern repeated along z: CSF ring (2400) around a cord disc (500)
    xx, yy = np.mgrid[:48, :48]
    r2 = (xx - 24) ** 2 + (yy - 24) ** 2
    axial_pattern = np.logical_and(r2 < 400, r2 >= 200) * 2400 + (r2 < 200) * 500
    data = np.zeros((48, 48, 96))
    data += axial_pattern[:, :, np.newaxis]
    data[16:22, 16:22, 64:90] = 1000  # fake lesion

    nii = nib.nifti1.Nifti1Image(data, np.eye(4))
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())

    seg = deepseg_lesion.segment_3d(model_path, contrast_test, img.copy())

    assert np.any(seg.data[16:22, 16:22, 64:90])  # check if lesion detected
    assert not np.any(seg.data[img.data != 1000])  # check if no FP
def dummy_centerline_small(size_arr=(9, 9, 9), subsampling=1, dilate_ctl=0, hasnan=False, orientation='RPI'):
    """
    Create a dummy Image centerline of small size. Return the full and sub-sampled version along z.
    :param size_arr: tuple: (nx, ny, nz)
    :param subsampling: int >=1. Subsampling factor along z. 1: no subsampling. 2: centerline defined every other z.
    :param dilate_ctl: Dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
                         if dilate_ctl=0, result will be a single pixel per slice.
    :param hasnan: Bool: put nan/inf at voxels (0,0,0)/(1,0,0) of the full image only
                   (the sub-sampled image is built before they are added).
    :param orientation: orientation passed to Image.change_orientation (default 'RPI')
    :return: img, img_sub -- full centerline image and z-subsampled image
    """
    from numpy import poly1d, polyfit
    nx, ny, nz = size_arr
    # define polynomial-based centerline within X-Z plane, located at y=ny/4
    x = np.array([round(nx/4.), round(nx/2.), round(3*nx/4.)])
    z = np.array([0, round(nz/2.), nz-1])
    p = poly1d(polyfit(z, x, deg=3))
    data = np.zeros((nx, ny, nz))
    # Integer x of the centerline at each z (builtin int: np.int was removed in NumPy 1.24)
    ctl_x = p(range(nz)).astype(int)
    ctl_y = round(ny / 4.)
    # Loop across dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
    for ixiy_ctl in itertools.product(range(-dilate_ctl, dilate_ctl+1, 1), range(-dilate_ctl, dilate_ctl+1, 1)):
        data[ctl_x + ixiy_ctl[0], ctl_y + ixiy_ctl[1], range(nz)] = 1
    # generate Image object with RPI orientation
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    # subsample data
    img_sub = img.copy()
    img_sub.data = np.zeros((nx, ny, nz))
    for iz in range(0, nz, subsampling):
        img_sub.data[..., iz] = data[..., iz]
    # Add non-numerical values at the top corner of the image
    if hasnan:
        img.data[0, 0, 0] = np.nan
        img.data[1, 0, 0] = np.inf
    # Update orientation
    img.change_orientation(orientation)
    img_sub.change_orientation(orientation)
    return img, img_sub
示例#21
0
def test_segment():
    """deepseg_lesion.segment_3d must detect a synthetic lesion with no false positives."""
    contrast_test = 't2'
    model_path = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models',
                              '{}_lesion.h5'.format(contrast_test))

    # Synthetic axial pattern repeated along z: CSF ring (2400) around a cord disc (500)
    xx, yy = np.mgrid[:48, :48]
    r2 = (xx - 24)**2 + (yy - 24)**2
    axial_pattern = np.logical_and(r2 < 400, r2 >= 200) * 2400 + (r2 < 200) * 500
    data = np.zeros((48, 48, 96))
    data += axial_pattern[:, :, np.newaxis]
    data[16:22, 16:22, 64:90] = 1000  # fake lesion

    nii = nib.nifti1.Nifti1Image(data, np.eye(4))
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())

    seg = deepseg_lesion.segment_3d(model_path, contrast_test, img.copy())

    assert np.any(seg.data[16:22, 16:22, 64:90])  # check if lesion detected
    assert not np.any(seg.data[img.data != 1000])  # check if no FP
def test_crop_image_around_centerline():
    """crop_image_around_centerline must return a (crop_size, crop_size, nz)
    volume whose slice z=0 equals a manual crop around the centerline point."""
    input_shape = (100, 100, 100)
    crop_size = 20

    data = np.random.rand(input_shape[0], input_shape[1], input_shape[2])
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())

    ctr, _, _ = dummy_centerline(size_arr=input_shape)

    _, _, _, img_out = deepseg_sc.crop_image_around_centerline(im_in=img.copy(),
                                                               ctr_in=ctr.copy(),
                                                               crop_size=crop_size)

    # Reference: crop slice z=0 manually around its first centerline voxel
    img_in_z0 = img.data[:, :, 0]
    x_ctr_z0, y_ctr_z0 = np.where(ctr.data[:, :, 0])[0][0], np.where(ctr.data[:, :, 0])[1][0]
    x_start, x_end = deepseg_sc._find_crop_start_end(x_ctr_z0, crop_size, img.dim[0])
    y_start, y_end = deepseg_sc._find_crop_start_end(y_ctr_z0, crop_size, img.dim[1])
    img_in_z0_crop = img_in_z0[x_start:x_end, y_start:y_end]

    assert img_out.data.shape == (crop_size, crop_size, input_shape[2])
    assert np.allclose(img_in_z0_crop, img_out.data[:, :, 0])
示例#23
0
def main(argv=None):
    """
    Smooth the spinal cord along its centerline.

    Steps:
      1. Copy the inputs to a temporary folder, convert them to NIfTI and reorient them to RPI.
      2. Straighten the cord, reusing warping fields cached in the current directory when the cache
         signature matches.
      3. Smooth the straightened image with the requested sigmas (mm).
      4. Warp the result back to the curved space and refill voxels zeroed by the warp with the
         original intensities.
      5. Write the output file (default: <input>_smooth.nii).

    :param argv: list of command-line arguments (None: parsed by argparse from sys.argv).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # Initialization
    param = Param()
    start_time = time.time()

    fname_anat = arguments.i
    fname_centerline = arguments.s
    param.algo_fitting = arguments.algo_fitting

    # `sigmas` used to be left undefined when -smooth was missing, which crashed
    # later with a NameError; fail early with a usage message instead.
    if arguments.smooth is None:
        parser.error("-smooth is required")
    sigmas = arguments.smooth
    remove_temp_files = arguments.r
    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = extract_fname(fname_anat)[1] + '_smooth.nii'

    # Display arguments
    printv('\nCheck input arguments...')
    printv('  Volume to smooth .................. ' + fname_anat)
    printv('  Centerline ........................ ' + fname_centerline)
    printv('  Sigma (mm) ........................ ' + str(sigmas))
    printv('  Verbose ........................... ' + str(verbose))

    # Check that input is 3D:
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_anat).dim
    dim = 4  # by default, will be adjusted later
    if nt == 1:
        dim = 3
    if nz == 1:
        dim = 2
    if dim == 4:
        printv(
            'WARNING: the input image is 4D, please split your image to 3D before smoothing spinalcord using :\n'
            'sct_image -i ' + fname_anat + ' -split t -o ' + fname_anat,
            verbose, 'warning')
        printv('4D images not supported, aborting ...', verbose, 'error')
        # Actually abort, as announced, instead of relying on printv to exit
        raise SystemExit(1)

    # Extract file extensions
    _, _, ext_anat = extract_fname(fname_anat)
    _, _, ext_centerline = extract_fname(fname_centerline)

    path_tmp = tmp_create(basename="smooth_spinalcord")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    copy(fname_anat, os.path.join(path_tmp, "anat" + ext_anat))
    copy(fname_centerline, os.path.join(path_tmp, "centerline" + ext_centerline))

    # Straightening outputs cached in the current directory and reused when valid
    cached_files = ['warp_curve2straight.nii.gz', 'warp_straight2curve.nii.gz', 'straight_ref.nii.gz']

    # go to tmp folder (the `finally` below always restores the working directory)
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # convert to nii format
        im_anat = convert(Image('anat' + ext_anat))
        im_anat.save('anat.nii', mutable=True, verbose=verbose)
        im_centerline = convert(Image('centerline' + ext_centerline))
        im_centerline.save('centerline.nii', mutable=True, verbose=verbose)

        # Change orientation of the input image into RPI
        printv('\nOrient input volume to RPI orientation...')

        img_anat_rpi = Image("anat.nii").change_orientation("RPI")
        fname_anat_rpi = add_suffix(img_anat_rpi.absolutepath, "_rpi")
        img_anat_rpi.save(path=fname_anat_rpi, mutable=True)

        # Change orientation of the input image into RPI
        printv('\nOrient centerline to RPI orientation...')

        img_centerline_rpi = Image("centerline.nii").change_orientation("RPI")
        fname_centerline_rpi = add_suffix(img_centerline_rpi.absolutepath, "_rpi")
        img_centerline_rpi.save(path=fname_centerline_rpi, mutable=True)

        # Straighten the spinal cord
        # straighten segmentation
        printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
        cache_sig = cache_signature(
            input_files=[fname_anat_rpi, fname_centerline_rpi],
            input_params={"x": "spline"})
        cachefile = os.path.join(curdir, "straightening.cache")
        if cache_valid(cachefile, cache_sig) and all(
                os.path.isfile(os.path.join(curdir, f)) for f in cached_files):
            # if they exist, copy them into current folder
            printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            for f in cached_files:
                copy(os.path.join(curdir, f), f)
            # apply straightening
            run_proc([
                'sct_apply_transfo', '-i', fname_anat_rpi, '-w',
                'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
                'anat_rpi_straight.nii', '-x', 'spline'
            ], verbose)
        else:
            run_proc([
                'sct_straighten_spinalcord', '-i', fname_anat_rpi, '-o',
                'anat_rpi_straight.nii', '-s', fname_centerline_rpi, '-x',
                'spline', '-param', 'algo_fitting=' + param.algo_fitting
            ], verbose)
            cache_save(cachefile, cache_sig)
            # Copy the warping fields AND straight_ref.nii.gz locally (to use caching next time).
            # straight_ref.nii.gz is required by the cache check above; it used to be left out,
            # so a cache written here could never be reused.
            for f in cached_files:
                copy(f, os.path.join(curdir, f))

        # Smooth the straightened image along z
        printv('\nSmooth the straightened image...')

        img = Image("anat_rpi_straight.nii")
        out = img.copy()

        if len(sigmas) == 1:
            sigmas = [sigmas[0]] * len(img.data.shape)
        elif len(sigmas) != len(img.data.shape):
            raise ValueError(
                "-smooth need the same number of inputs as the number of image dimension OR only one input"
            )

        # Convert sigmas from mm to voxels using the pixel sizes (dim[4:7] = px, py, pz)
        sigmas = [sigmas[i] / img.dim[i + 4] for i in range(3)]
        out.data = smooth(out.data, sigmas)
        out.save(path="anat_rpi_straight_smooth.nii")

        # Apply the reversed warping field to get back the curved spinal cord
        printv('\nApply the reversed warping field to get back the curved spinal cord...')
        run_proc([
            'sct_apply_transfo', '-i', 'anat_rpi_straight_smooth.nii', '-o',
            'anat_rpi_straight_smooth_curved.nii', '-d', 'anat.nii', '-w',
            'warp_straight2curve.nii.gz', '-x', 'spline'
        ], verbose)

        # replace zeroed voxels by original image (issue #937)
        printv('\nReplace zeroed voxels by original image...', verbose)
        nii_smooth = Image('anat_rpi_straight_smooth_curved.nii')
        data_smooth = nii_smooth.data
        data_input = Image('anat.nii').data
        indzero = np.where(data_smooth == 0)
        data_smooth[indzero] = data_input[indzero]
        nii_smooth.data = data_smooth
        nii_smooth.save('anat_rpi_straight_smooth_curved_nonzero.nii')
    finally:
        # come back
        os.chdir(curdir)

    # Generate output file
    printv('\nGenerate output file...')
    generate_output_file(
        os.path.join(path_tmp, "anat_rpi_straight_smooth_curved_nonzero.nii"),
        fname_out)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...')
        rmtree(path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's\n')

    display_viewer_syntax([fname_anat, fname_out], verbose=verbose)
示例#24
0
def propseg(img_input, options_dict):
    """
    Segment the spinal cord of ``img_input`` with the ``isct_propseg`` binary.

    Steps: build the ``isct_propseg`` command line from the parsed arguments,
    optionally obtain an initialization (interactive viewer, OptiC centerline,
    or a user-provided centerline/mask), run the binary, rename its outputs if
    the header was rescaled, optionally check/correct the segmentation, and
    copy the input header onto every output file.

    :param img_input: source image, to be segmented. Only its ``absolutepath``
        is used; the image is read again from that path.
    :param options_dict: parsed command-line arguments. Despite its name it is
        accessed by attribute (``options_dict.c``, ``options_dict.o``, ...),
        as with an ``argparse.Namespace``.
    :return: segmented Image, loaded from the output segmentation file.
    """
    import shutil  # only used to remove the temporary folder

    arguments = options_dict
    fname_input_data = img_input.absolutepath
    fname_data = os.path.abspath(fname_input_data)
    contrast_type = arguments.c
    # isct_propseg only handles 't1' and 't2': map every SCT contrast onto one of them
    contrast_type_conversion = {
        't1': 't1',
        't2': 't2',
        't2s': 't2',
        'dwi': 't1'
    }
    contrast_type_propseg = contrast_type_conversion[contrast_type]

    # Starting building the command
    cmd = ['isct_propseg', '-t', contrast_type_propseg]

    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = os.path.basename(add_suffix(fname_data, "_seg"))

    # A bare file name has no directory part: fall back to the current
    # directory rather than '' (os.makedirs('') raises, and isct_propseg
    # needs an actual output folder).
    folder_output = os.path.dirname(fname_out) or os.curdir
    cmd += ['-o', folder_output]
    if not os.path.isdir(folder_output) and os.path.exists(folder_output):
        logger.error("output directory %s is not a valid directory" %
                     folder_output)
        sys.exit(1)
    if not os.path.exists(folder_output):
        os.makedirs(folder_output)

    if arguments.down is not None:
        cmd += ["-down", str(arguments.down)]
    if arguments.up is not None:
        cmd += ["-up", str(arguments.up)]

    remove_temp_files = arguments.r

    verbose = int(arguments.v)
    # Update for propseg binary
    if verbose > 0:
        cmd += ["-verbose"]

    # Output options (value-less flags forwarded to isct_propseg)
    if arguments.mesh is not None:
        cmd += ["-mesh"]
    if arguments.centerline_binary is not None:
        cmd += ["-centerline-binary"]
    if arguments.CSF is not None:
        cmd += ["-CSF"]
    if arguments.centerline_coord is not None:
        cmd += ["-centerline-coord"]
    if arguments.cross is not None:
        cmd += ["-cross"]
    if arguments.init_tube is not None:
        cmd += ["-init-tube"]
    if arguments.low_resolution_mesh is not None:
        cmd += ["-low-resolution-mesh"]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_nii is not None:
    #     cmd += ["-detect-nii"]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_png is not None:
    #     cmd += ["-detect-png"]

    # Helping options
    use_viewer = None  # 'centerline' or 'mask' when labels are drawn in the viewer
    use_optic = True  # enabled by default
    init_option = None
    rescale_header = arguments.rescale
    # Compare by value: `is not 1` tests identity (and is False-positive for 1.0).
    do_rescale = rescale_header != 1
    if arguments.init is not None:
        init_option = float(arguments.init)
        if init_option < 0:
            printv(
                'Command-line usage error: ' + str(init_option) +
                " is not a valid value for '-init'", 1, 'error')
            sys.exit(1)
    if arguments.init_centerline is not None:
        if str(arguments.init_centerline) == "viewer":
            use_viewer = "centerline"
        elif str(arguments.init_centerline) == "hough":
            use_optic = False
        else:
            if do_rescale:
                fname_labels_viewer = func_rescale_header(
                    str(arguments.init_centerline),
                    rescale_header,
                    verbose=verbose)
            else:
                fname_labels_viewer = str(arguments.init_centerline)
            cmd += ["-init-centerline", fname_labels_viewer]
            use_optic = False
    if arguments.init_mask is not None:
        if str(arguments.init_mask) == "viewer":
            use_viewer = "mask"
        else:
            if do_rescale:
                fname_labels_viewer = func_rescale_header(
                    str(arguments.init_mask), rescale_header)
            else:
                fname_labels_viewer = str(arguments.init_mask)
            cmd += ["-init-mask", fname_labels_viewer]
            use_optic = False
    if arguments.mask_correction is not None:
        cmd += ["-mask-correction", str(arguments.mask_correction)]
    if arguments.radius is not None:
        cmd += ["-radius", str(arguments.radius)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_n is not None:
    #     cmd += ["-detect-n", str(arguments.detect_n)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_gap is not None:
    #     cmd += ["-detect-gap", str(arguments.detect_gap)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.init_validation is not None:
    #     cmd += ["-init-validation"]
    if arguments.nbiter is not None:
        cmd += ["-nbiter", str(arguments.nbiter)]
    if arguments.max_area is not None:
        cmd += ["-max-area", str(arguments.max_area)]
    if arguments.max_deformation is not None:
        cmd += ["-max-deformation", str(arguments.max_deformation)]
    if arguments.min_contrast is not None:
        cmd += ["-min-contrast", str(arguments.min_contrast)]
    if arguments.d is not None:
        # attribute access: `arguments` is a namespace, not a dict
        cmd += ["-d", str(arguments.d)]
    if arguments.distance_search is not None:
        cmd += ["-dsearch", str(arguments.distance_search)]
    if arguments.alpha is not None:
        cmd += ["-alpha", str(arguments.alpha)]

    # check if input image is in 3D. Otherwise itk image reader will cut the 4D image in 3D volumes and only take the first one.
    image_input = Image(fname_data)
    image_input_rpi = image_input.copy().change_orientation('RPI')
    nx, ny, nz, nt, px, py, pz, pt = image_input_rpi.dim
    if nt > 1:
        printv(
            'ERROR: your input image needs to be 3D in order to be segmented.',
            1, 'error')
        sys.exit(1)

    path_tmp = tmp_create(basename="propseg")

    # rescale header (see issue #1406)
    if do_rescale:
        fname_data_propseg = func_rescale_header(fname_data, rescale_header)
    else:
        fname_data_propseg = fname_data

    # add to command
    cmd += ['-i', fname_data_propseg]

    # if centerline or mask is asked using viewer
    if use_viewer:
        from spinalcordtoolbox.gui.base import AnatomicalParams
        from spinalcordtoolbox.gui.centerline import launch_centerline_dialog

        params = AnatomicalParams()
        if use_viewer == 'mask':
            params.num_points = 3
            params.interval_in_mm = 15  # superior-inferior interval between two consecutive labels
            params.starting_slice = 'midfovminusinterval'
        if use_viewer == 'centerline':
            # setting maximum number of points to a reasonable value
            params.num_points = 20
            params.interval_in_mm = 30
            params.starting_slice = 'top'
        im_data = Image(fname_data_propseg)

        im_mask_viewer = zeros_like(im_data)
        # im_mask_viewer.absolutepath = add_suffix(fname_data_propseg, '_labels_viewer')
        controller = launch_centerline_dialog(im_data, im_mask_viewer, params)
        fname_labels_viewer = add_suffix(fname_data_propseg, '_labels_viewer')

        if not controller.saved:
            printv(
                'The viewer has been closed before entering all manual points. Please try again.',
                1, 'error')
            sys.exit(1)
        # save labels
        controller.as_niftii(fname_labels_viewer)

        # add mask filename to parameters string
        if use_viewer == "centerline":
            cmd += ["-init-centerline", fname_labels_viewer]
        elif use_viewer == "mask":
            cmd += ["-init-mask", fname_labels_viewer]

    # If using OptiC
    elif use_optic:
        image_centerline = optic.detect_centerline(image_input, contrast_type,
                                                   verbose)
        fname_centerline_optic = os.path.join(path_tmp,
                                              'centerline_optic.nii.gz')
        image_centerline.save(fname_centerline_optic)
        cmd += ["-init-centerline", fname_centerline_optic]

    if init_option is not None:
        # values > 1 are slice indices: convert them to a fraction of the z extent
        if init_option > 1:
            init_option /= (nz - 1)
        cmd += ['-init', str(init_option)]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, _ = run_proc(cmd,
                         verbose,
                         raise_exception=False,
                         is_sct_binary=True)

    # the temporary folder only holds inputs of isct_propseg (OptiC centerline)
    if remove_temp_files == 1:
        shutil.rmtree(path_tmp, ignore_errors=True)

    # check status is not 0
    if not status == 0:
        printv(
            'Automatic cord detection failed. Please initialize using -init-centerline or -init-mask (see help)',
            1, 'error')
        sys.exit(1)

    # build output filename. fname_out may already contain folder_output, so
    # only its base name is joined (joining it whole would repeat the folder).
    fname_seg = os.path.join(folder_output, os.path.basename(fname_out))
    fname_centerline = os.path.join(
        folder_output, os.path.basename(add_suffix(fname_data, "_centerline")))
    # in case header was rescaled, we need to update the output file names by removing the "_rescaled"
    if do_rescale:
        mv(
            os.path.join(
                folder_output,
                add_suffix(os.path.basename(fname_data_propseg), "_seg")),
            fname_seg)
        mv(
            os.path.join(
                folder_output,
                add_suffix(os.path.basename(fname_data_propseg),
                           "_centerline")), fname_centerline)
        # if user was used, copy the labelled points to the output folder (they will then be scaled back)
        if use_viewer:
            fname_labels_viewer_new = os.path.join(
                folder_output,
                os.path.basename(add_suffix(fname_data, "_labels_viewer")))
            copy(fname_labels_viewer, fname_labels_viewer_new)
            # update variable (used later)
            fname_labels_viewer = fname_labels_viewer_new

    # check consistency of segmentation
    if arguments.correct_seg:
        check_and_correct_segmentation(fname_seg,
                                       fname_centerline,
                                       folder_output=folder_output,
                                       threshold_distance=3.0,
                                       remove_temp_files=remove_temp_files,
                                       verbose=verbose)

    # copy header from input to segmentation to make sure qform is the same
    printv("Copy header input --> output(s) to make sure qform is the same.",
           verbose)
    list_fname = [fname_seg, fname_centerline]
    if use_viewer:
        list_fname.append(fname_labels_viewer)
    for fname in list_fname:
        im = Image(fname)
        im.header = image_input.header
        im.save(dtype='int8'
                )  # they are all binary masks hence fine to save as int8

    return Image(fname_seg)
def interpolate_im_to_ref(im_input,
                          im_input_sc,
                          new_res=0.3,
                          sq_size_size_mm=22.5,
                          interpolation_mode=3):
    """
    Resample each slice of an image onto a square grid centered on the cord.

    For every index ``iz`` along the third axis of ``im_input``, a reference
    square of ``int(sq_size_size_mm / new_res)`` pixels of size ``new_res`` is
    centered on the center of mass of ``im_input_sc`` in that slice, and
    ``im_input`` is interpolated onto it.

    Side effect: writes the reference image to ``im_ref.nii.gz`` in the
    current working directory.

    :param im_input: Image to resample (not modified; a copy is used).
    :param im_input_sc: segmentation Image used to locate each slice center
        (not modified; a copy is used).
    :param new_res: in-plane pixel size of the output slices (same unit as the
        image header zooms).
    :param sq_size_size_mm: side of the output square, in mm.
    :param interpolation_mode: passed through to ``Image.interpolate_from_image``.
    :return: list of Images, one per slice of ``im_input``; their data is
        reshaped to 2D when the interpolation returns a 3D array.
    """
    nx, ny, nz, nt, px, py, pz, pt = im_input.dim

    # work on copies: the headers are modified below
    im_input_sc = im_input_sc.copy()
    im_input = im_input.copy()

    # keep only spacing and origin in qform to avoid rotation issues
    # (zero every off-diagonal term except the translation column)
    input_qform = im_input.hdr.get_qform()
    for i in range(4):
        for j in range(4):
            if i != j and j != 3:
                input_qform[i, j] = 0

    im_input.hdr.set_qform(input_qform)
    im_input.hdr.set_sform(input_qform)
    im_input_sc.hdr = im_input.hdr

    sq_size = int(sq_size_size_mm / new_res)
    # create a reference image : square of ones
    # (builtin `int`: the `np.int` alias it replaced was removed in NumPy 1.24)
    im_ref = Image(np.ones((sq_size, sq_size, 1), dtype=int),
                   dim=(sq_size, sq_size, 1, 0, new_res, new_res, pz, 0),
                   orientation='RPI')

    # copy input qform matrix to reference image
    im_ref.hdr.set_qform(im_input.hdr.get_qform())
    im_ref.hdr.set_sform(im_input.hdr.get_sform())

    # set correct header to reference image
    im_ref.hdr.set_data_shape((sq_size, sq_size, 1))
    im_ref.hdr.set_zooms((new_res, new_res, pz))

    # save image to set orientation to RPI (not properly done at the creation of the image)
    fname_ref = 'im_ref.nii.gz'
    im_ref.save(fname_ref).change_orientation("RPI")

    # set header origin to zero to get physical coordinates of the center of the square
    im_ref.hdr.as_analyze_map()['qoffset_x'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_y'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_z'] = 0
    im_ref.hdr.set_sform(im_ref.hdr.get_qform())
    im_ref.hdr.set_qform(im_ref.hdr.get_qform())
    [[x_square_center_phys, y_square_center_phys,
      z_square_center_phys]] = im_ref.transfo_pix2phys(
          coordi=[[int(sq_size / 2), int(sq_size / 2), 0]])

    list_interpolate_images = []
    # iterate on z dimension of input image
    for iz in range(nz):
        # copy reference image: one reference image per slice
        im_ref_slice_iz = im_ref.copy()

        # get center of mass of SC for slice iz
        # NOTE(review): a slice without any segmented voxel yields NaN here
        # (np.mean of an empty array); callers presumably pass a non-empty cord.
        x_seg, y_seg = (im_input_sc.data[:, :, iz] > 0).nonzero()
        x_center, y_center = np.mean(x_seg), np.mean(y_seg)
        [[x_center_phys, y_center_phys, z_center_phys]
         ] = im_input_sc.transfo_pix2phys(coordi=[[x_center, y_center, iz]])

        # center reference image on SC for slice iz
        im_ref_slice_iz.hdr.as_analyze_map(
        )['qoffset_x'] = x_center_phys - x_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map(
        )['qoffset_y'] = y_center_phys - y_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_z'] = z_center_phys
        im_ref_slice_iz.hdr.set_sform(im_ref_slice_iz.hdr.get_qform())
        im_ref_slice_iz.hdr.set_qform(im_ref_slice_iz.hdr.get_qform())

        # interpolate input image to reference image
        im_input_interpolate_iz = im_input.interpolate_from_image(
            im_ref_slice_iz,
            interpolation_mode=interpolation_mode,
            border='nearest')
        # reshape data to 2D if needed (drop the singleton last axis)
        if len(im_input_interpolate_iz.data.shape) == 3:
            im_input_interpolate_iz.data = im_input_interpolate_iz.data.reshape(
                im_input_interpolate_iz.data.shape[:-1])
        # add slice to list
        list_interpolate_images.append(im_input_interpolate_iz)

    return list_interpolate_images
def pre_processing(fname_target,
                   fname_sc_seg,
                   fname_level=None,
                   fname_manual_gmseg=None,
                   new_res=0.3,
                   square_size_size_mm=22.5,
                   denoising=True,
                   verbose=1,
                   rm_tmp=True,
                   for_model=False):
    """
    Prepare an image and its cord segmentation for gray matter segmentation.

    The inputs are copied to a temporary folder and reoriented to RPI. They are
    optionally cropped around the cord and denoised (non-local means). Each
    slice is then resampled onto a square grid centered on the cord, masked
    with the binarized cord segmentation, and wrapped into a ``Slice`` object.
    Vertebral levels and manual GM segmentations are attached when provided.

    The working directory is changed to the temporary folder during
    processing and is always restored before returning, even on error.

    :param fname_target: image to segment.
    :param fname_sc_seg: spinal cord segmentation of ``fname_target``; must
        have the same orientation.
    :param fname_level: optional vertebral level file, passed to ``load_level``.
    :param fname_manual_gmseg: optional manual GM segmentation(s), passed to
        ``load_manual_gmseg``.
    :param new_res: in-plane pixel size of the resampled slices.
    :param square_size_size_mm: side of the square grid, in mm.
    :param denoising: crop around the cord and denoise before resampling.
    :param verbose: verbosity passed to ``printv``.
    :param rm_tmp: remove the temporary folder before returning.
    :param for_model: forwarded to ``load_manual_gmseg``.
    :return: tuple ``(list_slices_target, original_info)``. ``original_info``
        holds the original orientation (``'orientation'``), the RPI cord
        segmentation (``'im_sc_seg_rpi'``) and the resampled, masked slice
        images (``'interpolated_images'``).
    :raises AssertionError: if the image and segmentation orientations differ.
    """
    printv('\nPre-process data...', verbose, 'normal')

    tmp_dir = tmp_create()

    copy(fname_target, tmp_dir)
    fname_target = ''.join(extract_fname(fname_target)[1:])
    copy(fname_sc_seg, tmp_dir)
    fname_sc_seg = ''.join(extract_fname(fname_sc_seg)[1:])

    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        original_info = {
            'orientation': None,
            'im_sc_seg_rpi': None,
            'interpolated_images': []
        }

        im_target = Image(fname_target).copy()
        im_sc_seg = Image(fname_sc_seg).copy()

        # get original orientation
        printv('  Reorient...', verbose, 'normal')
        original_info['orientation'] = im_target.orientation

        # assert images are in the same orientation
        assert im_target.orientation == im_sc_seg.orientation, "ERROR: the image to segment and it's SC segmentation are not in the same orientation"

        im_target_rpi = im_target.copy().change_orientation(
            'RPI', generate_path=True).save()
        im_sc_seg_rpi = im_sc_seg.copy().change_orientation(
            'RPI', generate_path=True).save()
        original_info['im_sc_seg_rpi'] = im_sc_seg_rpi.copy(
        )  # target image in RPI will be used to post-process segmentations

        # denoise using P. Coupe non local means algorithm (see [Manjon et al. JMRI 2010]) implemented in dipy
        if denoising:
            printv('  Denoise...', verbose, 'normal')
            # crop image before denoising to fasten denoising
            nx, ny, nz, nt, px, py, pz, pt = im_target_rpi.dim
            size_x = (square_size_size_mm + 1) / px
            size_y = (square_size_size_mm + 1) / py
            size = int(np.ceil(max(size_x, size_y)))
            # create mask (box centered on the cord segmentation)
            fname_mask = 'mask_pre_crop.nii.gz'
            sct_create_mask.main([
                '-i', im_target_rpi.absolutepath, '-p',
                'centerline,' + im_sc_seg_rpi.absolutepath, '-f', 'box',
                '-size', str(size), '-o', fname_mask
            ])
            # crop image
            cropper = ImageCropper(im_target_rpi)
            cropper.get_bbox_from_mask(Image(fname_mask))
            im_target_rpi_crop = cropper.crop()
            # crop segmentation
            cropper = ImageCropper(im_sc_seg_rpi)
            cropper.get_bbox_from_mask(Image(fname_mask))
            im_sc_seg_rpi_crop = cropper.crop()
            # denoising
            from spinalcordtoolbox.math import denoise_nlmeans
            # shrink the block radius when the cropped volume is thin along z
            block_radius = 3
            nz_crop = im_target_rpi_crop.data.shape[2]
            if nz_crop < block_radius * 2:
                block_radius = int(nz_crop / 2)
            patch_radius = block_radius - 1
            data_denoised = denoise_nlmeans(im_target_rpi_crop.data,
                                            block_radius=block_radius,
                                            patch_radius=patch_radius)
            im_target_rpi_crop.data = data_denoised

            im_target_rpi = im_target_rpi_crop
            im_sc_seg_rpi = im_sc_seg_rpi_crop
        else:
            fname_mask = None

        # interpolate image to reference square image (resample and square crop centered on SC)
        printv('  Interpolate data to the model space...', verbose, 'normal')
        list_im_slices = interpolate_im_to_ref(
            im_target_rpi,
            im_sc_seg_rpi,
            new_res=new_res,
            sq_size_size_mm=square_size_size_mm)
        original_info[
            'interpolated_images'] = list_im_slices  # list of images (not Slice() objects)

        printv('  Mask data using the spinal cord segmentation...', verbose,
               'normal')
        list_sc_seg_slices = interpolate_im_to_ref(
            im_sc_seg_rpi,
            im_sc_seg_rpi,
            new_res=new_res,
            sq_size_size_mm=square_size_size_mm,
            interpolation_mode=1)
        for i in range(len(list_im_slices)):
            # binarize the resampled segmentation, then zero the image outside of it
            list_sc_seg_slices[i] = binarize(list_sc_seg_slices[i],
                                             thr_min=0.5,
                                             thr_max=1)
            list_im_slices[
                i].data = list_im_slices[i].data * list_sc_seg_slices[i].data

        printv('  Split along rostro-caudal direction...', verbose, 'normal')
        list_slices_target = [
            Slice(slice_id=i, im=im_slice.data, gm_seg=[], wm_seg=[])
            for i, im_slice in enumerate(list_im_slices)
        ]

        # load vertebral levels
        if fname_level is not None:
            printv('  Load vertebral levels...', verbose, 'normal')
            # copy level file to tmp dir (fname_level may be relative to curdir)
            os.chdir(curdir)
            copy(fname_level, tmp_dir)
            os.chdir(tmp_dir)
            # change fname level to only file name (path = tmp dir now)
            fname_level = ''.join(extract_fname(fname_level)[1:])
            # load levels
            list_slices_target = load_level(list_slices_target, fname_level)
    finally:
        # always come back to the caller's directory, even if a step failed
        os.chdir(curdir)

    # load manual gmseg if there is one (model data)
    if fname_manual_gmseg is not None:
        printv('\n\tLoad manual GM segmentation(s) ...', verbose, 'normal')
        list_slices_target = load_manual_gmseg(list_slices_target,
                                               fname_manual_gmseg,
                                               tmp_dir,
                                               im_sc_seg_rpi,
                                               new_res,
                                               square_size_size_mm,
                                               for_model=for_model,
                                               fname_mask=fname_mask)

    if rm_tmp:
        # remove tmp folder
        rmtree(tmp_dir)
    return list_slices_target, original_info
def main():
    """Command-line entry point: segment the spinal cord with a deep learning model.

    Parses the command line, resolves the segmentation options (centerline
    method, brain inclusion, kernel dimension), runs
    ``deep_segmentation_spinalcord`` and writes the outputs to the output
    folder. Optionally generates a QC report.
    """
    sct.init_sct()
    parser = get_parser()
    args = sys.argv[1:]
    arguments = parser.parse(args)

    fname_image = os.path.abspath(arguments['-i'])
    contrast_type = arguments['-c']
    ctr_algo = arguments["-centerline"]

    # Without an explicit "-brain", infer it from the contrast.
    if "-brain" in args:
        brain_bool = bool(int(arguments["-brain"]))
    elif contrast_type in ['t1', 't2']:
        brain_bool = True
    elif contrast_type in ['t2s', 'dwi']:
        brain_bool = False

    kernel_size = arguments["-kernel"]
    if contrast_type == 'dwi' and kernel_size == '3d':
        # no 3D model exists for dwi: fall back to 2D
        kernel_size = '2d'
        sct.printv('3D kernel model for dwi contrast is not available. 2D kernel model is used instead.', type="warning")

    output_folder = arguments["-ofolder"] if '-ofolder' in args else os.getcwd()

    has_centerline_file = "-file_centerline" in args
    if ctr_algo == 'file' and not has_centerline_file:
        logger.warning('Please use the flag -file_centerline to indicate the centerline filename.')
        sys.exit(1)

    manual_centerline_fname = None
    if has_centerline_file:
        # providing a centerline file always selects the 'file' method
        manual_centerline_fname = arguments["-file_centerline"]
        ctr_algo = 'file'

    remove_temp_files = int(arguments['-r'])

    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    path_qc = arguments.get("-qc", None)
    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)

    sct.printv(''.join([
        '\nMethod:',
        '\n\tCenterline algorithm: ' + str(ctr_algo),
        '\n\tAssumes brain section included in the image: ' + str(brain_bool),
        '\n\tDimension of the segmentation kernel convolutions: ' + kernel_size + '\n',
    ]))

    im_image = Image(fname_image)
    # a copy is passed, otherwise the field absolutepath becomes None after execution of this function
    im_seg, _, _, im_labels_viewer, im_ctr = deep_segmentation_spinalcord(
        im_image.copy(), contrast_type, ctr_algo=ctr_algo, ctr_file=manual_centerline_fname,
        brain_bool=brain_bool, kernel_size=kernel_size, remove_temp_files=remove_temp_files, verbose=verbose)

    # every output reuses the input base name and extension, plus a suffix
    _, base_name, extension = sct.extract_fname(fname_image)

    def output_path(suffix):
        return os.path.abspath(os.path.join(output_folder, base_name + suffix + extension))

    # Save segmentation
    fname_seg = output_path('_seg')
    im_seg.save(fname_seg)

    if ctr_algo == 'viewer':
        # Save the labels placed in the viewer
        im_labels_viewer.save(output_path('_labels-centerline'))

    if verbose == 2:
        # Save the centerline
        im_ctr.save(output_path('_centerline'))

    if path_qc is not None:
        generate_qc(fname_image, fname_seg=fname_seg, args=args, path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset, subject=qc_subject, process='sct_deepseg_sc')
    sct.display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
def main():
    """Command-line entry point: segment the spinal cord with a deep learning model.

    Parses the command line, resolves the segmentation options (centerline
    method, brain inclusion, kernel dimension), runs
    ``deep_segmentation_spinalcord`` and writes the outputs to the output
    folder. Optionally generates a QC report.
    """
    sct.init_sct()
    parser = get_parser()
    args = sys.argv[1:]
    arguments = parser.parse(args)

    fname_image = os.path.abspath(arguments['-i'])
    contrast_type = arguments['-c']
    ctr_algo = arguments["-centerline"]

    # Without an explicit "-brain", infer it from the contrast.
    if "-brain" in args:
        brain_bool = bool(int(arguments["-brain"]))
    elif contrast_type in ['t1', 't2']:
        brain_bool = True
    elif contrast_type in ['t2s', 'dwi']:
        brain_bool = False

    kernel_size = arguments["-kernel"]
    if contrast_type == 'dwi' and kernel_size == '3d':
        # no 3D model exists for dwi: fall back to 2D
        kernel_size = '2d'
        sct.printv('3D kernel model for dwi contrast is not available. 2D kernel model is used instead.', type="warning")

    output_folder = arguments["-ofolder"] if '-ofolder' in args else os.getcwd()

    has_centerline_file = "-file_centerline" in args
    if ctr_algo == 'file' and not has_centerline_file:
        sct.log.warning('Please use the flag -file_centerline to indicate the centerline filename.')
        sys.exit(1)

    manual_centerline_fname = None
    if has_centerline_file:
        # providing a centerline file always selects the 'file' method
        manual_centerline_fname = arguments["-file_centerline"]
        ctr_algo = 'file'

    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    path_qc = arguments.get("-qc", None)

    sct.printv(''.join([
        '\nMethod:',
        '\n\tCenterline algorithm: ' + str(ctr_algo),
        '\n\tAssumes brain section included in the image: ' + str(brain_bool),
        '\n\tDimension of the segmentation kernel convolutions: ' + kernel_size + '\n',
    ]))

    im_image = Image(fname_image)
    # a copy is passed, otherwise the field absolutepath becomes None after execution of this function
    im_seg, _, _, im_labels_viewer, im_ctr = deep_segmentation_spinalcord(
        im_image.copy(), contrast_type, ctr_algo=ctr_algo, ctr_file=manual_centerline_fname,
        brain_bool=brain_bool, kernel_size=kernel_size, remove_temp_files=remove_temp_files, verbose=verbose)

    # every output reuses the input base name and extension, plus a suffix
    _, base_name, extension = sct.extract_fname(fname_image)

    def output_path(suffix):
        return os.path.abspath(os.path.join(output_folder, base_name + suffix + extension))

    # Save segmentation
    fname_seg = output_path('_seg')
    im_seg.save(fname_seg)

    if ctr_algo == 'viewer':
        # Save the labels placed in the viewer
        im_labels_viewer.save(output_path('_labels-centerline'))

    if verbose == 2:
        # Save the centerline
        im_ctr.save(output_path('_centerline'))

    if path_qc is not None:
        generate_qc(fname_image, fname_seg=fname_seg, args=args, path_qc=os.path.abspath(path_qc),
                    process='sct_deepseg_sc')
    sct.display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
def main():
    """Command-line entry point: segment the spinal cord with a deep learning model.

    Parses the command line (showing the help when no argument is given),
    resolves the segmentation options (centerline method, brain inclusion,
    kernel dimension), runs ``deep_segmentation_spinalcord``, copies the input
    q/sform onto the segmentation and writes the outputs to the output folder.
    Optionally generates a QC report.
    """
    parser = get_parser()
    # no argument at all: display the help instead of failing on missing options
    args = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_image = os.path.abspath(args.i)
    contrast_type = args.c
    ctr_algo = args.centerline

    # Without an explicit "-brain", infer it from the contrast.
    if args.brain is not None:
        brain_bool = bool(args.brain)
    elif contrast_type in ['t1', 't2']:
        brain_bool = True
    elif contrast_type in ['t2s', 'dwi']:
        brain_bool = False

    kernel_size = args.kernel
    if contrast_type == 'dwi' and kernel_size == '3d':
        # no 3D model exists for dwi: fall back to 2D
        kernel_size = '2d'
        sct.printv('3D kernel model for dwi contrast is not available. 2D kernel model is used instead.',
                   type="warning")

    manual_centerline_fname = args.file_centerline
    if ctr_algo == 'file' and manual_centerline_fname is None:
        sct.printv('Please use the flag -file_centerline to indicate the centerline filename.', 1, 'warning')
        sys.exit(1)
    if manual_centerline_fname is not None:
        # providing a centerline file always selects the 'file' method
        ctr_algo = 'file'

    remove_temp_files = args.r
    verbose = args.v
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    path_qc, qc_dataset, qc_subject = args.qc, args.qc_dataset, args.qc_subject
    output_folder = args.ofolder

    sct.printv(''.join([
        '\nMethod:',
        '\n\tCenterline algorithm: ' + str(ctr_algo),
        '\n\tAssumes brain section included in the image: ' + str(brain_bool),
        '\n\tDimension of the segmentation kernel convolutions: ' + kernel_size + '\n',
    ]))

    # Segment image
    from spinalcordtoolbox.image import Image
    from spinalcordtoolbox.deepseg_sc.core import deep_segmentation_spinalcord
    from spinalcordtoolbox.reports.qc import generate_qc

    im_image = Image(fname_image)
    # a copy is passed, otherwise the field absolutepath becomes None after execution of this function
    im_seg, _, _, im_labels_viewer, im_ctr = deep_segmentation_spinalcord(
        im_image.copy(), contrast_type, ctr_algo=ctr_algo, ctr_file=manual_centerline_fname,
        brain_bool=brain_bool, kernel_size=kernel_size, remove_temp_files=remove_temp_files,
        verbose=verbose)

    # every output reuses the input base name and extension, plus a suffix
    _, base_name, extension = sct.extract_fname(fname_image)

    def output_path(suffix):
        return os.path.abspath(os.path.join(output_folder, base_name + suffix + extension))

    fname_seg = output_path('_seg')
    # the segmentation gets the same q/sform as the input image
    im_seg.copy_qform_from_ref(im_image)
    im_seg.save(fname_seg)

    if ctr_algo == 'viewer':
        # Save the labels placed in the viewer
        im_labels_viewer.save(output_path('_labels-centerline'))

    if verbose == 2:
        # Save the centerline
        im_ctr.save(output_path('_centerline'))

    if path_qc is not None:
        generate_qc(fname_image, fname_seg=fname_seg, args=sys.argv[1:], path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset, subject=qc_subject, process='sct_deepseg_sc')
    sct.display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
示例#30
0
def propseg(img_input, options_dict):
    """
    Segment the spinal cord by running the ``isct_propseg`` binary.

    The ``isct_propseg`` command line is assembled from ``options_dict``, whose keys are sct_propseg flags
    (e.g. "-c", "-ofolder", "-init-centerline", "-init-mask", "-rescale", "-correct-seg"). The propagation is
    initialized from the interactive viewer ("viewer"), from a user-provided centerline/mask file, or, by
    default, from a centerline detected with OptiC. With "-init-centerline hough", OptiC is disabled and no
    initialization file is passed to the binary.

    :param img_input: source image, to be segmented. Only its ``absolutepath`` is used: data are re-read from disk.
    :param options_dict: arguments as dictionary (flag -> value)
    :return: segmented Image (read back from the ``*_seg`` file written in the output folder)
    """
    arguments = options_dict
    fname_input_data = img_input.absolutepath
    fname_data = os.path.abspath(fname_input_data)
    contrast_type = arguments["-c"]
    # isct_propseg only knows t1/t2: map the other contrasts to the closest one
    contrast_type_conversion = {'t1': 't1', 't2': 't2', 't2s': 't2', 'dwi': 't1'}
    contrast_type_propseg = contrast_type_conversion[contrast_type]

    # Starting building the command
    cmd = ['isct_propseg', '-t', contrast_type_propseg]

    folder_output = arguments.get("-ofolder", './')
    cmd += ['-o', folder_output]
    if os.path.exists(folder_output) and not os.path.isdir(folder_output):
        # a file is in the way: the binary would not be able to write its outputs there
        logger.error("output directory %s is not a valid directory" % folder_output)
        sys.exit(1)
    if not os.path.exists(folder_output):
        os.makedirs(folder_output)

    if "-down" in arguments:
        cmd += ["-down", str(arguments["-down"])]
    if "-up" in arguments:
        cmd += ["-up", str(arguments["-up"])]

    remove_temp_files = int(arguments.get("-r", 1))

    verbose = int(arguments.get('-v', 1))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    # Update for propseg binary
    if verbose > 0:
        cmd += ["-verbose"]

    # Output options: boolean flags forwarded unchanged to the binary (order preserved)
    for flag in ("-mesh", "-centerline-binary", "-CSF", "-centerline-coord", "-cross", "-init-tube",
                 "-low-resolution-mesh", "-detect-nii", "-detect-png"):
        if flag in arguments:
            cmd += [flag]

    # Helping options
    use_viewer = None
    use_optic = True  # enabled by default
    init_option = None
    rescale_header = arguments["-rescale"]
    # compare by value: an identity test (`is not 1`) is wrong for e.g. 1.0 and triggers a SyntaxWarning
    do_rescale = rescale_header != 1
    if "-init" in arguments:
        init_option = float(arguments["-init"])
        if init_option < 0:
            sct.printv('Command-line usage error: ' + str(init_option) + " is not a valid value for '-init'", 1, 'error')
            sys.exit(1)
    if "-init-centerline" in arguments:
        init_centerline = str(arguments["-init-centerline"])
        if init_centerline == "viewer":
            use_viewer = "centerline"
        elif init_centerline == "hough":
            use_optic = False
        else:
            if do_rescale:
                fname_labels_viewer = func_rescale_header(init_centerline, rescale_header, verbose=verbose)
            else:
                fname_labels_viewer = init_centerline
            cmd += ["-init-centerline", fname_labels_viewer]
            use_optic = False
    if "-init-mask" in arguments:
        init_mask = str(arguments["-init-mask"])
        if init_mask == "viewer":
            use_viewer = "mask"
        else:
            if do_rescale:
                fname_labels_viewer = func_rescale_header(init_mask, rescale_header, verbose=verbose)
            else:
                fname_labels_viewer = init_mask
            cmd += ["-init-mask", fname_labels_viewer]
            use_optic = False
    if "-mask-correction" in arguments:
        cmd += ["-mask-correction", str(arguments["-mask-correction"])]
    if "-radius" in arguments:
        cmd += ["-radius", str(arguments["-radius"])]
    if "-detect-n" in arguments:
        cmd += ["-detect-n", str(arguments["-detect-n"])]
    if "-detect-gap" in arguments:
        cmd += ["-detect-gap", str(arguments["-detect-gap"])]
    if "-init-validation" in arguments:
        cmd += ["-init-validation"]
    if "-nbiter" in arguments:
        cmd += ["-nbiter", str(arguments["-nbiter"])]
    if "-max-area" in arguments:
        cmd += ["-max-area", str(arguments["-max-area"])]
    if "-max-deformation" in arguments:
        cmd += ["-max-deformation", str(arguments["-max-deformation"])]
    if "-min-contrast" in arguments:
        cmd += ["-min-contrast", str(arguments["-min-contrast"])]
    if "-d" in arguments:
        cmd += ["-d", str(arguments["-d"])]
    if "-distance-search" in arguments:
        # the binary names this option -dsearch
        cmd += ["-dsearch", str(arguments["-distance-search"])]
    if "-alpha" in arguments:
        cmd += ["-alpha", str(arguments["-alpha"])]

    # check if input image is in 3D. Otherwise itk image reader will cut the 4D image in 3D volumes and only take the first one.
    image_input = Image(fname_data)
    image_input_rpi = image_input.copy().change_orientation('RPI')
    nx, ny, nz, nt, px, py, pz, pt = image_input_rpi.dim
    if nt > 1:
        sct.printv('ERROR: your input image needs to be 3D in order to be segmented.', 1, 'error')
        sys.exit(1)

    path_tmp = sct.tmp_create(basename="propseg", verbose=verbose)

    # rescale header (see issue #1406)
    if do_rescale:
        fname_data_propseg = func_rescale_header(fname_data, rescale_header, verbose=verbose)
    else:
        fname_data_propseg = fname_data

    # add to command
    cmd += ['-i', fname_data_propseg]

    # if centerline or mask is asked using viewer
    if use_viewer:
        from spinalcordtoolbox.gui.base import AnatomicalParams
        from spinalcordtoolbox.gui.centerline import launch_centerline_dialog

        params = AnatomicalParams()
        if use_viewer == 'mask':
            params.num_points = 3
            params.interval_in_mm = 15  # superior-inferior interval between two consecutive labels
            params.starting_slice = 'midfovminusinterval'
        if use_viewer == 'centerline':
            # setting maximum number of points to a reasonable value
            params.num_points = 20
            params.interval_in_mm = 30
            params.starting_slice = 'top'
        im_data = Image(fname_data_propseg)

        im_mask_viewer = msct_image.zeros_like(im_data)
        controller = launch_centerline_dialog(im_data, im_mask_viewer, params)
        fname_labels_viewer = sct.add_suffix(fname_data_propseg, '_labels_viewer')

        if not controller.saved:
            sct.printv('The viewer has been closed before entering all manual points. Please try again.', 1, 'error')
            sys.exit(1)
        # save labels
        controller.as_niftii(fname_labels_viewer)

        # add mask filename to parameters string
        if use_viewer == "centerline":
            cmd += ["-init-centerline", fname_labels_viewer]
        elif use_viewer == "mask":
            cmd += ["-init-mask", fname_labels_viewer]

    # If using OptiC
    elif use_optic:
        image_centerline = optic.detect_centerline(image_input, contrast_type, verbose)
        fname_centerline_optic = os.path.join(path_tmp, 'centerline_optic.nii.gz')
        image_centerline.save(fname_centerline_optic)
        cmd += ["-init-centerline", fname_centerline_optic]

    if init_option is not None:
        # values > 1 are interpreted as a slice index and converted to a fraction of the z extent
        if init_option > 1:
            init_option /= (nz - 1)
        cmd += ['-init', str(init_option)]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, output = sct.run(cmd, verbose, raise_exception=False, is_sct_binary=True)

    # check status is not 0
    if not status == 0:
        sct.printv('Automatic cord detection failed. Please initialize using -init-centerline or -init-mask (see help)',
                   1, 'error')
        sys.exit(1)

    # build output filename
    fname_seg = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data, "_seg")))
    fname_centerline = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data, "_centerline")))
    # in case header was rescaled, the binary named its outputs after the rescaled file: rename them after the input
    if do_rescale:
        sct.mv(os.path.join(folder_output, sct.add_suffix(os.path.basename(fname_data_propseg), "_seg")),
               fname_seg)
        sct.mv(os.path.join(folder_output, sct.add_suffix(os.path.basename(fname_data_propseg), "_centerline")),
               fname_centerline)
        # if the viewer was used, copy the labelled points to the output folder (they will then be scaled back)
        if use_viewer:
            fname_labels_viewer_new = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data,
                                                                                                  "_labels_viewer")))
            sct.copy(fname_labels_viewer, fname_labels_viewer_new)
            # update variable (used later)
            fname_labels_viewer = fname_labels_viewer_new

    # check consistency of segmentation
    if arguments["-correct-seg"] == "1":
        check_and_correct_segmentation(fname_seg, fname_centerline, folder_output=folder_output, threshold_distance=3.0,
                                       remove_temp_files=remove_temp_files, verbose=verbose)

    # copy header from input to segmentation to make sure qform is the same
    sct.printv("Copy header input --> output(s) to make sure qform is the same.", verbose)
    list_fname = [fname_seg, fname_centerline]
    if use_viewer:
        list_fname.append(fname_labels_viewer)
    for fname in list_fname:
        im = Image(fname)
        im.header = image_input.header
        im.save(dtype='int8')  # they are all binary masks hence fine to save as int8

    return Image(fname_seg)
示例#31
0
def main(args=None):
    """
    Entry point of the image-manipulation command line tool.

    Parses the command line, applies the single requested operation to the input image(s) and writes the
    result(s). A single output image is written to ``-o`` (or over the first input file when ``-o`` is not
    given). ``-mcs`` and ``-split`` write one file per output image, adding a suffix to the output file name.

    :param args: list of command-line arguments. When None, ``sys.argv[1:]`` is parsed, and ``--help`` is shown
        if it is empty.
    :return: None
    """
    # initializations
    output_type = None
    dim_list = ['x', 'y', 'z', 't']

    # Get parser args
    if args is None:
        args = None if sys.argv[1:] else ['--help']
    parser = get_parser()
    arguments = parser.parse_args(args=args)
    fname_in = arguments.i
    n_in = len(fname_in)
    verbose = arguments.v
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    # None when -o is not given (replaced by the first input file name below)
    fname_out = arguments.o

    # Run command
    # Arguments are sorted alphabetically (not according to the usage order)
    if arguments.concat is not None:
        dim = arguments.concat
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(fname_in, dim)]  # TODO: adapt to fname_in

    elif arguments.copy_header is not None:
        im_in = Image(fname_in[0])
        im_dest = Image(arguments.copy_header)
        # copy of the input (so, its header) holding the data of the destination image
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]
        fname_out = arguments.copy_header

    elif arguments.display_warp:
        im_in = fname_in[0]
        visualize_warp(im_in, fname_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif arguments.getorient:
        im_in = Image(fname_in[0])
        orient = im_in.orientation
        im_out = None

    elif arguments.keep_vol is not None:
        index_vol = [int(vol) for vol in arguments.keep_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif arguments.mcs:
        # parser.error() prints the message and exits
        if n_in != 1:
            parser.error('ERROR: -mcs need only one input')
        im_in = Image(fname_in[0])
        if len(im_in.data.shape) != 5:
            parser.error('ERROR: -mcs input need to be a multi-component image')
        im_out = multicomponent_split(im_in)

    elif arguments.omc:
        im_ref = Image(fname_in[0])
        for fname in fname_in:
            im = Image(fname)
            if im.data.shape != im_ref.data.shape:
                parser.error('ERROR: -omc inputs need to have all the same shapes')
            del im
        im_out = [multicomponent_merge(fname_in)]  # TODO: adapt to fname_in

    elif arguments.pad is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1,
                       'error')
            return

        pad_arguments = arguments.pad.split(',')
        if len(pad_arguments) != 3:
            sct.printv('ERROR: you need to specify 3 padding values.', 1,
                       'error')
            return

        padx, pady, padz = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padx,
                      pad_x_f=padx,
                      pad_y_i=pady,
                      pad_y_f=pady,
                      pad_z_i=padz,
                      pad_z_f=padz)
        ]

    elif arguments.pad_asym is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1,
                       'error')
            return

        pad_arguments = arguments.pad_asym.split(',')
        if len(pad_arguments) != 6:
            sct.printv('ERROR: you need to specify 6 padding values.', 1,
                       'error')
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padxi,
                      pad_x_f=padxf,
                      pad_y_i=padyi,
                      pad_y_f=padyf,
                      pad_z_i=padzi,
                      pad_z_f=padzf)
        ]

    elif arguments.remove_vol is not None:
        index_vol = [int(vol) for vol in arguments.remove_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif arguments.setorient is not None:
        sct.printv(fname_in[0])
        im_in = Image(fname_in[0])
        im_out = [msct_image.change_orientation(im_in, arguments.setorient)]

    elif arguments.setorient_data is not None:
        im_in = Image(fname_in[0])
        im_out = [
            msct_image.change_orientation(im_in,
                                          arguments.setorient_data,
                                          data_only=True)
        ]

    elif arguments.split is not None:
        dim = arguments.split
        assert dim in dim_list
        im_in = Image(fname_in[0])
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif arguments.type is not None:
        output_type = arguments.type
        im_in = Image(fname_in[0])
        im_out = [im_in]  # TODO: adapt to fname_in

    elif arguments.to_fsl is not None:
        space_files = arguments.to_fsl
        if len(space_files) > 2 or len(space_files) < 1:
            parser.error('ERROR: -to-fsl expects 1 or 2 arguments')
            return
        spaces = [Image(s) for s in space_files]
        if len(spaces) < 2:
            spaces.append(None)
        im_out = [
            displacement_to_abs_fsl(Image(fname_in[0]), spaces[0], spaces[1])
        ]

    else:
        im_out = None
        parser.error('ERROR: you need to specify an operation to do on the input image')

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        sct.printv('Generate output files...', verbose)
        # if only one output. Split outputs always get a numbered suffix (handled below), even when there is
        # a single one: writing it to fname_out as well could overwrite the input file.
        if len(im_out) == 1 and arguments.split is None:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            sct.display_viewer_syntax([fname_out], verbose=verbose)
        if arguments.mcs:
            # use input file name and add _X, _Y _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(
                    sct.add_suffix(fname_out or fname_in[0],
                                   '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            sct.display_viewer_syntax(l_fname_out, verbose=verbose)
        if arguments.split is not None:
            # use input file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(
                    sct.add_suffix(
                        fname_out or fname_in[0],
                        '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            sct.display_viewer_syntax(l_fname_out, verbose=verbose)

    elif arguments.getorient:
        sct.printv(orient)

    elif arguments.display_warp:
        sct.printv('Warping grid generated.', verbose, 'info')
class ProcessLabels(object):
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1, vertebral_levels=None, value=None, msg=""):
        """
        Collection of processes that deal with label creation/modification.

        Loads the label image (and the optional reference image) and stores the processing parameters
        used by process().

        :param fname_label: file name of the label image, loaded into ``self.image_input``.
        :param fname_output: output file name. A list holding exactly one name is unwrapped to that name;
            any other value (including longer lists) is stored as-is. None means process() writes nothing.
        :param fname_ref: optional reference image file name, loaded into ``self.image_ref`` (None otherwise).
        :param cross_radius: passed by the 'plan' process as the slab half-thickness.
        :param dilate: stored as ``self.dilate``.
        :param coordinates: list of coordinates used by the 'create*' processes.
        :param verbose: verbosity level, also forwarded to Image().
        :param vertebral_levels: levels kept by the 'vert-body' process.
        :param value: value used by the 'add' and 'create-viewer' processes.
        :param msg: string. message to display to the user.
        """
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = None if fname_ref is None else Image(fname_ref, verbose=verbose)

        unwrap = isinstance(fname_output, list) and len(fname_output) == 1
        self.fname_output = fname_output[0] if unwrap else fname_output

        self.msg = msg
        self.value = value
        self.verbose = verbose
        self.coordinates = coordinates
        self.dilate = dilate
        self.vertebral_levels = vertebral_levels
        self.cross_radius = cross_radius

    def process(self, type_process):
        # for some processes, change orientation of input image to RPI
        change_orientation = False
        if type_process in ['vert-body', 'vert-disc', 'vert-continuous']:
            # get orientation of input image
            input_orientation = self.image_input.orientation
            # change orientation
            self.image_input.change_orientation('RPI')
            change_orientation = True
        if type_process == 'add':
            self.output_image = self.add(self.value)
        if type_process == 'plan':
            self.output_image = self.plan(self.cross_radius, 100, 5)
        if type_process == 'plan_ref':
            self.output_image = self.plan_ref()
        if type_process == 'increment':
            self.output_image = self.increment_z_inverse()
        if type_process == 'disks':
            self.output_image = self.labelize_from_disks()
        if type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        if type_process == 'remove':
            self.output_image = self.remove_label()
        if type_process == 'remove-symm':
            self.output_image = self.remove_label(symmetry=True)
        if type_process == 'create':
            self.output_image = self.create_label()
        if type_process == 'create-add':
            self.output_image = self.create_label(add=True)
        if type_process == 'create-seg':
            self.output_image = self.create_label_along_segmentation()
        if type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        if type_process == 'diff':
            self.diff()
            self.fname_output = None
        if type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        if type_process == 'cubic-to-point':
            self.output_image = self.cubic_to_point()
        if type_process == 'vert-body':
            self.output_image = self.label_vertebrae(self.vertebral_levels)
        if type_process == 'vert-continuous':
            self.output_image = self.continuous_vertebral_levels()
        if type_process == 'create-viewer':
            self.output_image = self.launch_sagittal_viewer(self.value)

        if self.fname_output is not None:
            if change_orientation:
                self.output_image.change_orientation(input_orientation)
            self.output_image.absolutepath = self.fname_output
            if type_process == 'vert-continuous':
                self.output_image.save(dtype='float32')
            elif type_process != 'plan_ref':
                self.output_image.save(dtype='minimize_int')
            else:
                self.output_image.save()

    def add(self, value):
        """
        This function add a specified value to all non-zero voxels.
        """
        image_output = self.image_input.copy()
        # image_output.data *= 0

        coordinates_input = self.image_input.getNonZeroCoordinates()

        # for all points with non-zeros neighbors, force the neighbors to 0
        for i, coord in enumerate(coordinates_input):
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = image_output.data[int(coord.x), int(coord.y), int(coord.z)] + float(value)
        return image_output

    def create_label(self, add=False):
        """
        Create an image with labels listed by the user.
        This method works only if the user inserted correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types).
        a Coordinate contains x, y, z and value.
        If only one label is to be added, coordinates must be completed with '[]'
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: if True, labels are written on a copy of the input image; otherwise on an empty image of the
            same geometry.
        :return: Image with the labels. For 2D images, every coordinate must have z == 0 (asserted).
        """
        if add:
            image_output = self.image_input.copy()
        else:
            image_output = msct_image.zeros_like(self.image_input)

        for label_index, coord in enumerate(self.coordinates):
            n_dims = len(image_output.data.shape)
            if n_dims == 2:
                assert str(coord.z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(coord.z)
                image_output.data[int(coord.x), int(coord.y)] = coord.value
            elif n_dims == 3:
                image_output.data[int(coord.x), int(coord.y), int(coord.z)] = coord.value
            else:
                sct.printv('ERROR: Data should be 2D or 3D. Current shape is: ' + str(image_output.data.shape), 1, 'error')

            position = ','.join([str(coord.x), str(coord.y), str(coord.z)])
            sct.printv('Label #' + str(label_index) + ': ' + position + ' --> ' + str(coord.value), 1)

        return image_output

    def create_label_along_segmentation(self):
        """
        Create an image with labels defined along the spinal cord segmentation (or centerline).

        ``self.coordinates`` is a list of strings 'z,value'. For each entry, ``value`` is written at the rounded
        center of mass of the input segmentation in slice ``z`` (third axis). If z=-1, then use z=nz/2
        (i.e. center of FOV in superior-inferior direction), rounded.

        Example:
        object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'.

        Returns
        -------
        image_output: Image object with labels.

        Raises
        ------
        ValueError: if the segmentation is empty in a requested slice (no center of mass can be computed).
        """
        image_output = msct_image.zeros_like(self.image_input)

        for i_label, coord in enumerate(self.coordinates):
            # 'z,value' -> two ints
            z, value = [int(item) for item in coord.split(',')]
            # if z=-1, replace with nz/2
            if z == -1:
                z = int(np.round(image_output.dim[2] / 2.0))
            slice_seg = np.array(self.image_input.data[:, :, z])
            if not np.any(slice_seg):
                # center_of_mass of an all-zero slice is NaN, which cannot be turned into an index
                raise ValueError("Segmentation is empty in slice z={}: cannot place label {}".format(z, value))
            # scipy.ndimage.measurements is a deprecated namespace; center_of_mass lives in scipy.ndimage
            x, y = ndimage.center_of_mass(slice_seg)
            # round values to make indices
            x, y = int(np.round(x)), int(np.round(y))
            # display info
            sct.printv('Label #' + str(i_label) + ': ' + str(x) + ',' + str(y) + ',' + str(z) + ' --> ' + str(value), 1)
            if len(image_output.data.shape) == 3:
                image_output.data[x, y, z] = value
            elif len(image_output.data.shape) == 2:
                assert str(z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(z)
                image_output.data[x, y] = value
        return image_output

    def plan(self, width, offset=0, gap=1):
        """
        Create a plane of thickness="width" and changes its value with an offset and a gap between labels.

        For each non-zero voxel of the input image, every output voxel whose z index lies in
        [z - width, z + width) is set to ``offset + gap * value``. Where slabs overlap, the label
        processed last wins.

        :param width: half-thickness of the slab, in voxels.
        :param offset: constant added to the scaled label value.
        :param gap: factor applied to the label value.
        :return: Image with the same geometry as the input.
        """
        image_output = msct_image.zeros_like(self.image_input)

        for coord in self.image_input.getNonZeroCoordinates():
            z = int(coord.z)
            # clamp the start: a negative slice start would wrap around to the end of the axis and
            # give an empty (or wrong) slab for labels closer than `width` to z=0
            z_start = max(z - width, 0)
            image_output.data[:, :, z_start:z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.

        For every labelled voxel of the input, the whole slice at its z index in an image shaped like
        ``self.image_ref`` is filled with the label value, keeping the sign of negative labels.
        Positive labels are written after negative ones and therefore win on shared slices.

        :return: float32 Image in the reference space.
        """
        image_output = msct_image.zeros_like(Image(self.image_ref))

        image_input_neg = msct_image.zeros_like(Image(self.image_input))
        image_input_pos = msct_image.zeros_like(Image(self.image_input))

        labels = self.image_input.data
        # negative labels are stored with flipped sign so that getNonZeroCoordinates() handles them
        xn, yn, zn = (labels < 0).nonzero()
        image_input_neg.data[xn, yn, zn] = -labels[xn, yn, zn]
        xp, yp, zp = (labels > 0).nonzero()
        image_input_pos.data[xp, yp, zp] = labels[xp, yp, zp]

        negatives = image_input_neg.getNonZeroCoordinates()
        positives = image_input_pos.getNonZeroCoordinates()

        image_output.change_type('float32')
        for coord in negatives:
            # NOTE(review): the original author flagged that coord.value may be taken as an int here
            image_output.data[:, :, int(coord.z)] = -coord.value
        for coord in positives:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output

    def cubic_to_point(self):
        """
        Reduce each group of same-valued voxels to a single voxel at the group's center of mass.

        It is to be used after applying homothetic warping field to a label file as the labels will be dilated.
        Be careful: the center of mass is computed over all voxels sharing a value, so two spatially
        separated groups with the same value are merged into one point.

        :return: image_output: Image of the same size with one voxel per label value.
        """
        output_image = msct_image.zeros_like(self.image_input)

        # non-null voxels, sorted by value
        coordinates = self.image_input.getNonZeroCoordinates(sorting='value')

        # value -> list of coordinates having that value
        groups = {}
        for coord in coordinates:
            groups.setdefault(coord.value, []).append(coord)

        for members in groups.values():
            center_of_mass = sum(members) / float(len(members))
            rx = np.round(center_of_mass.x)
            ry = np.round(center_of_mass.y)
            rz = np.round(center_of_mass.z)
            sct.printv("Value = " + str(center_of_mass.value) + " : (" + str(center_of_mass.x) + ", " + str(center_of_mass.y) + ", " + str(center_of_mass.z) + ") --> ( " + str(rx) + ", " + str(ry) + ", " + str(rz) + ")", verbose=self.verbose)
            output_image.data[int(rx), int(ry), int(rz)] = center_of_mass.value

        return output_image

    def increment_z_inverse(self):
        """
        Relabel the non-zero voxels 1, 2, 3, ... ordered along decreasing z.

        The coordinates come from getNonZeroCoordinates(sorting='z', reverse_coord=True). This function
        assumes RPI orientation.

        :return: Image with the new labels.
        """
        image_output = msct_image.zeros_like(self.image_input)

        ordered = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)
        for rank, coord in enumerate(ordered, start=1):
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = rank

        return image_output

    def labelize_from_disks(self):
        """
        Create an image with regions labelized depending on values from reference.

        Typically, the input is a segmentation and the reference holds disk labels; each input voxel lying
        between two consecutive reference labels (in value order), i.e. ``lower.z < z <= upper.z``, receives
        the value of the upper one. Labels are assumed to be non-zero and incremented from top to bottom,
        assuming a RPI orientation. When several intervals match, the last one wins.

        :return: Image with vertebral levels labelized.
        """
        image_output = msct_image.zeros_like(self.image_input)

        coordinates_input = self.image_input.getNonZeroCoordinates()
        disks = self.image_ref.getNonZeroCoordinates(sorting='value')
        intervals = list(zip(disks[:-1], disks[1:]))

        for coord in coordinates_input:
            for upper, lower in intervals:
                if lower.z < coord.z <= upper.z:
                    image_output.data[int(coord.x), int(coord.y), int(coord.z)] = upper.value

        return image_output

    def label_vertebrae(self, levels_user=None):
        """
        Keep only the vertebral levels requested by the user, each reduced to a single point.

        The label image is first converted with ``self.cubic_to_point()`` (per the original comment, one point per
        vertebral level, at its center of mass). Then every point whose integer value (``int(value)``) is not listed in
        `levels_user` is set to zero.

        :param levels_user: list of int. Levels to keep. ``None``, an empty list, or a list whose first element is 0
            means "keep all levels".
        :return: image_output: Image with labels.
        """
        # get center of mass of each vertebral level
        image_cubic2point = self.cubic_to_point()
        # no selection (or the historical "0" sentinel) means: keep every level, nothing to remove
        if not levels_user or levels_user[0] == 0:
            return image_cubic2point
        levels_to_keep = set(levels_user)
        # loop across labels and remove those that are not listed by the user
        for coord in image_cubic2point.getNonZeroCoordinates(sorting='value'):
            if int(coord.value) not in levels_to_keep:
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = 0
        return image_cubic2point

    def MSE(self, threshold_mse=0):
        """
        Compute the root-mean-square error along z between matching labels of the input and reference images.

        Labels are matched on their rounded value (``np.round``). A warning is printed when the two images have a
        different number of labels and for every label present in only one of them. For each input label, the
        squared z-difference to the first reference label with the same rounded value is accumulated; the sum is
        divided by the total number of input labels (unmatched input labels contribute 0) and square-rooted.
        If the result is above `threshold_mse` (default 0), a report ``error_log_<input stem>.txt`` is written next to
        the input image.

        NOTE(review): z is the voxel index returned by ``getNonZeroCoordinates``; the "mm" in the printed message only
        holds for 1 mm slices.

        :param threshold_mse: float. Threshold above which the report file is written.
        :return: float. The error value.
        :raises ValueError: if the input image contains no label.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # labels are matched on their rounded value; build the lookup sets once instead of once per label
        values_input = {np.round(coord.value) for coord in coordinates_input}
        values_ref = {np.round(coord_ref.value) for coord_ref in coordinates_ref}

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if np.round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if np.round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # without any input label the mean is undefined (previously a bare ZeroDivisionError)
        if not coordinates_input:
            raise ValueError('No label found in input image {}: cannot compute MSE.'.format(
                self.image_input.absolutepath))

        # first reference label for each rounded value, mirroring a first-match search over coordinates_ref
        ref_by_value = {}
        for coord_ref in coordinates_ref:
            ref_by_value.setdefault(np.round(coord_ref.value), coord_ref)

        result = 0.0
        for coord in coordinates_input:
            coord_ref = ref_by_value.get(np.round(coord.value))
            if coord_ref is not None:
                result += (coord_ref.z - coord.z) ** 2
        result = np.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            parent, stem, ext = sct.extract_fname(self.image_input.absolutepath)
            fname_report = os.path.join(parent, 'error_log_{}.txt'.format(stem))
            with open(fname_report, 'w') as f:
                f.write('The labels error (MSE) between {} and {} is: {}\n'.format(
                 os.path.relpath(self.image_input.absolutepath, os.path.dirname(fname_report)),
                 os.path.relpath(self.image_ref.absolutepath, os.path.dirname(fname_report)),
                 result))

        return result

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep the labels of `coord_input` that have a counterpart in `coord_ref` (and, with `symmetry`, vice versa).

        - If `symmetry` is True and both sequences hold CoordinateValue objects, labels are matched with CoordinateValue
          equality through a set intersection (so the objects must be hashable).
        - Otherwise, labels are matched on their ``.value`` attribute.

        Empty inputs are allowed. The order of the original sequences is preserved in the results.

        :param coord_input: sequence of CoordinateValue (labels of the input image)
        :param coord_ref: sequence of CoordinateValue (labels of the reference image)
        :param symmetry: boolean. If True, also drop labels of `coord_ref` that have no counterpart in `coord_input`.
        :return: tuple (result_coord_input, result_coord_ref). When `symmetry` is False, `result_coord_ref` is
            `coord_ref` itself.
        """
        from msct_types import CoordinateValue
        # guard the [0] accesses: empty sequences used to raise IndexError here
        if (symmetry and len(coord_input) > 0 and len(coord_ref) > 0
                and isinstance(coord_input[0], CoordinateValue) and isinstance(coord_ref[0], CoordinateValue)):
            coord_intersection = set(coord_input).intersection(coord_ref)
            result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
            result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
        else:
            result_coord_ref = coord_ref
            # match on label value; a set makes each lookup O(1) instead of scanning the other list
            values_ref = {coord.value for coord in coord_ref}
            result_coord_input = [coord for coord in coord_input if coord.value in values_ref]
            if symmetry:
                values_input = {coord.value for coord in result_coord_input}
                result_coord_ref = [coord for coord in coord_ref if coord.value in values_input]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any labels in input image that are not in reference image.

        With `symmetry`, labels of the reference image that are not in the input image are removed as well. The
        cleaned reference image is written to ``self.fname_output[1]`` (as 'minimize_int'), and ``self.fname_output``
        is then replaced by ``self.fname_output[0]``.

        :param symmetry: boolean. Also clean the reference image (expects ``self.fname_output`` to hold two names).
        :return: Image with the kept input labels (values rounded to int), other voxels at zero.
        """
        image_output = msct_image.zeros_like(self.image_input)

        result_coord_input, result_coord_ref = self.remove_label_coord(
            self.image_input.getNonZeroCoordinates(coordValue=True),
            self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))

        if symmetry:
            # start from an empty image: copying the reference would keep the labels we are supposed to remove
            image_output_ref = msct_image.zeros_like(self.image_ref)
            for coord in result_coord_ref:
                image_output_ref.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))
            image_output_ref.absolutepath = self.fname_output[1]
            # 'minimize_int' is the dtype, not the path (save's first positional argument is the path)
            image_output_ref.save(self.fname_output[1], dtype='minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def display_voxel(self):
        """
        Print every label of the input image, then a compact colon-separated summary of all of them.

        Labels are listed in the order returned by ``getNonZeroCoordinates(sorting='value')``. The image is expected
        to be RPI, but other orientations also work. The summary (``str()`` of each label joined with ':') is kept in
        ``self.useful_notation``.

        :return: the list of label coordinates of the input image.
        """
        labels = self.image_input.getNonZeroCoordinates(sorting='value')
        self.useful_notation = ''
        for label in labels:
            sct.printv('Position=({},{},{}) -- Value= {}'.format(label.x, label.y, label.z, label.value),
                       verbose=self.verbose)
            piece = str(label)
            # separator only between entries, never in front of the first one
            self.useful_notation = self.useful_notation + ':' + piece if self.useful_notation else piece
        sct.printv('All labels (useful syntax):', verbose=self.verbose)
        sct.printv(self.useful_notation, verbose=self.verbose)
        return labels

    def get_physical_coordinates(self):
        """
        Express the labels of the input image in the physical (scanner) referential.

        Labels come from ``getNonZeroCoordinates(sorting='value')``; each voxel position is converted with
        ``transfo_pix2phys`` and kept together with its label value.

        :return: a list of CoordinateValue ([x, y, z, value]), in the physical (scanner) space, same order as the labels.
        """
        image = self.image_input
        physical_labels = []
        for label in image.getNonZeroCoordinates(sorting='value'):
            # voxel indices -> physical coordinates (one point per call)
            point = image.transfo_pix2phys([[label.x, label.y, label.z]])[0]
            physical_labels.append(CoordinateValue([point[0], point[1], point[2], label.value]))
        return physical_labels

    def get_coordinates_in_destination(self, im_dest, type='discrete'):
        """
        Compute the position of the input labels in the voxel space of a destination image.

        Labels are first converted to physical space (``get_physical_coordinates``), then mapped into `im_dest` with
        ``im_dest.transfo_phys2pix``.

        :param im_dest: Object Image
        :param type: 'discrete' (``transfo_phys2pix`` called with its default arguments) or 'continuous'
            (called with ``real=False``). NOTE(review): confirm in ``transfo_phys2pix`` which mode rounds to integers.
        :return: a list of CoordinateValue, in the voxel (image) space of the destination image
        :raises ValueError: if `type` is neither 'discrete' nor 'continuous'.
        """
        # compare strings with ==, not `is` (identity of string objects is an implementation detail)
        if type == 'discrete':
            phys2pix_kwargs = {}
        elif type == 'continuous':
            phys2pix_kwargs = {'real': False}
        else:
            raise ValueError("The value of 'type' should either be 'discrete' or 'continuous'.")

        dest_coord = []
        for c in self.get_physical_coordinates():
            # physical -> voxel coordinates in the destination image (z was previously replaced by y)
            c_p = im_dest.transfo_phys2pix([[c.x, c.y, c.z]], **phys2pix_kwargs)[0]
            dest_coord.append(CoordinateValue([c_p[0], c_p[1], c_p[2], c.value]))
        return dest_coord

    def diff(self):
        """
        Print the label values found in only one of the two images (input vs. reference).

        Values are compared with ``==`` (no rounding). An unmatched value is printed once per label carrying it, in
        the order returned by ``getNonZeroCoordinates()``.
        """
        labels_in = self.image_input.getNonZeroCoordinates()
        labels_ref = self.image_ref.getNonZeroCoordinates()

        sct.printv("Label in input image that are not in reference image:")
        for label in labels_in:
            if not any(label.value == other.value for other in labels_ref):
                sct.printv(label.value)

        sct.printv("Label in ref image that are not in input image:")
        for label in labels_ref:
            if not any(other.value == label.value for other in labels_in):
                sct.printv(label.value)

    def distance_interlabels(self, max_dist):
        """
        Calculate the distances between consecutive labels in the input image.

        Labels are taken in the order returned by ``getNonZeroCoordinates()``, and the Euclidean distance is computed
        in voxel units between each label and the next one. If a distance is larger than `max_dist`, a warning
        message is displayed.

        :param max_dist: float. Maximum allowed distance (voxel units) between two consecutive labels.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        # compare each label with the next one
        for i, (c1, c2) in enumerate(zip(coordinates_input, coordinates_input[1:])):
            dist = np.sqrt((c1.x - c2.x) ** 2 + (c1.y - c2.y) ** 2 + (c1.z - c2.z) ** 2)
            # warn when the gap is too LARGE (the check used to be inverted w.r.t. the docstring and message)
            if dist > max_dist:
                sct.printv('Warning: the distance between label ' + str(i) + '[' + str(c1.x) + ',' + str(c1.y) + ',' +
                           str(c1.z) + ']=' + str(c1.value) + ' and label ' + str(i + 1) + '[' + str(c2.x) + ',' +
                           str(c2.y) + ',' + str(c2.z) + ']=' + str(c2.value) + ' is larger than ' + str(max_dist) +
                           '. Distance=' + str(dist))

    def continuous_vertebral_levels(self):
        """
        Transform a vertebral-level image (e.g. from the template) into a continuous-valued image.

        Instead of holding the integer vertebral level, each non-zero voxel of the output holds ``level + d / D``.
        Here ``D`` is the length of that level along the centerline, and ``d`` is the centerline length from the
        voxel's slice to the last centerline point of the same level. Both lengths are in physical units (pixel
        dimensions of the input), so ``d / D`` lies in [0, 1]. A level made of a single centerline point has no
        length and gets an offset of 0.
        The image must be RPI.

        :return: Image (float32) with the continuous vertebral level values.
        """
        im_input = Image(self.image_input, verbose=self.verbose)
        im_output = msct_image.zeros_like(self.image_input)

        # pixel dimensions, used to measure centerline lengths in physical units
        _, _, _, _, px, py, pz, _ = im_input.dim

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each centerline point, extract corresponding level
        from spinalcordtoolbox.centerline.core import get_centerline
        _, arr_ctl, _ = get_centerline(self.image_input, algo_fitting='bspline')
        x_centerline_fit, y_centerline_fit, z_centerline = arr_ctl
        value_centerline = np.array(
            [im_input.data[int(x_centerline_fit[it]), int(y_centerline_fit[it]), int(z_centerline[it])]
             for it in range(len(z_centerline))])

        x_ctl = np.asarray(x_centerline_fit)
        y_ctl = np.asarray(y_centerline_fit)
        z_ctl = np.asarray(z_centerline)

        def _path_length(indexes):
            """Physical length of the polyline joining the centerline points at `indexes`, in that order."""
            dx = np.diff(x_ctl[indexes]) * px
            dy = np.diff(y_ctl[indexes]) * py
            dz = np.diff(z_ctl[indexes]) * pz
            return np.sum(np.sqrt(dx ** 2 + dy ** 2 + dz ** 2))

        # 2. compute the length of each vertebral level along the centerline --> Di
        length_levels = {level: _path_length(np.where(value_centerline == level)[0])
                         for level in np.unique(value_centerline)}

        # 3. for each centerline point:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate remaining distance to the end of that level --> d
        #   c. compute relative distance in the vertebral level coordinate system --> d/Di
        continuous_values = {}
        for it, iz in enumerate(z_centerline):
            level = value_centerline[it]
            indexes_level = np.where(value_centerline == level)[0]
            distance_from_level = _path_length(indexes_level[indexes_level >= it])
            length_level = float(length_levels[level])
            # single-point levels have zero length: avoid producing nan
            offset = distance_from_level / length_level if length_level > 0 else 0.0
            continuous_values[iz] = level + offset

        # 4. saving data
        # for each slice, get all non-zero pixels and replace with continuous values
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.change_type(np.float32)
        for coord in coordinates_input:
            im_output.data[int(coord.x), int(coord.y), int(coord.z)] = continuous_values[coord.z]

        return im_output

    def launch_sagittal_viewer(self, labels):
        """
        Open the sagittal labeling dialog on the input image.

        :param labels: value stored as ``params.vertebraes`` for the dialog.
        :return: the label Image handed to the dialog (created empty, bound to ``self.fname_output``); presumably
            filled in by the dialog — TODO confirm in ``launch_sagittal_dialog``.
        """
        from spinalcordtoolbox.gui import base
        from spinalcordtoolbox.gui.sagittal import launch_sagittal_dialog

        # empty label image, same geometry as the input, that the dialog receives
        label_image = msct_image.zeros_like(self.image_input)
        label_image.absolutepath = self.fname_output

        dialog_params = base.AnatomicalParams()
        dialog_params.vertebraes = labels
        dialog_params.input_file_name = self.image_input.absolutepath
        dialog_params.output_file_name = self.fname_output
        dialog_params.subtitle = self.msg

        launch_sagittal_dialog(self.image_input, label_image, dialog_params)
        return label_image
示例#33
0
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then added.
    To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images), the
    resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is mapped to
    dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after linear interpolation
    will be 0.5. To account for partial volume, the resulting voxel will be: dest(i,j,k) = 0.5*0.5/0.5 = 0.5.
    Now, if two voxels overlap in the destination space, let's say: src(x,y,z)=1 and src2'(x',y',z')=1, then the
    resulting value will be: dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a weighted
    average operator, only in destination voxels that share multiple source voxels.
    Destination voxels with a total partial volume of 0 (not covered by any source) are set to 0.

    Intermediate files are written to a temporary folder, which is removed at the end (also on error) when
    ``param.rm_tmp`` is set.

    Parameters
    ----------
    list_fname_src : list of str
        Source images to merge.
    fname_dest : str
        Destination image; defines the output grid and header.
    list_fname_warp : list of str
        Warping field of each source image (same order as `list_fname_src`).
    param : object
        Options; reads ``interp``, ``almost_zero``, ``verbose``, ``fname_out`` and ``rm_tmp``.

    Returns
    -------
    None. The merged image is written to ``param.fname_out``.
    """
    import os

    # create temporary folder: every intermediate file goes there instead of the current directory
    path_tmp = tmp_create()

    try:
        # get dimensions of destination file
        nii_dest = Image(fname_dest)
        shape = (nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2], len(list_fname_src))

        # initialize variables: one 3D volume per source, stacked along the 4th axis
        data = np.zeros(shape)
        partial_volume = np.zeros(shape)

        # loop across files
        for i_file, fname_src in enumerate(list_fname_src):
            fname_src_template = os.path.join(path_tmp, 'src_' + str(i_file) + '_template.nii.gz')
            fname_src_bin = os.path.join(path_tmp, 'src_' + str(i_file) + 'native_bin.nii.gz')
            fname_src_pv = os.path.join(path_tmp, 'src_' + str(i_file) + '_template_partialVolume.nii.gz')

            # apply transformation src --> dest
            sct_apply_transfo.main(argv=[
                '-i', fname_src, '-d', fname_dest, '-w', list_fname_warp[i_file],
                '-x', param.interp, '-o', fname_src_template, '-v', str(param.verbose)
            ])

            # create binary mask from input file by assigning one to all non-null voxels
            out = Image(fname_src).copy()
            out.data = binarize(out.data, param.almost_zero)
            out.save(path=fname_src_bin)

            # apply transformation to binary mask to compute partial volume
            sct_apply_transfo.main(argv=[
                '-i', fname_src_bin, '-d', fname_dest, '-w', list_fname_warp[i_file],
                '-x', param.interp, '-o', fname_src_pv
            ])

            # open data
            data[:, :, :, i_file] = Image(fname_src_template).data
            partial_volume[:, :, :, i_file] = Image(fname_src_pv).data

        # merge files using partial volume information: divide only where the total partial volume is non-zero
        # (avoids nan/inf from 0/0 and x/0), then convert any nan coming from the input data to zero
        sum_partial_volume = np.sum(partial_volume, axis=3)
        sum_weighted = np.sum(data * partial_volume, axis=3)
        data_merge = np.divide(sum_weighted, sum_partial_volume,
                               out=np.zeros_like(sum_weighted), where=sum_partial_volume != 0)
        data_merge = np.nan_to_num(data_merge)

        # write result in file
        nii_dest.data = data_merge
        nii_dest.save(param.fname_out)
    finally:
        # remove temporary folder, also when something above failed
        if param.rm_tmp:
            rmtree(path_tmp)
def dummy_centerline(size_arr=(9, 9, 9),
                     subsampling=1,
                     dilate_ctl=0,
                     hasnan=False,
                     zeroslice=None,
                     outlier=None,
                     orientation='RPI',
                     debug=False):
    """
    Create a dummy Image centerline of small size. Return the full and sub-sampled version along z.
    :param size_arr: tuple: (nx, ny, nz)
    :param subsampling: int >=1. Subsampling factor along z. 1: no subsampling. 2: centerline defined every other z.
    :param dilate_ctl: Dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
                         if dilate_ctl=0, result will be a single pixel per slice.
    :param hasnan: Bool: Image has non-numerical values: nan, inf. In this case, do not subsample.
    :param zeroslice: list int: zero all slices listed in this param. None (default): no slice is zeroed.
    :param outlier: list int: replace the current point with an outlier at the corner of the image for the slices listed.
                    None (default): no outlier.
    :param orientation: orientation the returned images are converted to.
    :param debug: Bool: Write temp files
    :return: tuple (img, img_sub, arr_ctl): full image, z-subsampled image, and the (3, nz) uint16 array of centerline
             voxel coordinates [x, y, z] (before dilation, zeroed slices and outliers).
    """
    from numpy import poly1d, polyfit
    # None instead of a shared mutable default
    zeroslice = [] if zeroslice is None else zeroslice
    outlier = [] if outlier is None else outlier
    nx, ny, nz = size_arr
    # define array based on a polynomial function, within X-Z plane, located at y=ny/4, based on the following points:
    x = np.array([round(nx / 4.), round(nx / 2.), round(3 * nx / 4.)])
    z = np.array([0, round(nz / 2.), nz - 1])
    p = poly1d(polyfit(z, x, deg=3))
    data = np.zeros((nx, ny, nz))
    # builtin int: the np.int alias was removed in NumPy 1.24
    arr_ctl = np.array([
        p(range(nz)).astype(int), [round(ny / 4.)] * len(range(nz)),
        range(nz)
    ],
                       dtype=np.uint16)
    # Loop across dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
    for ixiy_ctl in itertools.product(range(-dilate_ctl, dilate_ctl + 1, 1),
                                      range(-dilate_ctl, dilate_ctl + 1, 1)):
        data[(arr_ctl[0] + ixiy_ctl[0]).tolist(),
             (arr_ctl[1] + ixiy_ctl[1]).tolist(), arr_ctl[2].tolist()] = 1
    # Zero specified slices (`x is not []` was always True; test the length instead)
    if len(zeroslice) > 0:
        data[:, :, zeroslice] = 0
    # Add outlier
    if len(outlier) > 0:
        # First, zero all the slice
        data[:, :, outlier] = 0
        # Then, add point in the corner
        data[0, 0, outlier] = 1
    # Create image with default orientation LPI
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    # subsample data
    img_sub = img.copy()
    img_sub.data = np.zeros((nx, ny, nz))
    for iz in range(0, nz, subsampling):
        img_sub.data[..., iz] = data[..., iz]
    # Add non-numerical values at the top corner of the image
    if hasnan:
        img.data[0, 0, 0] = np.nan
        img.data[1, 0, 0] = np.inf
    # Update orientation
    img.change_orientation(orientation)
    img_sub.change_orientation(orientation)
    if debug:
        img_sub.save('tmp_dummy_seg_' +
                     datetime.now().strftime("%Y%m%d%H%M%S%f") + '.nii.gz')
    return img, img_sub, arr_ctl
def dummy_centerline(size_arr=(9, 9, 9),
                     pixdim=(1, 1, 1),
                     subsampling=1,
                     dilate_ctl=0,
                     hasnan=False,
                     zeroslice=None,
                     outlier=None,
                     orientation='RPI',
                     debug=False):
    """
    Create a dummy Image centerline of small size. Return the full and sub-sampled version along z. Voxel resolution
    on fully-sampled data is given by `pixdim` (1x1x1 mm by default).
    :param size_arr: tuple: (nx, ny, nz)
    :param pixdim: tuple: (px, py, pz). Scales the diagonal of the image affine.
    :param subsampling: int >=1. Subsampling factor along z. 1: no subsampling. 2: centerline defined every other z.
    :param dilate_ctl: Dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
                         if dilate_ctl=0, result will be a single pixel per slice.
    :param hasnan: Bool: Image has non-numerical values: nan, inf. In this case, do not subsample.
    :param zeroslice: list int: zero all slices listed in this param. None (default): no slice is zeroed.
    :param outlier: list int: replace the current point with an outlier at the corner of the image for the slices listed.
                    None (default): no outlier.
    :param orientation: orientation the returned images are converted to.
    :param debug: Bool: Write temp files
    :return: tuple (img, img_sub, arr_ctl): full image, z-subsampled image, and the (3, nz) uint16 array of centerline
             voxel coordinates [x, y, z] (before dilation, zeroed slices and outliers).
    """
    # None instead of a shared mutable default
    zeroslice = [] if zeroslice is None else zeroslice
    outlier = [] if outlier is None else outlier
    nx, ny, nz = size_arr
    # create regularized curve, within X-Z plane, located at y=ny/4, passing through the following points:
    x = np.array([round(nx / 4.), round(nx / 2.), round(3 * nx / 4.)])
    z = np.array([0, round(nz / 2.), nz - 1])
    # we use bspline (instead of poly) in order to avoid bad extrapolation at edges
    # see: https://github.com/spinalcordtoolbox/spinalcordtoolbox/pull/2754
    xfit, _ = bspline(z, x, range(nz), 10)
    data = np.zeros((nx, ny, nz))
    # builtin int: the np.int alias was removed in NumPy 1.24
    arr_ctl = np.array(
        [xfit.astype(int), [round(ny / 4.)] * len(range(nz)),
         range(nz)],
        dtype=np.uint16)
    # Loop across dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
    for ixiy_ctl in itertools.product(range(-dilate_ctl, dilate_ctl + 1, 1),
                                      range(-dilate_ctl, dilate_ctl + 1, 1)):
        data[(arr_ctl[0] + ixiy_ctl[0]).tolist(),
             (arr_ctl[1] + ixiy_ctl[1]).tolist(), arr_ctl[2].tolist()] = 1
    # Zero specified slices (`x is not []` was always True; test the length instead)
    if len(zeroslice) > 0:
        data[:, :, zeroslice] = 0
    # Add outlier
    if len(outlier) > 0:
        # First, zero all the slice
        data[:, :, outlier] = 0
        # Then, add point in the corner
        data[0, 0, outlier] = 1
    # Create image with default orientation LPI, voxel size given by pixdim
    affine = np.eye(4)
    affine[0:3, 0:3] = affine[0:3, 0:3] * pixdim
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    # subsample data
    img_sub = img.copy()
    img_sub.data = np.zeros((nx, ny, nz))
    for iz in range(0, nz, subsampling):
        img_sub.data[..., iz] = data[..., iz]
    # Add non-numerical values at the top corner of the image
    if hasnan:
        img.data[0, 0, 0] = np.nan
        img.data[1, 0, 0] = np.inf
    # Update orientation
    img.change_orientation(orientation)
    img_sub.change_orientation(orientation)
    if debug:
        img_sub.save('tmp_dummy_seg_' +
                     datetime.now().strftime("%Y%m%d%H%M%S%f") + '.nii.gz')
    return img, img_sub, arr_ctl
示例#36
0
def main(args=None):
    """
    Entry point of the image-manipulation command-line tool.

    Parses `args` and runs one operation on the input image(s). The supported operations are -concat, -copy-header,
    -display-warp, -getorient, -keep-vol, -mcs, -omc, -pad, -pad-asym, -remove-vol, -setorient, -setorient-data,
    -split and -type. The function then writes the resulting image(s), or prints the requested information.
    When -o is not given, the output is written to the first input file.

    :param args: list of command-line arguments, without the program name. Defaults to ``sys.argv[1:]``.
    """
    # initializations
    output_type = None
    param = Param()
    dim_list = ['x', 'y', 'z', 't']

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_in = arguments["-i"]
    n_in = len(fname_in)
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    if "-o" in arguments:
        fname_out = arguments["-o"]
    else:
        fname_out = None

    # run command
    if "-concat" in arguments:
        dim = arguments["-concat"]
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(fname_in, dim)]  # TODO: adapt to fname_in

    elif "-copy-header" in arguments:
        im_in = Image(fname_in[0])
        im_dest = Image(arguments["-copy-header"])
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]
        fname_out = arguments["-copy-header"]

    elif '-display-warp' in arguments:
        im_in = fname_in[0]
        visualize_warp(im_in, fname_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif "-getorient" in arguments:
        im_in = Image(fname_in[0])
        orient = im_in.orientation
        im_out = None

    elif '-keep-vol' in arguments:
        index_vol = arguments['-keep-vol']
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif '-mcs' in arguments:
        im_in = Image(fname_in[0])
        if n_in != 1:
            sct.printv(
                parser.usage.generate(error='ERROR: -mcs need only one input'))
        if len(im_in.data.shape) != 5:
            sct.printv(
                parser.usage.generate(
                    error='ERROR: -mcs input need to be a multi-component image'
                ))
        im_out = multicomponent_split(im_in)

    elif '-omc' in arguments:
        im_ref = Image(fname_in[0])
        for fname in fname_in:
            im = Image(fname)
            if im.data.shape != im_ref.data.shape:
                sct.printv(
                    parser.usage.generate(
                        error=
                        'ERROR: -omc inputs need to have all the same shapes'))
            del im
        im_out = [multicomponent_merge(fname_in)]  # TODO: adapt to fname_in

    elif "-pad" in arguments:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1,
                       'error')
            return

        pad_arguments = arguments["-pad"].split(',')
        if len(pad_arguments) != 3:
            sct.printv('ERROR: you need to specify 3 padding values.', 1,
                       'error')
            # stop here, like the 3D check above (unpacking below would fail)
            return

        padx, pady, padz = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padx,
                      pad_x_f=padx,
                      pad_y_i=pady,
                      pad_y_f=pady,
                      pad_z_i=padz,
                      pad_z_f=padz)
        ]

    elif "-pad-asym" in arguments:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1,
                       'error')
            return

        pad_arguments = arguments["-pad-asym"].split(',')
        if len(pad_arguments) != 6:
            sct.printv('ERROR: you need to specify 6 padding values.', 1,
                       'error')
            # stop here, like the 3D check above (unpacking below would fail)
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padxi,
                      pad_x_f=padxf,
                      pad_y_i=padyi,
                      pad_y_f=padyf,
                      pad_z_i=padzi,
                      pad_z_f=padzf)
        ]

    elif '-remove-vol' in arguments:
        index_vol = arguments['-remove-vol']
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif "-setorient" in arguments:
        sct.printv(fname_in[0])
        im_in = Image(fname_in[0])
        # written once, with the other single outputs, below (it used to be saved twice)
        im_out = [msct_image.change_orientation(im_in, arguments["-setorient"])]

    elif "-setorient-data" in arguments:
        im_in = Image(fname_in[0])
        # written once, with the other single outputs, below (it used to be saved twice)
        im_out = [msct_image.change_orientation(im_in, arguments["-setorient-data"], inverse=True)]

    elif "-split" in arguments:
        dim = arguments["-split"]
        assert dim in dim_list
        im_in = Image(fname_in[0])
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif '-type' in arguments:
        output_type = arguments['-type']
        im_in = Image(fname_in[0])
        im_out = [im_in]  # TODO: adapt to fname_in

    else:
        im_out = None
        sct.printv(
            parser.usage.generate(
                error=
                'ERROR: you need to specify an operation to do on the input image'
            ))

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        sct.printv('Generate output files...', verbose)
        # if only one output
        if len(im_out) == 1 and '-split' not in arguments:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            sct.display_viewer_syntax([fname_out], verbose=verbose)
        if '-mcs' in arguments:
            # use input file name and add _X, _Y _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(
                    sct.add_suffix(fname_out or fname_in[0],
                                   '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            # show the written component files (not the characters of a single file name)
            sct.display_viewer_syntax(l_fname_out, verbose=verbose)
        if '-split' in arguments:
            # use input file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(
                    sct.add_suffix(
                        fname_out or fname_in[0],
                        '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            sct.display_viewer_syntax(l_fname_out)

    elif "-getorient" in arguments:
        sct.printv(orient)

    elif '-display-warp' in arguments:
        sct.printv('Warping grid generated.', verbose, 'info')
class ProcessLabels(object):
    """
    Collection of processes that create or modify label images.

    A process reads ``self.image_input`` (and, for comparison processes, ``self.image_ref``) and builds an output
    Image, which :meth:`process` saves to ``self.fname_output`` when that attribute is not None.
    """
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1, vertebral_levels=None, value=None, msg=""):
        """
        Collection of processes that deal with label creation/modification.
        :param fname_label: input label image (passed to ``Image``).
        :param fname_output: output file name, or list of output file names. A one-element list is unwrapped to
            its single element; longer lists are kept as-is (``remove_label(symmetry=True)`` reads two names).
        :param fname_ref: optional reference image (passed to ``Image``), used by the processes that compare the
            input against a reference ('plan_ref', 'disks', 'MSE', 'remove-reference', 'remove-symm', 'diff').
        :param cross_radius: half-thickness (in voxels) of the planes created by the 'plan' process.
        :param dilate:  # TODO: remove dilate (does not seem to be used)
        :param coordinates: coordinates used by 'create', 'create-add' and 'create-seg'.
        :param verbose: verbosity level.
        :param vertebral_levels: levels kept by 'vert-body'; None or a list starting with 0 keeps all levels.
        :param value: extra value used by 'add', 'create-viewer', 'remove' and 'keep'.
        :param msg: string. message to display to the user.
        """
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = None
        if fname_ref is not None:
            self.image_ref = Image(fname_ref, verbose=verbose)

        # a single-element list is treated like a plain file name
        if isinstance(fname_output, list) and len(fname_output) == 1:
            fname_output = fname_output[0]
        self.fname_output = fname_output
        self.cross_radius = cross_radius
        self.vertebral_levels = vertebral_levels
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose
        self.value = value
        self.msg = msg
        self.output_image = None

    def process(self, type_process):
        """
        Run the process named ``type_process`` and save its output when ``self.fname_output`` is set.

        For 'vert-body', 'vert-disc' and 'vert-continuous' the input image is reoriented to RPI before
        processing, and the output is reoriented back to the original orientation before saving.
        The analysis-only processes ('MSE', 'display-voxel', 'diff', 'dist-inter') only print their results
        and disable saving.
        :param type_process: name of the process to run.
        :return: ``self.output_image``.
        :raises ValueError: if an output file is requested but the process did not produce an output image
            (e.g. unknown process name).
        """
        # for some processes, change orientation of input image to RPI
        change_orientation = False
        if type_process in ['vert-body', 'vert-disc', 'vert-continuous']:
            # remember the native orientation so the output can be restored to it before saving
            input_orientation = self.image_input.orientation
            self.image_input.change_orientation('RPI')
            change_orientation = True

        if type_process == 'add':
            self.output_image = self.add(self.value)
        elif type_process == 'plan':
            self.output_image = self.plan(self.cross_radius, 100, 5)
        elif type_process == 'plan_ref':
            self.output_image = self.plan_ref()
        elif type_process == 'increment':
            self.output_image = self.increment_z_inverse()
        elif type_process == 'disks':
            self.output_image = self.labelize_from_disks()
        elif type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        elif type_process == 'remove-reference':
            self.output_image = self.remove_label()
        elif type_process == 'remove-symm':
            self.output_image = self.remove_label(symmetry=True)
        elif type_process == 'create':
            self.output_image = self.create_label()
        elif type_process == 'create-add':
            self.output_image = self.create_label(add=True)
        elif type_process == 'create-seg':
            self.output_image = self.create_label_along_segmentation()
        elif type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        elif type_process == 'diff':
            self.diff()
            self.fname_output = None
        elif type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        elif type_process == 'cubic-to-point':
            self.output_image = self.cubic_to_point()
        elif type_process == 'vert-body':
            self.output_image = self.label_vertebrae(self.vertebral_levels)
        elif type_process == 'vert-continuous':
            self.output_image = self.continuous_vertebral_levels()
        elif type_process == 'create-viewer':
            self.output_image = self.launch_sagittal_viewer(self.value)
        elif type_process in ['remove', 'keep']:
            self.output_image = self.remove_or_keep_labels(self.value, action=type_process)

        # TODO: do not save here. Create another function save() for that
        if self.fname_output is not None:
            if self.output_image is None:
                # previously this crashed with an AttributeError on None (e.g. for 'vert-disc')
                raise ValueError("Process '" + str(type_process) + "' did not produce an output image to save.")
            if change_orientation:
                self.output_image.change_orientation(input_orientation)
            self.output_image.absolutepath = self.fname_output
            if type_process == 'vert-continuous':
                self.output_image.save(dtype='float32')
            elif type_process != 'plan_ref':
                self.output_image.save(dtype='minimize_int')
            else:
                self.output_image.save()
        return self.output_image

    def add(self, value):
        """
        This function add a specified value to all non-zero voxels.
        :param value: value (converted to float) added to every non-zero voxel of the input image.
        :return: copy of the input image with the value added to its non-zero voxels.
        """
        image_output = self.image_input.copy()

        coordinates_input = self.image_input.getNonZeroCoordinates()

        # shift the value of every labelled voxel
        for coord in coordinates_input:
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] += float(value)
        return image_output

    def create_label(self, add=False):
        """
        Create an image with labels listed by the user.
        This method works only if the user inserted correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types).
        a Coordinate contains x, y, z and value.
        If only one label is to be added, coordinates must be completed with '[]'
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types
        :param add: if True, labels are written on top of a copy of the input image; otherwise on an empty image.
        :return: image_output
        """
        image_output = self.image_input.copy() if add else msct_image.zeros_like(self.image_input)

        # loop across labels
        for i, coord in enumerate(self.coordinates):
            if len(image_output.data.shape) == 3:
                image_output.data[int(coord.x), int(coord.y), int(coord.z)] = coord.value
            elif len(image_output.data.shape) == 2:
                assert str(coord.z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(coord.z)
                image_output.data[int(coord.x), int(coord.y)] = coord.value
            else:
                sct.printv('ERROR: Data should be 2D or 3D. Current shape is: ' + str(image_output.data.shape), 1, 'error')
            # display info
            sct.printv('Label #' + str(i) + ': ' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ' --> ' +
                       str(coord.value), 1)
        return image_output

    def create_label_along_segmentation(self):
        """
        Create an image with labels defined along the spinal cord segmentation (or centerline).
        Input image does **not** need to be RPI (re-orientation is done within this function).
        Example:
        object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'. If z=-1, then use z=nz/2 (i.e. center of FOV in superior-inferior direction)
        :return: label image in the native orientation of the input image.
        """
        # reorient input image to RPI
        im_rpi = self.image_input.copy().change_orientation('RPI')
        im_output_rpi = zeros_like(im_rpi)
        # loop across labels
        for ilabel, coord in enumerate(self.coordinates):
            # split 'z,value' string and convert to int
            z, value = [int(i) for i in coord.split(',')]
            # since we don't know which dimension corresponds to the superior-inferior axis, we put z in all
            # dimensions (we don't care about x and y here)
            coord = Coordinate([z, z, z])
            _, _, z_rpi = coord.permute(self.image_input, 'RPI')
            # if z=-1, replace with nz/2
            if z == -1:
                z_rpi = int(np.round(im_output_rpi.dim[2] / 2.0))
            # get center of mass of segmentation at given z, rounded to make indices
            x, y = ndimage.measurements.center_of_mass(np.array(im_rpi.data[:, :, z_rpi]))
            x, y = int(np.round(x)), int(np.round(y))
            # display info
            sct.printv('Label #' + str(ilabel) + ': ' + str(x) + ',' + str(y) + ',' + str(z_rpi) + ' --> ' + str(value), 1)
            if len(im_output_rpi.data.shape) == 3:
                im_output_rpi.data[x, y, z_rpi] = value
            elif len(im_output_rpi.data.shape) == 2:
                assert str(z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(z)
                im_output_rpi.data[x, y] = value
        # change orientation back to native
        return im_output_rpi.change_orientation(self.image_input.orientation)

    def plan(self, width, offset=0, gap=1):
        """
        Create a plane of thickness="width" and changes its value with an offset and a gap between labels.
        For every label at slice z, slices [z - width, z + width) are filled with ``offset + gap * label``.
        :param width: half-thickness of each plane, in voxels.
        :param offset: value added to every plane.
        :param gap: factor applied to the label value.
        :return: image_output
        """
        image_output = msct_image.zeros_like(self.image_input)

        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            z = int(coord.z)
            # clamp the lower bound: a negative slice start would wrap around to the end of the array and
            # produce an empty slab for labels closer than `width` to the first slice
            image_output.data[:, :, max(z - width, 0):z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.
        Each plane (a full slice at the label's z) takes the label value; negative labels are kept negative.
        :return: float32 image with the geometry of the reference image.
        """
        image_output = msct_image.zeros_like(Image(self.image_ref))

        image_input_neg = msct_image.zeros_like(Image(self.image_input))
        image_input_pos = msct_image.zeros_like(Image(self.image_input))

        # split negative and positive labels: negative ones are stored as absolute values so that
        # getNonZeroCoordinates can be applied to them
        X, Y, Z = (self.image_input.data < 0).nonzero()
        for i in range(len(X)):
            image_input_neg.data[X[i], Y[i], Z[i]] = -self.image_input.data[X[i], Y[i], Z[i]]
        X_pos, Y_pos, Z_pos = (self.image_input.data > 0).nonzero()
        for i in range(len(X_pos)):
            image_input_pos.data[X_pos[i], Y_pos[i], Z_pos[i]] = self.image_input.data[X_pos[i], Y_pos[i], Z_pos[i]]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.change_type('float32')
        for coord in coordinates_input_neg:
            image_output.data[:, :, int(coord.z)] = -coord.value  # PB: takes the int value of coord.value
        for coord in coordinates_input_pos:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output

    def cubic_to_point(self):
        """
        Calculate the center of mass of each group of labels and returns a file of same size with only a
        label by group at the center of mass of this group.
        It is to be used after applying homothetic warping field to a label file as the labels will be dilated.
        Be careful: this algorithm computes the center of mass of voxels with same value, if two groups of voxels with
         the same value are present but separated in space, this algorithm will compute the center of mass of the two
         groups together.
        :return: image_output
        """

        # 0. Initialization of output image
        output_image = msct_image.zeros_like(self.image_input)

        # 1. Extraction of coordinates from all non-null voxels in the image. Coordinates are sorted by value.
        coordinates = self.image_input.getNonZeroCoordinates(sorting='value')

        # 2. Separate all coordinates into groups by value
        groups = dict()
        for coord in coordinates:
            groups.setdefault(coord.value, []).append(coord)

        # 3. Compute the center of mass of each group of voxels and write them into the output image
        for value, list_coord in groups.items():
            center_of_mass = sum(list_coord) / float(len(list_coord))
            sct.printv("Value = " + str(center_of_mass.value) + " : (" + str(center_of_mass.x) + ", " + str(center_of_mass.y) + ", " + str(center_of_mass.z) + ") --> ( " + str(np.round(center_of_mass.x)) + ", " + str(np.round(center_of_mass.y)) + ", " + str(np.round(center_of_mass.z)) + ")", verbose=self.verbose)
            output_image.data[int(np.round(center_of_mass.x)), int(np.round(center_of_mass.y)), int(np.round(center_of_mass.z))] = center_of_mass.value

        return output_image

    def increment_z_inverse(self):
        """
        Take all non-zero values, sort them along the inverse z direction, and attributes the values 1,
        2, 3, etc. This function assuming RPI orientation.
        :return: image_output
        """
        image_output = msct_image.zeros_like(self.image_input)

        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        # number the labels 1, 2, 3, ... in the order returned above
        for i, coord in enumerate(coordinates_input):
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = i + 1

        return image_output

    def labelize_from_disks(self):
        """
        Create an image with regions labelized depending on values from reference.
        Typically, user inputs a segmentation image, and labels with disks position, and this function produces
        a segmentation image with vertebral levels labelized.
        Labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI orientation
        :return: image_output
        """
        image_output = msct_image.zeros_like(self.image_input)

        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value')

        # for all points in input, find the value that has to be set up, depending on the vertebral level
        for coord in coordinates_input:
            for j in range(0, len(coordinates_ref) - 1):
                if coordinates_ref[j + 1].z < coord.z <= coordinates_ref[j].z:
                    image_output.data[int(coord.x), int(coord.y), int(coord.z)] = coordinates_ref[j].value

        return image_output

    def label_vertebrae(self, levels_user=None):
        """
        Find the center of mass of vertebral levels specified by the user.
        :param levels_user: list of levels to keep. None, an empty list, or a list whose first element is 0
            keeps every level.
        :return: image_output: Image with labels.
        """
        # get center of mass of each vertebral level
        image_cubic2point = self.cubic_to_point()
        # get list of coordinates for each label
        list_coordinates = image_cubic2point.getNonZeroCoordinates(sorting='value')
        # if user did not specify levels, include all (levels_user defaults to None, which used to crash here)
        if not levels_user or levels_user[0] == 0:
            levels_user = [int(i.value) for i in list_coordinates]
        # remove labels that are not listed by the user
        for coord in list_coordinates:
            if int(coord.value) not in levels_user:
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = 0
        return image_cubic2point

    def MSE(self, threshold_mse=0):
        """
        Compute the Mean Square Distance Error between two sets of labels (input and ref).
        Moreover, a warning is generated for each label mismatch.
        If the MSE is above the threshold provided (by default = 0mm), a log is reported with the filenames considered here.
        The value computed is the root of the mean squared z-difference between labels of equal (rounded) value.
        :param threshold_mse: above this value, an error log file is written next to the input image.
        :return: the error, or NaN if the input image contains no label.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # rounded label values, computed once instead of inside the nested loops
        values_input = [np.round(coord.value) for coord in coordinates_input]
        values_ref = [np.round(coord_ref.value) for coord_ref in coordinates_ref]

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for value in values_input:
            if value not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for value_ref in values_ref:
            if value_ref not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        if len(coordinates_input) == 0:
            # the mean below would divide by zero
            sct.printv('ERROR: no label found in input image, cannot compute MSE', 1, 'warning')
            return float('nan')

        result = 0.0
        for coord, value in zip(coordinates_input, values_input):
            for coord_ref, value_ref in zip(coordinates_ref, values_ref):
                if value_ref == value:
                    result += (coord_ref.z - coord.z) ** 2
                    break
        result = np.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            parent, stem, ext = sct.extract_fname(self.image_input.absolutepath)
            fname_report = os.path.join(parent, 'error_log_{}.txt'.format(stem))
            with open(fname_report, 'w') as f:
                f.write('The labels error (MSE) between {} and {} is: {}\n'.format(
                 os.path.relpath(self.image_input.absolutepath, os.path.dirname(fname_report)),
                 os.path.relpath(self.image_ref.absolutepath, os.path.dirname(fname_report)),
                 result))

        return result

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep only the input labels whose value exists in the reference (and, with symmetry, vice versa).
        coord_input and coord_ref should be sequences of CoordinateValue in order to improve speed of intersection
        :param coord_input: list of CoordinateValue
        :param coord_ref: list of CoordinateValue
        :param symmetry: boolean, if True also filter the reference labels against the input labels
        :return: (filtered input coordinates, filtered reference coordinates), both lists
        """
        if symmetry and len(coord_input) > 0 and len(coord_ref) > 0:
            # imported lazily: only needed to detect CoordinateValue inputs for the set-based fast path
            from msct_types import CoordinateValue
            if isinstance(coord_input[0], CoordinateValue) and isinstance(coord_ref[0], CoordinateValue):
                coord_intersection = set(coord_input).intersection(set(coord_ref))
                result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
                result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
                return result_coord_input, result_coord_ref

        # generic path: labels are matched by value (also handles empty inputs)
        result_coord_ref = coord_ref
        result_coord_input = [coord for coord in coord_input if any(x.value == coord.value for x in coord_ref)]
        if symmetry:
            result_coord_ref = [coord for coord in coord_ref
                                if any(x.value == coord.value for x in result_coord_input)]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any labels in input image that are not in reference image.
        The symmetry option enables to remove labels from reference image that are not in input image.
        With symmetry, the filtered reference image is saved to ``self.fname_output[1]`` and
        ``self.fname_output`` is replaced by ``self.fname_output[0]``.
        :param symmetry: boolean
        :return: filtered input label image
        """
        image_output = msct_image.zeros_like(self.image_input)

        result_coord_input, result_coord_ref = self.remove_label_coord(self.image_input.getNonZeroCoordinates(coordValue=True),
                                                                       self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))

        if symmetry:
            # start from an empty image with the reference geometry: copying the reference kept every reference
            # label, so the labels missing from the input were never removed
            image_output_ref = msct_image.zeros_like(self.image_ref)
            for coord in result_coord_ref:
                image_output_ref.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))
            image_output_ref.absolutepath = self.fname_output[1]
            # dtype must be passed by keyword: the first positional argument of save() is the output path
            image_output_ref.save(dtype='minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def display_voxel(self):
        """
        Display all the labels that are contained in the input image.
        The image is suppose to be RPI to display voxels. But works also for other orientations
        Also stores the ':'-separated list of labels in ``self.useful_notation``.
        :return: list of non-zero coordinates, sorted by value.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='value')
        self.useful_notation = ''
        for coord in coordinates_input:
            sct.printv('Position=(' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ') -- Value= ' + str(coord.value), verbose=self.verbose)
            if self.useful_notation:
                self.useful_notation = self.useful_notation + ':'
            self.useful_notation += str(coord)
        sct.printv('All labels (useful syntax):', verbose=self.verbose)
        sct.printv(self.useful_notation, verbose=self.verbose)
        return coordinates_input

    def get_physical_coordinates(self):
        """
        This function returns the coordinates of the labels in the physical referential system.
        :return: a list of CoordinateValue, in the physical (scanner) space
        """
        # same local import as remove_label_coord, so the name is available in this method too
        from msct_types import CoordinateValue
        coord = self.image_input.getNonZeroCoordinates(sorting='value')
        phys_coord = []
        for c in coord:
            # convert pixelar coordinates to physical coordinates
            c_p = self.image_input.transfo_pix2phys([[c.x, c.y, c.z]])[0]
            phys_coord.append(CoordinateValue([c_p[0], c_p[1], c_p[2], c.value]))
        return phys_coord

    def get_coordinates_in_destination(self, im_dest, type='discrete'):
        """
        This function calculate the position of labels in the pixelar space of a destination image
        :param im_dest: Object Image
        :param type: 'discrete' or 'continuous' ('continuous' calls transfo_phys2pix with real=False)
        :return: a list of CoordinateValue, in the pixelar (image) space of the destination image
        :raises ValueError: if type is neither 'discrete' nor 'continuous'
        """
        from msct_types import CoordinateValue
        # validate up front (and compare strings with ==, not `is`, which tests object identity)
        if type not in ('discrete', 'continuous'):
            raise ValueError("The value of 'type' should either be 'discrete' or 'continuous'.")
        phys_coord = self.get_physical_coordinates()
        dest_coord = []
        for c in phys_coord:
            # the third coordinate is z (it was previously c.y by mistake)
            if type == 'discrete':
                c_p = im_dest.transfo_phys2pix([[c.x, c.y, c.z]])[0]
            else:
                c_p = im_dest.transfo_phys2pix([[c.x, c.y, c.z]], real=False)[0]
            dest_coord.append(CoordinateValue([c_p[0], c_p[1], c_p[2], c.value]))
        return dest_coord

    def diff(self):
        """
        Detect any label mismatch between input image and reference image.
        Prints the values present in only one of the two images.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        sct.printv("Label in input image that are not in reference image:")
        for coord in coordinates_input:
            if not any(coord.value == coord_ref.value for coord_ref in coordinates_ref):
                sct.printv(coord.value)

        sct.printv("Label in ref image that are not in input image:")
        for coord_ref in coordinates_ref:
            if not any(coord.value == coord_ref.value for coord in coordinates_input):
                sct.printv(coord_ref.value)

    def distance_interlabels(self, max_dist):
        """
        Calculate the distances between each label in the input image.
        If a distance is larger than max_dist, a warning message is displayed.
        Distances are computed between consecutive labels, in voxels.
        :param max_dist: maximum allowed distance between two consecutive labels.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for i, (c1, c2) in enumerate(zip(coordinates_input[:-1], coordinates_input[1:])):
            dist = np.sqrt((c1.x - c2.x) ** 2 + (c1.y - c2.y) ** 2 + (c1.z - c2.z) ** 2)
            # warn when the distance is LARGER than max_dist (the test was previously inverted)
            if dist > max_dist:
                sct.printv('Warning: the distance between label ' + str(i) + '[' + str(c1.x) + ',' + str(c1.y) + ',' + str(
                    c1.z) + ']=' + str(c1.value) + ' and label ' + str(i + 1) + '[' + str(
                    c2.x) + ',' + str(c2.y) + ',' + str(c2.z) + ']=' + str(
                    c2.value) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist))

    def continuous_vertebral_levels(self):
        """
        This function transforms the vertebral levels file from the template into a continuous file.
        Instead of having integer representing the vertebral level on each slice, a continuous value that represents
        the position of the slice in the vertebral level coordinate system.
        The image must be RPI
        :return: float32 image with the continuous level value on every non-zero voxel of the input.
        """
        # verbose must be passed by keyword (the second positional argument of Image() is not verbose)
        im_input = Image(self.image_input, verbose=self.verbose)
        im_output = msct_image.zeros_like(self.image_input)

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each slice, extract corresponding level
        nx, ny, nz, nt, px, py, pz, pt = im_input.dim
        from spinalcordtoolbox.centerline.core import get_centerline
        _, arr_ctl, _ = get_centerline(self.image_input, algo_fitting='bspline')
        x_centerline_fit, y_centerline_fit, z_centerline = arr_ctl
        value_centerline = np.array(
            [im_input.data[int(x_centerline_fit[it]), int(y_centerline_fit[it]), int(z_centerline[it])]
             for it in range(len(z_centerline))])

        # 2. compute distance for each vertebral level --> Di for i being the vertebral levels
        vertebral_levels = {}
        for slice_image, level in enumerate(value_centerline):
            if level not in vertebral_levels:
                vertebral_levels[level] = slice_image

        length_levels = {}
        for level in vertebral_levels:
            indexes_slice = np.where(value_centerline == level)
            length_levels[level] = np.sum([np.sqrt(((x_centerline_fit[indexes_slice[0][index_slice + 1]] - x_centerline_fit[indexes_slice[0][index_slice]]) * px)**2 +
                                                     ((y_centerline_fit[indexes_slice[0][index_slice + 1]] - y_centerline_fit[indexes_slice[0][index_slice]]) * py)**2 +
                                                     ((z_centerline[indexes_slice[0][index_slice + 1]] - z_centerline[indexes_slice[0][index_slice]]) * pz)**2)
                                           for index_slice in range(len(indexes_slice[0]) - 1)])

        # 3. for each slice:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate distance of slice from upper vertebral level --> d
        #   c. compute relative distance in the vertebral level coordinate system --> d/Di
        # NOTE(review): d uses px*px (pixel size squared) while Di uses px; combined with the 2.0 factor below this
        # looks tuned for a specific voxel size (0.5 mm would cancel out) -- confirm before changing either.
        continuous_values = {}
        for it, iz in enumerate(z_centerline):
            level = value_centerline[it]
            indexes_slice = np.where(value_centerline == level)
            indexes_slice = indexes_slice[0][indexes_slice[0] >= it]
            distance_from_level = np.sum([np.sqrt(((x_centerline_fit[indexes_slice[index_slice + 1]] - x_centerline_fit[indexes_slice[index_slice]]) * px * px) ** 2 +
                                                    ((y_centerline_fit[indexes_slice[index_slice + 1]] - y_centerline_fit[indexes_slice[index_slice]]) * py * py) ** 2 +
                                                    ((z_centerline[indexes_slice[index_slice + 1]] - z_centerline[indexes_slice[index_slice]]) * pz * pz) ** 2)
                                          for index_slice in range(len(indexes_slice) - 1)])
            continuous_values[iz] = level + 2.0 * distance_from_level / float(length_levels[level])

        # 4. saving data
        # for each slice, get all non-zero pixels and replace with continuous values
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.change_type(np.float32)
        # for all points in input, find the value that has to be set up, depending on the vertebral level
        for coord in coordinates_input:
            im_output.data[int(coord.x), int(coord.y), int(coord.z)] = continuous_values[coord.z]

        return im_output

    def launch_sagittal_viewer(self, labels):
        """
        Open the sagittal labelling GUI on the input image.
        :param labels: labels offered to the user (stored in the GUI parameters as ``vertebraes``).
        :return: empty image (same geometry as the input) handed to the GUI as its output image.
        """
        from spinalcordtoolbox.gui import base
        from spinalcordtoolbox.gui.sagittal import launch_sagittal_dialog

        params = base.AnatomicalParams()
        params.vertebraes = labels
        params.input_file_name = self.image_input.absolutepath
        params.output_file_name = self.fname_output
        params.subtitle = self.msg
        output = msct_image.zeros_like(self.image_input)
        output.absolutepath = self.fname_output
        launch_sagittal_dialog(self.image_input, output, params)

        return output

    def remove_or_keep_labels(self, labels, action):
        """
        Create or remove labels from self.image_input
        :param labels: list(int): Labels to keep or remove
        :param action: str: 'remove': remove specified labels (i.e. set to zero), 'keep': keep specified labels and remove the others
        :return: image_output
        :raises ValueError: if action is neither 'keep' nor 'remove'
        """
        if action == 'keep':
            image_output = msct_image.zeros_like(self.image_input)
        elif action == 'remove':
            image_output = self.image_input.copy()
        else:
            raise ValueError("action should be either 'keep' or 'remove', got: " + str(action))
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for labelNumber in labels:
            # every voxel carrying this label is processed (previously only the last matching voxel was)
            matching = [coord for coord in coordinates_input if labelNumber == coord.value]
            if not matching:
                sct.printv("WARNING: Label " + str(float(labelNumber)) + " not found in input image.", type='warning')
                continue
            for coord in matching:
                new_value = coord.value if action == 'keep' else 0.0
                image_output.data[int(coord.x), int(coord.y), int(coord.z)] = new_value

        return image_output
示例#38
0
def moco(param):
    """
    Main function that performs motion correction.

    Each volume (along T) is registered to a target image. If ``param.is_sagittal`` is set, the data, the
    target and the mask are first split along Z and each slice is processed independently. Otherwise a single
    3D "slice" is processed. The first volumes (index < 10) that register successfully are blended into the
    target when ``param.iterAvg`` is set. Failed transformations are replaced by the closest (in T) successful one.

    :param param: parameter object. Fields read here: file_data, file_target, mat_moco, todo, suffix, verbose,
        poly, smooth, gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg, interp.
    :return: tuple (file_mat, im_out):
        - file_mat: numpy object array of shape (nz, nt) if sagittal, (1, nt) otherwise, holding the prefix of
          each output transformation file.
        - im_out: the motion-corrected Image, or None when ``param.todo == 'estimate'`` (no corrected data
          is produced in that mode).
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose
    use_mask = param.fname_mask != ''

    # other parameters
    file_mask = 'mask.nii'

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + file_data, param.verbose)
    printv('  Reference file ........ ' + file_target, param.verbose)
    printv('  Polynomial degree ..... ' + param.poly, param.verbose)
    printv('  Smoothing kernel ...... ' + param.smooth, param.verbose)
    printv('  Gradient step ......... ' + param.gradStep, param.verbose)
    printv('  Metric ................ ' + param.metric, param.verbose)
    printv('  Sampling .............. ' + param.sampling, param.verbose)
    printv('  Todo .................. ' + todo, param.verbose)
    printv('  Mask  ................. ' + param.fname_mask, param.verbose)
    printv('  Output mat folder ..... ' + folder_mat, param.verbose)

    try:
        os.makedirs(folder_mat)
    except FileExistsError:
        pass

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv(
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),
        verbose)

    # copy file_target to a temporary file
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if use_mask:
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwritting
        # Copy the mask into the working directory under a fixed name. Both branches below read it from
        # there (previously only the axial branch created it, so the sagittal branch could not find it).
        convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        for im_z in im_z_list:
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        im_targetz_list = split_data(Image(file_target), dim=dim_sag, squeeze_data=False)
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        if use_mask:
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)
        # deal with mask
        if use_mask:
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    im_out = None  # stays None when todo == 'estimate'
    file_data_splitZ_moco = []
    printv(
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial)'
    )
    for iz, file in enumerate(file_data_splitZ):
        # Split data along T dimension
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        for im_zt in list_im_zt:
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1),
                                   ascii=False,
                                   ncols=80):

            # create indices and display stuff
            file_mat[iz][it] = os.path.join(
                folder_mat,
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(
                add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if use_mask and not param.todo == 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data, im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    #  Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwritting

            # run 3D registration
            failed_transfo[it] = register(param,
                                          file_data_splitZ_splitT[it],
                                          file_target_splitZ[iz],
                                          file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.data = data_targetz
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        failed_its = [i for i, j in enumerate(failed_transfo) if j == 1]
        good_its = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it_bad in failed_its:
            if not good_its:
                # exit program if no transformation exists.
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,
                    'error')
                sys.exit(2)
            # min() returns the first minimal item, so ties go to the earlier good volume (as before)
            it_good = min(good_its, key=lambda it_g: abs(it_g - it_bad))
            printv(
                '  transfo #' + str(it_bad) + ' --> use transfo #' +
                str(it_good), verbose)
            # copy transformation
            copy(file_mat[iz][it_good] + 'Warp.nii.gz',
                 file_mat[iz][it_bad] + 'Warp.nii.gz')
            # apply transformation, using the target of the current slice as destination (identical to the
            # full target in the axial case, a single Z slice in the sagittal case)
            sct_apply_transfo.main(argv=[
                '-i', file_data_splitZ_splitT[it_bad], '-d', file_target_splitZ[iz],
                '-w', file_mat[iz][it_bad] + 'Warp.nii.gz', '-o',
                file_data_splitZ_splitT_moco[it_bad], '-x', param.interp
            ])

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(file, suffix))
        if todo != 'estimate':
            im_data_splitZ_splitT_moco = [
                Image(fname) for fname in file_data_splitZ_splitT_moco
            ]
            im_out = concat_data(im_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z (only when per-slice corrected data was actually written)
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_data_splitZ_moco = [Image(fname) for fname in file_data_splitZ_moco]
        im_out = concat_data(im_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
示例#39
0
def _average_image_along_t(fname_in, fname_out):
    """
    Average a 4D image along its 4th (T) axis and save the result.

    :param fname_in: path of the 4D image to average
    :param fname_out: path where the averaged image is saved
    :raises ValueError: if the image has fewer than 4 dimensions
    """
    img = Image(fname_in)
    out = img.copy()
    dim_idx = 3
    if len(np.shape(img.data)) < dim_idx + 1:
        raise ValueError("Expecting image with 4 dimensions!")
    out.data = np.mean(out.data, dim_idx)
    out.save(path=fname_out)


def main(argv=None):
    """
    Separate the b=0 and diffusion-weighted (DW) volumes of a 4D dMRI dataset.

    The input is split along T. The volumes classified by ``identify_b0`` are merged into ``<input>_b0`` and
    ``<input>_dwi`` images, and with averaging enabled each set is also averaged along T
    (``<input>_b0_mean`` / ``<input>_dwi_mean``). Intermediate files live in a temporary folder.

    :param argv: command-line arguments, passed to ``get_parser().parse_args``
    :return: tuple (fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean) of absolute output paths. The two
        "_mean" paths are returned even when averaging is disabled, in which case those files are not written.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initialize parameters
    param = Param()

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    average = arguments.a
    remove_temp_files = arguments.r
    path_out = arguments.ofolder

    fname_bvals = arguments.bval
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # Initialization
    start_time = time.time()

    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    # The bvals file is optional (may be None), so convert it before concatenating
    printv('  bvals file ............' + str(fname_bvals), verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder; the finally clause restores the caller's working directory even if a step fails
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv(
            '.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt),
            verbose)

        # Identify b=0 and DWI images
        printv(fname_bvals)
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals,
                                                         param.bval_min, verbose)

        # Split into T dimension; the volumes are read back below as <dmri_name>_T<4-digit index><ext>
        printv('\nSplit along T dimension...', verbose)
        for im_d in split_data(im_dmri, 3):
            im_d.save()

        # Merge b=0 images
        printv('\nMerge b=0...', verbose)
        fname_in_list_b0 = [dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext for it in range(nb_b0)]
        concat_data([Image(fname) for fname in fname_in_list_b0], 3).save(b0_name + ext)

        # Average b=0 images
        if average:
            printv('\nAverage b=0...', verbose)
            _average_image_along_t(b0_name + ext, b0_mean_name + ext)

        # Merge DWI
        fname_in_list_dwi = [dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext for it in range(nb_dwi)]
        concat_data([Image(fname) for fname in fname_in_list_dwi], 3).save(dwi_name + ext)

        # Average DWI images
        if average:
            printv('\nAverage DWI...', verbose)
            _average_image_along_t(dwi_name + ext, dwi_mean_name + ext)
    finally:
        # come back
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(
        os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(
        os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext),
                         fname_b0,
                         verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext),
                         fname_dwi,
                         verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext),
                             fname_b0_mean,
                             verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext),
                             fname_dwi_mean,
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
def interpolate_im_to_ref(im_input, im_input_sc, new_res=0.3, sq_size_size_mm=22.5, interpolation_mode=3):
    """
    Resample every z-slice of ``im_input`` onto a square grid centered on the spinal cord.

    For each slice ``iz``, a square reference image is built:
    - ``int(sq_size_size_mm / new_res)`` pixels per side;
    - in-plane pixel size ``new_res``;
    - slice thickness ``pz`` of the input.
    The square is centered on the center of mass of the non-zero voxels of ``im_input_sc`` in that slice.
    ``im_input`` is then interpolated onto it. Rotation terms are removed from the qform (only spacing and
    origin are kept) to avoid rotation issues.

    Side effect: saves the reference square as ``im_ref.nii.gz`` in the current working directory.

    :param im_input: Image to resample (the function works on a copy)
    :param im_input_sc: spinal cord segmentation. It is presumably in the same space as ``im_input``, since
        its header is replaced by the one of ``im_input`` (works on a copy).
    :param new_res: in-plane resolution of the output grid (header units, presumably mm)
    :param sq_size_size_mm: side length of the output square (same units as ``new_res``)
    :param interpolation_mode: forwarded to ``Image.interpolate_from_image``
    :return: list of Image, one per z-slice, whose data is 2D (the singleton z axis is dropped)
    """
    nx, ny, nz, nt, px, py, pz, pt = im_input.dim

    im_input_sc = im_input_sc.copy()
    im_input = im_input.copy()

    # keep only spacing and origin in qform to avoid rotation issues
    input_qform = im_input.hdr.get_qform()
    for i in range(4):
        for j in range(4):
            if i != j and j != 3:
                input_qform[i, j] = 0

    im_input.hdr.set_qform(input_qform)
    im_input.hdr.set_sform(input_qform)
    im_input_sc.hdr = im_input.hdr

    sq_size = int(sq_size_size_mm / new_res)
    # create a reference image : square of ones
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    im_ref = Image(np.ones((sq_size, sq_size, 1), dtype=int), dim=(sq_size, sq_size, 1, 0, new_res, new_res, pz, 0), orientation='RPI')

    # copy input qform matrix to reference image
    im_ref.hdr.set_qform(im_input.hdr.get_qform())
    im_ref.hdr.set_sform(im_input.hdr.get_sform())

    # set correct header to reference image
    im_ref.hdr.set_data_shape((sq_size, sq_size, 1))
    im_ref.hdr.set_zooms((new_res, new_res, pz))

    # save image to set orientation to RPI (not properly done at the creation of the image)
    fname_ref = 'im_ref.nii.gz'
    im_ref.save(fname_ref).change_orientation("RPI")

    # set header origin to zero to get physical coordinates of the center of the square
    im_ref.hdr.as_analyze_map()['qoffset_x'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_y'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_z'] = 0
    im_ref.hdr.set_sform(im_ref.hdr.get_qform())
    im_ref.hdr.set_qform(im_ref.hdr.get_qform())
    [[x_square_center_phys, y_square_center_phys, _]] = im_ref.transfo_pix2phys(coordi=[[int(sq_size / 2), int(sq_size / 2), 0]])

    list_interpolate_images = []
    # iterate on z dimension of input image
    for iz in range(nz):
        # copy reference image: one reference image per slice
        im_ref_slice_iz = im_ref.copy()

        # get center of mass of SC for slice iz
        # NOTE(review): a slice without any voxel > 0 gives NaN here (mean of an empty array)
        x_seg, y_seg = (im_input_sc.data[:, :, iz] > 0).nonzero()
        x_center, y_center = np.mean(x_seg), np.mean(y_seg)
        [[x_center_phys, y_center_phys, z_center_phys]] = im_input_sc.transfo_pix2phys(coordi=[[x_center, y_center, iz]])

        # center reference image on SC for slice iz
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_x'] = x_center_phys - x_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_y'] = y_center_phys - y_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_z'] = z_center_phys
        im_ref_slice_iz.hdr.set_sform(im_ref_slice_iz.hdr.get_qform())
        im_ref_slice_iz.hdr.set_qform(im_ref_slice_iz.hdr.get_qform())

        # interpolate input image to reference image
        im_input_interpolate_iz = im_input.interpolate_from_image(im_ref_slice_iz, interpolation_mode=interpolation_mode, border='nearest')
        # reshape data to 2D if needed
        if len(im_input_interpolate_iz.data.shape) == 3:
            im_input_interpolate_iz.data = im_input_interpolate_iz.data.reshape(im_input_interpolate_iz.data.shape[:-1])
        # add slice to list
        list_interpolate_images.append(im_input_interpolate_iz)

    return list_interpolate_images
def pre_processing(fname_target, fname_sc_seg, fname_level=None, fname_manual_gmseg=None, new_res=0.3, square_size_size_mm=22.5, denoising=True, verbose=1, rm_tmp=True, for_model=False):
    """
    Pre-process a target image and its spinal cord (SC) segmentation into a list of Slice objects.

    Steps (run inside a temporary folder):
    - reorient both images to RPI;
    - optionally crop around the cord and denoise the target (non-local means);
    - resample every axial slice to a square of side ``square_size_size_mm`` and resolution ``new_res``,
      centered on the cord;
    - multiply the target slices by the binarized resampled SC segmentation;
    - wrap each slice in a Slice object, optionally attaching vertebral levels and manual GM segmentations.

    :param fname_target: path of the image to segment
    :param fname_sc_seg: path of its SC segmentation (must be in the same orientation as the target)
    :param fname_level: optional vertebral level file, loaded with ``load_level``
    :param fname_manual_gmseg: optional manual GM segmentation(s), loaded with ``load_manual_gmseg``
    :param new_res: in-plane resolution of the resampled slices
    :param square_size_size_mm: side length of the square crop around the cord
    :param denoising: if True, crop around the cord and denoise the target before resampling
    :param verbose: verbosity level passed to printv
    :param rm_tmp: if True, remove the temporary folder at the end
    :param for_model: forwarded to ``load_manual_gmseg``
    :return: (list_slices_target, original_info). original_info is a dict with keys:
        - 'orientation': original orientation of the target;
        - 'im_sc_seg_rpi': copy of the uncropped RPI SC segmentation, used to post-process segmentations;
        - 'interpolated_images': list of resampled target Images, one per slice; these are the same objects
          that get multiplied by the SC mask.
    :raises AssertionError: if the target and the SC segmentation are not in the same orientation
    """
    printv('\nPre-process data...', verbose, 'normal')

    tmp_dir = sct.tmp_create()

    sct.copy(fname_target, tmp_dir)
    fname_target = ''.join(extract_fname(fname_target)[1:])
    sct.copy(fname_sc_seg, tmp_dir)
    fname_sc_seg = ''.join(extract_fname(fname_sc_seg)[1:])

    curdir = os.getcwd()
    os.chdir(tmp_dir)
    # The steps below work inside tmp_dir; the finally clause guarantees the caller's working
    # directory is restored even if one of them raises.
    try:
        original_info = {'orientation': None, 'im_sc_seg_rpi': None, 'interpolated_images': []}

        im_target = Image(fname_target).copy()
        im_sc_seg = Image(fname_sc_seg).copy()

        # get original orientation
        printv('  Reorient...', verbose, 'normal')
        original_info['orientation'] = im_target.orientation

        # assert images are in the same orientation
        assert im_target.orientation == im_sc_seg.orientation, "ERROR: the image to segment and it's SC segmentation are not in the same orientation"

        im_target_rpi = im_target.copy().change_orientation('RPI', generate_path=True).save()
        im_sc_seg_rpi = im_sc_seg.copy().change_orientation('RPI', generate_path=True).save()
        original_info['im_sc_seg_rpi'] = im_sc_seg_rpi.copy()  # SC segmentation in RPI will be used to post-process segmentations

        # denoise using P. Coupe non local means algorithm (see [Manjon et al. JMRI 2010]) implemented in dipy
        if denoising:
            printv('  Denoise...', verbose, 'normal')
            # crop image before denoising to fasten denoising
            nx, ny, nz, nt, px, py, pz, pt = im_target_rpi.dim
            size_x, size_y = (square_size_size_mm + 1) / px, (square_size_size_mm + 1) / py
            size = int(np.ceil(max(size_x, size_y)))
            # create mask
            fname_mask = 'mask_pre_crop.nii.gz'
            sct_create_mask.main(['-i', im_target_rpi.absolutepath, '-p', 'centerline,' + im_sc_seg_rpi.absolutepath, '-f', 'box', '-size', str(size), '-o', fname_mask])
            # crop image
            fname_target_crop = add_suffix(im_target_rpi.absolutepath, '_pre_crop')
            crop_im = ImageCropper(input_file=im_target_rpi.absolutepath, output_file=fname_target_crop, mask=fname_mask)
            im_target_rpi_crop = crop_im.crop()
            # crop segmentation
            fname_sc_seg_crop = add_suffix(im_sc_seg_rpi.absolutepath, '_pre_crop')
            crop_sc_seg = ImageCropper(input_file=im_sc_seg_rpi.absolutepath, output_file=fname_sc_seg_crop, mask=fname_mask)
            im_sc_seg_rpi_crop = crop_sc_seg.crop()
            # denoising
            from sct_maths import denoise_nlmeans
            block_radius = 3
            # shrink the block radius when the cropped volume has few slices
            block_radius = int(im_target_rpi_crop.data.shape[2] / 2) if im_target_rpi_crop.data.shape[2] < (block_radius*2) else block_radius
            patch_radius = block_radius -1
            data_denoised = denoise_nlmeans(im_target_rpi_crop.data, block_radius=block_radius, patch_radius=patch_radius)
            im_target_rpi_crop.data = data_denoised

            im_target_rpi = im_target_rpi_crop
            im_sc_seg_rpi = im_sc_seg_rpi_crop
        else:
            fname_mask = None

        # interpolate image to reference square image (resample and square crop centered on SC)
        printv('  Interpolate data to the model space...', verbose, 'normal')
        list_im_slices = interpolate_im_to_ref(im_target_rpi, im_sc_seg_rpi, new_res=new_res, sq_size_size_mm=square_size_size_mm)
        original_info['interpolated_images'] = list_im_slices # list of images (not Slice() objects)

        printv('  Mask data using the spinal cord segmentation...', verbose, 'normal')
        list_sc_seg_slices = interpolate_im_to_ref(im_sc_seg_rpi, im_sc_seg_rpi, new_res=new_res, sq_size_size_mm=square_size_size_mm, interpolation_mode=1)
        for i in range(len(list_im_slices)):
            list_sc_seg_slices[i] = binarize(list_sc_seg_slices[i], thr_min=0.5, thr_max=1)
            list_im_slices[i].data = list_im_slices[i].data * list_sc_seg_slices[i].data

        printv('  Split along rostro-caudal direction...', verbose, 'normal')
        list_slices_target = [Slice(slice_id=i, im=im_slice.data, gm_seg=[], wm_seg=[]) for i, im_slice in enumerate(list_im_slices)]

        # load vertebral levels
        if fname_level is not None:
            printv('  Load vertebral levels...', verbose, 'normal')
            # copy level file to tmp dir (fname_level may be relative to the caller's directory)
            os.chdir(curdir)
            sct.copy(fname_level, tmp_dir)
            os.chdir(tmp_dir)
            # change fname level to only file name (path = tmp dir now)
            fname_level = ''.join(extract_fname(fname_level)[1:])
            # load levels
            list_slices_target = load_level(list_slices_target, fname_level)
    finally:
        os.chdir(curdir)

    # load manual gmseg if there is one (model data)
    if fname_manual_gmseg is not None:
        printv('\n\tLoad manual GM segmentation(s) ...', verbose, 'normal')
        list_slices_target = load_manual_gmseg(list_slices_target, fname_manual_gmseg, tmp_dir, im_sc_seg_rpi, new_res, square_size_size_mm, for_model=for_model, fname_mask=fname_mask)

    if rm_tmp:
        # remove tmp folder
        sct.rmtree(tmp_dir)
    return list_slices_target, original_info
示例#42
0
def deep_segmentation_MSlesion(im_image, contrast_type, ctr_algo='svm', ctr_file=None, brain_bool=True,
                               remove_temp_files=1, verbose=1):
    """
    Segment MS lesions from MRI data.

    Pipeline, run in a temporary folder:
    - reorient to RPI;
    - find the cord centerline;
    - crop around it and normalize intensities;
    - resample to 0.5 mm isotropic and segment with a 3D CNN;
    - resample back, uncrop and binarize at 0.1.

    Note: ``im_image.change_orientation('RPI')`` is called on the argument itself, which presumably
    reorients it in place; pass a copy if the caller needs the original object unchanged.

    :param im_image: Image() object containing the lesions to segment. Its ``absolutepath`` must be set; it
        is used to name intermediate files.
    :param contrast_type: Contrast of the image. Need to use one supported by the CNN models.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional). Only used when ``ctr_algo == 'file'``.
    :param brain_bool: If brain if present or not in the image.
    :param remove_temp_files: if truthy, the temporary folder is removed at the end
    :param verbose: verbosity; with ``verbose == 2`` the resampled centerline is also returned
    :return: tuple (seg, labels_viewer, ctr):
        - seg: binary lesion segmentation (Image) in the original orientation;
        - labels_viewer: resampled viewer labels (Image) if ``ctr_algo == 'viewer'``, else None;
        - ctr: resampled centerline (Image) if ``verbose == 2``, else None.
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    # Everything below runs inside the temporary folder; the finally clause guarantees that the
    # caller's working directory is restored even if a step fails.
    try:
        # orientation of the image, should be RPI
        logger.info("\nReorient the image to RPI, if necessary...")
        # Intermediate segmentations are named after the input file but saved with a relative path, i.e.
        # inside the temporary folder, so they are removed with it instead of being left next to the input.
        fname_in = os.path.basename(im_image.absolutepath)
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        fname_res, centerline_filename = find_centerline(algo=ctr_algo,
                                                        image_fname=fname_orient,
                                                        contrast_type=contrast_type_ctr,
                                                        brain_bool=brain_bool,
                                                        folder_output=tmp_folder_path,
                                                        remove_temp_files=remove_temp_files,
                                                        centerline_fname=file_ctr)
        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_nii,
                                                                                      ctr_in=ctr_nii,
                                                                                      crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii, contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm, fname_res3d, '0.5x0.5x0.5', 'mm', 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info("\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models',
                                                '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res, fname_seg_res2d, initial_2d_resolution,
                                 'mm', 'linear', verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii, data_crop=seg_crop.copy().data, x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST, z_crop_lst=Z_CROP_LST)
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info("Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]), str(input_resolution[2])])
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI, fname_seg_RPI, initial_resolution,
                                 'mm', 'linear', verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            fname_res_labels = sct.add_suffix(fname_orient, '_labels-centerline')
            resampling.resample_file(fname_res_labels, fname_res_labels, initial_resolution,
                                     'mm', 'linear', verbose=0)
            im_image_res_labels_downsamp = Image(fname_res_labels).change_orientation(original_orientation)
        else:
            im_image_res_labels_downsamp = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution,
                                     'mm', 'linear', verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info("\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
        seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

        # reorient to initial orientation
        logger.info("\nReorienting the segmentation to the original image orientation...")
    finally:
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(original_orientation), im_image_res_labels_downsamp, im_image_res_ctr_downsamp
示例#43
0
def deep_segmentation_MSlesion(im_image,
                               contrast_type,
                               ctr_algo='svm',
                               ctr_file=None,
                               brain_bool=True,
                               remove_temp_files=1,
                               verbose=1):
    """
    Segment lesions from MRI data.

    Pipeline: reorient to RPI, resample to 0.5 mm in-plane, find the spinal cord centerline, crop around it,
    normalize intensities, resample to 0.5 mm isotropic, segment with a 3D CNN, then bring the segmentation
    back to the input resolution, binarize it (threshold 0.1) and reorient it to the input orientation.

    :param im_image: Image() object containing the lesions to segment. It is not reoriented in place: the
        processing is done on a copy.
    :param contrast_type: Contrast of the image. Needs to be one supported by the CNN models: the model
        ``data/deepseg_lesion_models/<contrast_type>_lesion.h5`` is loaded, and the part of ``contrast_type``
        before the first underscore is used as contrast for the centerline detection.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional). Only used when ``ctr_algo == 'file'``.
    :param brain_bool: Whether the brain is present in the image.
    :param remove_temp_files: If truthy, the temporary folder is deleted after a successful run.
    :param verbose: Verbosity level. When equal to 2, the centerline image is also resampled and returned.
    :return: tuple ``(im_seg, im_viewer, im_ctr)``:
        - im_seg: lesion segmentation (uint8, values 0/1) at the input resolution and orientation;
        - im_viewer: viewer labels resampled to the input resolution and orientation, or None unless
          ``ctr_algo == 'viewer'``;
        - im_ctr: centerline image at the input resolution and orientation, or None unless ``verbose == 2``.
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    fname_in = im_image.absolutepath

    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        # work on a copy so that the caller's Image is not reoriented in place
        im_image = im_image.copy()
        im_image.change_orientation('RPI')

        input_resolution = im_image.dim[4:7]

        # Resample image to 0.5mm in plane
        im_image_res = \
            resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        _, im_ctl, im_labels_viewer = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type_ctr,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)
        if ctr_algo == 'file':
            # a user-provided centerline is not necessarily on the 0.5 mm in-plane grid: bring it there
            im_ctl = \
                resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)
        del im_ctl

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii,
                                                   contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm,
                                 fname_res3d,
                                 '0.5x0.5x0.5',
                                 'mm',
                                 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info(
            "\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = sct_dir_local_path(
            'data', 'deepseg_lesion_models', '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res,
                                 fname_seg_res2d,
                                 initial_2d_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_image_res,
                                      data_crop=seg_crop.copy().data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info(
            "Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join(str(res) for res in input_resolution[:3])
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI,
                                 fname_seg_RPI,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            im_labels_viewer_nib = nib.nifti1.Nifti1Image(
                im_labels_viewer.data, im_labels_viewer.hdr.get_best_affine())
            im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib,
                                                      new_size=input_resolution,
                                                      new_size_type='mm',
                                                      interpolation='linear')
            # np.asanyarray(img.dataobj) is the non-deprecated equivalent of img.get_data()
            im_viewer = Image(
                np.asanyarray(im_viewer_r_nib.dataobj),
                hdr=im_viewer_r_nib.header,
                orientation='RPI',
                dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(
                    original_orientation)
        else:
            im_viewer = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr,
                                     fname_res_ctr,
                                     initial_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(
                original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info(
            "\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
        seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

        # change data type
        seg_initres_nii.change_type(np.uint8)

        # reorient to initial orientation
        logger.info(
            "\nReorienting the segmentation to the original image orientation...")
    finally:
        # Always go back to the original working directory, even if one of the steps above raised;
        # otherwise the calling process would be left inside the temporary folder.
        tmp_folder.chdir_undo()

    # remove temporary files (only on success, so that intermediate files remain available after a failure)
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(
        original_orientation), im_viewer, im_image_res_ctr_downsamp
示例#44
0
def main(args=None):
    """
    Entry point of the sct_image command-line tool.

    Parses ``args`` with ``get_parser()``, runs the single operation selected by the arguments on the input
    image(s), then saves and/or displays the result.

    :param args: list of command-line arguments (without the program name). Defaults to ``sys.argv[1:]``.
    :return: None. Returns early, without writing anything, when ``-pad``/``-pad-asym`` receive a non-3D
        image or the wrong number of padding values.
    """

    # initializations
    output_type = None
    dim_list = ['x', 'y', 'z', 't']

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_in = arguments["-i"]
    n_in = len(fname_in)
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    # output file name; None when -o is not given (resolved after the operation has run)
    fname_out = arguments.get("-o")

    # run command
    if "-concat" in arguments:
        dim = arguments["-concat"]
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(fname_in, dim)]  # TODO: adapt to fname_in

    elif "-copy-header" in arguments:
        # header (and other metadata) from the input image, data from the destination image
        im_in = Image(fname_in[0])
        im_dest = Image(arguments["-copy-header"])
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]
        # the destination file is overwritten, regardless of -o
        fname_out = arguments["-copy-header"]

    elif '-display-warp' in arguments:
        visualize_warp(fname_in[0], fname_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif "-getorient" in arguments:
        orient = Image(fname_in[0]).orientation
        im_out = None

    elif '-keep-vol' in arguments:
        index_vol = arguments['-keep-vol']
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif '-mcs' in arguments:
        im_in = Image(fname_in[0])
        if n_in != 1:
            sct.printv(parser.usage.generate(error='ERROR: -mcs need only one input'))
        if len(im_in.data.shape) != 5:
            sct.printv(parser.usage.generate(error='ERROR: -mcs input need to be a multi-component image'))
        im_out = multicomponent_split(im_in)

    elif '-omc' in arguments:
        im_ref = Image(fname_in[0])
        for fname in fname_in:
            im = Image(fname)
            if im.data.shape != im_ref.data.shape:
                sct.printv(parser.usage.generate(error='ERROR: -omc inputs need to have all the same shapes'))
            del im
        im_out = [multicomponent_merge(fname_in)]  # TODO: adapt to fname_in

    elif "-pad" in arguments:
        im_in = Image(fname_in[0])
        if len(im_in.data.shape) != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments["-pad"].split(',')
        if len(pad_arguments) != 3:
            sct.printv('ERROR: you need to specify 3 padding values.', 1, 'error')
            # stop here: unpacking the values below would otherwise raise a ValueError
            return

        padx, pady, padz = (int(pad) for pad in pad_arguments)
        im_out = [pad_image(im_in, pad_x_i=padx, pad_x_f=padx, pad_y_i=pady,
                            pad_y_f=pady, pad_z_i=padz, pad_z_f=padz)]

    elif "-pad-asym" in arguments:
        im_in = Image(fname_in[0])
        if len(im_in.data.shape) != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments["-pad-asym"].split(',')
        if len(pad_arguments) != 6:
            sct.printv('ERROR: you need to specify 6 padding values.', 1, 'error')
            # stop here: unpacking the values below would otherwise raise a ValueError
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(pad) for pad in pad_arguments)
        im_out = [pad_image(im_in, pad_x_i=padxi, pad_x_f=padxf, pad_y_i=padyi, pad_y_f=padyf, pad_z_i=padzi, pad_z_f=padzf)]

    elif '-remove-vol' in arguments:
        index_vol = arguments['-remove-vol']
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif "-setorient" in arguments:
        sct.printv(fname_in[0])
        im_in = Image(fname_in[0])
        # saved below together with the other single-output operations
        im_out = [msct_image.change_orientation(im_in, arguments["-setorient"])]

    elif "-setorient-data" in arguments:
        im_in = Image(fname_in[0])
        # saved below together with the other single-output operations
        im_out = [msct_image.change_orientation(im_in, arguments["-setorient-data"], inverse=True)]

    elif "-split" in arguments:
        dim = arguments["-split"]
        assert dim in dim_list
        im_in = Image(fname_in[0])
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif '-type' in arguments:
        output_type = arguments['-type']
        im_in = Image(fname_in[0])
        im_out = [im_in]  # TODO: adapt to fname_in

    else:
        im_out = None
        sct.printv(parser.usage.generate(error='ERROR: you need to specify an operation to do on the input image'))

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        sct.printv('Generate output files...', verbose)
        # if only one output
        if len(im_out) == 1 and '-split' not in arguments:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            sct.display_viewer_syntax([fname_out], verbose=verbose)
        if '-mcs' in arguments:
            # use output file name and add _X, _Y, _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(sct.add_suffix(fname_out, '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            # display the files actually written (a list of paths, not a single string)
            sct.display_viewer_syntax(l_fname_out, verbose=verbose)
        if '-split' in arguments:
            # use output file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(sct.add_suffix(fname_out, '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            sct.display_viewer_syntax(l_fname_out)

    elif "-getorient" in arguments:
        sct.printv(orient)

    elif '-display-warp' in arguments:
        sct.printv('Warping grid generated.', verbose, 'info')