コード例 #1
0
ファイル: testing.py プロジェクト: apoz00003/banana
    def generate_test_data(cls,
                           study_class,
                           in_repo,
                           out_repo,
                           in_server=None,
                           out_server=None,
                           work_dir=None,
                           parameters=(),
                           include=None,
                           skip=(),
                           include_bases=(),
                           reprocess=False,
                           repo_depth=0,
                           modules_env=False,
                           clean_work_dir=True,
                           loggers=('nipype.workflow', 'arcana', 'banana')):
        """
        Generates reference data for pipeline-tester unittests given a study
        class and a set of parameters.

        The input data are scanned from ``in_repo``, the selected derivatives
        are generated into a temporary repository under ``work_dir`` and the
        resulting filesets/fields are then copied into ``out_repo``.

        Parameters
        ----------
        study_class : type(Study)
            The study class to test, e.g. banana.study.MriStudy
        in_repo : str
            If the 'in_server' argument is provided then in_repo is
            interpreted as the XNAT project ID that houses the input data.
            Otherwise it is interpreted as the path to a basic repository
        out_repo : str
            If the 'out_server' argument is provided then out_repo
            is interpreted as the project ID to use on the XNAT
            server (the project must exist already). Otherwise
            it is interpreted as the path to a basic repository
        in_server : str | None
            The server to download the input data from
        out_server : str | None
            The server to upload the reference data to
        work_dir : str | None
            The work directory. A new temporary directory is created if None
        parameters : dict[str, *]
            Parameters to set when initialising the study
        include : list[str] | None
            Spec names to include in the output repository. If None all names
            except those listed in 'skip' are included
        skip : list[str]
            Spec names to skip in the generation process. Only valid if
            'include' is None. Items with these names are never uploaded
        include_bases : list[type(Study)]
            List of base classes in which all entries in their data
            specification are added to the list to include
        reprocess : bool
            Whether to reprocess the generated datasets (also clears any
            existing temporary repository in the work directory)
        repo_depth : int
            The depth of the input repository
        modules_env : bool
            Whether to use modules environment or not
        clean_work_dir : bool
            Whether to clean the Nipype work directory or not
        loggers : Sequence[str]
            Names of loggers to set to INFO level and attach a stream handler
            to. Progress messages of this method are written to the last of
            them, or to the 'banana' logger if the sequence is empty

        Raises
        ------
        BananaUsageError
            If the sessions of the input repository provide different inputs
        """

        # Progress messages go to the last configured logger. Bind a default
        # first so that an empty 'loggers' sequence does not leave 'logger'
        # unbound (which previously raised UnboundLocalError at the end).
        logger = logging.getLogger('banana')
        for logger_name in loggers:
            logger = logging.getLogger(logger_name)
            logger.setLevel(logging.INFO)
            handler = logging.StreamHandler()
            formatter = logging.Formatter("%(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        if work_dir is None:
            work_dir = tempfile.mkdtemp()

        # Strip the conventional 'Study' suffix from the class name
        study_name = study_class.__name__
        if study_name.endswith('Study'):
            study_name = study_name[:-len('Study')]

        # Get input repository to read the data from
        if in_server is not None:
            in_repo = XnatRepo(project_id=in_repo,
                               server=in_server,
                               cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            in_repo = BasicRepo(in_repo, depth=repo_depth)

        # Temporary repository the derivatives are generated into
        temp_repo_root = op.join(work_dir, 'temp-repo')
        if os.path.exists(temp_repo_root) and reprocess:
            shutil.rmtree(temp_repo_root)
        os.makedirs(temp_repo_root, exist_ok=True)

        temp_repo = BasicRepo(temp_repo_root, depth=repo_depth)

        # Query the input tree once: it is reused both for the input scan and
        # for the subject/visit IDs of the study (may be a remote query)
        in_tree = in_repo.tree()

        # Derive the study inputs from the items present in each session,
        # requiring every session to provide the same set of inputs
        inputs = None
        for session in in_tree.sessions:
            session_inputs = []
            for item in chain(session.filesets, session.fields):
                if isinstance(item, Fileset):
                    inpt = InputFilesets(item.basename,
                                         item.basename,
                                         item.format,
                                         repository=in_repo)
                else:
                    inpt = InputFields(item.name,
                                       item.name,
                                       item.dtype,
                                       repository=in_repo)
                try:
                    # Only checks that the input matches a spec of the study
                    study_class.data_spec(inpt)
                except ArcanaNameError:
                    print(
                        "Skipping {} as it doesn't match a spec in {}".format(
                            item, study_class))
                else:
                    session_inputs.append(inpt)
            session_inputs = sorted(session_inputs)
            if inputs is not None and session_inputs != inputs:
                raise BananaUsageError(
                    "Inconsistent inputs ({} and {}) found in sessions of {}".
                    format(inputs, session_inputs, in_repo))
            inputs = session_inputs

        env = ModulesEnv() if modules_env else StaticEnv()

        study = study_class(
            study_name,
            repository=temp_repo,
            processor=SingleProc(
                work_dir,
                reprocess=reprocess,
                clean_work_dir_between_runs=clean_work_dir,
                prov_ignore=(
                    SingleProc.DEFAULT_PROV_IGNORE +
                    ['.*/pkg_version', 'workflow/nodes/.*/requirements/.*'])),
            environment=env,
            inputs=inputs,
            parameters=parameters,
            subject_ids=in_tree.subject_ids,
            visit_ids=in_tree.visit_ids,
            fill_tree=True)

        if include is None:
            # Get set of methods that could override pipeline getters in
            # base classes that are not included
            potentially_overridden = set()
            for cls in chain(include_bases, [study_class]):
                potentially_overridden.update(cls.__dict__.keys())

            include = set()
            for base in study_class.__mro__:
                if not hasattr(base, 'add_data_specs'):
                    continue
                for spec in base.add_data_specs:
                    # Input specs are provided, not generated
                    if isinstance(spec,
                                  BaseInputSpecMixin) or spec.name in skip:
                        continue
                    if (base is study_class or base in include_bases
                            or spec.pipeline_getter in potentially_overridden):
                        include.add(spec.name)

        # Generate all derived data
        for spec_name in sorted(include):
            study.data(spec_name)

        # Get output repository to write the data to
        if out_server is not None:
            out_repo = XnatRepo(project_id=out_repo,
                                server=out_server,
                                cache_dir=op.join(work_dir, 'xnat-cache'))
        else:
            out_repo = BasicRepo(out_repo, depth=repo_depth)

        # Upload data to repository
        for spec in study.data_specs():
            try:
                data = study.data(spec.name, generate=False)
            except ArcanaMissingDataException:
                continue
            for item in data:
                if not item.exists:
                    logger.info("Skipping upload of non-existant {}".format(
                        item.name))
                    continue
                if skip is not None and item.name in skip:
                    logger.info("Forced skip of {}".format(item.name))
                    continue
                # Re-create the item bound to the output repository so that
                # put() writes it there rather than to the temporary repo
                if item.is_fileset:
                    item_cpy = Fileset(name=item.name,
                                       format=item.format,
                                       frequency=item.frequency,
                                       path=item.path,
                                       aux_files=copy(item.aux_files),
                                       subject_id=item.subject_id,
                                       visit_id=item.visit_id,
                                       repository=out_repo,
                                       exists=True)
                else:
                    item_cpy = Field(name=item.name,
                                     value=item.value,
                                     dtype=item.dtype,
                                     frequency=item.frequency,
                                     array=item.array,
                                     subject_id=item.subject_id,
                                     visit_id=item.visit_id,
                                     repository=out_repo,
                                     exists=True)
                logger.info("Uploading {}".format(item_cpy))
                item_cpy.put()
                logger.info("Uploaded {}".format(item_cpy))
        logger.info(
            "Finished generating and uploading test data for {}".format(
                study_class))
コード例 #2
0
 def slice(self):
     """Return a single-member FilesetSlice for this object's path.

     The slice is named ``self.name`` and holds exactly one ``Fileset``,
     built with ``Fileset.from_path`` from ``self.path``. It carries this
     object's format (``self._format``) and frequency.
     """
     name = self.name
     frequency = self.frequency
     only_fileset = Fileset.from_path(self.path, frequency=frequency)
     options = {'format': self._format, 'frequency': frequency}
     return FilesetSlice(name, [only_fileset], **options)
コード例 #3
0
 def collection(self):
     """Build a one-member FilesetCollection from this object's path.

     Returns
     -------
     FilesetCollection
         Named ``self.name``. It contains a single ``Fileset`` created via
         ``Fileset.from_path`` at ``self.path``, and uses this object's
         format (``self._format``) and frequency.
     """
     name, freq = self.name, self.frequency
     members = [Fileset.from_path(self.path, frequency=freq)]
     return FilesetCollection(name, members, format=self._format,
                              frequency=freq)