Example #1
def apply_xfm_node(config: dict, **kwargs):
    '''
    Parses config file to return desired apply_xfm node.

    Parameters
    ----------
    config : dict
        PALS config file
    kwargs
        Keyword arguments to send to registration method.

    Returns
    -------
    MapNode
    '''

    if (not config['Analysis']['Registration']):
        # No registration; no xfm to apply.
        n = MapNode(Function(function=infile_to_outfile,
                             input_names=['in_file', 'in_matrix_file'],
                             output_names='out_file'),
                    name='transformation_skip',
                    iterfield=['in_file', 'in_matrix_file'])
    else:
        n = MapNode(fsl.FLIRT(apply_xfm=True,
                              reference=config['Registration']['reference']),
                    name='transformation_flirt',
                    iterfield=['in_file', 'in_matrix_file'])
    return n
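
A minimal usage sketch (hypothetical paths and config values; assumes nipype and FSL are installed, and that infile_to_outfile is the pass-through helper defined elsewhere in the PALS codebase):

config = {'Analysis': {'Registration': True},
          'Registration': {'reference': 'MNI152_T1_2mm.nii.gz'}}
xfm = apply_xfm_node(config)
# MapNode inputs are lists; one transform is applied per (file, matrix) pair.
xfm.inputs.in_file = ['sub-01_desc-lesion_mask.nii.gz']
xfm.inputs.in_matrix_file = ['sub-01_desc-transform.mat']
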
Example #2
def registration_node(config: dict, **kwargs):
    '''
    Parses config file to return desired registration method.

    Parameters
    ----------
    config : dict
        PALS config file
    kwargs
        Keyword arguments to send to registration method.

    Returns
    -------
    MapNode
    '''
    # Get registration method
    reg_method = config['Analysis']['RegistrationMethod']
    if (not config['Analysis']['Registration']):
        # No registration; in -> out
        n = MapNode(Function(function=reg_no_reg,
                             input_names=['in_file'],
                             output_names=['out_file', 'out_matrix_file']),
                    name='registration_identity',
                    iterfield='in_file')
    elif (reg_method.lower() == 'flirt'):
        # Use FLIRT
        n = MapNode(fsl.FLIRT(),
                    name='registration_flirt',
                    iterfield='in_file')
        for k, v in kwargs.items():
            setattr(n.inputs, k, v)
    else:
        raise NotImplementedError(
            f'Registration method {reg_method} not implemented.')
    return n
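
Usage sketch: every keyword argument is copied verbatim onto the node's inputs, so any fsl.FLIRT input (reference, cost_func, ...) is accepted (values below are hypothetical):

config = {'Analysis': {'Registration': True, 'RegistrationMethod': 'flirt'}}
reg = registration_node(config, reference='MNI152_T1_2mm.nii.gz',
                        cost_func='mutualinfo')
reg.inputs.in_file = ['sub-01_T1w.nii.gz']
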
Example #3
def extraction_node(config: dict, **kwargs):
    '''
    Parses config file to return desired brain extraction method.

    Parameters
    ----------
    config : dict
        PALS config file
    kwargs
        Keyword arguments to send to brain extraction method.

    Returns
    -------
    MapNode
    '''
    # Get extraction type
    extract_type = config['Analysis']['BrainExtractionMethod']
    if (not config['Analysis']['BrainExtraction']):
        # No brain extraction; in -> out
        n = MapNode(Function(function=infile_to_outfile,
                             input_names='in_file',
                             output_names='out_file'),
                    name='extract_skip',
                    iterfield='in_file')
        return n
    elif (extract_type.lower() == 'bet'):
        n = MapNode(fsl.BET(**kwargs),
                    name='extraction_bet',
                    iterfield='in_file')
        return n
    else:
        raise NotImplementedError(
            f'Extraction method {extract_type} not implemented.')
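
Usage sketch (hypothetical config): kwargs are passed straight to fsl.BET, so any BET input such as frac or robust can be supplied here:

config = {'Analysis': {'BrainExtraction': True,
                       'BrainExtractionMethod': 'bet'}}
bet = extraction_node(config, frac=0.5, robust=True)
bet.inputs.in_file = ['sub-01_T1w.nii.gz']
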
Example #4
def wf_transform_anat(in_file_list, in_matrix_file_list, reference):
    func2std_xform = MapNode(
        FLIRT(output_type='NIFTI', apply_xfm=True),
        name="func2std_xform",
        iterfield=['in_file', 'in_matrix_file', 'reference'])

    inputspec = Node(IdentityInterface(
        fields=['in_file_list', 'in_matrix_file_list', 'reference']),
                     name="inputspec")

    inputspec.inputs.in_file_list = in_file_list
    inputspec.inputs.in_matrix_file_list = in_matrix_file_list
    inputspec.inputs.reference = reference

    wf_transform_anat = Workflow(name="wf_transform_anat")
    wf_transform_anat.connect(inputspec, 'in_file_list', func2std_xform,
                              'in_file')
    wf_transform_anat.connect(inputspec, 'in_matrix_file_list', func2std_xform,
                              'in_matrix_file')
    wf_transform_anat.connect(inputspec, 'reference', func2std_xform,
                              'reference')

    return wf_transform_anat
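
Note that reference is part of the iterfield, so it must be a list with the same length as the other inputs. A usage sketch with placeholder paths:

wf = wf_transform_anat(
    in_file_list=['func1.nii.gz', 'func2.nii.gz'],
    in_matrix_file_list=['func2std1.mat', 'func2std2.mat'],
    reference=['MNI152_T1_2mm.nii.gz'] * 2)
wf.run()
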
Example #5
# composite transform
dictWarps = {'transforms': [], 'outpath': []}
for strWarpDir in lsWarpDirs:
    strAffinePath = glob.glob(os.path.join(strWarpDir, 'out_matrix',
                                           '*.mat'))[0]
    # Remove rigid body components (translation and rotation) which don't
    # contribute meaningful variation
    strAffinePath = remove_rigidbody(strAffinePath)
    # We use the inverse warp field, which contains the nonlinear
    # transformation from MNI -> subject.
    strNonlinearPath = glob.glob(
        os.path.join(strWarpDir, 'inverse_warp_field', '*.nii.gz'))[0]
    dictWarps['transforms'].append([strNonlinearPath, strAffinePath])
    dictWarps['outpath'].append(
        os.path.join(strWarpDir, 'composite_to_mni.nii.gz'))

# Use ANTs ApplyTransforms to compose the transforms
antstool = MapNode(ants.ApplyTransforms(input_image=TEMPLATE,
                                        reference_image=TEMPLATE,
                                        interpolation='BSpline',
                                        invert_transform_flags=[False, True],
                                        print_out_composite_warp_file=True),
                   name='applytransforms',
                   iterfield=['output_image', 'transforms'])
antstool.inputs.output_image = dictWarps['outpath']
antstool.inputs.transforms = dictWarps['transforms']

# Create and run nipype workflow
wf = Workflow('composite_transforms')
wf.add_nodes([antstool])
wf.run(plugin='MultiProc', plugin_args={'n_procs': PIPELINE_JOBS})
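
remove_rigidbody is a project-specific helper that is not shown in this snippet. A minimal sketch of what it might do, assuming FLIRT-style plain-text 4x4 affine files and a polar decomposition that keeps only the scale/shear (stretch) factor:

import numpy as np
from scipy.linalg import polar

def remove_rigidbody(mat_path):
    # Hypothetical implementation: drop translation and rotation, keeping
    # only the non-rigid (scale/shear) part of the affine.
    affine = np.loadtxt(mat_path)              # FLIRT .mat is a plain-text 4x4
    rotation, stretch = polar(affine[:3, :3])  # A = R @ S
    out = np.eye(4)
    out[:3, :3] = stretch                      # translation column stays zero
    out_path = mat_path.replace('.mat', '_norigid.mat')
    np.savetxt(out_path, out)
    return out_path
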
Example #6
    def create_workflow(self):
        """Create the Niype workflow of the super-resolution pipeline.

        It is composed of a succession of Nodes and their corresponding parameters,
        where the output of node i goes to the input of node i+1.

        """
        sub_ses = self.subject
        if self.session is not None:
            sub_ses = ''.join([sub_ses, '_', self.session])

        if self.session is None:
            wf_base_dir = os.path.join(
                self.output_dir, '-'.join(["nipype", __nipype_version__]),
                self.subject, "rec-{}".format(self.sr_id))
            final_res_dir = os.path.join(self.output_dir,
                                         '-'.join(["pymialsrtk",
                                                   __version__]), self.subject)
        else:
            wf_base_dir = os.path.join(
                self.output_dir, '-'.join(["nipype", __nipype_version__]),
                self.subject, self.session, "rec-{}".format(self.sr_id))
            final_res_dir = os.path.join(self.output_dir,
                                         '-'.join(["pymialsrtk", __version__]),
                                         self.subject, self.session)

        if not os.path.exists(wf_base_dir):
            os.makedirs(wf_base_dir)
        print("Process directory: {}".format(wf_base_dir))

        # Initialization: remove any stale nipype log (we may not be able to
        # control its name).
        if os.path.isfile(os.path.join(wf_base_dir, "pypeline.log")):
            os.unlink(os.path.join(wf_base_dir, "pypeline.log"))

        self.wf = Workflow(name=self.pipeline_name, base_dir=wf_base_dir)

        config.update_config({
            'logging': {
                'log_directory': os.path.join(wf_base_dir),
                'log_to_file': True
            },
            'execution': {
                'remove_unnecessary_outputs': False,
                'stop_on_first_crash': True,
                'stop_on_first_rerun': False,
                'crashfile_format': "txt",
                'use_relative_paths': True,
                'write_provenance': False
            }
        })

        # Update nipype logging with config
        nipype_logging.update_logging(config)
        # config.enable_provenance()

        if self.use_manual_masks:
            dg = Node(interface=DataGrabber(outfields=['T2ws', 'masks']),
                      name='data_grabber')
            dg.inputs.base_directory = self.bids_dir
            dg.inputs.template = '*'
            dg.inputs.raise_on_empty = False
            dg.inputs.sort_filelist = True

            if self.session is not None:
                t2ws_template = os.path.join(
                    self.subject, self.session, 'anat',
                    '_'.join([sub_ses, '*run-*', '*T2w.nii.gz']))
                if self.m_masks_desc is not None:
                    masks_template = os.path.join(
                        'derivatives', self.m_masks_derivatives_dir,
                        self.subject, self.session, 'anat', '_'.join([
                            sub_ses, '*_run-*', '_desc-' + self.m_masks_desc,
                            '*mask.nii.gz'
                        ]))
                else:
                    masks_template = os.path.join(
                        'derivatives', self.m_masks_derivatives_dir,
                        self.subject, self.session, 'anat',
                        '_'.join([sub_ses, '*run-*', '*mask.nii.gz']))
            else:
                t2ws_template = os.path.join(self.subject, 'anat',
                                             sub_ses + '*_run-*_T2w.nii.gz')

                if self.m_masks_desc is not None:
                    # self.session is None in this branch, so the session
                    # level is omitted from the path.
                    masks_template = os.path.join(
                        'derivatives', self.m_masks_derivatives_dir,
                        self.subject, 'anat', '_'.join([
                            sub_ses, '*_run-*', '_desc-' + self.m_masks_desc,
                            '*mask.nii.gz'
                        ]))
                else:
                    masks_template = os.path.join(
                        'derivatives', self.m_masks_derivatives_dir,
                        self.subject, 'anat', sub_ses + '*_run-*_*mask.nii.gz')

            dg.inputs.field_template = dict(T2ws=t2ws_template,
                                            masks=masks_template)

            brainMask = MapNode(
                interface=IdentityInterface(fields=['out_file']),
                name='brain_masks_bypass',
                iterfield=['out_file'])

            if self.m_stacks is not None:
                custom_masks_filter = Node(
                    interface=preprocess.FilteringByRunid(),
                    name='custom_masks_filter')
                custom_masks_filter.inputs.stacks_id = self.m_stacks

        else:
            dg = Node(interface=DataGrabber(outfields=['T2ws']),
                      name='data_grabber')

            dg.inputs.base_directory = self.bids_dir
            dg.inputs.template = '*'
            dg.inputs.raise_on_empty = False
            dg.inputs.sort_filelist = True

            dg.inputs.field_template = dict(
                T2ws=os.path.join(self.subject, 'anat', sub_ses +
                                  '*_run-*_T2w.nii.gz'))
            if self.session is not None:
                dg.inputs.field_template = dict(T2ws=os.path.join(
                    self.subject, self.session, 'anat', '_'.join(
                        [sub_ses, '*run-*', '*T2w.nii.gz'])))

            if self.m_stacks is not None:
                t2ws_filter_prior_masks = Node(
                    interface=preprocess.FilteringByRunid(),
                    name='t2ws_filter_prior_masks')
                t2ws_filter_prior_masks.inputs.stacks_id = self.m_stacks

            brainMask = MapNode(interface=preprocess.BrainExtraction(),
                                name='brainExtraction',
                                iterfield=['in_file'])

            brainMask.inputs.bids_dir = self.bids_dir
            brainMask.inputs.in_ckpt_loc = pkg_resources.resource_filename(
                "pymialsrtk",
                os.path.join("data", "Network_checkpoints",
                             "Network_checkpoints_localization",
                             "Unet.ckpt-88000.index")).split('.index')[0]
            brainMask.inputs.threshold_loc = 0.49
            brainMask.inputs.in_ckpt_seg = pkg_resources.resource_filename(
                "pymialsrtk",
                os.path.join("data", "Network_checkpoints",
                             "Network_checkpoints_segmentation",
                             "Unet.ckpt-20000.index")).split('.index')[0]
            brainMask.inputs.threshold_seg = 0.5

        t2ws_filtered = Node(interface=preprocess.FilteringByRunid(),
                             name='t2ws_filtered')
        masks_filtered = Node(interface=preprocess.FilteringByRunid(),
                              name='masks_filtered')

        if not self.m_skip_stacks_ordering:
            stacksOrdering = Node(interface=preprocess.StacksOrdering(),
                                  name='stackOrdering')
        else:
            stacksOrdering = Node(
                interface=IdentityInterface(fields=['stacks_order']),
                name='stackOrdering')
            stacksOrdering.inputs.stacks_order = self.m_stacks

        if not self.m_skip_nlm_denoising:
            nlmDenoise = MapNode(interface=preprocess.BtkNLMDenoising(),
                                 name='nlmDenoise',
                                 iterfield=['in_file', 'in_mask'])
            nlmDenoise.inputs.bids_dir = self.bids_dir

            # Without the mask, the first slice intensity correction...
            srtkCorrectSliceIntensity01_nlm = MapNode(
                interface=preprocess.MialsrtkCorrectSliceIntensity(),
                name='srtkCorrectSliceIntensity01_nlm',
                iterfield=['in_file', 'in_mask'])
            srtkCorrectSliceIntensity01_nlm.inputs.bids_dir = self.bids_dir
            srtkCorrectSliceIntensity01_nlm.inputs.out_postfix = '_uni'

        srtkCorrectSliceIntensity01 = MapNode(
            interface=preprocess.MialsrtkCorrectSliceIntensity(),
            name='srtkCorrectSliceIntensity01',
            iterfield=['in_file', 'in_mask'])
        srtkCorrectSliceIntensity01.inputs.bids_dir = self.bids_dir
        srtkCorrectSliceIntensity01.inputs.out_postfix = '_uni'

        srtkSliceBySliceN4BiasFieldCorrection = MapNode(
            interface=preprocess.MialsrtkSliceBySliceN4BiasFieldCorrection(),
            name='srtkSliceBySliceN4BiasFieldCorrection',
            iterfield=['in_file', 'in_mask'])
        srtkSliceBySliceN4BiasFieldCorrection.inputs.bids_dir = self.bids_dir

        srtkSliceBySliceCorrectBiasField = MapNode(
            interface=preprocess.MialsrtkSliceBySliceCorrectBiasField(),
            name='srtkSliceBySliceCorrectBiasField',
            iterfield=['in_file', 'in_mask', 'in_field'])
        srtkSliceBySliceCorrectBiasField.inputs.bids_dir = self.bids_dir

        # 4-module sequence to be defined as a stage.
        if not self.m_skip_nlm_denoising:
            srtkCorrectSliceIntensity02_nlm = MapNode(
                interface=preprocess.MialsrtkCorrectSliceIntensity(),
                name='srtkCorrectSliceIntensity02_nlm',
                iterfield=['in_file', 'in_mask'])
            srtkCorrectSliceIntensity02_nlm.inputs.bids_dir = self.bids_dir

            srtkIntensityStandardization01_nlm = Node(
                interface=preprocess.MialsrtkIntensityStandardization(),
                name='srtkIntensityStandardization01_nlm')
            srtkIntensityStandardization01_nlm.inputs.bids_dir = self.bids_dir

            srtkHistogramNormalization_nlm = Node(
                interface=preprocess.MialsrtkHistogramNormalization(),
                name='srtkHistogramNormalization_nlm')
            srtkHistogramNormalization_nlm.inputs.bids_dir = self.bids_dir

            srtkIntensityStandardization02_nlm = Node(
                interface=preprocess.MialsrtkIntensityStandardization(),
                name='srtkIntensityStandardization02_nlm')
            srtkIntensityStandardization02_nlm.inputs.bids_dir = self.bids_dir

        # 4-module sequence to be defined as a stage.
        srtkCorrectSliceIntensity02 = MapNode(
            interface=preprocess.MialsrtkCorrectSliceIntensity(),
            name='srtkCorrectSliceIntensity02',
            iterfield=['in_file', 'in_mask'])
        srtkCorrectSliceIntensity02.inputs.bids_dir = self.bids_dir

        srtkIntensityStandardization01 = Node(
            interface=preprocess.MialsrtkIntensityStandardization(),
            name='srtkIntensityStandardization01')
        srtkIntensityStandardization01.inputs.bids_dir = self.bids_dir

        srtkHistogramNormalization = Node(
            interface=preprocess.MialsrtkHistogramNormalization(),
            name='srtkHistogramNormalization')
        srtkHistogramNormalization.inputs.bids_dir = self.bids_dir

        srtkIntensityStandardization02 = Node(
            interface=preprocess.MialsrtkIntensityStandardization(),
            name='srtkIntensityStandardization02')
        srtkIntensityStandardization02.inputs.bids_dir = self.bids_dir

        srtkMaskImage01 = MapNode(interface=preprocess.MialsrtkMaskImage(),
                                  name='srtkMaskImage01',
                                  iterfield=['in_file', 'in_mask'])
        srtkMaskImage01.inputs.bids_dir = self.bids_dir

        srtkImageReconstruction = Node(
            interface=reconstruction.MialsrtkImageReconstruction(),
            name='srtkImageReconstruction')
        srtkImageReconstruction.inputs.bids_dir = self.bids_dir
        srtkImageReconstruction.inputs.sub_ses = sub_ses
        srtkImageReconstruction.inputs.no_reg = self.m_skip_svr

        srtkTVSuperResolution = Node(
            interface=reconstruction.MialsrtkTVSuperResolution(),
            name='srtkTVSuperResolution')
        srtkTVSuperResolution.inputs.bids_dir = self.bids_dir
        srtkTVSuperResolution.inputs.sub_ses = sub_ses
        srtkTVSuperResolution.inputs.in_loop = self.primal_dual_loops
        srtkTVSuperResolution.inputs.in_deltat = self.deltatTV
        srtkTVSuperResolution.inputs.in_lambda = self.lambdaTV
        srtkTVSuperResolution.inputs.use_manual_masks = self.use_manual_masks

        srtkN4BiasFieldCorrection = Node(
            interface=postprocess.MialsrtkN4BiasFieldCorrection(),
            name='srtkN4BiasFieldCorrection')
        srtkN4BiasFieldCorrection.inputs.bids_dir = self.bids_dir

        if self.m_do_refine_hr_mask:
            srtkHRMask = Node(
                interface=postprocess.MialsrtkRefineHRMaskByIntersection(),
                name='srtkHRMask')
            srtkHRMask.inputs.bids_dir = self.bids_dir
        else:
            srtkHRMask = Node(interface=postprocess.BinarizeImage(),
                              name='srtkHRMask')

        srtkMaskImage02 = Node(interface=preprocess.MialsrtkMaskImage(),
                               name='srtkMaskImage02')
        srtkMaskImage02.inputs.bids_dir = self.bids_dir

        # Build the workflow: all nodes are ready, connect them.
        if self.use_manual_masks:
            if self.m_stacks is not None:
                self.wf.connect(dg, "masks", custom_masks_filter,
                                "input_files")
                self.wf.connect(custom_masks_filter, "output_files", brainMask,
                                "out_file")
            else:
                self.wf.connect(dg, "masks", brainMask, "out_file")
        else:
            if self.m_stacks is not None:
                self.wf.connect(dg, "T2ws", t2ws_filter_prior_masks,
                                "input_files")
                self.wf.connect(t2ws_filter_prior_masks, "output_files",
                                brainMask, "in_file")
            else:
                self.wf.connect(dg, "T2ws", brainMask, "in_file")

        if not self.m_skip_stacks_ordering:
            self.wf.connect(brainMask, "out_file", stacksOrdering,
                            "input_masks")

        self.wf.connect(stacksOrdering, "stacks_order", t2ws_filtered,
                        "stacks_id")
        self.wf.connect(dg, "T2ws", t2ws_filtered, "input_files")

        self.wf.connect(stacksOrdering, "stacks_order", masks_filtered,
                        "stacks_id")
        self.wf.connect(brainMask, "out_file", masks_filtered, "input_files")

        if not self.m_skip_nlm_denoising:
            self.wf.connect(t2ws_filtered,
                            ("output_files", utils.sort_ascending), nlmDenoise,
                            "in_file")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending), nlmDenoise,
                            "in_mask")  ## Comment to match docker process

            self.wf.connect(nlmDenoise, ("out_file", utils.sort_ascending),
                            srtkCorrectSliceIntensity01_nlm, "in_file")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending),
                            srtkCorrectSliceIntensity01_nlm, "in_mask")

        self.wf.connect(t2ws_filtered, ("output_files", utils.sort_ascending),
                        srtkCorrectSliceIntensity01, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkCorrectSliceIntensity01, "in_mask")

        if not self.m_skip_nlm_denoising:
            self.wf.connect(srtkCorrectSliceIntensity01_nlm,
                            ("out_file", utils.sort_ascending),
                            srtkSliceBySliceN4BiasFieldCorrection, "in_file")
        else:
            self.wf.connect(srtkCorrectSliceIntensity01,
                            ("out_file", utils.sort_ascending),
                            srtkSliceBySliceN4BiasFieldCorrection, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkSliceBySliceN4BiasFieldCorrection, "in_mask")

        self.wf.connect(srtkCorrectSliceIntensity01,
                        ("out_file", utils.sort_ascending),
                        srtkSliceBySliceCorrectBiasField, "in_file")
        self.wf.connect(srtkSliceBySliceN4BiasFieldCorrection,
                        ("out_fld_file", utils.sort_ascending),
                        srtkSliceBySliceCorrectBiasField, "in_field")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkSliceBySliceCorrectBiasField, "in_mask")

        if not self.m_skip_nlm_denoising:
            self.wf.connect(srtkSliceBySliceN4BiasFieldCorrection,
                            ("out_im_file", utils.sort_ascending),
                            srtkCorrectSliceIntensity02_nlm, "in_file")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending),
                            srtkCorrectSliceIntensity02_nlm, "in_mask")
            self.wf.connect(srtkCorrectSliceIntensity02_nlm,
                            ("out_file", utils.sort_ascending),
                            srtkIntensityStandardization01_nlm, "input_images")
            self.wf.connect(srtkIntensityStandardization01_nlm,
                            ("output_images", utils.sort_ascending),
                            srtkHistogramNormalization_nlm, "input_images")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending),
                            srtkHistogramNormalization_nlm, "input_masks")
            self.wf.connect(srtkHistogramNormalization_nlm,
                            ("output_images", utils.sort_ascending),
                            srtkIntensityStandardization02_nlm, "input_images")

        self.wf.connect(srtkSliceBySliceCorrectBiasField,
                        ("out_im_file", utils.sort_ascending),
                        srtkCorrectSliceIntensity02, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkCorrectSliceIntensity02, "in_mask")
        self.wf.connect(srtkCorrectSliceIntensity02,
                        ("out_file", utils.sort_ascending),
                        srtkIntensityStandardization01, "input_images")

        self.wf.connect(srtkIntensityStandardization01,
                        ("output_images", utils.sort_ascending),
                        srtkHistogramNormalization, "input_images")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkHistogramNormalization, "input_masks")
        self.wf.connect(srtkHistogramNormalization,
                        ("output_images", utils.sort_ascending),
                        srtkIntensityStandardization02, "input_images")

        if not self.m_skip_nlm_denoising:
            self.wf.connect(srtkIntensityStandardization02_nlm,
                            ("output_images", utils.sort_ascending),
                            srtkMaskImage01, "in_file")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending),
                            srtkMaskImage01, "in_mask")
        else:
            self.wf.connect(srtkIntensityStandardization02,
                            ("output_images", utils.sort_ascending),
                            srtkMaskImage01, "in_file")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending),
                            srtkMaskImage01, "in_mask")

        self.wf.connect(srtkMaskImage01, "out_im_file",
                        srtkImageReconstruction, "input_images")
        self.wf.connect(masks_filtered, "output_files",
                        srtkImageReconstruction, "input_masks")
        self.wf.connect(stacksOrdering, "stacks_order",
                        srtkImageReconstruction, "stacks_order")

        self.wf.connect(srtkIntensityStandardization02, "output_images",
                        srtkTVSuperResolution, "input_images")
        self.wf.connect(srtkImageReconstruction,
                        ("output_transforms", utils.sort_ascending),
                        srtkTVSuperResolution, "input_transforms")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkTVSuperResolution, "input_masks")
        self.wf.connect(stacksOrdering, "stacks_order", srtkTVSuperResolution,
                        "stacks_order")

        self.wf.connect(srtkImageReconstruction, "output_sdi",
                        srtkTVSuperResolution, "input_sdi")

        if self.m_do_refine_hr_mask:
            self.wf.connect(srtkIntensityStandardization02,
                            ("output_images", utils.sort_ascending),
                            srtkHRMask, "input_images")
            self.wf.connect(masks_filtered,
                            ("output_files", utils.sort_ascending), srtkHRMask,
                            "input_masks")
            self.wf.connect(srtkImageReconstruction,
                            ("output_transforms", utils.sort_ascending),
                            srtkHRMask, "input_transforms")
            self.wf.connect(srtkTVSuperResolution, "output_sr", srtkHRMask,
                            "input_sr")
        else:
            self.wf.connect(srtkTVSuperResolution, "output_sr", srtkHRMask,
                            "input_image")

        self.wf.connect(srtkTVSuperResolution, "output_sr", srtkMaskImage02,
                        "in_file")
        self.wf.connect(srtkHRMask, "output_srmask", srtkMaskImage02,
                        "in_mask")

        self.wf.connect(srtkTVSuperResolution, "output_sr",
                        srtkN4BiasFieldCorrection, "input_image")
        self.wf.connect(srtkHRMask, "output_srmask", srtkN4BiasFieldCorrection,
                        "input_mask")

        # Datasinker
        finalFilenamesGeneration = Node(
            interface=postprocess.FilenamesGeneration(), name='filenames_gen')
        finalFilenamesGeneration.inputs.sub_ses = sub_ses
        finalFilenamesGeneration.inputs.sr_id = self.sr_id
        finalFilenamesGeneration.inputs.use_manual_masks = self.use_manual_masks

        self.wf.connect(stacksOrdering, "stacks_order",
                        finalFilenamesGeneration, "stacks_order")

        datasink = Node(interface=DataSink(), name='data_sinker')
        datasink.inputs.base_directory = final_res_dir

        if not self.m_skip_stacks_ordering:
            self.wf.connect(stacksOrdering, "report_image", datasink,
                            'figures.@stackOrderingQC')
            self.wf.connect(stacksOrdering, "motion_tsv", datasink,
                            'anat.@motionTSV')
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        datasink, 'anat.@LRmasks')
        self.wf.connect(srtkIntensityStandardization02,
                        ("output_images", utils.sort_ascending), datasink,
                        'anat.@LRsPreproc')
        self.wf.connect(srtkImageReconstruction,
                        ("output_transforms", utils.sort_ascending), datasink,
                        'xfm.@transforms')
        self.wf.connect(finalFilenamesGeneration, "substitutions", datasink,
                        "substitutions")
        self.wf.connect(srtkMaskImage01, ("out_im_file", utils.sort_ascending),
                        datasink, 'anat.@LRsDenoised')
        self.wf.connect(srtkImageReconstruction, "output_sdi", datasink,
                        'anat.@SDI')
        self.wf.connect(srtkN4BiasFieldCorrection, "output_image", datasink,
                        'anat.@SR')
        self.wf.connect(srtkTVSuperResolution, "output_json_path", datasink,
                        'anat.@SRjson')
        self.wf.connect(srtkTVSuperResolution, "output_sr_png", datasink,
                        'figures.@SRpng')
        self.wf.connect(srtkHRMask, "output_srmask", datasink, 'anat.@SRmask')
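
This method belongs to a larger pipeline class (pymialsrtk's anatomical super-resolution pipeline). A hedged sketch of the calling pattern, with the class name and constructor arguments assumed from the attributes used above:

# Hypothetical driver; the real constructor takes many more options.
pipeline = AnatomicalPipeline(bids_dir='/data/bids',
                              output_dir='/data/derivatives',
                              subject='sub-01')
pipeline.create_workflow()
pipeline.wf.run(plugin='MultiProc')
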
Example #7
def pals(config: dict):
    # Get config file defining workflow
    # configs = json.load(open(config_file, 'r'))
    print('Starting: initializing workflow.')
    # Build pipeline
    wf = Workflow(name='PALS')

    # bidsLayout = bids.BIDSLayout(config['BIDSRoot'])
    # Get data
    loader = BIDSDataGrabber(index_derivatives=False)
    loader.inputs.base_dir = config['BIDSRoot']
    loader.inputs.subject = config['Subject']
    if (config['Session'] is not None):
        loader.inputs.session = config['Session']
    loader.inputs.output_query = {
        't1w': dict(**config['T1Entities'], invalid_filters='allow')
    }
    loader.inputs.extra_derivatives = [config['BIDSRoot']]
    loader = Node(loader, name='BIDSgrabber')

    entities = {
        'subject': config['Subject'],
        'session': config['Session'],
        'suffix': 'T1w',
        'extension': '.nii.gz'
    }

    # Reorient to radiological
    if (config['Analysis']['Reorient']):
        radio = MapNode(
            Reorient(orientation=config['Analysis']['Orientation']),
            name="reorientation",
            iterfield='in_file')
        if ('Reorient' in config['Outputs'].keys()):
            reorient_sink = MapNode(Function(function=copyfile,
                                             input_names=['src', 'dst']),
                                    name='reorient_copy',
                                    iterfield='src')
            path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-' + \
                           config['Analysis']['Orientation'] + '_{suffix}{extension}'
            reorient_filename = join(config['Outputs']['Reorient'],
                                     path_pattern.format(**entities))
            pathlib.Path(os.path.dirname(reorient_filename)).mkdir(
                parents=True, exist_ok=True)
            reorient_sink.inputs.dst = reorient_filename
            wf.connect([(radio, reorient_sink, [('out_file', 'src')])])

    else:
        radio = MapNode(Function(function=infile_to_outfile,
                                 input_names='in_file',
                                 output_names='out_file'),
                        name='identity',
                        iterfield='in_file')

    # Brain extraction
    bet = node_fetch.extraction_node(config, **config['BrainExtraction'])
    if ('BrainExtraction' in config['Outputs'].keys()):
        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-' + \
                       config['Outputs']['StartRegistrationSpace'] + '_desc-brain_mask{extension}'
        brain_mask_sink = MapNode(Function(function=copyfile,
                                           input_names=['src', 'dst']),
                                  name='brain_mask_sink',
                                  iterfield='src')
        brain_mask_out = join(config['Outputs']['BrainExtraction'],
                              path_pattern.format(**entities))
        pathlib.Path(os.path.dirname(brain_mask_out)).mkdir(parents=True,
                                                            exist_ok=True)
        brain_mask_sink.inputs.dst = brain_mask_out

    ## Lesion load calculation
    # Registration
    reg = node_fetch.registration_node(config, **config['Registration'])
    if ('RegistrationTransform' in config['Outputs'].keys()):

        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-' + \
                       config['Outputs']['StartRegistrationSpace'] + '_desc-transform.mat'

        registration_transform_filename = join(
            config['Outputs']['RegistrationTransform'],
            path_pattern.format(**entities))
        registration_transform_sink = MapNode(Function(
            function=copyfile, input_names=['src', 'dst']),
                                              name='registration_transf_sink',
                                              iterfield='src')
        pathlib.Path(os.path.dirname(registration_transform_filename)).mkdir(
            parents=True, exist_ok=True)
        registration_transform_sink.inputs.dst = registration_transform_filename
        wf.connect([(reg, registration_transform_sink, [('out_matrix_file',
                                                         'src')])])

    # Get mask
    mask_path_fetcher = Node(BIDSDataGrabber(
        base_dir=config['LesionRoot'],
        subject=config['Subject'],
        index_derivatives=False,
        output_query={
            'mask': dict(**config['LesionEntities'], invalid_filters='allow')
        },
        extra_derivatives=[config['LesionRoot']]),
                             name='mask_grabber')
    if (config['Session'] is not None):
        mask_path_fetcher.inputs.session = config['Session']

    # Apply reg file to lesion mask
    apply_xfm = node_fetch.apply_xfm_node(config)

    # Lesion load calculation
    if (config['Analysis']['LesionLoadCalculation']):
        lesion_load = MapNode(Function(function=overlap,
                                       input_names=['ref_mask', 'roi_list'],
                                       output_names='out_list'),
                              name='overlap_calc',
                              iterfield=['ref_mask'])
        roi_list = []
        if (os.path.exists(config['ROIDir'])):
            buf = os.listdir(config['ROIDir'])
            roi_list = [
                os.path.abspath(os.path.join(config['ROIDir'], b)) for b in buf
            ]
        else:
            warnings.warn(f"ROIDir ({config['ROIDir']}) doesn't exist.")
        buf = config['ROIList']
        roi_list += [os.path.abspath(b) for b in buf]
        lesion_load.inputs.roi_list = roi_list

        # CSV output
        csv_output = MapNode(Function(
            function=csv_writer,
            input_names=['filename', 'data_dict', 'subject', 'session']),
                             name='csv_output',
                             iterfield=['data_dict'])
        csv_output.inputs.subject = config['Subject']
        csv_output.inputs.session = config['Session']
        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-LesionLoad.csv'
        csv_out_filename = join(config['Outputs']['RegistrationTransform'],
                                path_pattern.format(**entities))
        csv_output.inputs.filename = csv_out_filename

        wf.connect([(apply_xfm, lesion_load, [('out_file', 'ref_mask')]),
                    (lesion_load, csv_output, [('out_list', 'data_dict')])])

    ## Lesion correction
    if (config['Analysis']['LesionCorrection']):
        ## White matter removal node; performs the white matter correction
        ## and takes several inputs that must be supplied below.
        wm_removal = MapNode(Function(
            function=white_matter_correction,
            input_names=[
                'image', 'wm_mask', 'lesion_mask', 'max_difference_fraction'
            ],
            output_names=['out_data', 'corrected_volume']),
                             name='wm_removal',
                             iterfield=['image', 'wm_mask', 'lesion_mask'])
        wm_removal.inputs.max_difference_fraction = config['LesionCorrection'][
            'WhiteMatterSpread']

        ## File loaders
        # Loads the subject image, passes it to wm_removal node
        subject_image_loader = MapNode(Function(function=image_load,
                                                input_names=['in_filename'],
                                                output_names='out_image'),
                                       name='file_load0',
                                       iterfield='in_filename')
        wf.connect([
            (radio, subject_image_loader, [('out_file', 'in_filename')]),
            (subject_image_loader, wm_removal, [('out_image', 'image')])
        ])

        # Loads the mask image, passes it to wm_removal node
        mask_image_loader = MapNode(Function(function=image_load,
                                             input_names=['in_filename'],
                                             output_names='out_image'),
                                    name='file_load2',
                                    iterfield='in_filename')
        wf.connect([
            (mask_path_fetcher, mask_image_loader, [('mask', 'in_filename')]),
            (mask_image_loader, wm_removal, [('out_image', 'lesion_mask')])
        ])

        # Save lesion mask with white matter voxels removed
        output_image = MapNode(Function(
            function=image_write,
            input_names=['image', 'reference', 'file_name']),
                               name='image_writer0',
                               iterfield=['image', 'reference'])
        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-' + \
                       config['Outputs']['StartRegistrationSpace'] + '_desc-CorrectedLesion_mask{extension}'
        lesion_corrected_filename = join(config['Outputs']['LesionCorrected'],
                                         path_pattern.format(**entities))
        output_image.inputs.file_name = lesion_corrected_filename
        wf.connect([(wm_removal, output_image, [('out_data', 'image')]),
                    (mask_path_fetcher, output_image, [('mask', 'reference')])
                    ])

        ## CSV output
        csv_output_corr = MapNode(Function(function=csv_writer,
                                           input_names=[
                                               'filename', 'subject',
                                               'session', 'data', 'data_name'
                                           ]),
                                  name='csv_output_corr',
                                  iterfield=['data'])
        csv_output_corr.inputs.subject = config['Subject']
        csv_output_corr.inputs.session = config['Session']
        csv_output_corr.inputs.data_name = 'CorrectedVolume'

        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-LesionLoad.csv'
        csv_out_filename = join(config['Outputs']['RegistrationTransform'],
                                path_pattern.format(**entities))
        csv_output_corr.inputs.filename = csv_out_filename

        wf.connect([(wm_removal, csv_output_corr, [('corrected_volume', 'data')
                                                   ])])

        ## White matter segmentation; either do segmentation or load the file
        if (config['Analysis']['WhiteMatterSegmentation']):
            # Config is set to do white matter segmentation
            # T1 intensity normalization
            t1_norm = MapNode(Function(
                function=rescale_image,
                input_names=['image', 'range_min', 'range_max', 'save_image'],
                output_names='out_file'),
                              name='normalization',
                              iterfield=['image'])
            t1_norm.inputs.range_min = config['LesionCorrection'][
                'ImageNormMin']
            t1_norm.inputs.range_max = config['LesionCorrection'][
                'ImageNormMax']
            t1_norm.inputs.save_image = True
            wf.connect([(bet, t1_norm, [('out_file', 'image')])])

            # White matter segmentation
            wm_seg = MapNode(FAST(), name="wm_seg", iterfield='in_files')
            wm_seg.inputs.out_basename = "segmentation"
            wm_seg.inputs.img_type = 1
            wm_seg.inputs.number_classes = 3
            wm_seg.inputs.hyper = 0.1
            wm_seg.inputs.iters_afterbias = 4
            wm_seg.inputs.bias_lowpass = 20
            wm_seg.inputs.segments = True
            wm_seg.inputs.no_pve = True
            ex_last = MapNode(Function(function=extract_last,
                                       input_names=['in_list'],
                                       output_names='out_entry'),
                              name='ex_last',
                              iterfield='in_list')

            file_load1 = MapNode(Function(function=image_load,
                                          input_names=['in_filename'],
                                          output_names='out_image'),
                                 name='file_load1',
                                 iterfield='in_filename')
            # White matter output; only necessary if white matter is segmented
            wm_map = MapNode(Function(
                function=image_write,
                input_names=['image', 'reference', 'file_name']),
                             name='image_writer1',
                             iterfield=['image', 'reference'])
            path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-' + \
                           config['Outputs']['StartRegistrationSpace'] + '_desc-WhiteMatter_mask{extension}'
            wm_map_filename = join(config['Outputs']['LesionCorrected'],
                                   path_pattern.format(**entities))
            wm_map.inputs.file_name = wm_map_filename
            wf.connect([(file_load1, wm_map, [('out_image', 'image')]),
                        (mask_path_fetcher, wm_map, [('mask', 'reference')])])
            # Connect nodes in workflow
            wf.connect([
                (wm_seg, ex_last, [('tissue_class_files', 'in_list')]),
                (t1_norm, wm_seg, [('out_file', 'in_files')]),
                # (ex_last, wm_map, [('out_entry', 'image')]),
                (ex_last, file_load1, [('out_entry', 'in_filename')]),
                (file_load1, wm_removal, [('out_image', 'wm_mask')])
            ])

        else:
            # No white matter segmentation is done, but lesion correction is
            # expected, so a white matter segmentation file must be supplied.
            wm_seg_path = config['WhiteMatterSegmentationFile']
            if (len(wm_seg_path) == 0 or not os.path.exists(wm_seg_path)):
                # Fall back to a mask previously written to the output path.
                path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-' + \
                               config['Outputs']['StartRegistrationSpace'] + '_desc-WhiteMatter_mask{extension}'
                wm_map_filename = join(config['Outputs']['LesionCorrected'],
                                       path_pattern.format(**entities))
                if (os.path.exists(wm_map_filename)):
                    wm_seg_path = wm_map_filename
                else:
                    raise ValueError(
                        'Config file is inconsistent; if WhiteMatterSegmentation is false but LesionCorrection'
                        ' is true, then WhiteMatterSegmentationFile must be defined and must exist.'
                    )
            file_load1 = MapNode(Function(function=image_load,
                                          input_names=['in_filename'],
                                          output_names='out_image'),
                                 name='file_load1',
                                 iterfield='in_filename')
            file_load1.inputs.in_filename = wm_seg_path

            # Connect nodes in workflow
            wf.connect([(file_load1, wm_removal, [('out_image', 'wm_mask')])])

    # Connecting workflow.
    wf.connect([
        # Starter
        (loader, radio, [('t1w', 'in_file')]),
        (radio, bet, [('out_file', 'in_file')]),
        (bet, reg, [('out_file', 'in_file')]),
        (reg, apply_xfm, [('out_matrix_file', 'in_matrix_file')]),
        (mask_path_fetcher, apply_xfm, [('mask', 'in_file')]),
    ])

    try:
        graph_out = config['Outputs'][
            'LesionCorrected'] + '/sub-{subject}/ses-{session}/anat/'.format(
                **entities)
        wf.write_graph(graph2use='orig',
                       dotfilename=join(graph_out, 'graph.dot'),
                       format='png')
        os.remove(graph_out + 'graph.dot')
        os.remove(graph_out + 'graph_detailed.dot')
    except OSError:
        warnings.warn(
            "graphviz not installed; can't produce graph. See http://www.graphviz.org/download/ for "
            "installation instructions.")
    wf.run()
    return wf
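
Everything above is driven by a single config dict. A skeleton of the keys this function and its node_fetch helpers read (all values are placeholders, not recommended defaults):

config = {
    'BIDSRoot': '/data/bids', 'Subject': '01', 'Session': '01',
    'T1Entities': {'suffix': 'T1w'},
    'LesionRoot': '/data/lesions',
    'LesionEntities': {'suffix': 'mask'},
    'ROIDir': '/data/rois', 'ROIList': [],
    'Analysis': {'Reorient': True, 'Orientation': 'LAS',
                 'Registration': True, 'RegistrationMethod': 'flirt',
                 'BrainExtraction': True, 'BrainExtractionMethod': 'bet',
                 'LesionLoadCalculation': True, 'LesionCorrection': True,
                 'WhiteMatterSegmentation': True},
    'BrainExtraction': {'frac': 0.5},
    'Registration': {'reference': 'MNI152_T1_2mm.nii.gz'},
    'LesionCorrection': {'WhiteMatterSpread': 0.05,
                         'ImageNormMin': 0, 'ImageNormMax': 255},
    'WhiteMatterSegmentationFile': '',
    'Outputs': {'Reorient': '/out', 'BrainExtraction': '/out',
                'RegistrationTransform': '/out', 'LesionCorrected': '/out',
                'StartRegistrationSpace': 'MNI152'},
}
wf = pals(config)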