Exemplo n.º 1
0
def create_workflow():
    """Run FSL MCFLIRT on the first BOLD file of a BIDS dataset.

    The dataset location is read from the ``-d`` command-line option.  The
    motion-corrected image and a JSON record of the run are written into
    ``OUTDIR``; ``ENV['os']`` is embedded in both file names so that results
    produced in different environments do not overwrite each other.

    Side effects:
        Creates ``OUTDIR`` if needed and adds the ``inputs`` and
        ``nipype_version`` keys to the module-level ``ENV`` mapping.

    Raises:
        SystemExit: raised by argparse when ``-d`` is missing.
        IOError: when the dataset path does not exist, the dataset has no
            subjects, or the first subject has no ``nii.gz`` BOLD file.
    """
    parser = argparse.ArgumentParser()
    # The option is required: without it ``args.data`` is None and the
    # existence check below fails with an unhelpful TypeError.
    parser.add_argument('-d', dest='data', required=True,
                        help="Path to bids dataset")
    args = parser.parse_args()
    if not os.path.exists(args.data):
        raise IOError('Input data not found')
    if not os.path.exists(OUTDIR):
        os.makedirs(OUTDIR)

    # grab data from bids structure
    layout = BIDSLayout(args.data)
    subjects = layout.get_subjects()
    if not subjects:
        raise IOError('No subjects found in {}'.format(args.data))
    subj = subjects[0]

    bold_files = [f.filename for f in layout.get(subject=subj, type='bold',
                                                 extensions=['nii.gz'])]
    if not bold_files:
        raise IOError('No bold nii.gz file found for subject {}'.format(subj))
    func = bold_files[0]

    outfile = os.path.join(OUTDIR, 'test_{}_{}_motcor'.format(subj, ENV['os']))

    # run interface
    mcflirt = MCFLIRT()
    mcflirt.inputs.in_file = func
    # FIX: this has to be unique for each environment
    mcflirt.inputs.out_file = outfile + '.nii.gz'
    res = mcflirt.run()

    # write out json to keep track of information
    ENV.update({'inputs': res.inputs})
    ENV.update({'nipype_version': nipype.__version__})
    #ENV.update({'outputs': res.outputs})
    # write out to json
    env_to_json(ENV, outname=outfile + '.json')
def test_mcflirt_run():
    """Run MCFLIRT on the sample BOLD image and compare with the stored reference.

    The output is written to ``output_mcf.nii.gz`` relative to the current
    working directory.
    """
    input_path = os.path.join(
        Data_dir, "data_input/sub-02_task-fingerfootlips_bold.nii.gz")
    reference_path = os.path.join(
        Data_dir, "data_ref/sub-02_task-fingerfootlips_bold_MCF.nii.gz")

    interface = MCFLIRT()
    interface.inputs.in_file = input_path
    interface.inputs.out_file = "output_mcf.nii.gz"
    interface.run()

    expected = nb.load(reference_path).get_data()
    actual = nb.load(interface.inputs.out_file).get_data()

    # TODO: think about atol and rtol
    assert np.allclose(expected, actual)
def test_mcflirt_run_copy_image(image_fmri_nii, image_copy_fmri_nii, tmpdir):
    """Check that MCFLIRT leaves the ``image_copy_fmri_nii`` image unchanged.

    Every volume of that image is the same, so motion correction is expected
    to return the data exactly as it went in.
    """
    # Unpacked only to keep the fixture's three-element contract; unused here.
    _file_inp, _image_inp, _data_inp = image_fmri_nii
    copy_path, expected = image_copy_fmri_nii

    interface = MCFLIRT()
    interface.inputs.in_file = copy_path
    interface.inputs.out_file = str(tmpdir.join("output_mcf_copy_im.nii.gz"))
    interface.basedir = "test"
    interface.run()

    corrected = nb.load(interface.inputs.out_file).get_data()

    # since all images are the same mcflirt shouldn't do anything
    assert (expected == corrected).all()
def test_mcflirt_run(image_fmri_nii, cost_function, tmpdir):
    """Run MCFLIRT with the given cost function and sanity-check the output.

    Two properties are asserted: the middle volume is returned untouched, and
    the intensity sum of every volume stays within a relative tolerance of
    5e-3 of the input.
    """
    input_path, _, original = image_fmri_nii

    interface = MCFLIRT()
    interface.inputs.in_file = input_path
    interface.inputs.out_file = str(tmpdir.join("output_mcf.nii.gz"))
    interface.basedir = "test"
    interface.inputs.cost = cost_function
    interface.run()

    corrected = nb.load(interface.inputs.out_file).get_data()

    n_volumes = original.shape[3]
    middle = n_volumes // 2

    # the middle image shouldn't change
    assert (original[:, :, :, middle] == corrected[:, :, :, middle]).all()

    # i'm assuming that the sum shouldn't change "too much"
    for vol in range(n_volumes):
        sum_before = original[:, :, :, vol].sum()
        sum_after = corrected[:, :, :, vol].sum()
        assert np.allclose(sum_before, sum_after, rtol=5e-3)
def test_mcflirt_translate_image(image_fmri_nii, tmpdir):
    """Run MCFLIRT on the image built by ``image_translate_nii``.

    After correction, volumes 0 and 2 are expected to match volume 1 within a
    relative tolerance of 1e-1.
    """
    _, source_image, source_data = image_fmri_nii
    translated_path, _ = image_translate_nii(source_data, source_image)

    interface = MCFLIRT()
    interface.inputs.in_file = translated_path
    interface.inputs.out_file = str(tmpdir.join("output_mcf_translate.nii.gz"))
    interface.basedir = "test"
    interface.inputs.smooth = 0.
    interface.run()

    corrected = nb.load(interface.inputs.out_file).get_data()
    reference_volume = corrected[:, :, :, 1]

    # should think about some other error metric
    # this one gives a big error
    # mcflt.inputs.smooth = 0. doesn't really change
    for vol in (0, 2):
        assert np.allclose(corrected[:, :, :, vol], reference_volume,
                           rtol=1e-1)
Exemplo n.º 6
0
def mcflirt(infile: Path) -> Path:
    """Motion-correct a BOLD image with FSL MCFLIRT, reusing an existing result.

    The output path is the input path with ``bold.nii.gz`` in the file name
    replaced by ``MCFLIRT_SUFFIX``.  If that file already exists it is
    returned without running MCFLIRT again.

    Args:
        infile: path to the BOLD image; its file name must contain
            ``bold.nii.gz``.

    Returns:
        Path to the motion-corrected image.

    Raises:
        ValueError: if the file name does not contain ``bold.nii.gz``.
            Without this check the derived output path would equal the
            input path, and the uncorrected input would be returned as if
            it had been motion-corrected.
    """
    # motion-correction FIRST supported by e.g.
    # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6736626/
    # given we are dealing with FSL
    infile = Path(infile)
    if "bold.nii.gz" not in infile.name:
        raise ValueError(
            "Expected a file name containing 'bold.nii.gz', got: {}".format(infile)
        )
    # Replace within the file name only, so that directory components are
    # never rewritten.
    outfile = infile.with_name(infile.name.replace("bold.nii.gz", MCFLIRT_SUFFIX))
    if outfile.exists():
        return outfile
    cmd = MCFLIRT(
        in_file=str(infile),
        out_file=str(outfile),
        output_type="NIFTI_GZ",
        save_mats=False,
        save_rms=False,
        stages=3,
        stats_imgs=False,
        terminal_output="stream",
        mean_vol=False,  # speed up things
    )

    results = cmd.run()
    return Path(results.outputs.out_file)
def preproc(data_dir, sink_dir, subject, task, session, run, masks,
            motion_thresh, moco):
    """Motion-correct one EPI run, flag motion outliers and warp masks onto it.

    Steps, in order:
      1. locate the T1, the EPI and the MNI template for the requested
         subject / session / task / run;
      2. optionally run MCFLIRT on the EPI (``moco``);
      3. run FSL motion outliers (framewise displacement) on the EPI and save
         the outlier regressors as the confounds file;
      4. extract volume 144 of the functional image as the example volume;
      5. obtain a warp field, either by inverting an existing FEAT
         ``example_func2standard_warp`` or by running FLIRT + FNIRT;
      6. apply that warp to every mask in ``masks`` and save a QA figure.

    Args:
        data_dir: root directory laid out as
            ``<subject>/session-<session>/<task>/<task>-<run>/<task>.nii.gz``.
        sink_dir: output root.  Files go to ``<sink_dir>/<pre|post>/<subject>``
            and figures to ``<sink_dir>/qa``; these directories are not
            created here.
        subject: subject identifier (also the directory name).
        task: task name.
        session: 0 (written under ``pre``) or 1 (written under ``post``).
        run: run index.
        masks: mapping of mask name to mask file path.
        motion_thresh: framewise-displacement threshold.
        moco: run MCFLIRT when equal to ``True``; otherwise the EPI is used
            as-is.

    Returns:
        Tuple ``(functional_file, confound_file, status)``.  ``status`` is a
        message string when motion-outlier detection ran cleanly, otherwise
        the exception that was caught.

    Raises:
        ValueError: if ``session`` is neither 0 nor 1.
    """
    # Imported locally so the function does not depend on module-level
    # imports that are not visible from this block.
    from os.path import exists, join

    import numpy as np
    from nipype.interfaces.fsl import MCFLIRT, FLIRT, FNIRT, ExtractROI, ApplyWarp, MotionOutliers, InvWarp, FAST
    #from nipype.interfaces.afni import AlignEpiAnatPy
    from nipype.interfaces.utility import Function
    from nilearn.plotting import plot_anat
    from nilearn import input_data
    from nilearn import plotting

    # Data grabber.  It carries its own imports because it is handed to a
    # nipype ``Function`` interface rather than called directly.
    def get_niftis(subject_id, data_dir, task, run, session):
        from os.path import join, exists
        t1 = join(data_dir, subject_id, 'session-{0}'.format(session),
                  'anatomical', 'anatomical-0', 'anatomical.nii.gz')
        #t1_brain_mask = join(data_dir, subject_id, 'session-1', 'anatomical', 'anatomical-0', 'fsl', 'anatomical-bet.nii.gz')
        epi = join(data_dir, subject_id, 'session-{0}'.format(session), task,
                   '{0}-{1}'.format(task, run), '{0}.nii.gz'.format(task))
        assert exists(t1), "t1 does not exist at {0}".format(t1)
        assert exists(epi), "epi does not exist at {0}".format(epi)
        standard = '/home/applications/fsl/5.0.8/data/standard/MNI152_T1_2mm.nii.gz'
        return t1, epi, standard

    data = Function(
        function=get_niftis,
        input_names=["subject_id", "data_dir", "task", "run", "session"],
        output_names=["t1", "epi", "standard"])
    data.inputs.data_dir = data_dir
    data.inputs.subject_id = subject
    data.inputs.run = run
    data.inputs.session = session
    data.inputs.task = task
    grabber = data.run()

    # Previously an unexpected session left ``sesh`` undefined and failed
    # later with a NameError.
    if session == 0:
        sesh = 'pre'
    elif session == 1:
        sesh = 'post'
    else:
        raise ValueError(
            'session must be 0 (pre) or 1 (post), got {0!r}'.format(session))

    # Shared pieces of every output path.
    prefix = '{0}-session-{1}_{2}-{3}'.format(subject, session, task, run)
    out_dir = join(sink_dir, sesh, subject)
    qa_dir = join(sink_dir, 'qa')

    #reg_dir = '/home/data/nbc/physics-learning/data/first-level/{0}/session-1/retr/retr-{1}/retr-5mm.feat/reg'.format(subject, run)
    #set output paths for quality assurance pngs
    qa1 = join(qa_dir, prefix + '_t1_flirt.png')
    qa2 = join(qa_dir, prefix + '_mni_flirt.png')
    qa3 = join(qa_dir, prefix + '_mni_fnirt.png')
    confound_file = join(out_dir, prefix + '_confounds.txt')

    #run motion correction if indicated
    if moco == True:  # noqa: E712 - keeps the original strict comparison
        mcflirt = MCFLIRT(ref_vol=144, save_plots=True, output_type='NIFTI_GZ')
        mcflirt.inputs.in_file = grabber.outputs.epi
        #mcflirt.inputs.in_file = join(data_dir, subject, 'session-1', 'retr', 'retr-{0}'.format(run), 'retr.nii.gz')
        mcflirt.inputs.out_file = join(out_dir, prefix + '_mcf.nii.gz')
        flirty = mcflirt.run()
        func_file = flirty.outputs.out_file
    else:
        print("no moco needed")
        # Without motion correction there is no MCFLIRT result, so the
        # remaining steps work on the EPI itself.
        # NOTE(review): assumes the EPI is already motion-corrected in this
        # case - confirm with the caller.
        func_file = grabber.outputs.epi

    #calculate motion outliers
    fd_values_file = join(out_dir, prefix + '_fd.txt')
    mout = MotionOutliers(metric='fd', threshold=motion_thresh)
    mout.inputs.in_file = grabber.outputs.epi
    # The threshold is the fifth format argument; the original used ``{3}``
    # and so repeated the run number in the file name.
    mout.inputs.out_file = join(
        out_dir, '{0}_fd-gt-{1}mm'.format(prefix, motion_thresh))
    mout.inputs.out_metric_plot = join(out_dir, prefix + '_metrics.png')
    mout.inputs.out_metric_values = fd_values_file
    try:
        moutliers = mout.run()
        outliers = np.genfromtxt(moutliers.outputs.out_file)
        status = 'no errors in motion outliers, yay'
    except Exception as err:
        # Deliberate best-effort fallback: rebuild the regressor from the
        # saved metric values.  The exception is kept under its own name
        # because Python 3 unbinds the ``except ... as`` target afterwards.
        print(err)
        status = err
        fd_values = np.genfromtxt(fd_values_file)
        #set everything above the threshold to 1 and everything below to 0
        # Done in one step: thresholding in place in two passes zeroed the
        # freshly written ones whenever motion_thresh exceeded 1.
        outliers = (fd_values > motion_thresh).astype(float)

    #the confounds file currently holds the motion outliers only
    #outliers = outliers.reshape((outliers.shape[0],1))
    conf = outliers
    np.savetxt(confound_file, conf, delimiter=',')

    #extract an example volume for normalization
    ex_fun = ExtractROI(t_min=144, t_size=1)
    ex_fun.inputs.in_file = func_file
    ex_fun.inputs.roi_file = join(out_dir, prefix + '-example_func.nii.gz')
    fun = ex_fun.run()

    warp = ApplyWarp(interp="nn", abswarp=True)

    feat_warp = '/home/data/nbc/physics-learning/data/first-level/{0}/session-{1}/{2}/{2}-{3}/{2}-5mm.feat/reg/example_func2standard_warp.nii.gz'.format(
        subject, session, task, run)

    if not exists(feat_warp):
        #two-step normalization using flirt and fnirt, outputting qa pix
        flit = FLIRT(cost_func="corratio", dof=12)
        reg_func = flit.run(
            reference=fun.outputs.roi_file,
            in_file=grabber.outputs.t1,
            searchr_x=[-180, 180],
            searchr_y=[-180, 180],
            out_file=join(out_dir, prefix + '_t1-flirt.nii.gz'),
            out_matrix_file=join(out_dir, prefix + '_t1-flirt.mat'))
        reg_mni = flit.run(
            reference=grabber.outputs.t1,
            in_file=grabber.outputs.standard,
            searchr_y=[-180, 180],
            searchr_z=[-180, 180],
            out_file=join(out_dir, prefix + '_mni-flirt-t1.nii.gz'),
            out_matrix_file=join(out_dir, prefix + '_mni-flirt-t1.mat'))

        #plot_stat_map(aligner.outputs.out_file, bg_img=fun.outputs.roi_file, colorbar=True, draw_cross=False, threshold=1000, output_file=qa1a, dim=-2)
        display = plot_anat(fun.outputs.roi_file, dim=-1)
        display.add_edges(reg_func.outputs.out_file)
        display.savefig(qa1, dpi=300)
        display.close()

        display = plot_anat(grabber.outputs.t1, dim=-1)
        display.add_edges(reg_mni.outputs.out_file)
        display.savefig(qa2, dpi=300)
        display.close()

        # NOTE(review): this field registers the template to the T1, while
        # the masks below are resampled onto the example functional volume;
        # the T1-to-functional FLIRT result is only used for the QA figure.
        perf = FNIRT(output_type='NIFTI_GZ')
        perf.inputs.warped_file = join(out_dir, prefix + '_mni-fnirt-t1.nii.gz')
        perf.inputs.affine_file = reg_mni.outputs.out_matrix_file
        perf.inputs.in_file = grabber.outputs.standard
        perf.inputs.subsampling_scheme = [8, 4, 2, 2]
        perf.inputs.fieldcoeff_file = join(
            out_dir, prefix + '_mni-fnirt-t1-warpcoeff.nii.gz')
        perf.inputs.field_file = join(
            out_dir, prefix + '_mni-fnirt-t1-warp.nii.gz')
        perf.inputs.ref_file = grabber.outputs.t1
        reg2 = perf.run()
        warp.inputs.field_file = reg2.outputs.field_file
        #plot fnirted MNI overlaid on example func
        display = plot_anat(grabber.outputs.t1, dim=-1)
        display.add_edges(reg2.outputs.warped_file)
        display.savefig(qa3, dpi=300)
        display.close()
    else:
        # A FEAT registration already exists: invert its func->standard warp.
        warpspeed = InvWarp(output_type='NIFTI_GZ')
        warpspeed.inputs.warp = feat_warp
        warpspeed.inputs.reference = fun.outputs.roi_file
        warpspeed.inputs.inverse_warp = join(
            out_dir, prefix + '_mni-fnirt-t1-warp.nii.gz')
        mni2epiwarp = warpspeed.run()
        warp.inputs.field_file = mni2epiwarp.outputs.inverse_warp

    for key, mask_file in masks.items():
        #warp takes us from mni to epi
        warp.inputs.in_file = mask_file
        warp.inputs.ref_file = fun.outputs.roi_file
        warp.inputs.out_file = join(
            out_dir, '{0}_{1}.nii.gz'.format(prefix, key))
        net_warp = warp.run()

        qa_file = join(qa_dir, '{0}_qa_{1}.png'.format(prefix, key))

        display = plotting.plot_roi(net_warp.outputs.out_file,
                                    bg_img=fun.outputs.roi_file,
                                    colorbar=True,
                                    vmin=0,
                                    vmax=18,
                                    draw_cross=False)
        display.savefig(qa_file, dpi=300)
        display.close()

    return func_file, confound_file, status
Exemplo n.º 8
0
"""Test FSL's MCFLIRT for motion correction"""
from nipype.interfaces.fsl import MCFLIRT

# Configure MCFLIRT: input run, cost function and output location, and ask
# for the transformation matrices and motion-parameter plots to be saved.
mcflt = MCFLIRT()
mcflt.inputs.in_file = 'test-data/haxby2001/subj2/bold.nii.gz'
mcflt.inputs.cost = 'mutualinfo'
mcflt.inputs.out_file = 'output/fsl-mcflirt/functional_mcorr.nii.gz'
mcflt.inputs.save_mats = True
mcflt.inputs.save_plots = True
# Bare expression: its value is only displayed in an interactive session.
mcflt.cmdline

# How long to run?
# NOTE: ``%timeit`` is an IPython magic, so this line only works inside
# IPython/Jupyter, not as a plain Python script.
%timeit mcflt.run()
Exemplo n.º 9
0
    def _run_interface(self, runtime):
        """Merge the input images into one file and compute their average.

        The inputs are optionally reoriented (``to_ras``), squeezed, split
        into 3D volumes and concatenated into a single ``*_merged`` image
        written to ``runtime.cwd``.  When head-motion correction is enabled
        (``hmc``) and more than one input file was given, the merged series
        is realigned with FSL MCFLIRT.  The series is then optionally
        intensity-normalised (``grand_mean_scaling``) and averaged over its
        fourth axis into ``*_avg``.

        Results written to ``self._results``:
            out_file: merged (or motion-corrected) series.
            out_avg: averaged image; the merged file itself when the merged
                image has fewer than four dimensions.
            out_mats, out_movpar: MCFLIRT matrices and motion parameters,
                only set when MCFLIRT ran.

        Raises:
            RuntimeError: if an input image is still 5D after squeezing.
        """
        in_files = self.inputs.in_files
        if not isinstance(in_files, list):
            in_files = [in_files]

        # Output names derive from the first *original* input.  Indexing the
        # normalised list (rather than ``self.inputs.in_files``) keeps this
        # correct when a single path string was supplied, where ``[0]`` on
        # the raw input would yield its first character.
        first_file = in_files[0]

        if self.inputs.to_ras:
            in_files = [reorient(inf, newpath=runtime.cwd) for inf in in_files]

        run_hmc = self.inputs.hmc and len(in_files) > 1

        nii_list = []
        # Remove one-sized extra dimensions
        for f in in_files:
            filenii = nb.load(f)
            filenii = nb.squeeze_image(filenii)
            if len(filenii.shape) == 5:
                raise RuntimeError("Input image (%s) is 5D." % f)
            if filenii.dataobj.ndim == 4:
                # Split 4D inputs so every list entry is a single volume.
                nii_list += nb.four_to_three(filenii)
            else:
                nii_list.append(filenii)

        if len(nii_list) > 1:
            filenii = nb.concat_images(nii_list)
        else:
            filenii = nii_list[0]

        merged_fname = fname_presuffix(first_file,
                                       suffix="_merged",
                                       newpath=runtime.cwd)
        filenii.to_filename(merged_fname)
        self._results["out_file"] = merged_fname
        self._results["out_avg"] = merged_fname

        if filenii.dataobj.ndim < 4:
            # Nothing to realign or average.
            # TODO: generate identity out_mats and zero-filled out_movpar
            return runtime

        if run_hmc:
            from nipype.interfaces.fsl import MCFLIRT

            mcflirt = MCFLIRT(
                cost="normcorr",
                save_mats=True,
                save_plots=True,
                ref_vol=0,
                in_file=merged_fname,
            )
            mcres = mcflirt.run()
            filenii = nb.load(mcres.outputs.out_file)
            self._results["out_file"] = mcres.outputs.out_file
            self._results["out_mats"] = mcres.outputs.mat_file
            self._results["out_movpar"] = mcres.outputs.par_file

        hmcdata = filenii.get_fdata(dtype="float32")
        if self.inputs.grand_mean_scaling:
            if not isdefined(self.inputs.in_mask):
                # No mask given: keep voxels whose median over the last axis
                # exceeds the 25th percentile of that median image.
                mean = np.median(hmcdata, axis=-1)
                thres = np.percentile(mean, 25)
                mask = mean > thres
            else:
                mask = nb.load(
                    self.inputs.in_mask).get_fdata(dtype="float32") > 0.5

            # Rescale every volume so that its in-mask median matches the
            # largest in-mask median across volumes.
            nimgs = hmcdata.shape[-1]
            means = np.median(hmcdata[mask[..., np.newaxis]].reshape(
                (-1, nimgs)).T,
                              axis=-1)
            max_mean = means.max()
            for i in range(nimgs):
                hmcdata[..., i] *= max_mean / means[i]

        hmcdata = hmcdata.mean(axis=3)
        if self.inputs.zero_based_avg:
            hmcdata -= hmcdata.min()

        self._results["out_avg"] = fname_presuffix(first_file,
                                                   suffix="_avg",
                                                   newpath=runtime.cwd)
        nb.Nifti1Image(hmcdata, filenii.affine,
                       filenii.header).to_filename(self._results["out_avg"])

        return runtime
Exemplo n.º 10
0
# Finish configuring the ``st`` interface (created before this fragment)
# and run it; its output is the input of the motion-correction step below.
st.inputs.interleaved = True
st.inputs.time_repetition = 3
st.inputs.out_file = '../testdata/func_st.nii.gz'

st.run()

# Perform motion correction on the output of the previous step
mc = MCFLIRT()
mc.inputs.in_file = '../testdata/func_st.nii.gz'
mc.inputs.cost = 'mutualinfo'
mc.inputs.interpolation = 'sinc'
mc.inputs.save_mats = True
mc.inputs.save_plots = True
mc.inputs.mean_vol = True
# NOTE(review): the output name ends in ``.gz`` rather than ``.nii.gz``; the
# ``.par`` paths used below are built from this exact name, so the two have
# to be changed together.
mc.inputs.out_file = '../testdata/func_mc_st.gz'
mc.run()

# Plot motion parameters - saved as .png in the same directory

# Rotation
plotter_rot = PlotMotionParams()
plotter_rot.inputs.in_file = '../testdata/func_mc_st.gz.par'
plotter_rot.inputs.in_source = 'fsl'
plotter_rot.inputs.plot_type = 'rotations'
plotter_rot.run()

# Translation
plotter_trans = PlotMotionParams()
plotter_trans.inputs.in_file = '../testdata/func_mc_st.gz.par'
plotter_trans.inputs.in_source = 'fsl'
plotter_trans.inputs.plot_type = 'translations'