Example #1
    def run(cls, args):

        if not args.quiet:
            set_loggers(args.logger)

        if args.scratch is not None:
            scratch_dir = args.scratch
        else:
            scratch_dir = op.join(op.expanduser('~'), 'banana-scratch')

        # Ensure scratch dir exists
        os.makedirs(scratch_dir, exist_ok=True)

        work_dir = op.join(scratch_dir, 'work')

        proc_args = {'reprocess': args.reprocess}

        if args.processor[0] == 'single':
            processor = SingleProc(work_dir, **proc_args)
        elif args.processor[0] == 'multi':
            # Check for too many arguments first, otherwise this branch is
            # unreachable after the `> 1` test
            if len(args.processor) > 2:
                raise BananaUsageError(
                    "Unrecognised arguments passed to '--processor' option "
                    "({}) expected at most 1 additional argument for 'multi' "
                    "type processor (NUM_PROCS)".format(args.processor))
            elif len(args.processor) > 1:
                # Cast to int as argparse delivers strings
                num_processes = int(args.processor[1])
            else:
                num_processes = cpu_count()
            processor = MultiProc(work_dir,
                                  num_processes=num_processes,
                                  **proc_args)
        else:
            raise BananaUsageError(
                "Unrecognised processor type provided as first argument to "
                "'--processor' option ({})".format(args.processor[0]))

        if args.environment == 'static':
            environment = StaticEnv()
        else:
            environment = ModulesEnv()

        parts = args.test_class.split('.')

        sys.path.insert(0, op.join(args.test_root, *parts[:-2]))
        module = import_module(parts[-2])
        sys.path.pop(0)

        test_cls = getattr(module, parts[-1])

        test_cls().generate_reference_data(*args.specs,
                                           processor=processor,
                                           environment=environment)
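For reference, the '--processor' and '--scratch' options consumed above can be produced by argparse with variable-length nargs; a minimal sketch (the option declarations are assumptions, not the project's actual CLI):

import argparse

parser = argparse.ArgumentParser()
# e.g. '--processor single' or '--processor multi 4'
parser.add_argument('--processor', nargs='+', default=['single'])
parser.add_argument('--scratch', default=None)
parser.add_argument('--quiet', action='store_true')

args = parser.parse_args(['--processor', 'multi', '4'])
assert args.processor == ['multi', '4']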
Example #2
 def init_dataset(dataset_path,
                  dataset_type,
                  option_str,
                  dataset_args,
                  create_root=False,
                  **kwargs):
     if dataset_type == 'bids':
         if create_root:
             os.makedirs(dataset_path, exist_ok=True)
         dataset = BidsDataset(dataset_path, **kwargs)
     elif dataset_type == 'basic':
         if len(dataset_args) != 1:
             raise BananaUsageError(
                 "Unrecognised arguments passed to '--{}' option "
                 "({}) exactly 1 additional argument is required for "
                 "'basic' type dataset (DEPTH)".format(
                     option_str, dataset_args))
         if create_root:
             os.makedirs(dataset_path, exist_ok=True)
         dataset = Dataset(dataset_path,
                           depth=int(dataset_args[0]),
                           **kwargs)
     elif dataset_type == 'xnat':
         nargs = len(dataset_args)
         if nargs < 1:
             raise BananaUsageError(
                 "Not enough arguments passed to '--{}' option "
                 "({}), at least 1 additional argument is required for "
                 "'xnat' type dataset (SERVER)".format(
                     option_str, dataset_args))
         elif nargs > 3:
             raise BananaUsageError(
                 "Unrecognised arguments passed to '--{}' option "
                 "({}), at most 3 additional arguments are accepted for"
                 " 'xnat' type dataset (SERVER, USER, PASSWORD)".format(
                     option_str, dataset_args))
         dataset = XnatRepo(
             server=dataset_args[0],
             user=(dataset_args[1] if nargs > 1 else None),
             password=(dataset_args[2] if nargs > 2 else None),
             cache_dir=op.join(scratch_dir, 'cache')).dataset(  # 'scratch_dir' is assumed defined in the enclosing scope
                 dataset_path, **kwargs)
     else:
         raise BananaUsageError(
             "Unrecognised dataset type provided as first argument "
             "to '--{}' option ({})".format(option_str,
                                            dataset_args[0]))
     return dataset
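A usage sketch of the function above (the path and DEPTH value are hypothetical):

import os.path as op

# Create a 'basic' dataset with DEPTH=0 (single-level directory layout),
# creating the root directory if it doesn't exist yet
dataset = init_dataset(op.join(op.expanduser('~'), 'data', 'my-study'),
                       'basic',
                       'dataset',
                       ['0'],  # dataset_args arrive as strings from the CLI
                       create_root=True)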
Example #3
    def series_coreg_pipeline(self, **name_maps):

        pipeline = self.new_pipeline(
            'series_coreg',
            desc="Applies coregistration transform to DW series",
            citations=[],
            name_maps=name_maps)

        if self.provided('coreg_ref'):
            coreg_ref = 'coreg_ref'
        elif self.provided('coreg_ref_brain'):
            coreg_ref = 'coreg_ref_brain'
        else:
            raise BananaUsageError(
                "Cannot coregister DW series as reference ('coreg_ref' or "
                "'coreg_ref_brain') has not been provided to {}".format(self))

        # Apply co-registration transformation to DW series
        pipeline.add(
            'mask_transform',
            fsl.ApplyXFM(
                output_type='NIFTI_GZ',
                apply_xfm=True),
            inputs={
                'in_matrix_file': ('coreg_fsl_mat', text_matrix_format),
                'in_file': ('series_preproc', nifti_gz_format),
                'reference': (coreg_ref, nifti_gz_format)},
            outputs={
                'series_coreg': ('out_file', nifti_gz_format)},
            requirements=[fsl_req.v('5.0.10')],
            wall_time=10)

        return pipeline
Example #4
def resolve_class(class_str, prefixes=(DEFAULT_STUDY_CLASS_PATH, )):
    """
    Resolves a class from a '.'-delimited module + class name string
    """
    parts = class_str.split('.')
    module_name = '.'.join(parts[:-1])
    class_name = parts[-1]
    cls = None
    for prefix in [None] + list(prefixes):
        if prefix is not None:
            mod_name = prefix + '.' + module_name
        else:
            mod_name = module_name
        if not mod_name:
            continue
        mod_name = mod_name.strip('.')
        try:
            module = import_module(mod_name)
        except ModuleNotFoundError:
            continue
        else:
            try:
                cls = getattr(module, class_name)
            except AttributeError:
                continue
            else:
                break
    if cls is None:
        raise BananaUsageError("Did not find class '{}'".format(class_str))
    return cls
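A usage sketch, assuming DEFAULT_STUDY_CLASS_PATH points at something like 'banana.study' (the exact value is an assumption):

# A fully-qualified path resolves directly...
cls = resolve_class('banana.study.mri.MriStudy')

# ...while a short name is retried under each prefix in turn,
# e.g. 'mri.MriStudy' -> 'banana.study.mri.MriStudy'
cls = resolve_class('mri.MriStudy')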
Example #5
    def segmentation_pipeline(self, img_type=2, **name_maps):

        pipeline = self.new_pipeline(
            name='FAST_segmentation',
            name_maps=name_maps,
            desc="White matter segmentation of the reference image",
            citations=[fsl_cite])

        fast = pipeline.add('fast',
                            fsl.FAST(img_type=img_type,
                                     segments=True,
                                     out_basename='Reference_segmentation',
                                     output_type='NIFTI_GZ'),
                            inputs={'in_files': ('brain', nifti_gz_format)},
                            requirements=[fsl_req.v('5.0.9')])

        # Determine output field of split to use
        if img_type == 1:
            split_output = 'out3'
        elif img_type == 2:
            split_output = 'out2'
        else:
            raise BananaUsageError(
                "'img_type' parameter can either be 1 or 2 (not {})".format(
                    img_type))

        pipeline.add('split',
                     Split(splits=[1, 1, 1], squeeze=True),
                     inputs={'inlist': (fast, 'tissue_class_files')},
                     outputs={'wm_seg': (split_output, nifti_gz_format)})

        return pipeline
Example #6
def detect_format(path, aux_files):
    ext = split_extension(path)[1]
    aux_names = set(aux_files.keys())
    for frmt in BIDS_FORMATS:
        if frmt.extension == ext and set(frmt.aux_files.keys()) == aux_names:
            return frmt
    raise BananaUsageError(
        "No BIDS format matches the provided path ({}) and aux files ({})"
        .format(path, aux_files))
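The matching rule pairs the file extension with the exact set of auxiliary-file names; a standalone sketch with a hypothetical format table:

from collections import namedtuple

Format = namedtuple('Format', ['name', 'extension', 'aux_files'])

FORMATS = [Format('nifti_gz', '.nii.gz', {}),
           Format('nifti_gz_x', '.nii.gz', {'json': None})]

def match_format(ext, aux_names):
    # A format matches when both the extension and the aux-file
    # name set agree exactly
    for frmt in FORMATS:
        if frmt.extension == ext and set(frmt.aux_files.keys()) == set(aux_names):
            return frmt
    raise ValueError("No matching format")

assert match_format('.nii.gz', ['json']).name == 'nifti_gz_x'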
Example #7
    def header_extraction_pipeline(self, **name_maps):

        pipeline = self.new_pipeline(
            name='header_extraction',
            name_maps=name_maps,
            desc=("Pipeline to extract the most important scan "
                  "information from the image header"),
            citations=[])

        input_format = self.input(self.header_image_spec_name).format

        # Both extraction interfaces expose the same header fields
        hdr_outputs = {
            'tr': ('tr', float),
            'start_time': ('start_time', str),
            'total_duration': ('total_duration', str),
            'real_duration': ('real_duration', str),
            'ped': ('ped', str),
            'pe_angle': ('pe_angle', str),
            'echo_times': ('echo_times', float),
            'voxel_sizes': ('voxel_sizes', float),
            'main_field_strength': ('B0', float),
            'main_field_orient': ('H', float)}

        if input_format == dicom_format:

            pipeline.add(
                'hd_info_extraction',
                DicomHeaderInfoExtraction(
                    multivol=False),
                inputs={
                    'dicom_folder': (self.header_image_spec_name,
                                     dicom_format)},
                outputs=hdr_outputs)

        elif input_format == nifti_gz_x_format:

            pipeline.add(
                'hd_info_extraction',
                NiftixHeaderInfoExtraction(),
                inputs={
                    'in_file': (self.header_image_spec_name,
                                nifti_gz_x_format)},
                outputs=hdr_outputs)
        else:
            raise BananaUsageError(
                "Can only extract header info if 'magnitude' fileset "
                "is provided in DICOM or extended NIfTI format (provided {})"
                .format(self.input(self.header_image_spec_name).format))

        return pipeline
Example #8
 def init_repo(repo_path,
               repo_type,
               option_str,
               *repo_args,
               create_root=False):
     if repo_type == 'bids':
         if create_root:
             os.makedirs(repo_path, exist_ok=True)
         repo = BidsRepo(repo_path)
     elif repo_type == 'basic':
         if len(repo_args) != 1:
             raise BananaUsageError(
                 "Unrecognised arguments passed to '--{}' option "
                 "({}) exactly 1 additional argument is required for "
                 "'basic' type repository (DEPTH)".format(
                     option_str, repo_args))
         if create_root:
             os.makedirs(repo_path, exist_ok=True)
         repo = BasicRepo(repo_path, depth=int(repo_args[0]))
     elif repo_type == 'xnat':
         nargs = len(repo_args)
         if nargs < 1:
             raise BananaUsageError(
                 "Not enough arguments passed to '--{}' option "
                 "({}), at least 1 additional argument is required for "
                 "'xnat' type repository (SERVER)".format(
                     option_str, repo_args))
         elif nargs > 3:
             raise BananaUsageError(
                 "Unrecognised arguments passed to '--{}' option "
                 "({}), at most 3 additional arguments are accepted for"
                 " 'xnat' type repository (SERVER, USER, PASSWORD)".
                 format(option_str, args.respository))
         repo = XnatRepo(project_id=repo_path,
                         server=repo_args[0],
                         user=(repo_args[1] if nargs > 1 else None),
                         password=(repo_args[2] if nargs > 2 else None),
                         cache_dir=op.join(scratch_dir, 'cache'))  # 'scratch_dir' is assumed defined in the enclosing scope
     else:
         raise BananaUsageError(
             "Unrecognised repository type provided as first argument "
             "to '--{}' option ({})".format(option_str, repo_args[0]))
     return repo
Example #9
 def image_centre(cls, array, offset=None):
     """
     Returns the centre point of the non-zero voxels in the image
     """
     nonzeros = np.argwhere(array)
     min_ind = nonzeros.min(axis=0)
     max_ind = nonzeros.max(axis=0)
     centre = ((max_ind - min_ind) // 2) + min_ind
     if offset is not None:
         centre += np.array(offset, dtype=int)
          if np.any(centre < 0) or np.any(centre >= np.array(array.shape)):
              raise BananaUsageError(
                  "Specified offset ({}) is larger than the "
                  "dimension of the image / 2 ({})".format(
                      offset, np.array(array.shape) // 2))
     return centre
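To illustrate the bounding-box arithmetic, a self-contained numpy sketch (synthetic data, not from the source):

import numpy as np

array = np.zeros((10, 10, 10))
array[2:8, 3:7, 4:6] = 1  # non-zero block

nonzeros = np.argwhere(array)
min_ind = nonzeros.min(axis=0)   # [2, 3, 4]
max_ind = nonzeros.max(axis=0)   # [7, 6, 5]
centre = ((max_ind - min_ind) // 2) + min_ind
print(centre)  # [4, 4, 4] -- midpoint of the non-zero bounding box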
Example #10
    def sinogram_unlisting_pipeline(self, **kwargs):

        pipeline = self.new_pipeline(
            name='prepare_sinogram',
            desc=('Unlist pet listmode data into several sinograms and '
                  'perform ssrb compression to prepare data for motion '
                  'detection using PCA pipeline.'),
            citations=[],
            **kwargs)

        if not self.provided('list_mode'):
            raise BananaUsageError(
                "'list_mode' was not provided as an input to the study "
                "so cannot perform sinogram unlisting")

        prepare_inputs = pipeline.add('prepare_inputs',
                                      PrepareUnlistingInputs(),
                                      inputs={
                                          'list_mode':
                                          ('list_mode', list_mode_format),
                                          'time_offset': ('time_offset', int),
                                          'num_frames': ('num_frames', int),
                                          'temporal_len':
                                          ('temporal_length', float)
                                      })

        unlisting = pipeline.add(
            'unlisting',
            PETListModeUnlisting(),
            inputs={'list_inputs': (prepare_inputs, 'out')},
            iterfield=['list_inputs'])

        ssrb = pipeline.add(
            'ssrb',
            SSRB(),
            inputs={'unlisted_sinogram': (unlisting, 'pet_sinogram')},
            requirements=[stir_req.v('3.0')])

        pipeline.add(
            'merge_sinograms',
            MergeUnlistingOutputs(),
            inputs={'sinograms': (ssrb, 'ssrb_sinograms')},
            outputs={'ssrb_sinograms': ('sinogram_folder', directory_format)},
            joinsource='unlisting',
            joinfield=['sinograms'])

        return pipeline
Example #11
 def __init__(
         self,
         spec_name,
         primary,
         association,
         format=None,  # @ReservedAssignment @IgnorePep8
         fieldmap_order=0,
         **kwargs):
     FilesetSelector.__init__(self,
                              spec_name,
                              format,
                              frequency='per_session',
                              **kwargs)
     self._primary = primary
     if association not in self.VALID_ASSOCIATIONS:
         raise BananaUsageError(
             "Invalid association '{}' passed to BidsAssociatedSelector, "
             "can be one of '{}'".format(
                 association, "', '".join(self.VALID_ASSOCIATIONS)))
     self._association = association
     self._fieldmap_order = fieldmap_order
Example #12
 def _list_outputs(self):
     outputs = self._outputs().get()
     mag_img = nib.load(self.inputs.magnitude)
     tissue_phase_img = nib.load(self.inputs.tissue_phase)
     mask_img = nib.load(self.inputs.mask)
     mag = mag_img.get_fdata()
     tissue_phase = tissue_phase_img.get_fdata()
     mask = mask_img.get_fdata()
     if mag.shape != tissue_phase.shape:
         raise BananaUsageError(
             "Dimensions of provided magnitude and phase images "
             "differ ({} and {})".format(mag.shape, tissue_phase.shape))
      # Positive-phase voxels within the brain mask (boolean indexing)
      pos_mask = (tissue_phase > 0) & (mask > 0)
      rho = np.ones(tissue_phase.shape)
      # Linear SWI phase mask: (pi - phase) / pi, clipped at zero
      rho[pos_mask] = np.maximum(0, (np.pi - tissue_phase[pos_mask]) / np.pi)
     swi = mag * (rho**self.inputs.alpha)
     # Set filenames in output spec
     outputs['out_file'] = self._gen_filename('out_file')
     out_file_img = nib.Nifti1Image(swi, mag_img.affine, mag_img.header)
     nib.save(out_file_img, outputs['out_file'])
     return outputs
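A self-contained sketch of the masking step on synthetic arrays (illustrative only; the clipping formula here is the standard linear SWI positive-phase mask, which the code above appears to intend):

import numpy as np

tissue_phase = np.random.uniform(-np.pi, np.pi, (4, 4, 4))
mask = np.ones((4, 4, 4))
mag = np.abs(np.random.randn(4, 4, 4))
alpha = 4  # a typical number of phase-mask multiplications

pos_mask = (tissue_phase > 0) & (mask > 0)
rho = np.ones(tissue_phase.shape)
rho[pos_mask] = np.maximum(0, (np.pi - tissue_phase[pos_mask]) / np.pi)
swi = mag * rho ** alpha  # positive phase suppresses the magnitude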
Example #13
    def coreg_fsl_mat_pipeline(self, **name_maps):
        if self.branch('coreg_method', 'flirt'):
            pipeline = self._coreg_mat_pipeline(**name_maps)
        elif self.branch('coreg_method', 'ants'):
            # Convert ANTS transform to FSL transform
            pipeline = self.new_pipeline(
                name='convert_ants_to_fsl_coreg_mat',
                name_maps=name_maps)

            if self.provided('coreg_ref'):
                source = 'mag_preproc'
                ref = 'coreg_ref'
            elif self.provided('coreg_ref_brain'):
                source = 'brain'
                ref = 'coreg_ref_brain'
            else:
                raise BananaUsageError(
                    "Either 'coreg_ref' or 'coreg_ref_brain' needs to be "
                    "provided in order to derive brain_coreg or brain_coreg_"
                    "mask")

            pipeline.add(
                'transform_conv',
                ANTs2FSLMatrixConversion(
                    ras2fsl=True),
                inputs={
                    'itk_file': ('coreg_ants_mat', text_matrix_format),
                    'source_file': (source, nifti_gz_format),
                    'reference_file': (ref, nifti_gz_format)},
                outputs={
                    'coreg_fsl_mat': ('fsl_matrix', text_matrix_format)},
                requirements=[c3d_req.v('1.0')])
        else:
            self.unhandled_branch('coreg_method')

        return pipeline
Example #14
    def preprocess_pipeline(self, **name_maps):
        """
        Performs a series of FSL preprocessing steps, including Eddy and Topup

        Parameters
        ----------
        phase_dir : str{AP|LR|IS}
            The phase encode direction
        """

        # Determine whether we can correct for distortion, i.e. if reference
        # scans are provided
        # Include all references
        references = [fsl_cite, eddy_cite, topup_cite,
                      distort_correct_cite, n4_cite]
        if self.branch('preproc_denoise'):
            references.extend(dwidenoise_cites)

        pipeline = self.new_pipeline(
            name='preprocess',
            name_maps=name_maps,
            desc=(
                "Preprocess dMRI studies using distortion correction"),
            citations=references)

        # Create nodes to convert gradients to FSL format
        if self.input('series').format == dicom_format:
            extract_grad = pipeline.add(
                "extract_grad",
                ExtractFSLGradients(),
                inputs={
                    'in_file': ('series', dicom_format)},
                outputs={
                    'grad_dirs': ('bvecs_file', fsl_bvecs_format),
                    'bvalues': ('bvals_file', fsl_bvals_format)},
                requirements=[mrtrix_req.v('3.0rc3')])
            grad_fsl_inputs = {'in1': (extract_grad, 'bvecs_file'),
                               'in2': (extract_grad, 'bvals_file')}
        elif self.provided('grad_dirs') and self.provided('bvalues'):
            grad_fsl_inputs = {'in1': ('grad_dirs', fsl_bvecs_format),
                               'in2': ('bvalues', fsl_bvals_format)}
        else:
            raise BananaUsageError(
                "Either the 'series' input needs to be in DICOM format "
                "or gradient directions and b-values need to be explicitly "
                "provided to {}".format(self))

        # Gradient merge node
        grad_fsl = pipeline.add(
            "grad_fsl",
            MergeTuple(2),
            inputs=grad_fsl_inputs)

        gradients = (grad_fsl, 'out')

        # Create node to reorient preproc out_file
        if self.branch('reorient2std'):
            reorient = pipeline.add(
                'fslreorient2std',
                fsl.utils.Reorient2Std(
                    output_type='NIFTI_GZ'),
                inputs={
                    'in_file': ('series', nifti_gz_format)},
                requirements=[fsl_req.v('5.0.9')])
            reoriented = (reorient, 'out_file')
        else:
            reoriented = ('series', nifti_gz_format)

        # Denoise the dwi-scan
        if self.branch('preproc_denoise'):
            # Run denoising
            denoise = pipeline.add(
                'denoise',
                DWIDenoise(),
                inputs={
                    'in_file': reoriented},
                requirements=[mrtrix_req.v('3.0rc3')])

            # Calculate residual noise
            subtract_operands = pipeline.add(
                'subtract_operands',
                Merge(2),
                inputs={
                    'in1': reoriented,
                    'in2': (denoise, 'noise')})

            pipeline.add(
                'subtract',
                MRCalc(
                    operation='subtract'),
                inputs={
                    'operands': (subtract_operands, 'out')},
                outputs={
                    'noise_residual': ('out_file', mrtrix_image_format)},
                requirements=[mrtrix_req.v('3.0rc3')])
            denoised = (denoise, 'out_file')
        else:
            denoised = reoriented

        # Preproc kwargs
        preproc_kwargs = {}
        preproc_inputs = {'in_file': denoised,
                          'grad_fsl': gradients}

        if self.provided('reverse_phase'):

            if self.provided('magnitude', default_okay=False):
                dwi_reference = ('magnitude', mrtrix_image_format)
            else:
                # Extract b=0 volumes
                dwiextract = pipeline.add(
                    'dwiextract',
                    ExtractDWIorB0(
                        bzero=True,
                        out_ext='.nii.gz'),
                    inputs={
                        'in_file': denoised,
                        'fslgrad': gradients},
                    requirements=[mrtrix_req.v('3.0rc3')])

                # Get first b=0 from dwi b=0 volumes
                extract_first_b0 = pipeline.add(
                    "extract_first_vol",
                    MRConvert(
                        coord=(3, 0)),
                    inputs={
                        'in_file': (dwiextract, 'out_file')},
                    requirements=[mrtrix_req.v('3.0rc3')])

                dwi_reference = (extract_first_b0, 'out_file')

            # Concatenate extracted forward rpe with reverse rpe
            combined_images = pipeline.add(
                'combined_images',
                MRCat(),
                inputs={
                    'first_scan': dwi_reference,
                    'second_scan': ('reverse_phase', mrtrix_image_format)},
                requirements=[mrtrix_req.v('3.0rc3')])

            # Create node to assign the right PED to the diffusion
            prep_dwi = pipeline.add(
                'prepare_dwi',
                PrepareDWI(),
                inputs={
                    'pe_dir': ('ped', float),
                    'ped_polarity': ('pe_angle', float)})

            preproc_kwargs['rpe_pair'] = True

            distortion_correction = True
            preproc_inputs['se_epi'] = (combined_images, 'out_file')
        else:
            distortion_correction = False
            preproc_kwargs['rpe_none'] = True

        if self.parameter('preproc_pe_dir') is not None:
            preproc_kwargs['pe_dir'] = self.parameter('preproc_pe_dir')

        preproc = pipeline.add(
            'dwipreproc',
            DWIPreproc(
                no_clean_up=True,
                out_file_ext='.nii.gz',
                # FIXME: Need to determine this programmatically
                # eddy_parameters = '--data_is_shelled '
                temp_dir='dwipreproc_tempdir',
                **preproc_kwargs),
            inputs=preproc_inputs,
            outputs={
                'eddy_par': ('eddy_parameters', eddy_par_format)},
            requirements=[mrtrix_req.v('3.0rc3'), fsl_req.v('5.0.10')],
            wall_time=60)

        if distortion_correction:
            pipeline.connect(prep_dwi, 'pe', preproc, 'pe_dir')

        mask = pipeline.add(
            'dwi2mask',
            BrainMask(
                out_file='brainmask.nii.gz'),
            inputs={
                'in_file': (preproc, 'out_file'),
                'grad_fsl': gradients},
            requirements=[mrtrix_req.v('3.0rc3')])

        # Create bias correct node
        pipeline.add(
            "bias_correct",
            DWIBiasCorrect(
                method='ants'),
            inputs={
                'grad_fsl': gradients,  # internal
                'in_file': (preproc, 'out_file'),
                'mask': (mask, 'out_file')},
            outputs={
                'series_preproc': ('out_file', nifti_gz_format)},
            requirements=[mrtrix_req.v('3.0rc3'), ants_req.v('2.0')])

        return pipeline
Example #15
def arlo(te, y):
    """
    Used in calculating the R2* signal

    Parameters
    ----------
    te : list(float)
        array containing TE values (in s)
    y : 4-d array
        a multi-echo data set of arbitrary dimension; echo should be the
        last dimension

    Returns
    -------
    r2 : 3-d array
        r2-map (in Hz), empty when only one echo is provided

    If you use this function please cite
    Pei M, Nguyen TD, Thimmappa ND, Salustri C, Dong F, Cooper MA, Li J,
    Prince MR, Wang Y. Algorithm for fast monoexponential fitting based
    on Auto-Regression on Linear Operations (ARLO) of data.
    Magn Reson Med. 2015 Feb;73(2):843-50. doi: 10.1002/mrm.25137.
    Epub 2014 Mar 24. PubMed PMID: 24664497;
    PubMed Central PMCID: PMC4175304.
    """
    num_echos = len(te)
    if num_echos < 2:
        return []

    if y.shape[-1] != num_echos:
        raise BananaUsageError(
            'Last dimension of y has size {}, expected {}'.format(
                y.shape[-1], num_echos))

    yy = np.zeros(y.shape[:3])
    yx = np.zeros(y.shape[:3])
    beta_yx = np.zeros(y.shape[:3])
    beta_xx = np.zeros(y.shape[:3])

    for j in range(num_echos - 2):
        alpha = ((te[j + 2] - te[j]) *
                 (te[j + 2] - te[j]) / 2) / (te[j + 1] - te[j])
        tmp = (2 * te[j + 2] * te[j + 2] - te[j] * te[j + 2] - te[j] * te[j] +
               3 * te[j] * te[j + 1] - 3 * te[j + 1] * te[j + 2]) / 6
        beta = tmp / (te[j + 2] - te[j + 1])
        gamma = tmp / (te[j + 1] - te[j])

        echo0 = y[:, :, :, j]
        echo1 = y[:, :, :, j + 1]
        echo2 = y[:, :, :, j + 2]

        # [te[j+2]-te[j]-alpha+gamma alpha-beta-gamma beta]/((te[2]-te[1])/3)
        y1 = (echo0 * (te[j + 2] - te[j] - alpha + gamma) + echo1 *
              (alpha - beta - gamma) + echo2 * beta)
        x1 = echo0 - echo2

        yy = yy + y1 * y1
        yx = yx + y1 * x1
        beta_yx = beta_yx + beta * y1 * x1
        beta_xx = beta_xx + beta * x1 * x1

    r2 = (yx + beta_xx) / (beta_yx + yy)

    # Set NaN and inf values to 0.0
    r2[~np.isfinite(r2)] = 0.0
    return r2
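A quick sanity check of arlo on a noiseless monoexponential decay (synthetic values; assumes numpy and the arlo function above are in scope):

import numpy as np

r2_true = 50.0  # Hz
te = [0.005, 0.010, 0.015, 0.020]  # echo times in s
y = np.exp(-r2_true * np.array(te))[np.newaxis, np.newaxis, np.newaxis, :]
y = np.tile(y, (2, 2, 2, 1))  # a tiny 2x2x2 "volume", echo last

r2_map = arlo(te, y)
print(r2_map[0, 0, 0])  # approximately 50.0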
Example #16
    def brain_coreg_pipeline(self, **name_maps):
        """
        Coregistered + brain-extracted images can be derived in 2-ways. If an
        explicit brain-extracted reference is provided to
        'coreg_ref_brain' then that is used to coregister a brain extracted
        image against. Alternatively, if only a skull-included reference is
        provided then the registration is performed with skulls-included and
        then brain extraction is performed after
        """
        if self.provided('coreg_ref_brain'):
            # If a reference brain extracted image is provided we coregister
            # the brain extracted image to that
            pipeline = self.coreg_pipeline(
                name='brain_coreg',
                name_maps=dict(
                    input_map={
                        'mag_preproc': 'brain',
                        'coreg_ref': 'coreg_ref_brain'},
                    output_map={
                        'mag_coreg': 'brain_coreg'},
                    name_maps=name_maps))

            # Apply coregistration transform to brain mask
            if self.branch('coreg_method', 'flirt'):
                pipeline.add(
                    'mask_transform',
                    ApplyXFM(
                        output_type='NIFTI_GZ',
                        apply_xfm=True),
                    inputs={
                        'in_matrix_file': (pipeline.node('flirt'),
                                           'out_matrix_file'),
                        'in_file': ('brain_mask', nifti_gz_format),
                        'reference': ('coreg_ref_brain', nifti_gz_format)},
                    outputs={
                        'brain_mask_coreg': ('out_file', nifti_gz_format)},
                    requirements=[fsl_req.v('5.0.10')],
                    wall_time=10)

            elif self.branch('coreg_method', 'ants'):
                # Apply the ANTs transforms directly to the brain mask
                # using antsApplyTransforms (no FSL-format conversion is
                # needed in this branch)
                pipeline.add(
                    'mask_transform',
                    ants.resampling.ApplyTransforms(
                        interpolation='Linear',
                        input_image_type=3,
                        invert_transform_flags=[True, True, False]),
                    inputs={
                        'input_image': ('brain_mask', nifti_gz_format),
                        'reference_image': ('coreg_ref_brain',
                                            nifti_gz_format),
                        'transforms': (pipeline.node('ants_reg'),
                                       'forward_transforms')},
                    requirements=[ants_req.v('1.9')], mem_gb=16,
                    wall_time=30)
            else:
                self.unhandled_branch('coreg_method')

        elif self.provided('coreg_ref'):
            # If coreg_ref is provided then we co-register the non-brain
            # extracted images and then brain extract the co-registered image
            pipeline = self.brain_extraction_pipeline(
                name='bet_coreg',
                input_map={'mag_preproc': 'mag_coreg'},
                output_map={'brain': 'brain_coreg',
                            'brain_mask': 'brain_mask_coreg'},
                name_maps=name_maps)
        else:
            raise BananaUsageError(
                "Either 'coreg_ref' or 'coreg_ref_brain' needs to be provided "
                "in order to derive brain_coreg or brain_mask_coreg")
        return pipeline
Example #17
    def generate_test_data(cls,
                           study_class,
                           in_repo,
                           out_repo,
                           in_server=None,
                           out_server=None,
                           work_dir=None,
                           parameters=(),
                           include=None,
                           skip=(),
                           include_bases=(),
                           reprocess=False,
                           repo_depth=0,
                           modules_env=False,
                           clean_work_dir=True,
                           loggers=('nipype.workflow', 'arcana', 'banana')):
        """
        Generates reference data for a pipeline tester unittests given a study
        class and set of parameters

        Parameters
        ----------
        study_class : type(Study)
            The path to the study class to test, e.g. banana.study.MriStudy
        in_repo : str
            The path to repository that houses the input data
        out_repo : str
            If the 'out_server' argument is provided then 'out_repo'
            is interpreted as the project ID to use on the XNAT
            server (the project must already exist). Otherwise
            it is interpreted as the path to a basic repository
        in_server : str | None
            The server to download the input data from
        out_server : str | None
            The server to upload the reference data to
        work_dir : str
            The work directory
        parameters : dict[str, *]
            Parameter to set when initialising the study
        include : list[str] | None
            Spec names to include in the output repository. If None all names
            except those listed in 'skip' are included
        skip : list[str]
            Spec names to skip in the generation process. Only valid if
            'include' is None
        include_bases : list[type(Study)]
            List of base classes in which all entries in their data
            specification are added to the list to include
        reprocess : bool
            Whether to reprocess the generated datasets
        repo_depth : int
            The depth of the input repository
        modules_env : bool
            Whether to use modules environment or not
        clean_work_dir : bool
            Whether to clean the Nipype work directory or not
        """

        for logger_name in loggers:
            logger = logging.getLogger(logger_name)
            logger.setLevel(logging.INFO)
            handler = logging.StreamHandler()
            formatter = logging.Formatter("%(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        if work_dir is None:
            work_dir = tempfile.mkdtemp()

        if study_class.__name__.endswith('Study'):
            study_name = study_class.__name__[:-len('Study')]
        else:
            study_name = study_class.__name__

        # Get output repository to write the data to
        if in_server is not None:
            in_repo = XnatRepo(project_id=in_repo,
                               server=in_server,
                               cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            in_repo = BasicRepo(in_repo, depth=repo_depth)

        temp_repo_root = op.join(work_dir, 'temp-repo')
        if os.path.exists(temp_repo_root) and reprocess:
            shutil.rmtree(temp_repo_root)
        os.makedirs(temp_repo_root, exist_ok=True)

        temp_repo = BasicRepo(temp_repo_root, depth=repo_depth)

        inputs = None
        for session in in_repo.tree().sessions:
            session_inputs = []
            for item in chain(session.filesets, session.fields):
                if isinstance(item, Fileset):
                    inpt = InputFilesets(item.basename,
                                         item.basename,
                                         item.format,
                                         repository=in_repo)
                else:
                    inpt = InputFields(item.name,
                                       item.name,
                                       item.dtype,
                                       repository=in_repo)
                try:
                    spec = study_class.data_spec(inpt)
                except ArcanaNameError:
                    print(
                        "Skipping {} as it doesn't match a spec in {}".format(
                            item, study_class))
                else:
                    session_inputs.append(inpt)
            session_inputs = sorted(session_inputs)
            if inputs is not None and session_inputs != inputs:
                raise BananaUsageError(
                    "Inconsistent inputs ({} and {}) found in sessions of {}".
                    format(inputs, session_inputs, in_repo))
            else:
                inputs = session_inputs

        if modules_env:
            env = ModulesEnv()
        else:
            env = StaticEnv()

        study = study_class(
            study_name,
            repository=temp_repo,
            processor=SingleProc(
                work_dir,
                reprocess=reprocess,
                clean_work_dir_between_runs=clean_work_dir,
                prov_ignore=(
                    SingleProc.DEFAULT_PROV_IGNORE +
                    ['.*/pkg_version', 'workflow/nodes/.*/requirements/.*'])),
            environment=env,
            inputs=inputs,
            parameters=parameters,
            subject_ids=in_repo.tree().subject_ids,
            visit_ids=in_repo.tree().visit_ids,
            fill_tree=True)

        if include is None:
            # Get set of methods that could override pipeline getters in
            # base classes that are not included
            potentially_overridden = set()
            for cls in chain(include_bases, [study_class]):
                potentially_overridden.update(cls.__dict__.keys())

            include = set()
            for base in study_class.__mro__:
                if not hasattr(base, 'add_data_specs'):
                    continue
                for spec in base.add_data_specs:
                    if isinstance(spec,
                                  BaseInputSpecMixin) or spec.name in skip:
                        continue
                    if (base is study_class or base in include_bases
                            or spec.pipeline_getter in potentially_overridden):
                        include.add(spec.name)

        # Generate all derived data
        for spec_name in sorted(include):
            study.data(spec_name)

        # Get output repository to write the data to
        if out_server is not None:
            out_repo = XnatRepo(project_id=out_repo,
                                server=out_server,
                                cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            out_repo = BasicRepo(out_repo, depth=repo_depth)

        # Upload data to repository
        for spec in study.data_specs():
            try:
                data = study.data(spec.name, generate=False)
            except ArcanaMissingDataException:
                continue
            for item in data:
                if not item.exists:
                    logger.info("Skipping upload of non-existant {}".format(
                        item.name))
                    continue
                if skip is not None and item.name in skip:
                    logger.info("Forced skip of {}".format(item.name))
                    continue
                if item.is_fileset:
                    item_cpy = Fileset(name=item.name,
                                       format=item.format,
                                       frequency=item.frequency,
                                       path=item.path,
                                       aux_files=copy(item.aux_files),
                                       subject_id=item.subject_id,
                                       visit_id=item.visit_id,
                                       repository=out_repo,
                                       exists=True)
                else:
                    item_cpy = Field(name=item.name,
                                     value=item.value,
                                     dtype=item.dtype,
                                     frequency=item.frequency,
                                     array=item.array,
                                     subject_id=item.subject_id,
                                     visit_id=item.visit_id,
                                     repository=out_repo,
                                     exists=True)
                logger.info("Uploading {}".format(item_cpy))
                item_cpy.put()
                logger.info("Uploaded {}".format(item_cpy))
        logger.info(
            "Finished generating and uploading test data for {}".format(
                study_class))
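A hedged invocation sketch; the owning class is not shown above, so 'PipelineTester' and the study-class import path are assumptions:

from banana.study.mri.base import MriStudy  # assumed import path

PipelineTester.generate_test_data(
    study_class=MriStudy,
    in_repo='/data/test-inputs',       # basic repository with input data
    out_repo='/data/test-reference',   # where reference data is written
    parameters={},
    skip=('coreg_fsl_mat',))           # spec names to leave ungenerated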
Example #18
    def _list_outputs(self):
        outputs = self._outputs().get()
        mag_fname = outputs['magnitude'] = self._gen_filename('magnitude')
        phase_fname = outputs['phase'] = self._gen_filename('phase')
        q_fname = outputs['q'] = self._gen_filename('q')
        mag_paths = defaultdict(dict)
        phase_paths = defaultdict(dict)
        # Compile regular expression for extracting channel, echo and
        # complex axis indices from input file names
        fname_re = re.compile(self.inputs.in_fname_re)
        for dpath, dct in ((self.inputs.magnitudes_dir, mag_paths),
                           (self.inputs.phases_dir, phase_paths)):
            for fname in os.listdir(dpath):
                match = fname_re.match(fname)
                if match is None:
                    logger.warning("Skipping '{}' file in '{}' as it doesn't "
                                   "match expected filename pattern for raw "
                                   "channel files ('{}')".format(
                                       fname, dpath, self.inputs.in_fname_re))
                    continue
                # Keep the path indexed by channel and echo; echo is cast to
                # int so the [0]/[1] lookups below match
                dct[match.group('channel')][int(match.group('echo'))] = (
                    op.join(dpath, fname))
        if len(mag_paths) != len(phase_paths):
            raise BananaUsageError(
                "Mismatching number of channels between magnitude and phase "
                "directories")
        hip = None
        for chann_i in mag_paths:
            if len(mag_paths[chann_i]) != 2:
                raise BananaUsageError(
                    "Expected exactly two echos for channel magnitude {}, "
                    "found {}".format(chann_i, len(mag_paths[chann_i])))
            if len(phase_paths[chann_i]) != 2:
                raise BananaUsageError(
                    "Expected exactly two echos for channel phase {}, "
                    "found {}".format(chann_i, len(phase_paths[chann_i])))
            mag1 = nib.load(mag_paths[chann_i][0])
            phase1 = nib.load(phase_paths[chann_i][0])
            mag2 = nib.load(mag_paths[chann_i][1])
            phase2 = nib.load(phase_paths[chann_i][1])

            # Get array data
            mag1_array = mag1.get_fdata()
            phase1_array = phase1.get_fdata()
            mag2_array = mag2.get_fdata()
            phase2_array = phase2.get_fdata()

            if hip is None:
                hip = np.zeros(mag1_array.shape, dtype=complex)
                sum_mag = np.zeros(mag1_array.shape)
            hip += mag1_array * mag2_array * np.exp(
                -1j * (phase1_array - phase2_array))
            sum_mag += mag1_array * mag2_array
        # Get magnitude and phase
        phase = np.angle(hip)
        mag = np.abs(hip)
        q = mag / sum_mag
        # Create NIfTI images
        phase_img = nib.Nifti1Image(phase, phase1.affine, phase1.header)
        mag_img = nib.Nifti1Image(mag, mag1.affine, mag1.header)
        q_img = nib.Nifti1Image(q, mag1.affine, mag1.header)
        # Save NIfTIs
        nib.save(phase_img, phase_fname)
        nib.save(mag_img, mag_fname)
        nib.save(q_img, q_fname)
        return outputs
Example #19
 def _list_outputs(self):
     outputs = self._outputs().get()
     hip = None
     for fname in os.listdir(self.inputs.channels_dir):
         img = nib.load(op.join(self.inputs.channels_dir, fname))
         img_data = img.get_fdata()
         if hip is None:
             hip = np.zeros(img_data.shape[:3], dtype=complex)
             sum_mag = np.zeros(img_data.shape[:3])
             r2star = np.zeros(img_data.shape[:3])
             mag = np.zeros(img_data.shape[:3])
         num_echos = img_data.shape[3]
         if len(self.inputs.echo_times) != num_echos:
              raise BananaUsageError(
                  "Number of echos in the dataset ({}) differs from the "
                  "number of echo times provided ({})".format(
                      num_echos, self.inputs.echo_times))
         if num_echos < 2:
             raise BananaUsageError(
                 "At least two echos required for channel magnitude {}, "
                 "found {}".format(fname, num_echos))
         cmplx_coil = np.squeeze(img_data[:, :, :, :, 0] +
                                 1j * img_data[:, :, :, :, 1])
         phase_coil = np.angle(cmplx_coil)
         mag_coil = np.abs(cmplx_coil)
         for i, j in zip(range(0, num_echos - 1), range(1, num_echos)):
             # Get successive echos
             mag_a = mag_coil[:, :, :, i]
             mag_b = mag_coil[:, :, :, j]
             phase_a = phase_coil[:, :, :, i]
             phase_b = phase_coil[:, :, :, j]
             # Combine HIP and sum and total magnitude
             hip += mag_a * mag_b * np.exp(-1j * (phase_a - phase_b))
             sum_mag += mag_a * mag_b
         # Calculate R2*
         sum_echo_mags = np.sum(mag_coil, axis=3)
         r2star += sum_echo_mags * arlo(self.inputs.echo_times, mag_coil)
         mag += sum_echo_mags
     if hip is None:
         raise BananaUsageError(
             "No channels loaded from channels directory {}".format(
                 self.inputs.channels_dir))
     # Get magnitude and phase
     phase = np.angle(hip)
     mag = np.abs(hip)
     q = mag / sum_mag
     # Set filenames in output spec
     outputs['phase'] = self._gen_filename('phase')
     outputs['magnitude'] = self._gen_filename('magnitude')
     outputs['q'] = self._gen_filename('q')
     outputs['r2star'] = self._gen_filename('r2star')
     # Create NIfTI images
     phase_img = nib.Nifti1Image(phase, img.affine, img.header)
     mag_img = nib.Nifti1Image(mag, img.affine, img.header)
     q_img = nib.Nifti1Image(q, img.affine, img.header)
     r2star_img = nib.Nifti1Image(r2star, img.affine, img.header)
     # Save NIfTIs
     nib.save(phase_img, outputs['phase'])
     nib.save(mag_img, outputs['magnitude'])
     nib.save(q_img, outputs['q'])
     nib.save(r2star_img, outputs['r2star'])
     return outputs
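The channel combination in the last two examples is the Hermitian inner product (HIP) between two echoes, summed over channels; a standalone numpy sketch of the core step (synthetic data, not from the source):

import numpy as np

num_channels, shape = 4, (3, 3, 3)
hip = np.zeros(shape, dtype=complex)
sum_mag = np.zeros(shape)
for _ in range(num_channels):
    mag1, mag2 = np.abs(np.random.randn(*shape)), np.abs(np.random.randn(*shape))
    phase1, phase2 = [np.random.uniform(-np.pi, np.pi, shape) for _ in range(2)]
    # HIP accumulates m1*m2*exp(-i(phi1 - phi2)) across channels, so
    # channel-dependent coil phases cancel in the combined phase
    hip += mag1 * mag2 * np.exp(-1j * (phase1 - phase2))
    sum_mag += mag1 * mag2
phase = np.angle(hip)      # combined inter-echo phase difference
q = np.abs(hip) / sum_mag  # coherence-like quality measure in [0, 1]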
Example #20
    def run(cls, args):

        set_loggers(args.logger)

        study_class = resolve_class(args.study_class)

        if args.scratch is not None:
            scratch_dir = args.scratch
        else:
            scratch_dir = op.join(op.expanduser('~'), 'banana-scratch')

        # Ensure scratch dir exists
        os.makedirs(scratch_dir, exist_ok=True)

        work_dir = op.join(scratch_dir, 'work')

        if args.repository is None:
            if args.input:
                repository_type = 'basic'
            else:
                repository_type = 'bids'
        else:
            repository_type = args.repository[0]

        # Load subject_ids from a file if the provided value looks like
        # a path (i.e. contains a '/')
        if (args.subject_ids is not None and len(args.subject_ids)
                and '/' in args.subject_ids[0]):
            with open(args.subject_ids[0]) as f:
                subject_ids = f.read().split()
        else:
            subject_ids = args.subject_ids

        # Load visit_ids from a file if the provided value looks like
        # a path (i.e. contains a '/')
        if (args.visit_ids is not None and len(args.visit_ids)
                and '/' in args.visit_ids[0]):
            with open(args.visit_ids[0]) as f:
                visit_ids = f.read().split()
        else:
            visit_ids = args.visit_ids

        def init_repo(repo_path,
                      repo_type,
                      option_str,
                      *repo_args,
                      create_root=False):
            if repo_type == 'bids':
                if create_root:
                    os.makedirs(repo_path, exist_ok=True)
                repo = BidsRepo(repo_path)
            elif repo_type == 'basic':
                if len(repo_args) != 1:
                    raise BananaUsageError(
                        "Unrecognised arguments passed to '--{}' option "
                        "({}) exactly 1 additional argument is required for "
                        "'basic' type repository (DEPTH)".format(
                            option_str, repo_args))
                if create_root:
                    os.makedirs(repo_path, exist_ok=True)
                repo = BasicRepo(repo_path, depth=int(repo_args[0]))
            elif repo_type == 'xnat':
                nargs = len(repo_args)
                if nargs < 1:
                    raise BananaUsageError(
                        "Not enough arguments passed to '--{}' option "
                        "({}), at least 1 additional argument is required for "
                        "'xnat' type repository (SERVER)".format(
                            option_str, repo_args))
                elif nargs > 3:
                    raise BananaUsageError(
                        "Unrecognised arguments passed to '--{}' option "
                        "({}), at most 3 additional arguments are accepted for"
                        " 'xnat' type repository (SERVER, USER, PASSWORD)".
                        format(option_str, args.respository))
                repo = XnatRepo(project_id=repo_path,
                                server=repo_args[0],
                                user=(repo_args[1] if nargs > 1 else None),
                                password=(repo_args[2] if nargs > 2 else None),
                                cache_dir=op.join(scratch_dir, 'cache'))
            else:
                raise BananaUsageError(
                    "Unrecognised repository type provided as first argument "
                    "to '--{}' option ({})".format(option_str, repo_args[0]))
            return repo

        # Pass only the arguments after the type token (and handle the
        # case where '--repository' was not supplied at all)
        repository = init_repo(
            args.repository_path, repository_type, 'repository',
            *(args.repository[1:] if args.repository is not None else []))

        if args.output_repository is not None:
            input_repository = repository
            tree = repository.cached_tree()
            if subject_ids is None:
                subject_ids = list(tree.subject_ids)
            if visit_ids is None:
                visit_ids = list(tree.visit_ids)
            fill_tree = True
            nargs = len(args.output_repository)
            if nargs == 1:
                repo_type = 'basic'
                out_path = args.output_repository[0]
                out_repo_args = [input_repository.depth]
            else:
                repo_type = args.output_repository[0]
                out_path = args.output_repository[1]
                out_repo_args = args.output_repository[2:]
            repository = init_repo(out_path,
                                   repo_type,
                                   'output_repository',
                                   *out_repo_args,
                                   create_root=True)
        else:
            input_repository = None
            fill_tree = False

        if args.email is not None:
            email = args.email
        else:
            try:
                email = os.environ['EMAIL']
            except KeyError:
                email = None

        proc_args = {'reprocess': args.reprocess}

        if args.processor[0] == 'single':
            processor = SingleProc(work_dir, **proc_args)
        elif args.processor[0] == 'multi':
            # Check for too many arguments first, otherwise this branch is
            # unreachable after the `> 1` test
            if len(args.processor) > 2:
                raise BananaUsageError(
                    "Unrecognised arguments passed to '--processor' option "
                    "({}) expected at most 1 additional argument for 'multi' "
                    "type processor (NUM_PROCS)".format(args.processor))
            elif len(args.processor) > 1:
                # Cast to int as argparse delivers strings
                num_processes = int(args.processor[1])
            else:
                num_processes = cpu_count()
            processor = MultiProc(work_dir,
                                  num_processes=num_processes,
                                  **proc_args)
        elif args.processor[0] == 'slurm':
            if email is None:
                raise BananaUsageError(
                    "Email needs to be provided either via '--email' argument "
                    "or set in 'EMAIL' environment variable for SLURM "
                    "processor")
            nargs = len(args.processor)
            if nargs > 3:
                raise BananaUsageError(
                    "Unrecognised arguments passed to '--processor' option "
                    "with 'slurm' type ({}), expected at most 2 additional "
                    "arguments [ACCOUNT, PARTITION]".format(args.processor))
            processor = SlurmProc(
                work_dir,
                account=(args.processor[1] if nargs >= 2 else None),
                partition=(args.processor[2] if nargs >= 3 else None),
                email=email,
                mail_on=('FAIL', ),
                **proc_args)
        else:
            raise BananaUsageError(
                "Unrecognised processor type provided as first argument to "
                "'--processor' option ({})".format(args.processor[0]))

        if args.environment == 'static':
            environment = StaticEnv()
        else:
            environment = ModulesEnv()

        parameters = {}
        for name, value in args.parameter:
            parameters[name] = parse_value(
                value, dtype=study_class.param_spec(name).dtype)

        if input_repository is not None and input_repository.type == 'bids':
            inputs = study_class.get_bids_inputs(args.bids_task,
                                                 repository=input_repository)
        else:
            inputs = {}
        for name, pattern in args.input:
            spec = study_class.data_spec(name)
            if spec.is_fileset:
                inpt_cls = InputFilesets
            else:
                inpt_cls = InputFields
            inputs[name] = inpt_cls(name,
                                    pattern=pattern,
                                    is_regex=True,
                                    repository=input_repository)

        study = study_class(name=args.study_name,
                            repository=repository,
                            processor=processor,
                            environment=environment,
                            inputs=inputs,
                            parameters=parameters,
                            subject_ids=subject_ids,
                            visit_ids=visit_ids,
                            enforce_inputs=args.enforce_inputs,
                            fill_tree=fill_tree,
                            bids_task=args.bids_task)

        for spec_name in args.cache:
            spec = study.bound_spec(spec_name)
            if not isinstance(spec, InputFilesets):
                raise BananaUsageError(
                    "Cannot cache non-input fileset '{}'".format(spec_name))
            spec.cache()

        # Generate data
        study.data(args.derivatives)

        logger.info("Generated derivatives for '{}'".format(args.derivatives))
Example #21
    def display_slice_panel(self,
                            filesets,
                            img_size=5,
                            row_kwargs=None,
                            offset=None,
                            **kwargs):
        """
        Displays an image in a Nx3 panel axial, coronal and sagittal
        slices for the filesets correspdong to each of the data names
        provided.

        Parameters
        ----------
        data_names : List[str]
            List of image names to plot as rows of a panel
        size : Tuple(2)[int]
            Size of the figure to plot
        row_kargs : List[Dict[str, *]]
            A list of row-specific kwargs to passed on to
            _display_mid_slices
        offset : Tuple(3)[int]
            An array of integers with which to offset the slices displayed
        """
        n_rows = len(filesets)
        if row_kwargs is None:
            row_kwargs = repeat({})
        elif not n_rows == len(row_kwargs):
            raise BananaUsageError("Length of row_kwargs ({}) needs to "
                                   "match length of filesets ({})".format(
                                       len(row_kwargs), n_rows))
        # Set up figure
        gs = GridSpec(n_rows, 3)
        gs.update(wspace=0.0, hspace=0.0)
        fig = plt.figure(figsize=(3 * img_size, n_rows * img_size))
        # Loop through derivatives and generate image
        for i, (fileset, rkwargs) in enumerate(zip(filesets, row_kwargs)):
            array = fileset.get_array()
            header = fileset.get_header()
            if fileset.format in (nifti_format, nifti_gz_format):
                vox = header['pixdim'][1:4]
            elif fileset.format == dicom_format:
                vox = [float(v) for v in header.PixelSpacing]
                vox.append(float(header.SliceThickness))
            else:
                raise BananaUsageError(
                    "'{}' format images are not supported by "
                    "display_slice_panel".format(fileset.format))
            rkwargs = copy(rkwargs)
            rkwargs.update(kwargs)
            try:
                self._display_mid_slices(array,
                                         vox,
                                         fig,
                                         gs,
                                         i,
                                         offset=offset,
                                         **rkwargs)
            except BananaUsageError as e:
                raise BananaUsageError(
                    str(e) + " displaying {}".format(fileset.path))
        # Remove space around figure
        plt.tight_layout(pad=0.0)