Example #1
File: testing.py  Project: tclose/banana
 def generate_reference_data(self, *spec_names, processor=None,
                             work_dir=None, environment=None, **kwargs):
     """
     Generates the reference data and provenance against which the unittests
     are run
     """
     if work_dir is None:
         work_dir = tempfile.mkdtemp()
     if processor is None:
         processor = SingleProc(work_dir=work_dir, **kwargs)
     if environment is None:
         environment = StaticEnv()
     analysis = self.analysis_class(  # pylint: disable=no-member
         name=self.name,  # pylint: disable=no-member
         inputs=self.inputs_dict,
         parameters=self.parameters,  # pylint: disable=no-member
         dataset=self.dataset,
         environment=environment,
         processor=processor)
     if not spec_names:
         try:
             skip_specs = self.skip_specs
         except AttributeError:
             skip_specs = ()
         spec_names = [s.name for s in analysis.data_specs()
                       if s.derived and s.name not in skip_specs]
     analysis.derive(spec_names)
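A minimal usage sketch (hypothetical names, not from the source): 'tester' stands for an instance of the test class that defines generate_reference_data, i.e. one providing the analysis_class, name, inputs_dict, parameters and dataset attributes the method reads.

# Regenerate every derivable spec (minus any listed in skip_specs), letting the
# method create a temporary work directory, SingleProc and StaticEnv by default.
tester = MyAnalysisTester()  # hypothetical subclass defining the attributes above
tester.generate_reference_data()

# Or regenerate selected specs only, with an explicit work directory; extra keyword
# arguments (e.g. reprocess=True) are forwarded to the default SingleProc.
tester.generate_reference_data('brain', 'brain_mask',
                               work_dir='/tmp/reference-work', reprocess=True)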
Example #2
def create_analysis():
    my_analysis = MyExtendedBasicBrainAnalysis(
        'my_extended_analysis',  # The name needs to be the same as the previous version
        dataset=Dataset('output/sample-datasets/depth1', depth=1),
        processor=SingleProc('work', reprocess=True),
        inputs=[
            FilesetFilter('magnitude', '.*T1w$', is_regex=True)])
    return my_analysis
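A hedged follow-up sketch: deriving data from the analysis returned above, mirroring the data(..., derive=True) pattern used in the later examples (the spec name 'brain' is illustrative only).

# Illustrative only: derive a hypothetical 'brain' spec and print where each
# session's result is cached.
analysis = create_analysis()
for item in analysis.data('brain', derive=True):
    print(item.path)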
Example #3
from arcana import LocalFileSystemRepo, SingleProc, FilesetFilter
from banana.analysis.mri import DwiAnalysis
from banana.file_format import dicom_format
import os.path as op

test_dir = op.join(op.dirname(__file__), '..', 'test', 'data',
                   'diffusion-test')

analysis = DwiAnalysis('diffusion',
                       LocalFileSystemRepo(op.join(test_dir, 'analysis')),
                       SingleProc(op.join(test_dir, 'work')),
                       inputs=[
                           FilesetFilter('magnitude',
                                         dicom_format,
                                         '16.*',
                                         is_regex=True),
                           FilesetFilter('reverse_phase',
                                         dicom_format,
                                         '15.*',
                                         is_regex=True)
                       ])

print('FA: {}'.format(
    analysis.data('fa', derive=True).path(subject_id='subject',
                                          visit_id='visit')))
print('ADC: {}'.format(
    analysis.data('adc', derive=True).path(subject_id='subject',
                                           visit_id='visit')))
# print('tracking: {}'.format(analysis.derive('wb_tracking').path))
Example #4
    def generate_test_data(cls,
                           study_class,
                           in_repo,
                           out_repo,
                           in_server=None,
                           out_server=None,
                           work_dir=None,
                           parameters=(),
                           include=None,
                           skip=(),
                           include_bases=(),
                           reprocess=False,
                           repo_depth=0,
                           modules_env=False,
                           clean_work_dir=True,
                           loggers=('nipype.workflow', 'arcana', 'banana')):
        """
        Generates reference data for pipeline tester unittests, given a study
        class and a set of parameters

        Parameters
        ----------
        study_class : type(Study)
            The study class to test, e.g. banana.study.MriStudy
        in_repo : str
            The path to repository that houses the input data
        out_repo : str
            If the 'out_server' argument is provided then out_repo
            is interpreted as the ID of a project on that XNAT
            server (the project must already exist). Otherwise
            it is interpreted as the path to a basic repository
        in_server : str | None
            The server to download the input data from
        out_server : str | None
            The server to upload the reference data to
        work_dir : str | None
            The work directory (a temporary directory is created if None)
        parameters : dict[str, *]
            Parameters to set when initialising the study
        include : list[str] | None
            Spec names to include in the output repository. If None, all names
            except those listed in 'skip' are included
        skip : list[str]
            Spec names to skip in the generation process. Only valid if
            'include' is None
        include_bases : list[type(Study)]
            List of base classes whose data specification entries are all
            added to the list of specs to include
        reprocess : bool
            Whether to reprocess the generated datasets
        repo_depth : int
            The depth of the input repository
        modules_env : bool
            Whether to use modules environment or not
        clean_work_dir : bool
            Whether to clean the Nipype work directory or not
        """

        for logger_name in loggers:
            logger = logging.getLogger(logger_name)
            logger.setLevel(logging.INFO)
            handler = logging.StreamHandler()
            formatter = logging.Formatter("%(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        if work_dir is None:
            work_dir = tempfile.mkdtemp()

        if study_class.__name__.endswith('Study'):
            study_name = study_class.__name__[:-len('Study')]
        else:
            study_name = study_class.__name__

        # Get input repository to read the data from
        if in_server is not None:
            in_repo = XnatRepo(project_id=in_repo,
                               server=in_server,
                               cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            in_repo = BasicRepo(in_repo, depth=repo_depth)

        temp_repo_root = op.join(work_dir, 'temp-repo')
        if os.path.exists(temp_repo_root) and reprocess:
            shutil.rmtree(temp_repo_root)
        os.makedirs(temp_repo_root, exist_ok=True)

        temp_repo = BasicRepo(temp_repo_root, depth=repo_depth)

        inputs = None
        for session in in_repo.tree().sessions:
            session_inputs = []
            for item in chain(session.filesets, session.fields):
                if isinstance(item, Fileset):
                    inpt = InputFilesets(item.basename,
                                         item.basename,
                                         item.format,
                                         repository=in_repo)
                else:
                    inpt = InputFields(item.name,
                                       item.name,
                                       item.dtype,
                                       repository=in_repo)
                try:
                    spec = study_class.data_spec(inpt)
                except ArcanaNameError:
                    print(
                        "Skipping {} as it doesn't match a spec in {}".format(
                            item, study_class))
                else:
                    session_inputs.append(inpt)
            session_inputs = sorted(session_inputs)
            if inputs is not None and session_inputs != inputs:
                raise BananaUsageError(
                    "Inconsistent inputs ({} and {}) found in sessions of {}".
                    format(inputs, session_inputs, in_repo))
            else:
                inputs = session_inputs

        if modules_env:
            env = ModulesEnv()
        else:
            env = StaticEnv()

        study = study_class(
            study_name,
            repository=temp_repo,
            processor=SingleProc(
                work_dir,
                reprocess=reprocess,
                clean_work_dir_between_runs=clean_work_dir,
                prov_ignore=(
                    SingleProc.DEFAULT_PROV_IGNORE +
                    ['.*/pkg_version', 'workflow/nodes/.*/requirements/.*'])),
            environment=env,
            inputs=inputs,
            parameters=parameters,
            subject_ids=in_repo.tree().subject_ids,
            visit_ids=in_repo.tree().visit_ids,
            fill_tree=True)

        if include is None:
            # Get set of methods that could override pipeline getters in
            # base classes that are not included
            potentially_overridden = set()
            for cls in chain(include_bases, [study_class]):
                potentially_overridden.update(cls.__dict__.keys())

            include = set()
            for base in study_class.__mro__:
                if not hasattr(base, 'add_data_specs'):
                    continue
                for spec in base.add_data_specs:
                    if isinstance(spec,
                                  BaseInputSpecMixin) or spec.name in skip:
                        continue
                    if (base is study_class or base in include_bases
                            or spec.pipeline_getter in potentially_overridden):
                        include.add(spec.name)

        # Generate all derived data
        for spec_name in sorted(include):
            study.data(spec_name)

        # Get output repository to write the data to
        if out_server is not None:
            out_repo = XnatRepo(project_id=out_repo,
                                server=out_server,
                                cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            out_repo = BasicRepo(out_repo, depth=repo_depth)

        # Upload data to repository
        for spec in study.data_specs():
            try:
                data = study.data(spec.name, generate=False)
            except ArcanaMissingDataException:
                continue
            for item in data:
                if not item.exists:
                    logger.info("Skipping upload of non-existant {}".format(
                        item.name))
                    continue
                if skip is not None and item.name in skip:
                    logger.info("Forced skip of {}".format(item.name))
                    continue
                if item.is_fileset:
                    item_cpy = Fileset(name=item.name,
                                       format=item.format,
                                       frequency=item.frequency,
                                       path=item.path,
                                       aux_files=copy(item.aux_files),
                                       subject_id=item.subject_id,
                                       visit_id=item.visit_id,
                                       repository=out_repo,
                                       exists=True)
                else:
                    item_cpy = Field(name=item.name,
                                     value=item.value,
                                     dtype=item.dtype,
                                     frequency=item.frequency,
                                     array=item.array,
                                     subject_id=item.subject_id,
                                     visit_id=item.visit_id,
                                     repository=out_repo,
                                     exists=True)
                logger.info("Uploading {}".format(item_cpy))
                item_cpy.put()
                logger.info("Uploaded {}".format(item_cpy))
        logger.info(
            "Finished generating and uploading test data for {}".format(
                study_class))
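A hedged sketch of how generate_test_data might be invoked, based only on the parameters documented above; 'PipelineTester' stands in for whichever class defines the method, and every path, project ID and parameter name is a placeholder.

from banana.study import MriStudy  # import path as cited in the docstring above

PipelineTester.generate_test_data(
    MriStudy,
    in_repo='/data/mri-inputs',             # read from a basic repository on disk
    out_repo='REF_PROJECT',                 # treated as an XNAT project ID because
    out_server='https://xnat.example.org',  # out_server is provided
    work_dir='/tmp/mri-testgen-work',
    parameters={'example_parameter': 1},    # hypothetical parameter name
    skip=('slow_spec',),                    # hypothetical spec name to leave out
    repo_depth=1,
    modules_env=False)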
Example #5
    def run_pipeline_test(self,
                          pipeline_getter,
                          add_inputs=(),
                          test_criteria=None,
                          pipeline_args=None):
        """
        Runs a pipeline and tests its outputs against the reference data

        Parameters
        ----------
        pipeline_getter : str
            The name of the pipeline to test
        add_inputs : list[str]
            Inputs that the output study requires for the pipeline to run, in
            addition to the pipeline's direct inputs, i.e. inputs that are
            checked for with the 'provided' method during pipeline
            construction
        test_criteria : dict[str, *] | None
            A dictionary containing the criteria by which to determine equality
            for the derived filesets. The keys are spec-names and the values
            are specific to the format of the fileset. If a spec-name is not
            in the dictionary, or None is provided, then the default
            criteria are used for each fileset test.
        pipeline_args : dict[str, *] | None
            Arguments passed through to the pipeline getter when constructing
            the pipeline
        """
        if test_criteria is None:
            test_criteria = {}
        if pipeline_args is None:
            pipeline_args = {}
        # Construct the pipeline on the reference study (which has all inputs
        # provided) to determine which inputs the pipeline needs
        ref_pipeline = self.ref_study.pipeline(pipeline_getter,
                                               pipeline_args=pipeline_args)
        inputs = []
        for spec_name in chain(ref_pipeline.input_names, add_inputs):
            try:
                inputs.append(self.inputs[spec_name])
            except KeyError:
                pass  # Inputs with a default value
        # Set up output study
        output_study = self.study_class(  # pylint: disable=not-callable
            pipeline_getter,
            repository=self.output_repo,
            processor=SingleProc(self.work_dir, reprocess='force'),
            environment=self.environment,
            inputs=inputs,
            parameters=self.parameters,
            subject_ids=self.ref_study.subject_ids,
            visit_ids=self.ref_study.visit_ids,
            enforce_inputs=False,
            fill_tree=True)
        for spec_name in ref_pipeline.output_names:
            for ref, test in zip(self.ref_study.data(spec_name),
                                 output_study.data(spec_name)):
                if ref.is_fileset:
                    try:
                        self.assertTrue(
                            test.contents_equal(
                                ref, **test_criteria.get(spec_name, {})),
                            "'{}' fileset generated by {} in {} doesn't match "
                            "reference".format(spec_name, pipeline_getter,
                                               self.study_class))
                    except Exception:
                        if hasattr(test, 'headers_diff'):
                            header_diff = test.headers_diff(ref)
                            if header_diff:
                                print("Headers don't match on {}".format(
                                    header_diff))
                                print("Test header:\n{}".format(
                                    pformat(test.get_header())))
                                print("Reference header:\n{}".format(
                                    pformat(ref.get_header())))
                            else:
                                print("Image RMS diff: {}".format(
                                    test.rms_diff(ref)))
                                test.contents_equal(ref)
                        raise
                else:
                    self.assertEqual(
                        test.value, ref.value,
                        "value for {} ({}) generated by {} in {} doesn't "
                        "match reference ({})".format(spec_name, test.value,
                                                      pipeline_getter,
                                                      self.study_class,
                                                      ref.value))
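A minimal, hedged example of a concrete test method calling run_pipeline_test; the pipeline-getter name and the 'rms_tol' criterion key are illustrative, not taken from the source.

    def test_brain_extraction(self):
        # 'brain_extraction_pipeline' and the per-spec criteria below are
        # hypothetical; the criteria dict is forwarded to the fileset's
        # contents_equal() comparison.
        self.run_pipeline_test(
            'brain_extraction_pipeline',
            test_criteria={'brain': {'rms_tol': 1e-4}})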
Example #6
class T2StarT1Study(MultiStudy, metaclass=MultiStudyMetaClass):

    add_substudy_specs = [
        SubStudySpec(
            't1', T1Study),
        SubStudySpec(
            't2star', T2starStudy,
            name_map={'t1_brain': 'coreg_ref_brain',
                      't1_coreg_to_tmpl_ants_mat': 'coreg_to_tmpl_ants_mat',
                      't1_coreg_to_tmpl_ants_warp': 'coreg_to_tmpl_ants_warp'})]


study = T2StarT1Study(
    'qsm_corrected_times',
    repository=single_echo_dir,
    processor=SingleProc(op.join(test_data, 'work')),
    inputs=[
        InputFilesets('t2star_channels', 'swi_coils_icerecon', zip_format),
        InputFilesets('t2star_header_image', 'SWI_Images', dicom_format),
        InputFilesets('t2star_swi', 'SWI_Images', dicom_format),
        InputFilesets('t1_magnitude', dicom_format,
                        't1_mprage_sag_p2_iso_1mm')],
    parameters=[
        Parameter('t2star_reorient_to_std', False),
        Parameter('t1_reorient_to_std', False)])

# print(study.data('t2star_mag_channels', clean_work_dir=True).path(
#     subject_id='SUBJECT', visit_id='VISIT'))

print(study.data('t2star_vein_mask', clean_work_dir=True).path(
    subject_id='SUBJECT', visit_id='VISIT'))
Example #7
#!/usr/bin/env python3
import os.path as op
from arcana import (FilesetFilter, Dataset, SingleProc, StaticEnv)
from banana.analysis.mri.dwi import DwiAnalysis
from banana.file_format import dicom_format

analysis = DwiAnalysis(name='example_diffusion',
                       dataset=Dataset(op.join(op.expanduser('~'), 'Downloads',
                                               'test-dir'),
                                       depth=0),
                       processor=SingleProc(work_dir=op.expanduser('~/work')),
                       environment=StaticEnv(),
                       inputs=[
                           FilesetFilter('magnitude',
                                         'R_L.*',
                                         dicom_format,
                                         is_regex=True),
                           FilesetFilter('reverse_phase',
                                         'L_R.*',
                                         dicom_format,
                                         is_regex=True)
                       ],
                       parameters={'num_global_tracks': int(1e6)})

# Generate whole brain tracks and return path to cached dataset
wb_tcks = analysis.data('global_tracks', derive=True)
for sess_tcks in wb_tcks:
    print("Performed whole-brain tractography for {}:{} session, the results "
          "are stored at '{}'".format(sess_tcks.subject_id, sess_tcks.visit_id,
                                      sess_tcks.path))
Example #8
    MotionDetection, inputs = create_motion_detection_class('MotionDetection',
                                                            ref,
                                                            ref_type,
                                                            t1s=t1s,
                                                            t2s=t2s,
                                                            dwis=dwis,
                                                            epis=epis)

    sub_id = 'work_sub_dir'
    session_id = 'work_session_dir'
    repository = BasicRepo(op.join(input_dir, 'work_dir'))
    work_dir = op.join(input_dir, 'motion_detection_cache')
    WORK_PATH = work_dir
    try:
        os.makedirs(WORK_PATH)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    study = MotionDetection(name='MotionDetection',
                            processor=SingleProc(WORK_PATH),
                            environment=(ModulesEnv() if args.environment
                                         == 'modules' else StaticEnv()),
                            repository=repository,
                            inputs=inputs,
                            subject_ids=[sub_id],
                            visit_ids=[session_id])
    study.data('motion_detection_output')

print('Done!')