Beispiel #1
0
    def parse_args(self, argv=None, values=None):
        """
        Parse command line arguments, with support for option files and for
        ASL data supplied as several separate files.

        If ``options.optfile`` is set, the options from that file are added to
        the argument vector (via ``self._add_from_file``) and the whole vector
        is re-parsed, so options in the file behave exactly like CLI options.

        If positional arguments remain and ``options.asldata`` is not set,
        every positional argument is loaded as an image and the images are
        concatenated along the 4th axis (a 3D image counts as one volume).
        The merged image is saved under a temporary file name, which is then
        stored in ``options.asldata``.

        :param argv: Arguments to parse. Defaults to ``sys.argv[1:]``
        :param values: Passed through to ``OptionParser.parse_args``
        :return: Tuple of (options, args)
        :raises ValueError: if separately supplied ASL data files do not all
                            have the same shape
        """
        if argv is None:
            argv = sys.argv[1:]
        options, args = OptionParser.parse_args(self, argv, values)
        if options.optfile:
            # When an option file is specified, extract the options, build
            # a new argv vector and re-parse it. This is the only way to ensure
            # that options in the file work identically to CLI options.
            new_argv = self._add_from_file(argv, options.optfile)
            options, args = OptionParser.parse_args(self, new_argv, values)

        # Deal with case where asldata is given as separate files
        if args and options.asldata is None:
            merged_data = None
            first_shape = None
            for idx, fname in enumerate(args):
                img = Image(fname)
                shape = list(img.shape)
                if img.ndim == 3:
                    # Treat a 3D image as a single-volume 4D image
                    shape += [1]
                if merged_data is None:
                    # The merged array is sized from the first file, so every
                    # other file must have exactly the same shape
                    first_shape = shape
                    merged_data = np.zeros(shape[:3] + [shape[3] * len(args)])
                elif shape != first_shape:
                    raise ValueError(
                        "ASL data file %s has shape %s, expected %s"
                        % (fname, shape, first_shape))
                # Reshape explicitly: a 3D (x, y, z) array does not broadcast
                # into the (x, y, z, 1) target slice
                merged_data[..., idx * shape[3]:(idx + 1) * shape[3]] = \
                    img.data.reshape(shape)
            merged_img = Image(merged_data, header=img.header)
            # NOTE(review): the NamedTemporaryFile object is not kept, so with
            # delete=True the placeholder file is removed once it is garbage
            # collected. The merged data is only usable afterwards if
            # Image.save writes to a different path (e.g. by appending an
            # extension) - confirm this is intended.
            temp_asldata = tempfile.NamedTemporaryFile(prefix="oxasl",
                                                       delete=True)
            options.asldata = temp_asldata.name
            merged_img.save(options.asldata)

        return options, args
def tag_control_differencing(subject_dir, target='structural'):
    """
    Estimate perfusion and baseline images by tag-control differencing.

    Loads the motion- and distortion-corrected ASL series (Y_moco) and the
    registered combined scaling factors (S_st) from the subject's DistCorr
    directory. The regressor is X_perf = X_tc * S_st, where X_tc is -0.5 for
    even (0-based) volumes and +0.5 for odd volumes. For each consecutive
    (even, odd) volume pair the system

        Y_even = B_baseline + X_even * B_perf
        Y_odd  = B_baseline + X_odd  * B_perf

    is solved exactly for B_perf and B_baseline. Where X_odd == X_even the
    division yields inf/nan.

    Inputs
        - `subject_dir` = subject's base directory (passed to `load_json`)
        - `target` = 'structural' to use `<structasl>/TIs/DistCorr`; any
            other value uses `<TIs_dir>/SecondPass/DistCorr`

    Outputs
        - `beta_perf.nii.gz` and `beta_baseline.nii.gz` in a `Betas`
            directory next to the DistCorr directory
        - the json entry `beta_perf` pointing at the perfusion image

    Raises
        - ValueError if the ASL series has an odd number of volumes
    """
    # load subject's json
    json_dict = load_json(subject_dir)

    # load motion- and distortion- corrected data, Y_moco
    if target == 'structural':
        distcorr_dir = Path(json_dict['structasl']) / 'TIs/DistCorr'
    else:
        distcorr_dir = Path(json_dict['TIs_dir']) / 'SecondPass/DistCorr'
    Y_moco_name = distcorr_dir / 'tis_distcorr.nii.gz'
    Y_moco = Image(str(Y_moco_name))

    # load registered scaling factors, S_st
    sfs_name = distcorr_dir / 'combined_scaling_factors.nii.gz'
    S_st = Image(str(sfs_name))

    # number of volumes is taken from the data rather than hard-coded
    n_vols = Y_moco.shape[3]
    if n_vols % 2:
        raise ValueError(
            f'Expected an even number of tag/control volumes, got {n_vols}')

    # calculate X_perf = X_tc * S_st
    X_tc = np.full((1, 1, 1, n_vols), 0.5)
    X_tc[0, 0, 0, 0::2] = -0.5
    X_perf = X_tc * S_st.data

    # split X_perf and Y_moco into even and odd indices
    X_odd = X_perf[:, :, :, 1::2]
    X_even = X_perf[:, :, :, 0::2]
    Y_odd = Y_moco.data[:, :, :, 1::2]
    Y_even = Y_moco.data[:, :, :, 0::2]

    # calculate B_perf and B_baseline
    B_perf = (Y_odd - Y_even) / (X_odd - X_even)
    B_baseline = (X_odd * Y_even - X_even * Y_odd) / (X_odd - X_even)

    # save both images
    beta_dir_name = distcorr_dir.parent / 'Betas'
    create_dirs([beta_dir_name, ])
    B_perf_name = beta_dir_name / 'beta_perf.nii.gz'
    B_perf_img = Image(B_perf, header=Y_moco.header)
    B_perf_img.save(str(B_perf_name))
    B_baseline_name = beta_dir_name / 'beta_baseline.nii.gz'
    B_baseline_img = Image(B_baseline, header=Y_moco.header)
    B_baseline_img.save(str(B_baseline_name))

    # add B_perf_name to the json as will be needed in oxford_asl
    important_names = {'beta_perf': str(B_perf_name)}
    update_json(important_names, json_dict)
Beispiel #3
0
def _plot_mt_slice_means(slice_means, y_pred, tissue, n_subjects, n_slices):
    """Plot slicewise mean signal with the fitted model and save it to the cwd."""
    slice_numbers = np.arange(0, n_slices, 1)
    fig = plt.figure(figsize=(8, 4.5))
    plt.scatter(slice_numbers, slice_means)
    # y_pred must hold one value per point of this 0.001-step grid
    plt.scatter(np.arange(0, n_slices, 0.001), y_pred.flatten(),
                color='k', s=0.1)
    plt.ylim([0, PLOT_LIMS[tissue]])
    plt.xlim([0, n_slices])
    plt.title(f'Mean signal per slice in {tissue} across {n_subjects} subjects.')
    plt.xlabel('Slice number')
    plt.ylabel('Mean signal')
    # faint vertical guide lines every 10 slices
    for x_coord in np.arange(0, n_slices, 10):
        plt.axvline(x_coord, linestyle='-', linewidth=0.1, color='k')
    plt.savefig(Path().cwd() / f'{tissue}_mean_per_slice.png')
    # close the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def _plot_mt_voxel_counts(count_means, n_subjects, n_slices):
    """Plot slicewise mean WM/GM voxel counts and save them to the cwd."""
    slice_numbers = np.arange(0, n_slices, 1)
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.scatter(slice_numbers, count_means[:, 0], c='c', label='WM pve > 70%')  # wm mean
    ax.scatter(slice_numbers, count_means[:, 1], c='g', label='GM pve > 70%')  # gm mean
    ax.legend()
    for x_coord in np.arange(0, n_slices, 10):
        ax.axvline(x_coord, linestyle='-', linewidth=0.1, color='k')
    # raw strings: '\g' is not a valid Python escape sequence
    plt.title(r'Mean number of voxels per slice with PVE $\geqslant$ 70% '
              f'across {n_subjects} subjects.')
    plt.xlabel('Slice number')
    plt.ylabel(r'Mean number of voxels with PVE $\geqslant$ 70% in a given tissue')
    plt.savefig(Path().cwd() / 'mean_voxel_count.png')
    plt.close(fig)


def estimate_mt(subject_dirs, rois=('wm', ), tr=8, method='separate'):
    """
    Estimates the slice-dependent MT effect on the given subjects'
    calibration images. Performs the estimation using a linear
    model and calculates scaling factors which can be used to
    correct the effect.

    For each tissue in `rois`, the slice-timing-corrected masked
    calibration images (calib0 and calib1) of every subject are averaged
    per slice (zeros excluded). The across-image slicewise means are then
    passed to `fit_linear_model`, which returns the scaling factors.

    Inputs
        - `subject_dirs` = sequence of subject base directories
        - `rois` = tissue names; each needs json entries
            `calib0_<tissue>_masked` and `calib1_<tissue>_masked`. For
            'combined' each entry is a (gm_masked, wm_masked) pair.
        - `tr` = passed to `slicetime_correction`
        - `method` = passed to `fit_linear_model`

    Side effects
        - saves `<tissue>_mean_per_slice.png` in the current directory
        - for 'combined', saves `mean_voxel_count.png` in the current directory
        - saves `MTEstimation/MTcorr_SFs_<tissue>.nii.gz` in every
            subject's `calib_dir`
    """
    n_slices = 60
    n_subjects = len(subject_dirs)
    for tissue in rois:
        # slicewise means, one column per calibration image (2 per subject)
        mean_array = np.zeros((n_slices, 2 * n_subjects))
        # slicewise voxel counts, last axis = (wm, gm); only filled for 'combined'
        count_array = np.zeros((n_slices, 2 * n_subjects, 2))
        for n1, subject_dir in enumerate(subject_dirs):
            print(subject_dir)
            json_dict = load_json(subject_dir)
            masked_names = (
                json_dict[f'calib0_{tissue}_masked'],
                json_dict[f'calib1_{tissue}_masked']
            )
            for n2, masked_name in enumerate(masked_names):
                col = 2 * n1 + n2
                if tissue == 'combined':
                    gm_masked, wm_masked = masked_name
                    gm_masked_data = slicetime_correction(
                        image=Image(gm_masked).data,
                        tissue='gm',
                        tr=tr
                    )
                    wm_masked_data = slicetime_correction(
                        image=Image(wm_masked).data,
                        tissue='wm',
                        tr=tr
                    )
                    masked_data = gm_masked_data + wm_masked_data
                    gm_bin = np.where(gm_masked_data > 0, 1, 0)
                    gm_count = np.sum(gm_bin, axis=(0, 1))[..., np.newaxis]
                    wm_bin = np.where(wm_masked_data > 0, 1, 0)
                    wm_count = np.sum(wm_bin, axis=(0, 1))[..., np.newaxis]
                    count_array[:, col, :] = np.hstack((wm_count, gm_count))
                else:
                    # load masked calibration data
                    masked_data = slicetime_correction(
                        image=Image(masked_name).data,
                        tissue=tissue,
                        tr=tr
                    )
                # exclude zero (masked-out) voxels from the slicewise mean
                masked_data[masked_data == 0] = np.nan
                mean_array[:, col] = np.nanmean(masked_data, axis=(0, 1))

        # calculate non-zero slicewise mean of mean_array
        slice_means = np.nanmean(mean_array, axis=1)

        # fit linear models to central 4 bands
        # estimate scaling factors using these models
        scaling_factors, _, y_pred = fit_linear_model(slice_means, method=method)
        _plot_mt_slice_means(slice_means, y_pred, tissue, n_subjects, n_slices)

        # voxel counts are only collected for 'combined'; for any other tissue
        # count_array is all zeros and plotting it would overwrite a
        # meaningful mean_voxel_count.png with an empty one
        if tissue == 'combined':
            count_means = np.nanmean(count_array, axis=1)
            _plot_mt_voxel_counts(count_means, n_subjects, n_slices)

        # NOTE: the scaling factors are estimated on slice-timing-corrected
        # images but applied to images without that correction. An
        # adjustment via undo_st_correction(scaling_factors, tissue, tr) was
        # considered and is currently disabled.

        for subject_dir in subject_dirs:
            json_dict = load_json(subject_dir)
            # load calibration image
            calib_img = Image(json_dict['calib0_img'])
            # create and save scaling factors image
            scaling_img = Image(scaling_factors, header=calib_img.header)
            scaling_dir = Path(json_dict['calib_dir']) / 'MTEstimation'
            create_dirs([scaling_dir, ])
            scaling_name = scaling_dir / f'MTcorr_SFs_{tissue}.nii.gz'
            scaling_img.save(scaling_name)
Beispiel #4
0
def hcp_asl_moco(subject_dir,
                 mt_factors,
                 superlevel=1,
                 cores=mp.cpu_count(),
                 order=3):
    """
    This function performs the full motion-correction pipeline for
    the HCP ASL data. The steps of the pipeline include:
    - Bias-field correction
    - MT correction
    - Saturation recovery
    - Initial slice-timing correction
    - Motion estimation
    - Second slice-timing correction
    - Registration

    The bias/MT correction, saturation recovery, slice-timing correction
    and motion estimation steps are run twice (FirstPass, SecondPass). In
    the second pass the bias field is first resampled using the first pass's
    m0-to-ASLn matrices. Afterwards the second pass's MT-corrected series is
    registered to ASL0, saturation recovery is re-fitted on it (SatRecov2),
    and the resulting T1 map is mapped back to each ASLn for a final
    slice-timing correction (STCorr2).

    Inputs
        - `subject_dir` = pathlib.Path object specifying the
            subject's base directory
        - `mt_factors` = pathlib.Path object specifying the
            location of empirically estimated MT correction
            scaling factors (one per slice)
        - `superlevel`, `cores`, `order` = passed unchanged to every
            regtricks `apply_to_image` call (presumably supersampling
            level, worker count and interpolation order - see regtricks docs)

    Outputs
        - json entries `ASL_stcorr` (final slice-timing-corrected series)
            and `scaling_factors` (combined ST and MT scaling factors)
    """
    # asl sequence parameters
    ntis = 5
    iaf = "tc"
    ibf = "tis"
    tis = [1.7, 2.2, 2.7, 3.2, 3.7]
    rpts = [6, 6, 6, 10, 15]
    slicedt = 0.059
    sliceband = 10
    n_slices = 60
    # load json containing important file info
    json_dict = load_json(subject_dir)
    # load the MT scaling factors once - they are identical for both passes
    mt_sfs = np.loadtxt(mt_factors)
    # create directories for results
    tis_dir_name = Path(json_dict['TIs_dir'])
    first_pass_dir = tis_dir_name / 'FirstPass'
    second_pass_dir = tis_dir_name / 'SecondPass'
    create_dirs([tis_dir_name, first_pass_dir, second_pass_dir])
    # original ASL series and bias field names
    asl_name = Path(json_dict['ASL_seq'])
    bias_name = json_dict['calib0_bias']
    # m0-to-ASLn matrices written by the first pass, used by the second pass
    old_m02asl = first_pass_dir / 'MoCo/m02asln.mat'
    # iterate over first and second passes
    for pass_idx, pass_dir in enumerate((first_pass_dir, second_pass_dir)):
        bcorr_dir = pass_dir / 'BiasCorr'
        mtcorr_dir = pass_dir / 'MTCorr'
        satrecov_dir = pass_dir / 'SatRecov'
        stcorr_dir = pass_dir / 'STCorr'
        moco_dir = pass_dir / 'MoCo'
        asln2m0_name = moco_dir / 'asln2m0.mat'
        m02asln_name = moco_dir / 'm02asln.mat'
        asln2asl0_name = moco_dir / 'asln2asl0.mat'
        asl02asln_name = moco_dir / 'asl02asln.mat'
        create_dirs([
            bcorr_dir, mtcorr_dir, satrecov_dir, stcorr_dir, moco_dir,
            asln2m0_name, m02asln_name, asln2asl0_name, asl02asln_name
        ])
        # bias correct the original ASL series
        bcorr_img = bcorr_dir / 'tis_biascorr.nii.gz'
        if pass_idx == 1:
            # register bias field to ASL series using the first pass's
            # m0-to-ASLn motion estimates
            reg_bias_name = bcorr_dir / 'bias_reg.nii.gz'
            old_m02asl = rt.MotionCorrection.from_mcflirt(
                str(old_m02asl), bias_name, bias_name)
            nib.save(
                old_m02asl.apply_to_image(bias_name,
                                          bias_name,
                                          superlevel=superlevel,
                                          cores=cores,
                                          order=order), str(reg_bias_name))
            bias_name = reg_bias_name
        fslmaths(str(asl_name)).div(str(bias_name)).run(str(bcorr_img))
        # apply MT scaling factors to the bias-corrected ASL series
        mtcorr_name = mtcorr_dir / 'tis_mtcorr.nii.gz'
        biascorr_img = Image(str(bcorr_img))
        assert (len(mt_sfs) == biascorr_img.shape[2])
        # one factor per slice (3rd axis), broadcast over x, y and volumes
        mtcorr_img = Image(biascorr_img.data * mt_sfs.reshape(1, 1, -1, 1),
                           header=biascorr_img.header)
        mtcorr_img.save(str(mtcorr_name))
        # estimate satrecov model on bias and MT corrected ASL series
        t1_name = _saturation_recovery(mtcorr_name, satrecov_dir, ntis, iaf,
                                       ibf, tis, rpts)
        t1_filt_name = _fslmaths_med_filter_wrapper(t1_name)
        # perform slice-time correction using estimated tissue params
        stcorr_img, stfactors_img = _slicetiming_correction(
            mtcorr_name, t1_filt_name, tis, rpts, slicedt, sliceband, n_slices)
        stcorr_name = stcorr_dir / 'tis_stcorr.nii.gz'
        stcorr_img.save(stcorr_name)
        stfactors_name = stcorr_dir / 'st_scaling_factors.nii.gz'
        stfactors_img.save(stfactors_name)
        # register ASL series to calibration image
        reg_name = moco_dir / 'initial_registration_TIs.nii.gz'
        mcflirt(stcorr_img,
                reffile=json_dict['calib0_mc'],
                mats=True,
                out=str(reg_name))
        # rename mcflirt matrices directory
        orig_mcflirt = moco_dir / 'initial_registration_TIs.nii.gz.mat'
        if asln2m0_name.exists():
            shutil.rmtree(asln2m0_name)
        orig_mcflirt.rename(asln2m0_name)
        # get motion estimates from ASLn to ASL0 (and their inverses)
        asl2m0_list = sorted(asln2m0_name.glob('**/MAT*'))
        m02asl0 = np.linalg.inv(np.loadtxt(asl2m0_list[0]))
        for vol_idx, xform in enumerate(asl2m0_list):
            asl2m0 = np.loadtxt(xform)
            if vol_idx == 0:
                # m02asl0 @ asl2m0 is the identity for the first volume; use
                # an exact identity instead of the round-off result
                fwd_xform = np.eye(4)
            else:
                fwd_xform = m02asl0 @ asl2m0
            inv_xform = np.linalg.inv(fwd_xform)
            np.savetxt(m02asln_name / xform.stem, np.linalg.inv(asl2m0))
            np.savetxt(asln2asl0_name / xform.stem, fwd_xform)
            np.savetxt(asl02asln_name / xform.stem, inv_xform)
    # register pre-ST-correction ASLn to ASL0 (uses the second pass's outputs)
    temp_reg_mtcorr = moco_dir / 'temp_reg_tis_mtcorr.nii.gz'
    asln2m0_moco = rt.MotionCorrection.from_mcflirt(str(asln2m0_name),
                                                    str(mtcorr_name),
                                                    json_dict['calib0_mc'])
    # presumably ASLn -> M0 followed by M0 -> ASL0 (inverse of the first
    # volume's transform) - confirm against regtricks `chain` ordering
    asln2asl0 = rt.chain(asln2m0_moco, asln2m0_moco.transforms[0].inverse())
    reg_mtcorr = Image(
        asln2asl0.apply_to_image(str(mtcorr_name),
                                 json_dict['calib0_mc'],
                                 superlevel=superlevel,
                                 cores=cores,
                                 order=order))
    reg_mtcorr.save(str(temp_reg_mtcorr))

    # estimate satrecov model on motion-corrected data
    satrecov_dir = pass_dir / 'SatRecov2'
    stcorr_dir = pass_dir / 'STCorr2'
    create_dirs([satrecov_dir, stcorr_dir])
    t1_name = _saturation_recovery(temp_reg_mtcorr, satrecov_dir, ntis, iaf,
                                   ibf, tis, rpts)
    t1_filt_name = _fslmaths_med_filter_wrapper(t1_name)
    # apply asl0 to asln registrations to new t1 map
    reg_t1_filt_name = t1_filt_name.parent / f'{t1_filt_name.stem.split(".")[0]}_reg.nii.gz'
    reg_t1_filt = Image(asln2asl0.inverse().apply_to_image(
        str(t1_filt_name),
        json_dict['calib0_mc'],
        superlevel=superlevel,
        cores=cores,
        order=order))
    reg_t1_filt.save(str(reg_t1_filt_name))
    # perform slice-time correction using estimated tissue params
    stcorr_img, stfactors_img = _slicetiming_correction(
        mtcorr_name, reg_t1_filt_name, tis, rpts, slicedt, sliceband, n_slices)
    # save images
    stcorr_name = stcorr_dir / 'tis_stcorr.nii.gz'
    stfactors_name = stcorr_dir / 'st_scaling_factors.nii.gz'
    stcorr_img.save(str(stcorr_name))
    stfactors_img.save(str(stfactors_name))
    # combined MT and ST scaling factors
    combined_factors_name = stcorr_dir / 'combined_scaling_factors.nii.gz'
    combined_factors_img = Image(stfactors_img.data *
                                 mt_sfs.reshape(1, 1, -1, 1),
                                 header=stfactors_img.header)
    combined_factors_img.save(str(combined_factors_name))
    # save locations of important files in the json
    important_names = {
        'ASL_stcorr': str(stcorr_name),
        'scaling_factors': str(combined_factors_name)
    }
    update_json(important_names, json_dict)
Beispiel #5
0
def fabber(options, output=LOAD, ref_nii=None, progress_log=None, **kwargs):
    """
    Run the Fabber model-fitting tool through its Python API.

    Unlike a conventional FSL command-line wrapper, this goes through the
    Fabber Python API. That API talks to the C++ code either through the
    pure-C API with ctypes or through its own CLI wrapper. Fabber accepts an
    arbitrary number of inputs, including image inputs, and which image
    inputs exist depends on the chosen model. That makes @fileOrImage-style
    decorators a poor fit, so the generic Fabber API handles those details.

    The interface still mirrors other FSL wrappers where possible:
    fsl.data.image.Image instances are accepted as inputs, and the LOAD
    special value asks for outputs to be returned in memory.

    :param options: Fabber run options (mapping). Must contain ``data``.
    :param output: Output directory for results. The special value LOAD
                   returns the output data in the result dictionary instead.
    :param ref_nii: Optional reference Nibabel image used when writing
                    outputs. If omitted it is derived from ``options['data']``.
    :param progress_log: Optional file-like stream that progress percentages
                         are logged to.
    :param kwargs: Recognised keys are ``fabber_dirs`` (extra search
                   directories), ``exitcode``, ``stdout``, ``stderr`` and
                   ``log``. ``submit`` is rejected.
    :return: ``_Results`` mapping with ``paramnames``, ``logfile`` and, when
             output is LOAD, one entry per output data item. Images match
             the type of the main input data, or are fsl.data.image.Image
             if the main data was a file.
    :raises ValueError: if ``data`` is missing or ``submit`` is requested
    :raises FabberException: if Fabber fails and ``exitcode`` was not requested
    """
    search_dirs = kwargs.pop("fabber_dirs", ())
    fab = Fabber(*search_dirs)

    options = dict(options)
    main_data = options.get("data", None)
    if main_data is None:
        raise ValueError("Main data not specified")

    to_memory = output == LOAD
    if not to_memory and not os.path.exists(output):
        os.makedirs(output)

    # Work out a reference Nibabel image for generating output files
    if not ref_nii:
        ref_img = main_data if isinstance(main_data, Image) else Image(main_data)
        ref_nii = ref_img.nibImage

    if ref_nii:
        header, affine = ref_nii.header, ref_nii.header.get_best_affine()
    else:
        header, affine = None, np.identity(4)

    # The Fabber API understands numpy arrays, nibabel images and filenames,
    # but not fsl Image objects - swap those for their nibabel counterpart
    options.update({
        key: value.nibImage
        for key, value in options.items()
        if isinstance(value, Image)
    })

    # Standard wrapper keyword arguments
    want_exitcode = kwargs.pop("exitcode", False)
    want_stdout = kwargs.pop("stdout", True)
    want_stderr = kwargs.pop("stderr", False)
    log = kwargs.pop("log", {})

    # Capture stdout/stderr, optionally forwarding to other streams as well
    stdout, stderr = Tee(), Tee()
    stdout.add(log.get("stdout", None))
    stderr.add(log.get("stderr", None))
    if log.get("tee", False):
        stdout.add(sys.stdout)
        stderr.add(sys.stderr)

    if kwargs.pop("submit", False):
        raise ValueError("submit not supported for Fabber")

    failure = None
    cmd_output = []
    ret = _Results(cmd_output)
    try:
        ret["paramnames"] = fab.get_model_params(options)
        cmd_log = log.get("cmd", None)
        if cmd_log:
            cmd_log.write("Using fabber:\n  core lib=%s\n  core_exe=%s\n  model libs=%s\n  model exes=%s\n" % (fab.core_lib, fab.core_exe, fab.model_libs, fab.model_exes))
            cmd_log.write("fabber ")
            for key, value in options.items():
                # Non-scalar values (arrays, images) are logged by type only
                if not isinstance(value, six.string_types + (int, float)):
                    value = str(type(value))
                cmd_log.write("--%s=%s " % (key.replace("_", "-"), value))
            cmd_log.write("\n")
        progress_cb = percent_progress(progress_log) if progress_log else None
        run = fab.run(options, progress_cb)
        ret["logfile"] = run.log

        # Either return each output item in memory or save it to disk
        for data_name, data in run.data.items():
            nii = nib.Nifti1Image(data, header=header, affine=affine)
            nii.update_header()
            nii.header.set_data_dtype(data.dtype)
            img = Image(nii)
            if to_memory:
                # Match the image type of the main input data
                ret[data_name] = _matching_image(main_data, img)
            else:
                img.save(os.path.join(output, data_name))

    except FabberException as exc:
        # Fabber itself failed - re-raised below unless an exit code was requested
        failure = exc
        stderr.write(str(exc) + "\n")

    if want_stdout:
        cmd_output.append(str(stdout))
    if want_stderr:
        cmd_output.append(str(stderr))
    if want_exitcode:
        cmd_output.append(int(failure is not None))

    if failure is not None and not want_exitcode:
        raise failure

    return ret
Beispiel #6
0
def correct_M0(subject_dir, mt_factors):
    """
    Correct the M0 images for a particular subject whose data
    is stored in `subject_dir`. For each of calib0 and calib1 the
    corrections performed are:
        - Bias-field correction: FAST estimates the bias field on the
            BET-extracted image, and the original (un-BETted) image is
            divided by it
        - Magnetisation Transfer correction: the bias-corrected image is
            multiplied slice-wise by the factors in `mt_factors`

    Results go to FAST/, BiasCorr/ and MTCorr/ next to each calibration
    image. Their locations are added to the subject's json as
    `<stem>_bias`, `<stem>_bc` and `<stem>_mc`.

    Inputs
        - `subject_dir` = pathlib.Path object specifying the
            subject's base directory
        - `mt_factors` = pathlib.Path object specifying the
            location of empirically estimated MT correction
            scaling factors (one per slice, i.e. per 3rd image axis)
    """
    # load json containing info on where files are stored
    json_dict = load_json(subject_dir)

    # the MT scaling factors are the same for both calibration images
    mt_sfs = np.loadtxt(mt_factors)

    # do for both m0 images for the subject, calib0 and calib1
    calib_names = [json_dict['calib0_img'], json_dict['calib1_img']]
    for calib_name in calib_names:
        # get calib_dir and other info
        calib_path = Path(calib_name)
        calib_dir = calib_path.parent
        calib_name_stem = calib_path.stem.split('.')[0]

        # run BET on m0 image
        betted_m0 = bet(calib_name, LOAD)

        # create directories to store results
        fast_dir = calib_dir / 'FAST'
        biascorr_dir = calib_dir / 'BiasCorr'
        mtcorr_dir = calib_dir / 'MTCorr'
        create_dirs([fast_dir, biascorr_dir, mtcorr_dir])

        # estimate bias field on brain-extracted m0 image
        # run FAST, storing results in directory
        fast_base = fast_dir / calib_name_stem
        fast(
            betted_m0['output'],  # output of bet
            out=str(fast_base),
            type=3,  # image type, 3=PD image
            b=True,  # output estimated bias field
            nopve=True  # don't need pv estimates
        )
        bias_name = fast_dir / f'{calib_name_stem}_bias.nii.gz'

        # apply bias field to original m0 image (i.e. not BETted)
        biascorr_name = biascorr_dir / f'{calib_name_stem}_restore.nii.gz'
        fslmaths(calib_name).div(str(bias_name)).run(str(biascorr_name))

        # apply mt_factors to bias-corrected m0 image
        mtcorr_name = mtcorr_dir / f'{calib_name_stem}_mtcorr.nii.gz'
        biascorr_img = Image(str(biascorr_name))
        biascorr_data = biascorr_img.data
        assert (len(mt_sfs) == biascorr_img.shape[2])
        # Align the factors with the slice (3rd) axis explicitly. A bare
        # `data * mt_sfs` broadcasts along the *last* axis, which is only the
        # slice axis for 3D data; (x, y, z, 1) data would become (x, y, z, z).
        sfs_shape = [1] * biascorr_data.ndim
        sfs_shape[2] = -1
        mtcorr_img = Image(biascorr_data * mt_sfs.reshape(sfs_shape),
                           header=biascorr_img.header)
        mtcorr_img.save(str(mtcorr_name))

        # add locations of above files to the json
        important_names = {
            f'{calib_name_stem}_bias': str(bias_name),
            f'{calib_name_stem}_bc': str(biascorr_name),
            f'{calib_name_stem}_mc': str(mtcorr_name)
        }
        update_json(important_names, json_dict)
Beispiel #7
0
def se_based_bias_estimation():
    """
    This script seeks to replicate the SE-based bias estimation 
    in the HCP's ComputeSpinEchoBiasField.sh.

    Some modifications need to be made for the ASL pipeline so, 
    for now, we shall use the script below. Necessary 
    modifications include the use of an already pre-computed 
    SpinEchoMean.nii.gz (from topup), the re-sampling of 
    ribbon.mgz and wmparc.mgz into our ASL-gridded T1 space, 
    and the use of our proton density weighted images rather 
    than the GRE.nii.gz in the original script.
    """
    # argument handling
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input",
        help="Image from which we wish to estimate the bias field.",
        required=True)
    parser.add_argument(
        '--asl',
        help="ASL series to which we wish to apply the bias field. Optional.")
    parser.add_argument("-f",
                        "--fmapmag",
                        help="Fieldmap magnitude image from topup.",
                        required=True)
    parser.add_argument("-m", "--mask", help="Brain mask.", required=True)
    parser.add_argument('--wmparc',
                        help="wmparc.nii.gz from FreeSurfer",
                        required=not "--tissue_mask" in sys.argv,
                        default=None)
    parser.add_argument('--ribbon',
                        help="ribbon.nii.gz from FreeSurfer",
                        required=not "--tissue_mask" in sys.argv,
                        default=None)
    parser.add_argument("--corticallut",
                        help="Filename for FreeSurfer's Cortical Lable Table",
                        required=not "--tissue_mask" in sys.argv,
                        default=None)
    parser.add_argument(
        "--subcorticallut",
        help="Filename for FreeSurfer's Subcortical Lable Table",
        required=not "--tissue_mask" in sys.argv,
        default=None)
    parser.add_argument(
        "--struct2calib",
        help="flirt registration from structural space to the calibration " +
        "image from which we wish to estimate the bias field.",
        default=None)
    parser.add_argument(
        "--structural",
        help="Path to an image in T1w structural space for use when applying "
        + "struct2calib.mat. Only required if the --struct2calib option has " +
        "been provided.",
        required="--struct2calib" in sys.argv,
        default=None)
    parser.add_argument('-o',
                        '--outdir',
                        help="Output directory for results.",
                        required=True)
    parser.add_argument(
        '--debug',
        help="If this argument is specified, all intermediate files " +
        "will be saved for inspection.",
        action='store_true')
    parser.add_argument(
        '--tissue_mask',
        help="Filename for tissue mask we've derived ourselves to use " +
        "instead of a gray matter mask derived from FreeSurfer " + "outputs.",
        default=None)

    args = parser.parse_args()

    m0_name = args.input
    asl_name = args.asl
    sem_name = args.fmapmag
    mask_name = args.mask
    wmparc_name = args.wmparc
    ribbon_name = args.ribbon
    corticallut = args.corticallut
    subcorticallut = args.subcorticallut
    outdir = Path(args.outdir)
    debug = args.debug
    tissue_mask = args.tissue_mask

    # create output directory
    outdir.mkdir(exist_ok=True)

    # load images
    m0_img, sem_img, mask_img = [
        Image(name) for name in (m0_name, sem_name, mask_name)
    ]

    # find ratio between SpinEchoMean and M0
    SEdivM0 = np.where(m0_img.data != 0, (sem_img.data / m0_img.data), 0)
    if debug:
        SEdivM0_name = str(outdir / 'SEdivM0.nii.gz')
        SEdivM0_img = Image(SEdivM0, header=m0_img.header)
        SEdivM0_img.save(SEdivM0_name)

    # apply mask to ratio
    SEdivM0_brain = SEdivM0 * mask_img.data
    if debug:
        SEdivM0_brain_name = str(outdir / 'SEdivM0_brain.nii.gz')
        SEdivM0_brain_img = Image(SEdivM0_brain, header=m0_img.header)
        SEdivM0_brain_img.save(SEdivM0_brain_name)

    # get summary stats for thresholding
    nanned_temp = np.where(mask_img.data == 0, np.nan, SEdivM0_brain)
    median, std = [np.nanmedian(nanned_temp), np.nanstd(nanned_temp)]
    if debug:
        print(np.array(median), np.array(std))
        savenames = [
            str(outdir / f'ratio_{stat}.txt') for stat in ('median', 'std')
        ]
        [
            np.savetxt(name, [val])
            for name, val in zip(savenames, (median, std))
        ]

    # apply thresholding
    lower, upper = [median - (std / 3), median + (std / 3)]
    print(lower, upper)
    SEdivM0_brain_thr = np.where(
        np.logical_and(SEdivM0_brain >= lower, SEdivM0_brain <= upper),
        SEdivM0_brain, 0)
    if debug:
        SEdivM0_brain_thr_name = str(outdir / 'SEdivM0_brain_thr.nii.gz')
        SEdivM0_brain_thr_img = Image(SEdivM0_brain_thr, header=m0_img.header)
        SEdivM0_brain_thr_img.save(SEdivM0_brain_thr_name)

    ### HCP pipeline does median dilation here but isn't used - skip for now ###

    # set sigma for smoothing used in HCPPipeline
    fwhm = 5
    sigma = fwhm / np.sqrt(8 * np.log(2))
    # binarise and smooth the thresholded image
    SEdivM0_brain_thr_roi = np.where(SEdivM0_brain_thr > 0, 1,
                                     0).astype(np.float)
    SEdivM0_brain_thr_s5 = scipy.ndimage.gaussian_filter(SEdivM0_brain_thr,
                                                         sigma=sigma)
    SEdivM0_brain_thr_roi_s5 = scipy.ndimage.gaussian_filter(
        SEdivM0_brain_thr_roi, sigma=sigma)
    SEdivM0_brain_bias = np.where(
        np.logical_and(SEdivM0_brain_thr_roi_s5 != 0, mask_img.data == 1),
        (SEdivM0_brain_thr_s5 / SEdivM0_brain_thr_roi_s5), 0)
    if debug:
        savenames = [
            str(outdir / f'SEdivM0_brain_{name}.nii.gz')
            for name in ('thr_roi', 'thr_s5', 'thr_roi_s5', 'bias')
        ]
        images = [
            Image(array, header=m0_img.header)
            for array in (SEdivM0_brain_thr_roi, SEdivM0_brain_thr_s5,
                          SEdivM0_brain_thr_roi_s5, SEdivM0_brain_bias)
        ]
        [image.save(savename) for image, savename in zip(images, savenames)]

    ### HCP pipeline does median dilation here but isn't used - skip for now ###

    # correct the SEFM image
    SpinEchoMean_brain_BC = np.where(
        np.logical_and(mask_img.data == 1, SEdivM0_brain_bias != 0),
        sem_img.data / SEdivM0_brain_bias, 0)
    if debug:
        SpinEchoMean_brain_BC_name = str(outdir /
                                         'SpinEchoMean_brain_BC.nii.gz')
        SpinEchoMean_brain_BC_img = Image(SpinEchoMean_brain_BC,
                                          header=m0_img.header)
        SpinEchoMean_brain_BC_img.save(SpinEchoMean_brain_BC_name)

    # get ratio between bias-corrected FM and M0 image
    SEBCdivM0_brain = np.where(
        np.logical_and(mask_img.data == 1, SpinEchoMean_brain_BC != 0),
        m0_img.data / SpinEchoMean_brain_BC, 0)
    if debug:
        SEBCdivM0_brain_name = str(outdir / 'SEBCdivM0_brain.nii.gz')
        SEBCdivM0_brain_img = Image(SEBCdivM0_brain, header=m0_img.header)
        SEBCdivM0_brain_img.save(SEBCdivM0_brain_name)

    # find dropouts
    Dropouts = np.where(
        np.logical_and(SEBCdivM0_brain > 0, SEBCdivM0_brain < 0.6), 1, 0)
    Dropouts_inv = np.where(Dropouts == 1, 0, 1)
    if debug:
        savenames = [
            str(outdir / f'{name}.nii.gz')
            for name in ('Dropouts', 'Dropouts_inv')
        ]
        images = [
            Image(array, header=m0_img.header)
            for array in (Dropouts, Dropouts_inv)
        ]
        [image.save(savename) for image, savename in zip(images, savenames)]

    if tissue_mask:
        tissue_mask = rt.Registration.identity().apply_to_image(
            tissue_mask, m0_name, order=0).get_fdata()
        if debug:
            savename = str(outdir / 'TissueMask.nii.gz')
            image = Image(tissue_mask, header=m0_img.header)
            image.save(savename)
    else:
        # downsample wmparc and ribbon to ASL-gridded T1 resolution
        if args.struct2calib:
            registration = rt.Registration.from_flirt(args.struct2calib,
                                                      args.structural, m0_name)
        else:
            registration = rt.Registration.identity()
        wmparc_aslt1, ribbon_aslt1 = [
            registration.apply_to_image(name, m0_name, order=0)
            for name in (wmparc_name, ribbon_name)
        ]
        # parse LUTs
        c_labels, sc_labels = [
            parse_LUT(lut) for lut in (corticallut, subcorticallut)
        ]
        cgm, scgm = [
            np.zeros(ribbon_aslt1.shape),
            np.zeros(wmparc_aslt1.shape)
        ]
        cgm, scgm = [
            np.zeros(ribbon_aslt1.shape),
            np.zeros(wmparc_aslt1.shape)
        ]
        for label in c_labels:
            cgm = np.where(ribbon_aslt1.get_fdata() == label, 1, cgm)
        for label in sc_labels:
            scgm = np.where(wmparc_aslt1.get_fdata() == label, 1, scgm)
        if debug:
            savenames = [
                str(outdir / f'{pre}GreyMatter.nii.gz')
                for pre in ('Cortical', 'Subcortical')
            ]
            images = [
                Image(array, header=m0_img.header) for array in (cgm, scgm)
            ]
            [
                image.save(savename)
                for image, savename in zip(images, savenames)
            ]

        # combine masks
        tissue_mask = np.where(np.logical_or(cgm == 1, scgm == 1), 1, 0)
        if debug:
            savename = str(outdir / 'AllGreyMatter.nii.gz')
            image = Image(tissue_mask, header=m0_img.header)
            image.save(savename)

    # mask M0 image with both the tissue mask and Dropouts_inv mask
    M0_grey = np.where(np.logical_and(tissue_mask == 1, Dropouts_inv == 1),
                       m0_img.data, 0).astype(np.float)
    M0_greyroi = np.where(M0_grey != 0, 1, 0).astype(np.float)
    M0_grey_s5, M0_greyroi_s5 = [
        scipy.ndimage.gaussian_filter(arr, sigma=sigma)
        for arr in (M0_grey, M0_greyroi)
    ]
    if debug:
        savenames = [
            str(outdir / f'M0_grey{part}.nii.gz')
            for part in ('', 'roi', '_s5', 'roi_s5')
        ]
        images = [
            Image(array, header=m0_img.header)
            for array in (M0_grey, M0_greyroi, M0_grey_s5, M0_greyroi_s5)
        ]
        [image.save(savename) for image, savename in zip(images, savenames)]

    # M0_bias_raw needs to undergo fslmaths' -dilall
    M0_bias_raw = np.where(
        np.logical_and(tissue_mask != 0, M0_greyroi_s5 != 0),
        M0_grey_s5 / M0_greyroi_s5, 0)
    M0_bias_raw_name = str(outdir / 'M0_bias_raw.nii.gz')
    M0_bias_raw_img = Image(M0_bias_raw, header=m0_img.header)
    M0_bias_raw_img.save(M0_bias_raw_name)
    dilall_cmd = [
        'fslmaths', M0_bias_raw_name, '-dilall', '-mas', mask_name,
        M0_bias_raw_name
    ]
    subprocess.run(dilall_cmd, check=True)

    # reload dilalled-and-masked M0_bias_raw
    M0_bias_raw_img = Image(M0_bias_raw_name)
    M0_bias_raw = M0_bias_raw_img.data

    # refine bias field
    M0_bias_roi = np.where(M0_bias_raw > 0, 1, 0).astype(np.float)
    M0_bias_raw_s5, M0_bias_roi_s5 = [
        scipy.ndimage.gaussian_filter(array, sigma=sigma)
        for array in (M0_bias_raw, M0_bias_roi)
    ]
    M0_bias = np.where(np.logical_and(mask_img.data != 0, M0_bias_roi_s5 != 0),
                       M0_bias_raw_s5 / M0_bias_roi_s5, 0)
    if debug:
        savenames = [
            str(outdir / f'M0_bias{part}.nii.gz')
            for part in ('roi', 'raw_s5', 'roi_s5', '')
        ]
        images = [
            Image(array, header=m0_img.header)
            for array in (M0_bias_roi, M0_bias_raw_s5, M0_bias_roi_s5, M0_bias)
        ]
        [image.save(savename) for image, savename in zip(images, savenames)]

    # get summary stats
    nanned_temp = np.where(mask_img.data == 0, np.nan, M0_bias)
    mean = np.nanmean(nanned_temp)

    # get sebased bias - should get ref also but leaving for now
    sebased_bias = M0_bias / mean
    sebased_bias_name = str(outdir / 'sebased_bias.nii.gz')
    sebased_bias_img = Image(sebased_bias, header=m0_img.header)
    sebased_bias_img.save(sebased_bias_name)

    # apply 2 rounds of dilation to sebased_bias
    sebased_bias_dil_name = str(outdir / 'sebased_bias_dil.nii.gz')
    bias_dil_cmd = [
        'fslmaths', sebased_bias_name, '-dilM', '-dilM', sebased_bias_dil_name
    ]
    subprocess.run(bias_dil_cmd, check=True)

    # apply bias field to calibration and ASL images
    sebased_bias = Image(sebased_bias_dil_name)
    calib_bc = np.where(sebased_bias.data != 0,
                        m0_img.data / sebased_bias.data, m0_img.data)
    Image(calib_bc,
          header=m0_img.header).save(str(outdir / "calib0_secorr.nii.gz"))
    if asl_name:
        asl_img = Image(asl_name)
        asl_bc = np.where(sebased_bias.data[..., np.newaxis] != 0,
                          asl_img.data / sebased_bias.data[..., np.newaxis],
                          asl_img.data)
        Image(asl_bc,
              header=asl_img.header).save(str(outdir / "tis_secorr.nii.gz"))
Beispiel #8
0
def estimate_mt(subject_dirs,
                rois=('wm', ),
                tr=8,
                method='separate',
                outdir=None,
                ignore_dropouts=False):
    """
    Estimates the slice-dependent MT effect on the given subjects'
    calibration images. Performs the estimation using a linear
    model and calculates scaling factors which can be used to
    correct the effect.

    Parameters
    ----------
    subject_dirs : sequence of pathlib.Path
        Subject directories. Masked calibration images are read from
        ``ASL/Calib/Calib{0,1}/SEbased_MT_t1mask{suf}/DistCorr/masks`` and
        the image to correct from
        ``ASL/Calib/Calib0/SEbased_MT_t1mask{suf}/DistCorr/calib0_restore.nii.gz``.
    rois : iterable of str
        Tissues to estimate the effect in. ``'combined'`` sums the slice-timing
        corrected ``gm`` and ``wm`` masked images; any other value is used
        directly as the mask sub-directory and file suffix.
    tr : float
        Passed through to ``slicetime_correction``.
    method : str
        Passed through to ``fit_linear_model`` and used in output file names.
    outdir : str or path-like, optional
        Existing directory (resolved strictly) for the plots and the
        scaling-factor text files. Defaults to the current working directory.
    ignore_dropouts : bool
        If True, masks are read from the ``_ignoredropouts`` variant
        directories.

    Returns
    -------
    list of str
        One ``"<tissue> <subject_dir>"`` entry for every subject/tissue pair
        whose masked calibration data could not be processed.
    """
    outdir = Path(outdir).resolve(strict=True) if outdir else Path.cwd()
    errors = []
    suf = "_ignoredropouts" if ignore_dropouts else ""
    # NOTE(review): 60 slices per calibration volume is hard-coded - confirm it
    # matches the acquisition and the output of fit_linear_model.
    n_slices = 60
    slice_numbers = np.arange(0, n_slices, 1)
    x_coords = np.arange(0, n_slices, 10)
    for tissue in rois:
        # One column per (subject, calibration image). NaN-initialised so that
        # subjects which fail part-way through are ignored by np.nanmean
        # rather than contributing zeros to the slicewise means.
        mean_array = np.full((n_slices, 2 * len(subject_dirs)), np.nan)
        count_array = np.zeros((n_slices, 2 * len(subject_dirs), 2))  # wm and gm
        # subjects processed successfully for *this* tissue (used in titles)
        error_free_subs = []
        tissues = ("gm", "wm") if tissue == "combined" else (tissue, )
        # iterate over subjects
        for n1, subject_dir in enumerate(subject_dirs):
            subject_cols = slice(2 * n1, 2 * n1 + 2)
            try:
                print(subject_dir)
                mask_dirs = [
                    subject_dir / "ASL/Calib" / c /
                    f"SEbased_MT_t1mask{suf}/DistCorr/masks"
                    for c in ("Calib0", "Calib1")
                ]
                masked_names = [[
                    mask_dir / tissue / f"calib{n}_{t}_masked.nii.gz"
                    for t in tissues
                ] for n, mask_dir in enumerate(mask_dirs)]
                for n2, masked_name in enumerate(masked_names):
                    if tissue == 'combined':
                        gm_masked, wm_masked = masked_name
                        gm_masked_data = slicetime_correction(
                            image=Image(str(gm_masked)).data,
                            tissue='gm',
                            tr=tr)
                        wm_masked_data = slicetime_correction(
                            image=Image(str(wm_masked)).data,
                            tissue='wm',
                            tr=tr)
                        masked_data = gm_masked_data + wm_masked_data
                        gm_bin = np.where(gm_masked_data > 0, 1, 0)
                        gm_count = np.sum(gm_bin, axis=(0, 1))[..., np.newaxis]
                        wm_bin = np.where(wm_masked_data > 0, 1, 0)
                        wm_count = np.sum(wm_bin, axis=(0, 1))[..., np.newaxis]
                        count_array[:, 2 * n1 + n2, :] = np.hstack(
                            (wm_count, gm_count))
                    else:
                        # load masked calibration data
                        masked_data = slicetime_correction(
                            image=Image(str(*masked_name)).data,
                            tissue=tissue,
                            tr=tr)
                    # treat zero (masked-out) voxels as missing
                    masked_data[masked_data == 0] = np.nan
                    # calculate slicewise summary stats
                    slicewise_mean = np.nanmean(masked_data, axis=(0, 1))
                    mean_array[:, 2 * n1 + n2] = slicewise_mean
                error_free_subs.append(subject_dir)
            except Exception:
                errors.append(tissue + " " + str(subject_dir))
                # discard any partial results written for this subject
                mean_array[:, subject_cols] = np.nan
                count_array[:, subject_cols, :] = np.nan
        n_subjects = len(error_free_subs)

        # calculate non-missing slicewise mean of mean_array
        slice_means = np.nanmean(mean_array, axis=1)
        slice_std = np.nanstd(mean_array, axis=1)

        # calculate slicewise mean of tissue type counts
        count_means = np.nanmean(count_array, axis=1)

        # fit linear models to central 4 bands
        # estimate scaling factors using these models
        scaling_factors, _, y_pred = fit_linear_model(slice_means,
                                                      method=method)
        # plot slicewise mean signal
        fig = plt.figure(figsize=(8, 4.5))
        plt.scatter(slice_numbers, slice_means)
        plt.errorbar(slice_numbers,
                     slice_means,
                     slice_std,
                     linestyle='None',
                     capsize=3)
        plt.ylim([0, PLOT_LIMS[tissue]])
        plt.xlim([0, n_slices])
        if tissue == 'combined':
            plt.title(
                f'Mean signal per slice in GM and WM across {n_subjects} subjects.'
            )
        else:
            plt.title(
                f'Mean signal per slice in {tissue} ({method}) across {n_subjects} subjects.'
            )
        plt.xlabel('Slice number')
        plt.ylabel('Mean signal')
        for x_coord in x_coords:
            plt.axvline(x_coord, linestyle='-', linewidth=0.1, color='k')
        # save plot
        plt_name = outdir / f'{method}_{tissue}_mean_per_slice_t1.png'
        plt.savefig(plt_name)
        # add linear models on top
        plt.scatter(np.arange(10, 50, 0.001),
                    y_pred.flatten()[10000:50000],
                    color='k',
                    s=0.1)
        plt_name = outdir / f'{method}_{tissue}_mean_per_slice_with_lin_sebased.png'
        plt.savefig(plt_name)
        plt.close(fig)

        # plot rescaled slice-means
        fig, ax = plt.subplots(figsize=(8, 4.5))
        rescaled_means = slice_means * scaling_factors
        plt.scatter(slice_numbers, rescaled_means)
        plt.ylim([0, PLOT_LIMS[tissue]])
        plt.xlim([0, n_slices])
        if tissue == 'combined':
            plt.title(
                f'Rescaled mean signal per slice in GM and WM across {n_subjects} subjects.'
            )
        else:
            plt.title(
                f'Rescaled mean signal per slice in {tissue} across {n_subjects} subjects.'
            )
        plt.xlabel('Slice number')
        plt.ylabel('Rescaled mean signal')
        for x_coord in x_coords:
            plt.axvline(x_coord, linestyle='-', linewidth=0.1, color='k')
        # save plot
        plt_name = outdir / f'{method}_{tissue}_mean_per_slice_rescaled_sebased.png'
        plt.savefig(plt_name)
        plt.close(fig)

        # plot slicewise mean tissue count for WM and GM (counts are only
        # populated for the 'combined' roi; other rois plot zeros)
        fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.scatter(slice_numbers,
                   count_means[:, 0],
                   c='c',
                   label='WM pve > 70%')  # wm mean
        ax.scatter(slice_numbers,
                   count_means[:, 1],
                   c='g',
                   label='GM pve > 70%')  # gm mean
        ax.legend()
        for x_coord in x_coords:
            ax.axvline(x_coord, linestyle='-', linewidth=0.1, color='k')
        # raw strings: '\g' is not a valid escape sequence in a normal literal
        plt.title(
            'Mean number of voxels per slice with' +
            rf' PVE $\geqslant$ 70% across {n_subjects} subjects.'
        )
        plt.xlabel('Slice number')
        plt.ylabel(
            r'Mean number of voxels with PVE $\geqslant$ 70% in a given tissue')
        # NOTE(review): this file name does not include the tissue, so it is
        # overwritten on every iteration over rois.
        plt_name = outdir / 'mean_voxel_count_sebased.png'
        plt.savefig(plt_name)
        plt.close(fig)

        # NOTE: the scaling factors are estimated on slice-timing corrected
        # images but applied below to images without that correction. An
        # adjustment via undo_st_correction(scaling_factors, tissue, tr) has
        # been considered but is currently disabled.

        # save scaling factors as a .txt file
        sfs_savename = outdir / f'{method}_{tissue}_scaling_factors_sebased.txt'
        np.savetxt(sfs_savename, scaling_factors, fmt='%.5f')
        for subject_dir in subject_dirs:
            # load bias (and possibly distortion) corrected calibration image
            method_dir = subject_dir / f"ASL/Calib/Calib0/SEbased_MT_t1mask{suf}/DistCorr"
            calib_name = method_dir / "calib0_restore.nii.gz"
            calib_img = Image(str(calib_name))
            # broadcast the per-slice factors over the in-plane grid of the image
            sf_volume = np.tile(scaling_factors, (*calib_img.shape[:2], 1))
            # create and save scaling factors image
            scaling_img = Image(sf_volume, header=calib_img.header)
            mtcorr_dir = method_dir / "MTCorr"
            mtcorr_dir.mkdir(exist_ok=True)
            scaling_name = mtcorr_dir / f'MTcorr_SFs_{method}_{tissue}_sebased.nii.gz'
            scaling_img.save(str(scaling_name))

            # apply scaling factors to image to perform MT correction
            mtcorr_name = mtcorr_dir / f'calib0_mtcorr_{method}_{tissue}_sebased.nii.gz'
            mtcorr_img = Image(calib_img.data * sf_volume,
                               header=calib_img.header)
            mtcorr_img.save(str(mtcorr_name))
    return errors