def apply_xfm_node(config: dict, **kwargs):
    '''
    Parses config file to return desired apply_xfm node.

    When registration is disabled, the returned node passes ``in_file``
    through unchanged (via ``infile_to_outfile``). Otherwise it is an FSL
    FLIRT node in apply-transform mode, resampling onto
    ``config['Registration']['reference']``.

    Parameters
    ----------
    config : dict
        PALS config file. Reads ``config['Analysis']['Registration']`` and,
        when registration is enabled, ``config['Registration']['reference']``.
    kwargs
        Keyword arguments assigned as inputs on the FLIRT node (same
        convention as ``registration_node``). Ignored when registration is
        disabled.

    Returns
    -------
    MapNode
        Node iterating over ``in_file`` and ``in_matrix_file``.
    '''
    if not config['Analysis']['Registration']:
        # No registration; no xfm to apply.
        n = MapNode(Function(function=infile_to_outfile,
                             input_names=['in_file', 'in_matrix_file'],
                             output_names='out_file'),
                    name='transformation_skip',
                    iterfield=['in_file', 'in_matrix_file'])
    else:
        n = MapNode(fsl.FLIRT(apply_xfm=True,
                              reference=config['Registration']['reference']),
                    name='transformation_flirt',
                    iterfield=['in_file', 'in_matrix_file'])
        # Previously documented but silently ignored; honour caller overrides.
        for k, v in kwargs.items():
            setattr(n.inputs, k, v)
    return n
def registration_node(config: dict, **kwargs):
    '''
    Build the registration MapNode selected by the PALS config.

    Parameters
    ----------
    config : dict
        PALS config file. ``config['Analysis']['RegistrationMethod']`` is
        always read; ``config['Analysis']['Registration']`` toggles whether a
        real registration is performed.
    kwargs
        Values assigned onto the FLIRT node's inputs. Not used by the
        identity (no-registration) node.

    Returns
    -------
    MapNode
        Either an identity node producing ``out_file``/``out_matrix_file``,
        or an FSL FLIRT node; both iterate over ``in_file``.

    Raises
    ------
    NotImplementedError
        When registration is enabled with a method other than FLIRT.
    '''
    analysis = config['Analysis']
    method = analysis['RegistrationMethod']

    if not analysis['Registration']:
        # Registration disabled: pass images straight through.
        identity_fn = Function(function=reg_no_reg,
                               input_names=['in_file'],
                               output_names=['out_file', 'out_matrix_file'])
        return MapNode(identity_fn,
                       name='registration_identity',
                       iterfield='in_file')

    if method.lower() != 'flirt':
        raise NotImplementedError(
            f'Registration method {method} not implemented.')

    node = MapNode(fsl.FLIRT(), name='registration_flirt', iterfield='in_file')
    for input_name, input_value in kwargs.items():
        setattr(node.inputs, input_name, input_value)
    return node
def extraction_node(config: dict, **kwargs):
    '''
    Build the brain-extraction MapNode selected by the PALS config.

    Parameters
    ----------
    config : dict
        PALS config file. ``config['Analysis']['BrainExtractionMethod']`` is
        always read; ``config['Analysis']['BrainExtraction']`` toggles whether
        extraction is performed.
    kwargs
        Keyword arguments forwarded to the ``fsl.BET`` constructor.

    Returns
    -------
    MapNode
        Either a pass-through node (``in_file`` -> ``out_file``) or an FSL BET
        node; both iterate over ``in_file``.

    Raises
    ------
    NotImplementedError
        When extraction is enabled with a method other than BET.
    '''
    analysis = config['Analysis']
    method = analysis['BrainExtractionMethod']

    if not analysis['BrainExtraction']:
        # Extraction disabled: in -> out unchanged.
        passthrough = Function(function=infile_to_outfile,
                               input_names='in_file',
                               output_names='out_file')
        return MapNode(passthrough, name='extract_skip', iterfield='in_file')

    if method.lower() == 'bet':
        return MapNode(fsl.BET(**kwargs),
                       name='extraction_bet',
                       iterfield='in_file')

    raise NotImplementedError(f'Extraction method {method} not implemented.')
def wf_transform_anat(in_file_list, in_matrix_file_list, reference):
    """Build a workflow applying precomputed FLIRT matrices to a list of images.

    Parameters
    ----------
    in_file_list : list of str
        Images to resample.
    in_matrix_file_list : list of str
        FLIRT affine matrices, one per entry of ``in_file_list``.
    reference : str or list/tuple of str
        Reference image(s). A list/tuple is iterated in lock-step with
        ``in_file_list`` (as before); a single path is now shared by every
        input instead of being treated as a one-element iterfield, which made
        the MapNode fail as soon as more than one input file was given.

    Returns
    -------
    Workflow
        A workflow named ``wf_transform_anat`` (not run) whose FLIRT node
        writes uncompressed NIfTI outputs.
    """
    iterfield = ['in_file', 'in_matrix_file']
    if isinstance(reference, (list, tuple)):
        # Per-image references: iterate them alongside the inputs.
        iterfield.append('reference')

    func2std_xform = MapNode(
        FLIRT(output_type='NIFTI', apply_xfm=True),
        name="func2std_xform",
        iterfield=iterfield)

    inputspec = Node(IdentityInterface(
        fields=['in_file_list', 'in_matrix_file_list', 'reference']),
        name="inputspec")
    inputspec.inputs.in_file_list = in_file_list
    inputspec.inputs.in_matrix_file_list = in_matrix_file_list
    inputspec.inputs.reference = reference

    # Named differently from the function to avoid shadowing it.
    workflow = Workflow(name="wf_transform_anat")
    workflow.connect(inputspec, 'in_file_list', func2std_xform, 'in_file')
    workflow.connect(inputspec, 'in_matrix_file_list', func2std_xform,
                     'in_matrix_file')
    workflow.connect(inputspec, 'reference', func2std_xform, 'reference')
    return workflow
# composite transform dictWarps = {'transforms': [], 'outpath': []} for strWarpDir in lsWarpDirs: strAffinePath = glob.glob(os.path.join(strWarpDir, 'out_matrix', '*.mat'))[0] # Remove rigid body components (translation and rotation) which don't # contribute meaningful variation strAffinePath = remove_rigidbody(strAffinePath) # We use the inverse warp field, which contains the nonlinear transformation from MNI->subject strNonlinearPath = glob.glob( os.path.join(strWarpDir, 'inverse_warp_field', '*.nii.gz'))[0] dictWarps['transforms'].append([strNonlinearPath, strAffinePath]) dictWarps['outpath'].append( os.path.join(strWarpDir, 'composite_to_mni.nii.gz')) # Use ANTs ApplyTransforms to compose the transforms antstool = MapNode(ants.ApplyTransforms(input_image=TEMPLATE, reference_image=TEMPLATE, interpolation='BSpline', invert_transform_flags=[False, True], print_out_composite_warp_file=True), name='applytransforms', iterfield=['output_image', 'transforms']) antstool.inputs.output_image = dictWarps['outpath'] antstool.inputs.transforms = dictWarps['transforms'] # Create and run nipype workflow wf = Workflow('composite_transforms') wf.add_nodes([antstool]) wf.run(plugin='MultiProc', plugin_args={'n_procs': PIPELINE_JOBS})
def create_workflow(self):
    """Create the Nipype workflow of the super-resolution pipeline.

    It is composed of a succession of Nodes and their corresponding
    parameters, where the output of node i goes to the input of node i+1.
    The workflow is stored in ``self.wf``; nothing is returned. Working files
    go under ``<output_dir>/nipype-<version>/<subject>[/<session>]/rec-<sr_id>``
    and final results are sunk to
    ``<output_dir>/pymialsrtk-<version>/<subject>[/<session>]``.
    """
    sub_ses = self.subject
    if self.session is not None:
        sub_ses = ''.join([sub_ses, '_', self.session])

    if self.session is None:
        wf_base_dir = os.path.join(self.output_dir,
                                   '-'.join(["nipype", __nipype_version__]),
                                   self.subject,
                                   "rec-{}".format(self.sr_id))
        final_res_dir = os.path.join(self.output_dir,
                                     '-'.join(["pymialsrtk", __version__]),
                                     self.subject)
    else:
        wf_base_dir = os.path.join(self.output_dir,
                                   '-'.join(["nipype", __nipype_version__]),
                                   self.subject,
                                   self.session,
                                   "rec-{}".format(self.sr_id))
        final_res_dir = os.path.join(self.output_dir,
                                     '-'.join(["pymialsrtk", __version__]),
                                     self.subject,
                                     self.session)

    if not os.path.exists(wf_base_dir):
        os.makedirs(wf_base_dir)
    print("Process directory: {}".format(wf_base_dir))

    # Initialization (not sure we can control the name of the nipype log):
    # drop any log left over from a previous run.
    if os.path.isfile(os.path.join(wf_base_dir, "pypeline.log")):
        os.unlink(os.path.join(wf_base_dir, "pypeline.log"))

    self.wf = Workflow(name=self.pipeline_name, base_dir=wf_base_dir)

    config.update_config({
        'logging': {
            'log_directory': wf_base_dir,
            'log_to_file': True
        },
        'execution': {
            'remove_unnecessary_outputs': False,
            'stop_on_first_crash': True,
            'stop_on_first_rerun': False,
            'crashfile_format': "txt",
            'use_relative_paths': True,
            'write_provenance': False
        }
    })

    # Update nipype logging with config
    nipype_logging.update_logging(config)

    # ---- Inputs: T2w stacks and brain masks -------------------------------
    if self.use_manual_masks:
        dg = Node(interface=DataGrabber(outfields=['T2ws', 'masks']),
                  name='data_grabber')
        dg.inputs.base_directory = self.bids_dir
        dg.inputs.template = '*'
        dg.inputs.raise_on_empty = False
        dg.inputs.sort_filelist = True

        if self.session is not None:
            t2ws_template = os.path.join(
                self.subject, self.session, 'anat',
                '_'.join([sub_ses, '*run-*', '*T2w.nii.gz']))
            if self.m_masks_desc is not None:
                # 'desc-' (not '_desc-'): '_'.join already inserts the
                # separator; the old form produced an unmatchable '__desc-'.
                masks_template = os.path.join(
                    'derivatives', self.m_masks_derivatives_dir,
                    self.subject, self.session, 'anat',
                    '_'.join([sub_ses, '*run-*',
                              'desc-' + self.m_masks_desc,
                              '*mask.nii.gz']))
            else:
                masks_template = os.path.join(
                    'derivatives', self.m_masks_derivatives_dir,
                    self.subject, self.session, 'anat',
                    '_'.join([sub_ses, '*run-*', '*mask.nii.gz']))
        else:
            t2ws_template = os.path.join(self.subject, 'anat',
                                         sub_ses + '*_run-*_T2w.nii.gz')
            if self.m_masks_desc is not None:
                # No session level here: self.session is None in this branch
                # and previously made os.path.join raise TypeError.
                masks_template = os.path.join(
                    'derivatives', self.m_masks_derivatives_dir,
                    self.subject, 'anat',
                    sub_ses + '*_run-*_desc-' + self.m_masks_desc +
                    '_*mask.nii.gz')
            else:
                masks_template = os.path.join(
                    'derivatives', self.m_masks_derivatives_dir,
                    self.subject, 'anat',
                    sub_ses + '*_run-*_*mask.nii.gz')

        dg.inputs.field_template = dict(T2ws=t2ws_template,
                                        masks=masks_template)

        # Manual masks are used as-is: identity node standing in for the
        # brain-extraction step.
        brainMask = MapNode(
            interface=IdentityInterface(fields=['out_file']),
            name='brain_masks_bypass',
            iterfield=['out_file'])

        if self.m_stacks is not None:
            custom_masks_filter = Node(
                interface=preprocess.FilteringByRunid(),
                name='custom_masks_filter')
            custom_masks_filter.inputs.stacks_id = self.m_stacks
    else:
        dg = Node(interface=DataGrabber(outfields=['T2ws']),
                  name='data_grabber')
        dg.inputs.base_directory = self.bids_dir
        dg.inputs.template = '*'
        dg.inputs.raise_on_empty = False
        dg.inputs.sort_filelist = True
        if self.session is not None:
            dg.inputs.field_template = dict(T2ws=os.path.join(
                self.subject, self.session, 'anat',
                '_'.join([sub_ses, '*run-*', '*T2w.nii.gz'])))
        else:
            dg.inputs.field_template = dict(T2ws=os.path.join(
                self.subject, 'anat', sub_ses + '*_run-*_T2w.nii.gz'))

        if self.m_stacks is not None:
            t2ws_filter_prior_masks = Node(
                interface=preprocess.FilteringByRunid(),
                name='t2ws_filter_prior_masks')
            t2ws_filter_prior_masks.inputs.stacks_id = self.m_stacks

        # Automatic brain extraction from the packaged U-Net checkpoints
        # (localization + segmentation); '.index' is stripped to get the
        # checkpoint prefix.
        brainMask = MapNode(interface=preprocess.BrainExtraction(),
                            name='brainExtraction',
                            iterfield=['in_file'])
        brainMask.inputs.bids_dir = self.bids_dir
        brainMask.inputs.in_ckpt_loc = pkg_resources.resource_filename(
            "pymialsrtk",
            os.path.join("data", "Network_checkpoints",
                         "Network_checkpoints_localization",
                         "Unet.ckpt-88000.index")).split('.index')[0]
        brainMask.inputs.threshold_loc = 0.49
        brainMask.inputs.in_ckpt_seg = pkg_resources.resource_filename(
            "pymialsrtk",
            os.path.join("data", "Network_checkpoints",
                         "Network_checkpoints_segmentation",
                         "Unet.ckpt-20000.index")).split('.index')[0]
        brainMask.inputs.threshold_seg = 0.5

    t2ws_filtered = Node(interface=preprocess.FilteringByRunid(),
                         name='t2ws_filtered')
    masks_filtered = Node(interface=preprocess.FilteringByRunid(),
                          name='masks_filtered')

    if not self.m_skip_stacks_ordering:
        stacksOrdering = Node(interface=preprocess.StacksOrdering(),
                              name='stackOrdering')
    else:
        # Ordering skipped: forward the user-provided stack list unchanged.
        stacksOrdering = Node(
            interface=IdentityInterface(fields=['stacks_order']),
            name='stackOrdering')
        stacksOrdering.inputs.stacks_order = self.m_stacks

    # ---- Preprocessing nodes ------------------------------------------------
    if not self.m_skip_nlm_denoising:
        nlmDenoise = MapNode(interface=preprocess.BtkNLMDenoising(),
                             name='nlmDenoise',
                             iterfield=['in_file', 'in_mask'])
        nlmDenoise.inputs.bids_dir = self.bids_dir

        # Without the mask, the first slice-intensity correction...
        srtkCorrectSliceIntensity01_nlm = MapNode(
            interface=preprocess.MialsrtkCorrectSliceIntensity(),
            name='srtkCorrectSliceIntensity01_nlm',
            iterfield=['in_file', 'in_mask'])
        srtkCorrectSliceIntensity01_nlm.inputs.bids_dir = self.bids_dir
        srtkCorrectSliceIntensity01_nlm.inputs.out_postfix = '_uni'

    srtkCorrectSliceIntensity01 = MapNode(
        interface=preprocess.MialsrtkCorrectSliceIntensity(),
        name='srtkCorrectSliceIntensity01',
        iterfield=['in_file', 'in_mask'])
    srtkCorrectSliceIntensity01.inputs.bids_dir = self.bids_dir
    srtkCorrectSliceIntensity01.inputs.out_postfix = '_uni'

    srtkSliceBySliceN4BiasFieldCorrection = MapNode(
        interface=preprocess.MialsrtkSliceBySliceN4BiasFieldCorrection(),
        name='srtkSliceBySliceN4BiasFieldCorrection',
        iterfield=['in_file', 'in_mask'])
    srtkSliceBySliceN4BiasFieldCorrection.inputs.bids_dir = self.bids_dir

    srtkSliceBySliceCorrectBiasField = MapNode(
        interface=preprocess.MialsrtkSliceBySliceCorrectBiasField(),
        name='srtkSliceBySliceCorrectBiasField',
        iterfield=['in_file', 'in_mask', 'in_field'])
    srtkSliceBySliceCorrectBiasField.inputs.bids_dir = self.bids_dir

    # 4-modules sequence to be defined as a stage (denoised branch).
    if not self.m_skip_nlm_denoising:
        srtkCorrectSliceIntensity02_nlm = MapNode(
            interface=preprocess.MialsrtkCorrectSliceIntensity(),
            name='srtkCorrectSliceIntensity02_nlm',
            iterfield=['in_file', 'in_mask'])
        srtkCorrectSliceIntensity02_nlm.inputs.bids_dir = self.bids_dir

        srtkIntensityStandardization01_nlm = Node(
            interface=preprocess.MialsrtkIntensityStandardization(),
            name='srtkIntensityStandardization01_nlm')
        srtkIntensityStandardization01_nlm.inputs.bids_dir = self.bids_dir

        srtkHistogramNormalization_nlm = Node(
            interface=preprocess.MialsrtkHistogramNormalization(),
            name='srtkHistogramNormalization_nlm')
        srtkHistogramNormalization_nlm.inputs.bids_dir = self.bids_dir

        srtkIntensityStandardization02_nlm = Node(
            interface=preprocess.MialsrtkIntensityStandardization(),
            name='srtkIntensityStandardization02_nlm')
        srtkIntensityStandardization02_nlm.inputs.bids_dir = self.bids_dir

    # 4-modules sequence to be defined as a stage (non-denoised branch).
    srtkCorrectSliceIntensity02 = MapNode(
        interface=preprocess.MialsrtkCorrectSliceIntensity(),
        name='srtkCorrectSliceIntensity02',
        iterfield=['in_file', 'in_mask'])
    srtkCorrectSliceIntensity02.inputs.bids_dir = self.bids_dir

    srtkIntensityStandardization01 = Node(
        interface=preprocess.MialsrtkIntensityStandardization(),
        name='srtkIntensityStandardization01')
    srtkIntensityStandardization01.inputs.bids_dir = self.bids_dir

    srtkHistogramNormalization = Node(
        interface=preprocess.MialsrtkHistogramNormalization(),
        name='srtkHistogramNormalization')
    srtkHistogramNormalization.inputs.bids_dir = self.bids_dir

    srtkIntensityStandardization02 = Node(
        interface=preprocess.MialsrtkIntensityStandardization(),
        name='srtkIntensityStandardization02')
    srtkIntensityStandardization02.inputs.bids_dir = self.bids_dir

    srtkMaskImage01 = MapNode(interface=preprocess.MialsrtkMaskImage(),
                              name='srtkMaskImage01',
                              iterfield=['in_file', 'in_mask'])
    srtkMaskImage01.inputs.bids_dir = self.bids_dir

    # ---- Reconstruction nodes -----------------------------------------------
    srtkImageReconstruction = Node(
        interface=reconstruction.MialsrtkImageReconstruction(),
        name='srtkImageReconstruction')
    srtkImageReconstruction.inputs.bids_dir = self.bids_dir
    srtkImageReconstruction.inputs.sub_ses = sub_ses
    srtkImageReconstruction.inputs.no_reg = self.m_skip_svr

    srtkTVSuperResolution = Node(
        interface=reconstruction.MialsrtkTVSuperResolution(),
        name='srtkTVSuperResolution')
    srtkTVSuperResolution.inputs.bids_dir = self.bids_dir
    srtkTVSuperResolution.inputs.sub_ses = sub_ses
    srtkTVSuperResolution.inputs.in_loop = self.primal_dual_loops
    srtkTVSuperResolution.inputs.in_deltat = self.deltatTV
    srtkTVSuperResolution.inputs.in_lambda = self.lambdaTV
    srtkTVSuperResolution.inputs.use_manual_masks = self.use_manual_masks

    # ---- Postprocessing nodes -----------------------------------------------
    srtkN4BiasFieldCorrection = Node(
        interface=postprocess.MialsrtkN4BiasFieldCorrection(),
        name='srtkN4BiasFieldCorrection')
    srtkN4BiasFieldCorrection.inputs.bids_dir = self.bids_dir

    if self.m_do_refine_hr_mask:
        srtkHRMask = Node(
            interface=postprocess.MialsrtkRefineHRMaskByIntersection(),
            name='srtkHRMask')
        srtkHRMask.inputs.bids_dir = self.bids_dir
    else:
        srtkHRMask = Node(interface=postprocess.BinarizeImage(),
                          name='srtkHRMask')

    srtkMaskImage02 = Node(interface=preprocess.MialsrtkMaskImage(),
                           name='srtkMaskImage02')
    srtkMaskImage02.inputs.bids_dir = self.bids_dir

    # ---- Build workflow: connections of the nodes ---------------------------
    if self.use_manual_masks:
        if self.m_stacks is not None:
            self.wf.connect(dg, "masks", custom_masks_filter, "input_files")
            self.wf.connect(custom_masks_filter, "output_files",
                            brainMask, "out_file")
        else:
            self.wf.connect(dg, "masks", brainMask, "out_file")
    else:
        if self.m_stacks is not None:
            self.wf.connect(dg, "T2ws", t2ws_filter_prior_masks, "input_files")
            self.wf.connect(t2ws_filter_prior_masks, "output_files",
                            brainMask, "in_file")
        else:
            self.wf.connect(dg, "T2ws", brainMask, "in_file")

    if not self.m_skip_stacks_ordering:
        self.wf.connect(brainMask, "out_file", stacksOrdering, "input_masks")

    self.wf.connect(stacksOrdering, "stacks_order", t2ws_filtered, "stacks_id")
    self.wf.connect(dg, "T2ws", t2ws_filtered, "input_files")

    self.wf.connect(stacksOrdering, "stacks_order", masks_filtered, "stacks_id")
    self.wf.connect(brainMask, "out_file", masks_filtered, "input_files")

    if not self.m_skip_nlm_denoising:
        self.wf.connect(t2ws_filtered, ("output_files", utils.sort_ascending),
                        nlmDenoise, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        nlmDenoise, "in_mask")

        ## Comment to match docker process
        self.wf.connect(nlmDenoise, ("out_file", utils.sort_ascending),
                        srtkCorrectSliceIntensity01_nlm, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkCorrectSliceIntensity01_nlm, "in_mask")

    self.wf.connect(t2ws_filtered, ("output_files", utils.sort_ascending),
                    srtkCorrectSliceIntensity01, "in_file")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkCorrectSliceIntensity01, "in_mask")

    # The bias field is estimated on the denoised images when available...
    if not self.m_skip_nlm_denoising:
        self.wf.connect(srtkCorrectSliceIntensity01_nlm,
                        ("out_file", utils.sort_ascending),
                        srtkSliceBySliceN4BiasFieldCorrection, "in_file")
    else:
        self.wf.connect(srtkCorrectSliceIntensity01,
                        ("out_file", utils.sort_ascending),
                        srtkSliceBySliceN4BiasFieldCorrection, "in_file")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkSliceBySliceN4BiasFieldCorrection, "in_mask")

    # ...and then applied to the non-denoised images.
    self.wf.connect(srtkCorrectSliceIntensity01,
                    ("out_file", utils.sort_ascending),
                    srtkSliceBySliceCorrectBiasField, "in_file")
    self.wf.connect(srtkSliceBySliceN4BiasFieldCorrection,
                    ("out_fld_file", utils.sort_ascending),
                    srtkSliceBySliceCorrectBiasField, "in_field")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkSliceBySliceCorrectBiasField, "in_mask")

    if not self.m_skip_nlm_denoising:
        self.wf.connect(srtkSliceBySliceN4BiasFieldCorrection,
                        ("out_im_file", utils.sort_ascending),
                        srtkCorrectSliceIntensity02_nlm, "in_file")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkCorrectSliceIntensity02_nlm, "in_mask")
        self.wf.connect(srtkCorrectSliceIntensity02_nlm,
                        ("out_file", utils.sort_ascending),
                        srtkIntensityStandardization01_nlm, "input_images")
        self.wf.connect(srtkIntensityStandardization01_nlm,
                        ("output_images", utils.sort_ascending),
                        srtkHistogramNormalization_nlm, "input_images")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkHistogramNormalization_nlm, "input_masks")
        self.wf.connect(srtkHistogramNormalization_nlm,
                        ("output_images", utils.sort_ascending),
                        srtkIntensityStandardization02_nlm, "input_images")

    self.wf.connect(srtkSliceBySliceCorrectBiasField,
                    ("out_im_file", utils.sort_ascending),
                    srtkCorrectSliceIntensity02, "in_file")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkCorrectSliceIntensity02, "in_mask")
    self.wf.connect(srtkCorrectSliceIntensity02,
                    ("out_file", utils.sort_ascending),
                    srtkIntensityStandardization01, "input_images")
    self.wf.connect(srtkIntensityStandardization01,
                    ("output_images", utils.sort_ascending),
                    srtkHistogramNormalization, "input_images")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkHistogramNormalization, "input_masks")
    self.wf.connect(srtkHistogramNormalization,
                    ("output_images", utils.sort_ascending),
                    srtkIntensityStandardization02, "input_images")

    # Masked LR images feeding the SDI reconstruction come from the denoised
    # branch when denoising is enabled.
    if not self.m_skip_nlm_denoising:
        self.wf.connect(srtkIntensityStandardization02_nlm,
                        ("output_images", utils.sort_ascending),
                        srtkMaskImage01, "in_file")
    else:
        self.wf.connect(srtkIntensityStandardization02,
                        ("output_images", utils.sort_ascending),
                        srtkMaskImage01, "in_file")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkMaskImage01, "in_mask")

    self.wf.connect(srtkMaskImage01, "out_im_file",
                    srtkImageReconstruction, "input_images")
    self.wf.connect(masks_filtered, "output_files",
                    srtkImageReconstruction, "input_masks")
    self.wf.connect(stacksOrdering, "stacks_order",
                    srtkImageReconstruction, "stacks_order")

    # TV super-resolution always uses the non-denoised standardized images.
    self.wf.connect(srtkIntensityStandardization02, "output_images",
                    srtkTVSuperResolution, "input_images")
    self.wf.connect(srtkImageReconstruction,
                    ("output_transforms", utils.sort_ascending),
                    srtkTVSuperResolution, "input_transforms")
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    srtkTVSuperResolution, "input_masks")
    self.wf.connect(stacksOrdering, "stacks_order",
                    srtkTVSuperResolution, "stacks_order")
    self.wf.connect(srtkImageReconstruction, "output_sdi",
                    srtkTVSuperResolution, "input_sdi")

    if self.m_do_refine_hr_mask:
        self.wf.connect(srtkIntensityStandardization02,
                        ("output_images", utils.sort_ascending),
                        srtkHRMask, "input_images")
        self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                        srtkHRMask, "input_masks")
        self.wf.connect(srtkImageReconstruction,
                        ("output_transforms", utils.sort_ascending),
                        srtkHRMask, "input_transforms")
        self.wf.connect(srtkTVSuperResolution, "output_sr",
                        srtkHRMask, "input_sr")
    else:
        self.wf.connect(srtkTVSuperResolution, "output_sr",
                        srtkHRMask, "input_image")

    self.wf.connect(srtkTVSuperResolution, "output_sr",
                    srtkMaskImage02, "in_file")
    self.wf.connect(srtkHRMask, "output_srmask", srtkMaskImage02, "in_mask")

    self.wf.connect(srtkTVSuperResolution, "output_sr",
                    srtkN4BiasFieldCorrection, "input_image")
    self.wf.connect(srtkHRMask, "output_srmask",
                    srtkN4BiasFieldCorrection, "input_mask")

    # ---- Datasinker -----------------------------------------------------------
    finalFilenamesGeneration = Node(
        interface=postprocess.FilenamesGeneration(), name='filenames_gen')
    finalFilenamesGeneration.inputs.sub_ses = sub_ses
    finalFilenamesGeneration.inputs.sr_id = self.sr_id
    finalFilenamesGeneration.inputs.use_manual_masks = self.use_manual_masks

    self.wf.connect(stacksOrdering, "stacks_order",
                    finalFilenamesGeneration, "stacks_order")

    datasink = Node(interface=DataSink(), name='data_sinker')
    datasink.inputs.base_directory = final_res_dir

    if not self.m_skip_stacks_ordering:
        self.wf.connect(stacksOrdering, "report_image",
                        datasink, 'figures.@stackOrderingQC')
        self.wf.connect(stacksOrdering, "motion_tsv",
                        datasink, 'anat.@motionTSV')
    self.wf.connect(masks_filtered, ("output_files", utils.sort_ascending),
                    datasink, 'anat.@LRmasks')
    self.wf.connect(srtkIntensityStandardization02,
                    ("output_images", utils.sort_ascending),
                    datasink, 'anat.@LRsPreproc')
    self.wf.connect(srtkImageReconstruction,
                    ("output_transforms", utils.sort_ascending),
                    datasink, 'xfm.@transforms')
    self.wf.connect(finalFilenamesGeneration, "substitutions",
                    datasink, "substitutions")
    self.wf.connect(srtkMaskImage01, ("out_im_file", utils.sort_ascending),
                    datasink, 'anat.@LRsDenoised')
    self.wf.connect(srtkImageReconstruction, "output_sdi",
                    datasink, 'anat.@SDI')
    self.wf.connect(srtkN4BiasFieldCorrection, "output_image",
                    datasink, 'anat.@SR')
    self.wf.connect(srtkTVSuperResolution, "output_json_path",
                    datasink, 'anat.@SRjson')
    self.wf.connect(srtkTVSuperResolution, "output_sr_png",
                    datasink, 'figures.@SRpng')
    self.wf.connect(srtkHRMask, "output_srmask", datasink, 'anat.@SRmask')
def pals(config: dict):
    '''
    Build and run the PALS nipype workflow described by ``config``.

    Stages, each toggled by ``config['Analysis']``: optional reorientation,
    brain extraction, registration, application of the registration matrix
    to the lesion mask, optional lesion-load calculation against ROI masks,
    and optional white-matter lesion correction (with either a computed FAST
    segmentation or a supplied white-matter mask).

    Parameters
    ----------
    config : dict
        PALS config dictionary (BIDS roots, subject/session, entity filters,
        per-stage settings and output directories).

    Returns
    -------
    Workflow
        The workflow, after ``wf.run()`` has completed.

    Raises
    ------
    ValueError
        If lesion correction is requested without white-matter segmentation
        and no existing white-matter mask file can be found.
    '''
    print('Starting: initializing workflow.')

    # Build pipeline
    wf = Workflow(name='PALS')

    # Get data: T1w image(s) for the subject (and session, if given).
    loader = BIDSDataGrabber(index_derivatives=False)
    loader.inputs.base_dir = config['BIDSRoot']
    loader.inputs.subject = config['Subject']
    if config['Session'] is not None:
        loader.inputs.session = config['Session']
    loader.inputs.output_query = {
        't1w': dict(**config['T1Entities'], invalid_filters='allow')
    }
    loader.inputs.extra_derivatives = [config['BIDSRoot']]
    loader = Node(loader, name='BIDSgrabber')

    # BIDS entities used to format every output path below.
    # NOTE(review): when config['Session'] is None, these paths contain
    # 'ses-None' — confirm whether session-less datasets are supported.
    entities = {
        'subject': config['Subject'],
        'session': config['Session'],
        'suffix': 'T1w',
        'extension': '.nii.gz'
    }

    # Reorient to radiological (or pass through unchanged).
    if config['Analysis']['Reorient']:
        radio = MapNode(
            Reorient(orientation=config['Analysis']['Orientation']),
            name="reorientation",
            iterfield='in_file')
        if 'Reorient' in config['Outputs']:
            reorient_sink = MapNode(Function(function=copyfile,
                                             input_names=['src', 'dst']),
                                    name='reorient_copy',
                                    iterfield='src')
            path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-'
                            + config['Analysis']['Orientation']
                            + '_{suffix}{extension}')
            reorient_filename = join(config['Outputs']['Reorient'],
                                     path_pattern.format(**entities))
            pathlib.Path(os.path.dirname(reorient_filename)).mkdir(
                parents=True, exist_ok=True)
            reorient_sink.inputs.dst = reorient_filename
            wf.connect([(radio, reorient_sink, [('out_file', 'src')])])
    else:
        radio = MapNode(Function(function=infile_to_outfile,
                                 input_names='in_file',
                                 output_names='out_file'),
                        name='identity',
                        iterfield='in_file')

    # Brain extraction
    bet = node_fetch.extraction_node(config, **config['BrainExtraction'])
    if 'BrainExtraction' in config['Outputs']:
        path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-'
                        + config['Outputs']['StartRegistrationSpace']
                        + '_desc-brain_mask{extension}')
        brain_mask_sink = MapNode(Function(function=copyfile,
                                           input_names=['src', 'dst']),
                                  name='brain_mask_sink',
                                  iterfield='src')
        brain_mask_out = join(config['Outputs']['BrainExtraction'],
                              path_pattern.format(**entities))
        pathlib.Path(os.path.dirname(brain_mask_out)).mkdir(parents=True,
                                                            exist_ok=True)
        brain_mask_sink.inputs.dst = brain_mask_out
        # NOTE(review): brain_mask_sink is never connected, so no mask is
        # copied. Wiring it needs a mask output from `bet` (e.g. BET's
        # mask_file, produced only with mask=True), which the skip node does
        # not provide — confirm the intended source before connecting.

    # Registration
    reg = node_fetch.registration_node(config, **config['Registration'])
    if 'RegistrationTransform' in config['Outputs']:
        path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-'
                        + config['Outputs']['StartRegistrationSpace']
                        + '_desc-transform.mat')
        registration_transform_filename = join(
            config['Outputs']['RegistrationTransform'],
            path_pattern.format(**entities))
        registration_transform_sink = MapNode(Function(
            function=copyfile, input_names=['src', 'dst']),
            name='registration_transf_sink',
            iterfield='src')
        pathlib.Path(os.path.dirname(registration_transform_filename)).mkdir(
            parents=True, exist_ok=True)
        registration_transform_sink.inputs.dst = registration_transform_filename
        wf.connect([(reg, registration_transform_sink,
                     [('out_matrix_file', 'src')])])

    # Get lesion mask
    mask_path_fetcher = Node(BIDSDataGrabber(
        base_dir=config['LesionRoot'],
        subject=config['Subject'],
        index_derivatives=False,
        output_query={
            'mask': dict(**config['LesionEntities'], invalid_filters='allow')
        },
        extra_derivatives=[config['LesionRoot']]),
        name='mask_grabber')
    if config['Session'] is not None:
        mask_path_fetcher.inputs.session = config['Session']

    # Apply reg file to lesion mask
    apply_xfm = node_fetch.apply_xfm_node(config)

    # Lesion load calculation
    if config['Analysis']['LesionLoadCalculation']:
        lesion_load = MapNode(Function(function=overlap,
                                       input_names=['ref_mask', 'roi_list'],
                                       output_names='out_list'),
                              name='overlap_calc',
                              iterfield=['ref_mask'])
        # ROIs = every file in ROIDir (if it exists) plus explicit ROIList.
        roi_list = []
        if os.path.exists(config['ROIDir']):
            roi_list = [
                os.path.abspath(os.path.join(config['ROIDir'], b))
                for b in os.listdir(config['ROIDir'])
            ]
        else:
            warnings.warn(f"ROIDir ({config['ROIDir']}) doesn't exist.")
        roi_list += [os.path.abspath(b) for b in config['ROIList']]
        lesion_load.inputs.roi_list = roi_list

        # CSV output
        csv_output = MapNode(Function(
            function=csv_writer,
            input_names=['filename', 'data_dict', 'subject', 'session']),
            name='csv_output',
            iterfield=['data_dict'])
        csv_output.inputs.subject = config['Subject']
        csv_output.inputs.session = config['Session']
        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-LesionLoad.csv'
        # NOTE(review): the CSV is written under Outputs['RegistrationTransform'],
        # which must therefore be configured here — confirm this is intended.
        csv_out_filename = join(config['Outputs']['RegistrationTransform'],
                                path_pattern.format(**entities))
        csv_output.inputs.filename = csv_out_filename
        wf.connect([(apply_xfm, lesion_load, [('out_file', 'ref_mask')]),
                    (lesion_load, csv_output, [('out_list', 'data_dict')])])

    # Lesion correction
    if config['Analysis']['LesionCorrection']:
        # White matter removal node: does the white matter correction; its
        # inputs (image, wm_mask, lesion_mask) are wired below.
        wm_removal = MapNode(Function(
            function=white_matter_correction,
            input_names=[
                'image', 'wm_mask', 'lesion_mask', 'max_difference_fraction'
            ],
            output_names=['out_data', 'corrected_volume']),
            name='wm_removal',
            iterfield=['image', 'wm_mask', 'lesion_mask'])
        wm_removal.inputs.max_difference_fraction = config['LesionCorrection'][
            'WhiteMatterSpread']

        # File loaders
        # Loads the subject image, passes it to wm_removal node
        subject_image_loader = MapNode(Function(function=image_load,
                                                input_names=['in_filename'],
                                                output_names='out_image'),
                                       name='file_load0',
                                       iterfield='in_filename')
        wf.connect([
            (radio, subject_image_loader, [('out_file', 'in_filename')]),
            (subject_image_loader, wm_removal, [('out_image', 'image')])
        ])

        # Loads the mask image, passes it to wm_removal node
        mask_image_loader = MapNode(Function(function=image_load,
                                             input_names=['in_filename'],
                                             output_names='out_image'),
                                    name='file_load2',
                                    iterfield='in_filename')
        wf.connect([
            (mask_path_fetcher, mask_image_loader, [('mask', 'in_filename')]),
            (mask_image_loader, wm_removal, [('out_image', 'lesion_mask')])
        ])

        # Save lesion mask with white matter voxels removed
        output_image = MapNode(Function(
            function=image_write,
            input_names=['image', 'reference', 'file_name']),
            name='image_writer0',
            iterfield=['image', 'reference'])
        path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-'
                        + config['Outputs']['StartRegistrationSpace']
                        + '_desc-CorrectedLesion_mask{extension}')
        lesion_corrected_filename = join(config['Outputs']['LesionCorrected'],
                                         path_pattern.format(**entities))
        output_image.inputs.file_name = lesion_corrected_filename
        wf.connect([(wm_removal, output_image, [('out_data', 'image')]),
                    (mask_path_fetcher, output_image, [('mask', 'reference')])])

        # CSV output of the corrected lesion volume
        csv_output_corr = MapNode(Function(function=csv_writer,
                                           input_names=[
                                               'filename', 'subject',
                                               'session', 'data', 'data_name'
                                           ]),
                                  name='csv_output_corr',
                                  iterfield=['data'])
        csv_output_corr.inputs.subject = config['Subject']
        csv_output_corr.inputs.session = config['Session']
        csv_output_corr.inputs.data_name = 'CorrectedVolume'
        path_pattern = 'sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_desc-LesionLoad.csv'
        csv_out_filename = join(config['Outputs']['RegistrationTransform'],
                                path_pattern.format(**entities))
        csv_output_corr.inputs.filename = csv_out_filename
        wf.connect([(wm_removal, csv_output_corr,
                     [('corrected_volume', 'data')])])

        # White matter mask: either segment it, or load a supplied file.
        # Nested here because both paths feed wm_removal, which only exists
        # when lesion correction is enabled.
        if config['Analysis']['WhiteMatterSegmentation']:
            # T1 intensity normalization
            t1_norm = MapNode(Function(
                function=rescale_image,
                input_names=['image', 'range_min', 'range_max', 'save_image'],
                output_names='out_file'),
                name='normalization',
                iterfield=['image'])
            t1_norm.inputs.range_min = config['LesionCorrection'][
                'ImageNormMin']
            t1_norm.inputs.range_max = config['LesionCorrection'][
                'ImageNormMax']
            t1_norm.inputs.save_image = True
            wf.connect([(bet, t1_norm, [('out_file', 'image')])])

            # White matter segmentation (FSL FAST, 3 classes, no PVE)
            wm_seg = MapNode(FAST(), name="wm_seg", iterfield='in_files')
            wm_seg.inputs.out_basename = "segmentation"
            wm_seg.inputs.img_type = 1
            wm_seg.inputs.number_classes = 3
            wm_seg.inputs.hyper = 0.1
            wm_seg.inputs.iters_afterbias = 4
            wm_seg.inputs.bias_lowpass = 20
            wm_seg.inputs.segments = True
            wm_seg.inputs.no_pve = True

            # Last tissue-class file is taken as the white-matter map.
            ex_last = MapNode(Function(function=extract_last,
                                       input_names=['in_list'],
                                       output_names='out_entry'),
                              name='ex_last',
                              iterfield='in_list')

            file_load1 = MapNode(Function(function=image_load,
                                          input_names=['in_filename'],
                                          output_names='out_image'),
                                 name='file_load1',
                                 iterfield='in_filename')

            # White matter output; only necessary if white matter is segmented
            wm_map = MapNode(Function(
                function=image_write,
                input_names=['image', 'reference', 'file_name']),
                name='image_writer1',
                iterfield=['image', 'reference'])
            path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-'
                            + config['Outputs']['StartRegistrationSpace']
                            + '_desc-WhiteMatter_mask{extension}')
            wm_map_filename = join(config['Outputs']['LesionCorrected'],
                                   path_pattern.format(**entities))
            wm_map.inputs.file_name = wm_map_filename
            wf.connect([(file_load1, wm_map, [('out_image', 'image')]),
                        (mask_path_fetcher, wm_map, [('mask', 'reference')])])

            # Connect nodes in workflow
            wf.connect([
                (wm_seg, ex_last, [('tissue_class_files', 'in_list')]),
                (t1_norm, wm_seg, [('out_file', 'in_files')]),
                (ex_last, file_load1, [('out_entry', 'in_filename')]),
                (file_load1, wm_removal, [('out_image', 'wm_mask')])
            ])
        else:
            # No white matter segmentation, but lesion correction is
            # expected: a white matter segmentation must be supplied.
            wm_seg_path = config['WhiteMatterSegmentationFile']
            if len(wm_seg_path) == 0 or not os.path.exists(wm_seg_path):
                # Fall back to a mask written by a previous run, if any.
                path_pattern = ('sub-{subject}/ses-{session}/anat/sub-{subject}_ses-{session}_space-'
                                + config['Outputs']['StartRegistrationSpace']
                                + '_desc-WhiteMatter_mask{extension}')
                wm_map_filename = join(config['Outputs']['LesionCorrected'],
                                       path_pattern.format(**entities))
                if os.path.exists(wm_map_filename):
                    wm_seg_path = wm_map_filename
                else:
                    raise ValueError(
                        'Config file is inconsistent; if WhiteMatterSegmentation is false but LesionCorrection'
                        ' is true, then WhiteMatterSegmentationFile must be defined and must exist.'
                    )
            file_load1 = MapNode(Function(function=image_load,
                                          input_names=['in_filename'],
                                          output_names='out_image'),
                                 name='file_load1',
                                 iterfield='in_filename')
            file_load1.inputs.in_filename = wm_seg_path

            # Connect nodes in workflow
            wf.connect([(file_load1, wm_removal, [('out_image', 'wm_mask')])])

    # Connecting workflow: the main T1 -> extraction -> registration chain and
    # the lesion-mask transform.
    wf.connect([
        (loader, radio, [('t1w', 'in_file')]),
        (radio, bet, [('out_file', 'in_file')]),
        (bet, reg, [('out_file', 'in_file')]),
        (reg, apply_xfm, [('out_matrix_file', 'in_matrix_file')]),
        (mask_path_fetcher, apply_xfm, [('mask', 'in_file')]),
    ])

    # Best-effort graph rendering; the intermediate .dot files are removed.
    try:
        graph_out = join(config['Outputs']['LesionCorrected'],
                         'sub-{subject}/ses-{session}/anat/'.format(**entities))
        # Ensure the directory exists so a missing directory is not
        # misreported below as a missing graphviz installation.
        pathlib.Path(graph_out).mkdir(parents=True, exist_ok=True)
        wf.write_graph(graph2use='orig',
                       dotfilename=join(graph_out, 'graph.dot'),
                       format='png')
        os.remove(join(graph_out, 'graph.dot'))
        os.remove(join(graph_out, 'graph_detailed.dot'))
    except OSError:
        warnings.warn(
            "graphviz not installed; can't produce graph. See http://www.graphviz.org/download/ for "
            "installation instructions.")
    wf.run()
    return wf