예제 #1
0
파일: NLIN.py 프로젝트: psteadman/pydpiper
def NLIN_pipeline(options):
    """Build a nonlinear (NLIN) group model from the input images, optionally with determinants.

    Reads `options.application` (inputs, output directory, pipeline name), `options.registration`
    (resolution), `options.nlin` (initial target and mask, protocol, registration method) and
    `options.stats` (whether to compute determinants, and with which blur kernels).

    :returns: Result whose output Namespace holds `nlin_xfms` (the deferred `nlin_build_model`
              output, which also carries `avg_img`), `avg_img` and -- only when
              `options.stats.calc_stats` is set -- `determinants`
    """
    output_dir    = options.application.output_directory
    pipeline_name = options.application.pipeline_name

    # TODO this is tedious and annoyingly similar to the registration chain and MBM and LSQ6 ...
    nlin_dir = os.path.join(output_dir, pipeline_name + "_nlin")

    imgs = get_imgs(options.application)

    # Derive the fallback resolution from the images themselves: inputs are obtained via `get_imgs`,
    # so `options.application.files` is not required to be set here, and iterating it when it is
    # None raises TypeError.  NOTE(review): assumes get_imgs yields MincAtom-like objects with `.path`.
    # TODO does using the finest resolution here make sense?
    resolution = (options.registration.resolution
                  or min(get_resolution_from_file(img.path) for img in imgs))

    # NLIN settings: defaults overridden by any settings present in the protocol file, if it exists
    # (could add a hook to print a message announcing completion, output files,
    #  add more stages here to make a CSV)
    initial_target_mask = MincAtom(options.nlin.target_mask) if options.nlin.target_mask else None
    initial_target = MincAtom(options.nlin.target, mask=initial_target_mask)

    full_hierarchy = get_nonlinear_configuration_from_options(nlin_protocol=options.nlin.nlin_protocol,
                                                              reg_method=options.nlin.reg_method,
                                                              file_resolution=resolution)

    s = Stages()

    nlin_result = s.defer(nlin_build_model(imgs, initial_target=initial_target, conf=full_hierarchy, nlin_dir=nlin_dir))

    # TODO return these?
    inverted_xfms = [s.defer(invert_xfmhandler(xfm)) for xfm in nlin_result.output]

    # there's no consistency in what gets returned, yikes ...
    output_fields = dict(nlin_xfms=nlin_result, avg_img=nlin_result.avg_img)
    if options.stats.calc_stats:
        # note the swap: the inverted transform is passed as `xfm`, the forward one as `inv_xfm`
        output_fields["determinants"] = [s.defer(determinants_at_fwhms(xfm=inv_xfm,
                                                                       inv_xfm=xfm,
                                                                       blur_fwhms=options.stats.stats_kernels))
                                         for xfm, inv_xfm in zip(nlin_result.output, inverted_xfms)]

    return Result(stages=s, output=Namespace(**output_fields))
예제 #2
0
def NLIN_pipeline(options):
    """Build a nonlinear (NLIN) group model from `options.application.files`, optionally with determinants.

    Reads `options.application` (files, output directory, pipeline name), `options.registration`
    (resolution), `options.nlin` (initial target and mask, protocol, registration method) and
    `options.stats` (whether to compute determinants, and with which blur kernels).

    :returns: Result whose output Namespace holds `nlin_xfms` (the deferred `nlin_build_model`
              output, which also carries `avg_img`), `avg_img` and -- only when
              `options.stats.calc_stats` is set -- `determinants`
    :raises ValueError: if no input files were given (None or empty)
    """
    # `not` rather than `is None`: an empty file list would otherwise only fail later with an
    # unhelpful error (e.g. `min()` of an empty sequence when no resolution is given)
    if not options.application.files:
        raise ValueError("Please, some files! (or try '--help')")  # TODO make a util procedure for this

    output_dir    = options.application.output_directory
    pipeline_name = options.application.pipeline_name

    # TODO this is tedious and annoyingly similar to the registration chain and MBM and LSQ6 ...
    processed_dir = os.path.join(output_dir, pipeline_name + "_processed")
    nlin_dir      = os.path.join(output_dir, pipeline_name + "_nlin")

    # TODO does using the finest resolution here make sense?
    resolution = (options.registration.resolution
                  or min(get_resolution_from_file(f) for f in options.application.files))

    imgs = [MincAtom(f, pipeline_sub_dir=processed_dir) for f in options.application.files]

    # NLIN settings: defaults overridden by any settings present in the protocol file, if it exists
    # (could add a hook to print a message announcing completion, output files,
    #  add more stages here to make a CSV)
    initial_target_mask = MincAtom(options.nlin.target_mask) if options.nlin.target_mask else None
    initial_target = MincAtom(options.nlin.target, mask=initial_target_mask)

    full_hierarchy = get_nonlinear_configuration_from_options(nlin_protocol=options.nlin.nlin_protocol,
                                                              flag_nlin_protocol=next(iter(options.nlin.flags_.nlin_protocol)),
                                                              reg_method=options.nlin.reg_method,
                                                              file_resolution=resolution)

    s = Stages()

    nlin_result = s.defer(nlin_build_model(imgs, initial_target=initial_target, conf=full_hierarchy, nlin_dir=nlin_dir))

    # TODO return these?
    inverted_xfms = [s.defer(invert_xfmhandler(xfm)) for xfm in nlin_result.output]

    # there's no consistency in what gets returned, yikes ...
    output_fields = dict(nlin_xfms=nlin_result, avg_img=nlin_result.avg_img)
    if options.stats.calc_stats:
        # note the swap: the inverted transform is passed as `xfm`, the forward one as `inv_xfm`
        output_fields["determinants"] = [s.defer(determinants_at_fwhms(xfm=inv_xfm,
                                                                       inv_xfm=xfm,
                                                                       blur_fwhms=options.stats.stats_kernels))
                                         for xfm, inv_xfm in zip(nlin_result.output, inverted_xfms)]

    return Result(stages=s, output=Namespace(**output_fields))
예제 #3
0
def tamarack(imgs: pd.DataFrame, options):
    """Build one first-level model per group and bring them all into a common time point's space.

    Steps:
      1. build a first-level model (`mbm`) for each value of the `group` column of `imgs`;
      2. register each first-level average to the next one (lsq12 + nlin);
      3. compose those registrations into a transform from every first-level average to the
         average of `options.tamarack.common_time_pt`;
      4. resample the first-level log determinants into that common space, and compute
         determinants of the overall transforms (each first-level subject transform composed
         with its group's transform to the common average).

    :param imgs: input table; this function reads its `group` and `file` columns, and the first
                 entry of its `filename` column when no pride of models is used
                 (NOTE(review): an older comment mentioned `img`/`timept` columns -- confirm)
    :param options: full option namespace (`mbm`, `registration`, `application`, `tamarack`);
                    note that `options.registration.resolution` is overwritten with the
                    resolution shared by all first-level models
    :returns: Result with output Namespace(first_level_results, overall_determinants,
              resampled_determinants) -- the last without its `options` column
    :raises ValueError: if the first-level models would run at different resolutions
    """
    s = Stages()

    # TODO some assertions that the pride_of_models, if provided, is correct, and that this is intended target type

    def group_options(options, timepoint):
        """Return a copy of `options` configured for the first-level model of `timepoint`.

        With a pride of models, the lsq6 target becomes the pride's closest model (used as an
        initial model); otherwise the standard registration targets are used.  Either way the
        registration resolution is pinned (explicit option, else that of the registration standard).
        """
        options = copy.deepcopy(options)  # each group gets its own options; never mutate the caller's

        if options.mbm.lsq6.target_type == TargetType.pride_of_models:
            targets = get_closest_model_from_pride_of_models(
                pride_of_models_dict=get_pride_of_models_mapping(
                    pride_csv=options.mbm.lsq6.target_file,
                    output_dir=options.application.output_directory,
                    pipeline_name=options.application.pipeline_name),
                time_point=timepoint)
            # FIXME use of registration_standard here is quite wrong ...
            # part of the trouble is that mbm calls registration_targets itself,
            # so we can't send this RegistrationTargets to `mbm` directly ...
            # one option: add yet another optional arg to `mbm` ...
            options.mbm.lsq6 = options.mbm.lsq6.replace(target_type=TargetType.initial_model,
                                                        target_file=targets.registration_standard.path)
        else:
            targets = s.defer(registration_targets(lsq6_conf=options.mbm.lsq6,
                                                   app_conf=options.application,
                                                   reg_conf=options.registration,
                                                   first_input_file=imgs.filename.iloc[0]))

        # This must happen after calling registration_targets otherwise it will resample to options.registration.resolution
        resolution = (options.registration.resolution
                      or get_resolution_from_file(targets.registration_standard.path))
        options.registration = options.registration.replace(resolution=resolution)

        return options

    # build all first-level models, one per group (TODO 'group' => 'timept' ?):
    first_level_dir = os.path.join(options.application.output_directory,
                                   options.application.pipeline_name + "_first_level")
    first_level_results = (
        imgs
        # the usual annoying pattern to do an aggregate with access to the groupby object's keys ... TODO: fix
        .groupby('group', as_index=False)
        .aggregate({'file': lambda files: list(files)})
        .rename(columns={'file': "files"})
        .assign(options=lambda df: df.apply(axis=1, func=lambda row: group_options(options, row.group)))
        .assign(build_model=lambda df: df.apply(
            axis=1,
            func=lambda row: s.defer(mbm(imgs=row.files,
                                         options=row.options,
                                         prefix="%s" % row.group,
                                         output_dir=os.path.join(first_level_dir, "%s_processed" % row.group)))))
        .sort_values(by='group'))

    # all first-level models must agree on a resolution; the later stages run at that resolution
    first_resolution = first_level_results.options.iloc[0].registration.resolution
    if not all(first_level_results.options.map(lambda opts: opts.registration.resolution) == first_resolution):
        raise ValueError("some first-level models are run at different resolutions, possibly not what you want ...")
    options.registration = options.registration.replace(resolution=first_resolution)

    # construction of the overall inter-average transforms will be done iteratively (for efficiency/aesthetics),
    # which doesn't really fit the DataFrame mold ...
    full_hierarchy = get_nonlinear_configuration_from_options(
        nlin_protocol=options.mbm.nlin.nlin_protocol,
        reg_method=options.mbm.nlin.reg_method,
        file_resolution=options.registration.resolution)

    # FIXME no good can come of this
    nlin_protocol = full_hierarchy.confs[-1] if isinstance(full_hierarchy, MultilevelANTSConf) else full_hierarchy

    # first register consecutive averages together (row i: average of group i -> average of group i+1):
    average_registrations = (
        first_level_results[:-1]
        .assign(next_model=list(first_level_results[1:].build_model))
        # TODO: we should be able to do lsq6 registration here as well!
        .assign(xfm=lambda df: df.apply(axis=1, func=lambda row: s.defer(
            lsq12_nlin(source=row.build_model.avg_img,
                       target=row.next_model.avg_img,
                       lsq12_conf=get_linear_configuration_from_options(
                           options.mbm.lsq12,
                           transform_type=LinearTransType.lsq12,
                           file_resolution=options.registration.resolution),
                       nlin_conf=nlin_protocol)))))

    # now compose the above transforms to produce transforms from each average to the common average:
    common_time_pt = options.tamarack.common_time_pt
    common_model = first_level_results[first_level_results.group == common_time_pt].iloc[0].build_model.avg_img
    # asymmetry in before/after since we used `next_`, not `previous_`
    before = average_registrations[average_registrations.group < common_time_pt]
    after = average_registrations[average_registrations.group >= common_time_pt]

    def suffixes(xs):
        """Return [xs, xs[1:], ..., []]; suffix i chains registrations from before-group i up to the common one."""
        return [xs[i:] for i in range(len(xs) + 1)]

    def prefixes(xs):
        """Return [[], xs[:1], ..., xs]; prefix i chains registrations from the common group i steps forward.

        Each prefix keeps the original order of `xs`, e.g. [a, b] -> [[], [a], [a, b]].
        """
        return [xs[:i] for i in range(len(xs) + 1)]

    def compose_to_common(row):
        """Transform from this group's average to the common average; None for the common group itself."""
        if row.uncomposed_xfms is None:
            return None  # TODO None => identity??
        if row.group < common_time_pt:
            return s.defer(concat_xfmhandlers(row.uncomposed_xfms, name="%s_to_common" % row.group))
        # this chain runs from the common average to this group's average, so invert it
        from_common = s.defer(concat_xfmhandlers(row.uncomposed_xfms, name="%s_from_common" % row.group))
        return s.defer(invert_xfmhandler(from_common))

    # compose 1st and 2nd level transforms and resample into the common average space:
    xfms_to_common = (
        first_level_results
        .assign(uncomposed_xfms=suffixes(list(before.xfm))[:-1] + [None] + prefixes(list(after.xfm))[1:])
        .assign(xfm_to_common=lambda df: df.apply(axis=1, func=compose_to_common))
        .drop('uncomposed_xfms', axis=1))

    # TODO indexing here is not good ...
    first_level_determinants = pd.concat(
        list(first_level_results.build_model.apply(lambda x: x.determinants.assign(first_level_avg=x.avg_img))),
        ignore_index=True)

    def resample_to_common(row, column):
        """Resample `row[column]` into the common space; the common group's images are already there."""
        if row.xfm_to_common is None:
            return row[column]
        return s.defer(mincresample_new(img=row[column], xfm=row.xfm_to_common.xfm, like=common_model))

    resampled_determinants = (
        pd.merge(left=first_level_determinants,
                 # the common group has no transform, so key it by its own average; keying it by None
                 # would make the inner merge silently drop the common time point's determinants
                 right=xfms_to_common.assign(
                     source=lambda df: [xfm.source if xfm is not None else model.avg_img
                                        for xfm, model in zip(df.xfm_to_common, df.build_model)]),
                 left_on="first_level_avg", right_on='source')
        .assign(resampled_log_full_det=lambda df: df.apply(
                    axis=1, func=lambda row: resample_to_common(row, 'log_full_det')),
                resampled_log_nlin_det=lambda df: df.apply(
                    axis=1, func=lambda row: resample_to_common(row, 'log_nlin_det'))))

    inverted_overall_xfms = pd.Series({xfm: (s.defer(concat_xfmhandlers([xfm, row.xfm_to_common]))
                                             if row.xfm_to_common is not None else xfm)
                                       for _ix, row in xfms_to_common.iterrows()
                                       for xfm in row.build_model.xfms.lsq12_nlin_xfm})

    overall_xfms = inverted_overall_xfms.apply(lambda x: s.defer(invert_xfmhandler(x)))

    # deferred like every other determinants_at_fwhms call; otherwise its stages are never added to `s`
    # and the output would hold a Result rather than the determinants themselves
    overall_determinants = s.defer(determinants_at_fwhms(xfms=overall_xfms,
                                                         blur_fwhms=options.mbm.stats.stats_kernels,
                                                         inv_xfms=inverted_overall_xfms))

    # TODO turn off bootstrap as with two-level code?
    # TODO combine into one data frame
    return Result(stages=s,
                  output=Namespace(first_level_results=first_level_results,
                                   overall_determinants=overall_determinants,
                                   resampled_determinants=resampled_determinants.drop(['options'], axis=1)))
def tamarack(imgs: pd.DataFrame, options):
    """Build one first-level model per group and bring them all into a common time point's space.

    Steps:
      1. build a first-level model (`mbm`) for each value of the `group` column of `imgs`;
      2. register each first-level average to the next one (lsq12 + nlin);
      3. compose those registrations into a transform from every first-level average to the
         average of `options.tamarack.common_time_pt`;
      4. resample the first-level log determinants into that common space, and compute
         determinants of the overall transforms (each first-level subject transform composed
         with its group's transform to the common average).

    :param imgs: input table; this function reads its `group` and `file` columns, and the first
                 entry of its `filename` column when no pride of models is used
                 (NOTE(review): an older comment mentioned `img`/`timept` columns -- confirm)
    :param options: full option namespace (`mbm`, `registration`, `application`, `tamarack`);
                    note that `options.registration.resolution` is overwritten with the
                    resolution shared by all first-level models
    :returns: Result with output Namespace(first_level_results, overall_determinants,
              resampled_determinants) -- the last without its `options` column
    :raises ValueError: if the first-level models would run at different resolutions
    """
    s = Stages()

    # TODO some assertions that the pride_of_models, if provided, is correct, and that this is intended target type

    def group_options(options, timepoint):
        """Return a copy of `options` configured for the first-level model of `timepoint`.

        With a pride of models, the lsq6 target becomes the pride's closest model (used as an
        initial model); otherwise the standard registration targets are used.  Either way the
        registration resolution is pinned (explicit option, else that of the registration standard).
        """
        options = copy.deepcopy(options)  # each group gets its own options; never mutate the caller's

        if options.mbm.lsq6.target_type == TargetType.pride_of_models:
            targets = get_closest_model_from_pride_of_models(
                pride_of_models_dict=get_pride_of_models_mapping(
                    pride_csv=options.mbm.lsq6.target_file,
                    output_dir=options.application.output_directory,
                    pipeline_name=options.application.pipeline_name),
                time_point=timepoint)
            # FIXME use of registration_standard here is quite wrong ...
            # part of the trouble is that mbm calls registration_targets itself,
            # so we can't send this RegistrationTargets to `mbm` directly ...
            # one option: add yet another optional arg to `mbm` ...
            options.mbm.lsq6 = options.mbm.lsq6.replace(target_type=TargetType.initial_model,
                                                        target_file=targets.registration_standard.path)
        else:
            targets = s.defer(registration_targets(lsq6_conf=options.mbm.lsq6,
                                                   app_conf=options.application,
                                                   reg_conf=options.registration,
                                                   first_input_file=imgs.filename.iloc[0]))

        # This must happen after calling registration_targets otherwise it will resample to options.registration.resolution
        resolution = (options.registration.resolution
                      or get_resolution_from_file(targets.registration_standard.path))
        options.registration = options.registration.replace(resolution=resolution)

        return options

    # build all first-level models, one per group (TODO 'group' => 'timept' ?):
    first_level_dir = os.path.join(options.application.output_directory,
                                   options.application.pipeline_name + "_first_level")
    first_level_results = (
        imgs
        # the usual annoying pattern to do an aggregate with access to the groupby object's keys ... TODO: fix
        .groupby('group', as_index=False)
        .aggregate({'file': lambda files: list(files)})
        .rename(columns={'file': "files"})
        .assign(options=lambda df: df.apply(axis=1, func=lambda row: group_options(options, row.group)))
        .assign(build_model=lambda df: df.apply(
            axis=1,
            func=lambda row: s.defer(mbm(imgs=row.files,
                                         options=row.options,
                                         prefix="%s" % row.group,
                                         output_dir=os.path.join(first_level_dir, "%s_processed" % row.group)))))
        .sort_values(by='group'))

    # all first-level models must agree on a resolution; the later stages run at that resolution
    first_resolution = first_level_results.options.iloc[0].registration.resolution
    if not all(first_level_results.options.map(lambda opts: opts.registration.resolution) == first_resolution):
        raise ValueError("some first-level models are run at different resolutions, possibly not what you want ...")
    options.registration = options.registration.replace(resolution=first_resolution)

    # construction of the overall inter-average transforms will be done iteratively (for efficiency/aesthetics),
    # which doesn't really fit the DataFrame mold ...
    full_hierarchy = get_nonlinear_configuration_from_options(
        nlin_protocol=options.mbm.nlin.nlin_protocol,
        reg_method=options.mbm.nlin.reg_method,
        file_resolution=options.registration.resolution)

    # FIXME no good can come of this
    nlin_protocol = full_hierarchy.confs[-1] if isinstance(full_hierarchy, MultilevelANTSConf) else full_hierarchy

    # first register consecutive averages together (row i: average of group i -> average of group i+1):
    average_registrations = (
        first_level_results[:-1]
        .assign(next_model=list(first_level_results[1:].build_model))
        # TODO: we should be able to do lsq6 registration here as well!
        .assign(xfm=lambda df: df.apply(axis=1, func=lambda row: s.defer(
            lsq12_nlin(source=row.build_model.avg_img,
                       target=row.next_model.avg_img,
                       lsq12_conf=get_linear_configuration_from_options(
                           options.mbm.lsq12,
                           transform_type=LinearTransType.lsq12,
                           file_resolution=options.registration.resolution),
                       nlin_conf=nlin_protocol)))))

    # now compose the above transforms to produce transforms from each average to the common average:
    common_time_pt = options.tamarack.common_time_pt
    common_model = first_level_results[first_level_results.group == common_time_pt].iloc[0].build_model.avg_img
    # asymmetry in before/after since we used `next_`, not `previous_`
    before = average_registrations[average_registrations.group < common_time_pt]
    after = average_registrations[average_registrations.group >= common_time_pt]

    def suffixes(xs):
        """Return [xs, xs[1:], ..., []]; suffix i chains registrations from before-group i up to the common one."""
        return [xs[i:] for i in range(len(xs) + 1)]

    def prefixes(xs):
        """Return [[], xs[:1], ..., xs]; prefix i chains registrations from the common group i steps forward.

        Each prefix keeps the original order of `xs`, e.g. [a, b] -> [[], [a], [a, b]].
        """
        return [xs[:i] for i in range(len(xs) + 1)]

    def compose_to_common(row):
        """Transform from this group's average to the common average; None for the common group itself."""
        if row.uncomposed_xfms is None:
            return None  # TODO None => identity??
        if row.group < common_time_pt:
            return s.defer(concat_xfmhandlers(row.uncomposed_xfms, name="%s_to_common" % row.group))
        # this chain runs from the common average to this group's average, so invert it
        from_common = s.defer(concat_xfmhandlers(row.uncomposed_xfms, name="%s_from_common" % row.group))
        return s.defer(invert_xfmhandler(from_common))

    # compose 1st and 2nd level transforms and resample into the common average space:
    xfms_to_common = (
        first_level_results
        .assign(uncomposed_xfms=suffixes(list(before.xfm))[:-1] + [None] + prefixes(list(after.xfm))[1:])
        .assign(xfm_to_common=lambda df: df.apply(axis=1, func=compose_to_common))
        .drop('uncomposed_xfms', axis=1))

    # TODO indexing here is not good ...
    first_level_determinants = pd.concat(
        list(first_level_results.build_model.apply(lambda x: x.determinants.assign(first_level_avg=x.avg_img))),
        ignore_index=True)

    def resample_to_common(row, column):
        """Resample `row[column]` into the common space; the common group's images are already there."""
        if row.xfm_to_common is None:
            return row[column]
        return s.defer(mincresample_new(img=row[column], xfm=row.xfm_to_common.xfm, like=common_model))

    resampled_determinants = (
        pd.merge(left=first_level_determinants,
                 # the common group has no transform, so key it by its own average; keying it by None
                 # would make the inner merge silently drop the common time point's determinants
                 right=xfms_to_common.assign(
                     source=lambda df: [xfm.source if xfm is not None else model.avg_img
                                        for xfm, model in zip(df.xfm_to_common, df.build_model)]),
                 left_on="first_level_avg", right_on='source')
        .assign(resampled_log_full_det=lambda df: df.apply(
                    axis=1, func=lambda row: resample_to_common(row, 'log_full_det')),
                resampled_log_nlin_det=lambda df: df.apply(
                    axis=1, func=lambda row: resample_to_common(row, 'log_nlin_det'))))

    inverted_overall_xfms = pd.Series({xfm: (s.defer(concat_xfmhandlers([xfm, row.xfm_to_common]))
                                             if row.xfm_to_common is not None else xfm)
                                       for _ix, row in xfms_to_common.iterrows()
                                       for xfm in row.build_model.xfms.lsq12_nlin_xfm})

    overall_xfms = inverted_overall_xfms.apply(lambda x: s.defer(invert_xfmhandler(x)))

    # deferred like every other determinants_at_fwhms call; otherwise its stages are never added to `s`
    # and the output would hold a Result rather than the determinants themselves
    overall_determinants = s.defer(determinants_at_fwhms(xfms=overall_xfms,
                                                         blur_fwhms=options.mbm.stats.stats_kernels,
                                                         inv_xfms=inverted_overall_xfms))

    # TODO turn off bootstrap as with two-level code?
    # TODO combine into one data frame
    return Result(stages=s,
                  output=Namespace(first_level_results=first_level_results,
                                   overall_determinants=overall_determinants,
                                   resampled_determinants=resampled_determinants.drop(['options'], axis=1)))
예제 #5
0
def mbm(imgs : List[MincAtom], options : MBMConf, prefix : str, output_dir : str = ""):

    # TODO could also allow pluggable pipeline parts e.g. LSQ6 could be substituted out for the modified LSQ6
    # for the kidney tips, etc...

    # TODO this is tedious and annoyingly similar to the registration chain ...
    lsq6_dir  = os.path.join(output_dir, prefix + "_lsq6")
    lsq12_dir = os.path.join(output_dir, prefix + "_lsq12")
    nlin_dir  = os.path.join(output_dir, prefix + "_nlin")

    s = Stages()

    if len(imgs) == 0:
        raise ValueError("Please, some files!")

    # FIXME: why do we have to call registration_targets *outside* of lsq6_nuc_inorm? is it just because of the extra
    # options required?  Also, shouldn't options.registration be a required input (as it contains `input_space`) ...?
    targets = registration_targets(lsq6_conf=options.mbm.lsq6,
                                   app_conf=options.application,
                                   first_input_file=imgs[0].path)

    # TODO this is quite tedious and duplicates stuff in the registration chain ...
    resolution = (options.registration.resolution or
                  get_resolution_from_file(targets.registration_standard.path))
    options.registration = options.registration.replace(resolution=resolution)

    # FIXME it probably makes most sense if the lsq6 module itself (even within lsq6_nuc_inorm) handles the run_lsq6
    # setting (via use of the identity transform) since then this doesn't have to be implemented for every pipeline
    if options.mbm.lsq6.run_lsq6:
        lsq6_result = s.defer(lsq6_nuc_inorm(imgs=imgs,
                                             resolution=resolution,
                                             registration_targets=targets,
                                             lsq6_dir=lsq6_dir,
                                             lsq6_options=options.mbm.lsq6))
    else:
        # TODO don't actually do this resampling if not required (i.e., if the imgs already have the same grids)
        identity_xfm = s.defer(param2xfm(out_xfm=FileAtom(name="identity.xfm")))
        lsq6_result  = [XfmHandler(source=img, target=img, xfm=identity_xfm,
                                   resampled=s.defer(mincresample_new(img=img,
                                                                      like=targets.registration_standard,
                                                                      xfm=identity_xfm)))
                        for img in imgs]
    # what about running nuc/inorm without a linear registration step??

    full_hierarchy = get_nonlinear_configuration_from_options(nlin_protocol=options.mbm.nlin.nlin_protocol,
                                                              reg_method=options.mbm.nlin.reg_method,
                                                              file_resolution=resolution)

    lsq12_nlin_result = s.defer(lsq12_nlin_build_model(imgs=[xfm.resampled for xfm in lsq6_result],
                                                       resolution=resolution,
                                                       lsq12_dir=lsq12_dir,
                                                       nlin_dir=nlin_dir,
                                                       nlin_prefix=prefix,
                                                       lsq12_conf=options.mbm.lsq12,
                                                       nlin_conf=full_hierarchy))

    inverted_xfms = [s.defer(invert_xfmhandler(xfm)) for xfm in lsq12_nlin_result.output]

    determinants = s.defer(determinants_at_fwhms(
                             xfms=inverted_xfms,
                             inv_xfms=lsq12_nlin_result.output,
                             blur_fwhms=options.mbm.stats.stats_kernels))

    overall_xfms = [s.defer(concat_xfmhandlers([rigid_xfm, lsq12_nlin_xfm]))
                    for rigid_xfm, lsq12_nlin_xfm in zip(lsq6_result, lsq12_nlin_result.output)]

    output_xfms = (pd.DataFrame({ "rigid_xfm"      : lsq6_result,  # maybe don't return this if LSQ6 not run??
                                  "lsq12_nlin_xfm" : lsq12_nlin_result.output,
                                  "overall_xfm"    : overall_xfms }))
    # we could `merge` the determinants with this table, but preserving information would cause lots of duplication
    # of the transforms (or storing determinants in more columns, but iterating over dynamically known columns
    # seems a bit odd ...)

                            # TODO transpose these fields?})
                            #avg_img=lsq12_nlin_result.avg_img,  # inconsistent w/ WithAvgImgs[...]-style outputs
                           # "determinants"    : determinants })

    #output.avg_img = lsq12_nlin_result.avg_img
    #output.determinants = determinants   # TODO temporary - remove once incorporated properly into `output` proper
    # TODO add more of lsq12_nlin_result?

    # FIXME: this needs to go outside of the `mbm` function to avoid being run from within other pipelines (or
    # those other pipelines need to turn off this option)
    # TODO return some MAGeT stuff from MBM function ??
    # if options.mbm.mbm.run_maget:
    #     import copy
    #     maget_options = copy.deepcopy(options)  #Namespace(maget=options)
    #     #maget_options
    #     #maget_options.maget = maget_options.mbm
    #     #maget_options.execution = options.execution
    #     #maget_options.application = options.application
    #     maget_options.maget = options.mbm.maget
    #     del maget_options.mbm
    #
    #     s.defer(maget([xfm.resampled for xfm in lsq6_result],
    #                   options=maget_options,
    #                   prefix="%s_MAGeT" % prefix,
    #                   output_dir=os.path.join(output_dir, prefix + "_processed")))

    # should also move outside `mbm` function ...
    #if options.mbm.thickness.run_thickness:
    #    if not options.mbm.segmentation.run_maget:
    #        warnings.warn("MAGeT files (atlases, protocols) are needed to run thickness calculation.")
    #    # run MAGeT to segment the nlin average:
    #    import copy
    #    maget_options = copy.deepcopy(options)  #Namespace(maget=options)
    #    maget_options.maget = options.mbm.maget
    #    del maget_options.mbm
    #    segmented_avg = s.defer(maget(imgs=[lsq12_nlin_result.avg_img],
    #                                  options=maget_options,
    #                                  output_dir=os.path.join(options.application.output_directory,
    #                                                          prefix + "_processed"),
    #                                  prefix="%s_thickness_MAGeT" % prefix)).ix[0].img
    #    thickness = s.defer(cortical_thickness(xfms=pd.Series(inverted_xfms), atlas=segmented_avg,
    #                                           label_mapping=FileAtom(options.mbm.thickness.label_mapping),
    #                                           atlas_fwhm=0.56, thickness_fwhm=0.56))  # TODO magic fwhms
    #    # TODO write CSV -- should `cortical_thickness` do this/return a table?


    # FIXME: this needs to go outside of the `mbm` function to avoid being run from within other pipelines (or
    # those other pipelines need to turn off this option)
    if options.mbm.common_space.do_common_space_registration:
        warnings.warn("This feature is experimental ...")
        if not options.mbm.common_space.common_space_model:
            raise ValueError("No common space template provided!")
        # TODO allow lsq6 registration as well ...
        common_space_model = MincAtom(options.mbm.common_space.common_space_model,
                                      pipeline_sub_dir=os.path.join(options.application.output_directory,
                                                         options.application.pipeline_name + "_processed"))
        # TODO allow different lsq12/nlin config params than the ones used in MBM ...
        # WEIRD ... see comment in lsq12_nlin code ...
        nlin_conf  = full_hierarchy.confs[-1] if isinstance(full_hierarchy, MultilevelMincANTSConf) else full_hierarchy
        # also weird that we need to call get_linear_configuration_from_options here ... ?
        lsq12_conf = get_linear_configuration_from_options(conf=options.mbm.lsq12,
                                                           transform_type=LinearTransType.lsq12,
                                                           file_resolution=resolution)
        xfm_to_common = s.defer(lsq12_nlin(source=lsq12_nlin_result.avg_img, target=common_space_model,
                                           lsq12_conf=lsq12_conf, nlin_conf=nlin_conf,
                                           resample_source=True))

        model_common = s.defer(mincresample_new(img=lsq12_nlin_result.avg_img,
                                                xfm=xfm_to_common.xfm, like=common_space_model,
                                                postfix="_common"))

        overall_xfms_common = [s.defer(concat_xfmhandlers([rigid_xfm, nlin_xfm, xfm_to_common]))
                               for rigid_xfm, nlin_xfm in zip(lsq6_result, lsq12_nlin_result.output)]

        xfms_common = [s.defer(concat_xfmhandlers([nlin_xfm, xfm_to_common]))
                       for nlin_xfm in lsq12_nlin_result.output]

        output_xfms = output_xfms.assign(xfm_common=xfms_common, overall_xfm_common=overall_xfms_common)

        log_nlin_det_common, log_full_det_common = [dets.map(lambda d:
                                                      s.defer(mincresample_new(
                                                        img=d,
                                                        xfm=xfm_to_common.xfm,
                                                        like=common_space_model,
                                                        postfix="_common",
                                                        extra_flags=("-keep_real_range",),
                                                        interpolation=Interpolation.nearest_neighbour)))
                                                    for dets in (determinants.log_nlin_det, determinants.log_full_det)]

        determinants = determinants.assign(log_nlin_det_common=log_nlin_det_common,
                                           log_full_det_common=log_full_det_common)

    output = Namespace(avg_img=lsq12_nlin_result.avg_img, xfms=output_xfms, determinants=determinants)

    if options.mbm.common_space.do_common_space_registration:
        output.model_common = model_common

    return Result(stages=s, output=output)
예제 #6
0
def maget_mask(imgs : List[MincAtom], atlases, options):
    """Create a masked version of every image in `imgs` by voting on atlas masks propagated onto it.

    Steps:
      1. Register every image to every atlas (LSQ12, then the nonlinear configuration built from
         `options.maget.maget.masking_nlin_protocol` / `mask_method`).
      2. Resample each atlas mask onto its paired image through the inverted transform.
      3. Combine the propagated masks of each image with `voxel_vote`.
      4. Apply the voted mask with `mincmath -mult` (image first, mask second).
    The atlas images themselves are also resampled onto each input. Those stages are only added to
    the pipeline and are not part of the returned table.

    Returns a Result whose output is the alignment table: one row per (image, atlas) pair, with
    columns `atlas`, `xfm`, `unmasked_img`, `resampled_masks`, `voted_mask` and `masked_img`. In this
    table `img` has been replaced by the masked image.
    """

    stages = Stages()

    vec_resample = np.vectorize(mincresample_new, excluded={"extra_flags"})
    vec_defer    = np.vectorize(stages.defer)

    lsq12_conf = get_linear_configuration_from_options(options.maget.lsq12,
                                                       LinearTransType.lsq12,
                                                       options.registration.resolution)

    nlin_conf = get_nonlinear_configuration_from_options(options.maget.maget.masking_nlin_protocol,
                                                         options.maget.maget.mask_method,
                                                         options.registration.resolution)

    def register(source_img, target_atlas):
        # coarse image -> atlas registration used only to propagate the atlas mask
        return stages.defer(lsq12_nlin(source=source_img, target=target_atlas,
                                       lsq12_conf=lsq12_conf,
                                       nlin_conf=nlin_conf,
                                       resample_source=False))

    # one row per (image, atlas) pair, image-major
    pairs = pd.DataFrame({'img': img, 'atlas': atlas, 'xfm': register(img, atlas)}
                         for img in imgs for atlas in atlases)

    # bring every atlas mask into the space of the image it was aligned with
    propagated_masks = vec_defer(vec_resample(img=pairs.atlas.apply(lambda a: a.mask),
                                              xfm=pairs.xfm.apply(lambda handler: handler.xfm),
                                              like=pairs.img,
                                              invert=True,
                                              interpolation=Interpolation.nearest_neighbour,
                                              postfix="-input-mask",
                                              subdir="tmp",
                                              # TODO annoying hack; fix mincresample(_mask) ...
                                              extra_flags=("-keep_real_range",)))

    # collect the propagated masks of each image into a single list-valued cell
    # (sort=False: just for speed; MincAtoms may lack the comparison methods sorting needs)
    per_image = (pairs.assign(resampled_mask=propagated_masks)
                 .groupby('img', sort=False, as_index=False)
                 .aggregate({'resampled_mask' : lambda masks: list(masks)})
                 .rename(columns={"resampled_mask" : "resampled_masks"}))

    def vote_on_masks(row):
        return stages.defer(voxel_vote(label_files=row.resampled_masks,
                                       name="%s_voted_mask" % row.img.filename_wo_ext,
                                       output_dir=os.path.join(row.img.output_sub_dir, "tmp")))

    per_image = per_image.assign(voted_mask=per_image.apply(vote_on_masks, axis=1))

    def mask_image(row):
        # the image must come before the mask for the output image range to be correct
        return stages.defer(mincmath(op="mult",
                                     vols=[row.img, row.voted_mask],
                                     new_name="%s_masked" % row.img.filename_wo_ext,
                                     subdir="resampled"))

    per_image = per_image.assign(masked_img=per_image.apply(mask_image, axis=1))

    # additionally stage the atlas images resampled back onto the inputs; the result is intentionally
    # not kept, so `pairs` itself is unchanged
    vec_defer(vec_resample(img=pairs.atlas,
                           xfm=pairs.xfm.apply(lambda handler: handler.xfm),
                           subdir="tmp",
                           like=pairs.img, invert=True))

    # swap the unmasked images in the alignment table for their masked counterparts
    masked_alignments = (pd.merge(left=pairs.assign(unmasked_img=pairs.img),
                                  right=per_image,
                                  on=["img"], how="right", sort=False)
                         .assign(img=lambda merged: merged.masked_img))

    return Result(stages=stages, output=masked_alignments)
예제 #7
0
def maget(imgs : List[MincAtom], options, prefix, output_dir):     # FIXME prefix, output_dir aren't used !!
    """Plan a MAGeT (multiple-atlas) labelling of `imgs`.

    Each input image is registered (LSQ12 followed by a nonlinear registration) to a set of templates.
    The templates are the first `max_templates` atlases read from `options.maget.maget.atlas_lib`. In
    pairwise mode they also include a selection of the input images, which are first labelled from
    those atlases. The template labels resampled onto each image are then combined with `voxel_vote`.
    If masking is requested, the images and atlases are first masked via `maget_mask`, and only the
    masked versions are propagated.

    Parameters
    ----------
    imgs : the input images to label.
    options : the full options namespace. Reads `options.maget.*`, `options.registration.resolution`
              and `options.application.{output_directory,pipeline_name}`.
    prefix, output_dir : currently unused (see FIXME on the signature).

    Returns
    -------
    A Result whose `output` is one of:
      - with `mask_only` set: the alignment table returned by `maget_mask`;
      - otherwise: a DataFrame with one row per (possibly masked) input image and columns `img`,
        `resampled_labels` and `voted_labels`. As a side effect, each row's `img.labels` is set to
        its voted labels.

    Raises
    ------
    ValueError if no atlas library is given, or the atlas directory contains no atlases.
    """

    s = Stages()

    maget_options = options.maget.maget

    pipeline_sub_dir = os.path.join(options.application.output_directory,
                                    options.application.pipeline_name + "_atlases")

    if maget_options.atlas_lib is None:
        raise ValueError("Need some atlases ...")

    #atlas_dir = os.path.join(output_dir, "input_atlases") ???

    # TODO should alternately accept a CSV file ...
    atlas_library = read_atlas_dir(atlas_lib=maget_options.atlas_lib, pipeline_sub_dir=pipeline_sub_dir)

    if len(atlas_library) == 0:
        raise ValueError("No atlases found in specified directory '%s' ..." % options.maget.maget.atlas_lib)

    num_atlases_needed = min(maget_options.max_templates, len(atlas_library))
    # TODO arbitrary; could choose atlases better ...
    atlases = atlas_library[:num_atlases_needed]
    # TODO issue a warning if not all atlases used or if more atlases requested than available?
    # TODO also, doesn't slicing with a higher number (i.e., if max_templates > n) go to the end of the list anyway?

    lsq12_conf = get_linear_configuration_from_options(options.maget.lsq12,
                                                       LinearTransType.lsq12,
                                                       options.registration.resolution)

    nlin_hierarchy = get_nonlinear_configuration_from_options(options.maget.nlin.nlin_protocol,
                                                              options.maget.nlin.reg_method,
                                                              options.registration.resolution)

    resample  = np.vectorize(mincresample_new, excluded={"extra_flags"})
    defer     = np.vectorize(s.defer)

    # The coarse image->atlas registrations (using the masking protocol) are planned only when masking
    # is requested, inside `maget_mask`. They used to also be staged here unconditionally, which added
    # a full set of never-used registrations to the pipeline whenever masking was turned off.
    if maget_options.mask or maget_options.mask_only:

        masking_alignments = s.defer(maget_mask(imgs, atlases, options))

        masked_atlases = atlases.apply(lambda atlas:
                           s.defer(mincmath(op='mult', vols=[atlas, atlas.mask], subdir="resampled",
                                            new_name="%s_masked" % atlas.filename_wo_ext)))

        if maget_options.mask_only:
            # register each input to each atlas, creating a mask
            return Result(stages=s, output=masking_alignments)   # TODO rename `alignments` to `registrations`??

        # Now propagate only the masked form of the images and atlases.
        # `masking_alignments` has one row per (image, atlas) pair, so each masked image appears once per
        # atlas. Keep only the first occurrence of each (order preserved). Then every input is registered
        # and voted on once, and the pairwise template choice below can't pick the same image repeatedly.
        imgs    = masking_alignments.img.drop_duplicates().tolist()
        atlases = masked_atlases  # TODO is this needed?

        del masking_alignments
        # This `del` is just to verify that we don't accidentally use this later. The intent is that these
        # coarser alignments shouldn't be re-used, only the masked images they create. It can be removed
        # later if a sensible use is found.

    if maget_options.pairwise:

        def choose_new_templates(ts, n):
            # currently silly, but we might implement a smarter method ...
            # FIXME what if there aren't enough other imgs around?!  This silently goes weird ...
            return ts[:n+1]  # n+1 instead of n: choose one more since we won't use image as its own template ...

        new_templates = choose_new_templates(ts=imgs, n=maget_options.max_templates)
        # note these images are the masked ones if masking was done ...

        # TODO write a function to do these alignments and the image->atlas one in `maget_mask`
        # align the new templates chosen from the images to the initial atlases:
        new_template_to_atlas_alignments = (
            pd.DataFrame({ 'img'   : template,
                           'atlas' : atlas,
                           'xfm'   : s.defer(lsq12_nlin(source=template, target=atlas,
                                                        lsq12_conf=lsq12_conf,
                                                        nlin_conf=nlin_hierarchy,
                                                        resample_source=False))}
                         for template in new_templates for atlas in atlases))
                         # ... and these atlases are multiplied by their masks (but is this necessary?)

        # label the new templates from resampling the atlas labels onto them:
        # TODO extract into procedure?
        new_templates_labelled = (
            new_template_to_atlas_alignments
            .assign(resampled_labels=lambda df: defer(
                                           resample(img=df.atlas.apply(lambda x: x.labels),
                                                    xfm=df.xfm.apply(lambda x: x.xfm),
                                                    interpolation=Interpolation.nearest_neighbour,
                                                    extra_flags=("-keep_real_range",),
                                                    like=df.img, invert=True)))
            .groupby('img', sort=False, as_index=False)
            .aggregate({'resampled_labels' : lambda labels: list(labels)})
            .assign(voted_labels=lambda df: df.apply(axis=1,
                                                     func=lambda row:
                                                       s.defer(voxel_vote(label_files=row.resampled_labels,
                                                                          name="%s_template_labels" %
                                                                               row.img.filename_wo_ext,
                                                                          output_dir=os.path.join(
                                                                              row.img.pipeline_sub_dir,
                                                                              row.img.output_sub_dir,
                                                                              "labels"))))))

        # TODO write a procedure for this assign-groupby-aggregate-rename...
        # FIXME should be in above algebraic manipulation but MincAtoms don't support flexible immutable updating
        # NOTE(review): this loop visits one row per (template, atlas) pair, so `row.img.labels` is reassigned
        # once per atlas and the last assignment wins. `voted_labels` were also already resampled into the
        # template's space above. Confirm that this extra inverse resample is intended.
        for row in pd.merge(left=new_template_to_atlas_alignments, right=new_templates_labelled,
                            on=["img"], how="right", sort=False).itertuples():
            row.img.labels = s.defer(mincresample_new(img=row.voted_labels, xfm=row.xfm.xfm, like=row.img,
                                                      invert=True, interpolation=Interpolation.nearest_neighbour,
                                                      # TODO naming doesn't work well for running MAGeT on the nlin avg
                                                      new_name_wo_ext="%s_via_%s" % (row.voted_labels.filename_wo_ext,
                                                                                     row.xfm.xfm.filename_wo_ext),
                                                      extra_flags=("-keep_real_range",)))

        # now that the new templates have been labelled, combine with the atlases:
        # FIXME use the masked atlases created earlier ??
        all_templates = pd.concat([new_templates_labelled.img, atlases], ignore_index=True)

    else:
        all_templates = atlases

    # now register each input to each selected template
    # N.B.: When masking was done, each image has already been registered to each initial atlas (in
    #       `maget_mask`). This happens again here, but using `nlin_hierarchy` instead of the masking
    #       hierarchy. That is not 'work-efficient', since this computation happens twice (although
    #       hopefully at greater precision the second time!). The idea is to run a coarse initial
    #       registration to get a mask, then do a better registration with that mask.
    #       This _can_ allow the overall computation to finish more rapidly
    #       (depending on the relative speed of the two alignment methods/parameters,
    #       number of atlases and other templates used, number of cores available, etc.).
    image_to_template_alignments = (
        pd.DataFrame({ "img"      : img,
                       "template" : template_img,
                       "xfm"      : xfm }
                     for img in imgs      # these are the masked imgs if masking was done
                     for template_img in
                         all_templates
                         # FIXME delete the alignment of an image to itself (pairwise mode);
                         # equality is equality of filepaths (a bit dangerous)
                     for xfm in [s.defer(lsq12_nlin(source=img, target=template_img,
                                                    lsq12_conf=lsq12_conf,
                                                    nlin_conf=nlin_hierarchy))]
                     )
    )

    # now do a voxel_vote on all resampled template labels, just as earlier with the masks
    voted = (image_to_template_alignments
             .assign(resampled_labels=lambda df:
                                        defer(resample(img=df.template.apply(lambda x: x.labels),
                                                       # FIXME bug: at this point templates from template_alignments
                                                       # don't have associated labels (i.e., `None`s) -- fatal
                                                       xfm=df.xfm.apply(lambda x: x.xfm),
                                                       interpolation=Interpolation.nearest_neighbour,
                                                       extra_flags=("-keep_real_range",),
                                                       like=df.img, invert=True)))
             .groupby('img', sort=False)
             # TODO the pattern groupby-aggregate(lambda x: list(x))-reset_index-assign is basically a hack
             # to do a groupby-assign with access to the group name;
             # see http://stackoverflow.com/a/30224447/849272 for a better solution
             # (note this pattern occurs several times in MAGeT and two-level code)
             .aggregate({'resampled_labels' : lambda labels: list(labels)})
             .reset_index()
             .assign(voted_labels=lambda df: defer(np.vectorize(voxel_vote)(label_files=df.resampled_labels,
                                                                            output_dir=df.img.apply(
                                                                                lambda x: os.path.join(
                                                                                    x.pipeline_sub_dir,
                                                                                    x.output_sub_dir))))))

    # TODO doing mincresample -invert separately for the img->atlas xfm for mask, labels is silly
    # (when Pydpiper's `mincresample` does both automatically)?

    # blargh, another destructive update ...
    for row in voted.itertuples():
        row.img.labels = row.voted_labels

    # returning voted_labels as a column is slightly redundant, but possibly useful ...
    return Result(stages=s, output=voted)  # voted.drop("voted_labels", axis=1))
예제 #8
0
def chain(options):
    """Create a registration chain pipeline from the given options."""

    # TODO:
    # one overall question for this entire piece of code is how
    # we are going to make sure that we can concatenate/add all
    # the transformations together. Many of the sub-registrations
    # that are performed (inter-subject registration, lsq6 using
    # multiple initial models) are applied to subsets of the entire 
    # data, making it harder to keep the mapping simple/straightforward


    chain_opts = options.chain  # type : ChainConf

    s = Stages()
    
    with open(options.chain.csv_file, 'r') as f:
        subject_info = parse_csv(rows=f, common_time_pt=options.chain.common_time_point)

    output_dir    = options.application.output_directory
    pipeline_name = options.application.pipeline_name

    pipeline_processed_dir = os.path.join(output_dir, pipeline_name + "_processed")
    pipeline_lsq12_common_dir = os.path.join(output_dir, pipeline_name + "_lsq12_" + options.chain.common_time_point_name)
    pipeline_nlin_common_dir = os.path.join(output_dir, pipeline_name + "_nlin_" + options.chain.common_time_point_name)
    pipeline_montage_dir = os.path.join(output_dir, pipeline_name + "_montage")
    
    
    pipeline_subject_info = map_over_time_pt_dict_in_Subject(
                                     lambda subj_str:  MincAtom(name=subj_str, pipeline_sub_dir=pipeline_processed_dir),
                                     subject_info)  # type: Dict[str, Subject[MincAtom]]
    
    # verify that in input files are proper MINC files, and that there 
    # are no duplicates in the filenames
    all_Minc_atoms = []  # type: List[MincAtom]
    for s_id, subj in pipeline_subject_info.items():
        for subj_time_pt, subj_filename in subj.time_pt_dict.items():
            all_Minc_atoms.append(subj_filename)
    # check_MINC_input_files takes strings, so pass along those instead of the actual MincAtoms
    check_MINC_input_files([minc_atom.path for minc_atom in all_Minc_atoms])

    if options.registration.input_space == InputSpace.lsq6 or \
        options.registration.input_space == InputSpace.lsq12:
        # the input files are not going through the lsq6 alignment. This is the place
        # where they will all be resampled using a single like file, and get the same
        # image dimensions/lengths/resolution. So in order for the subsequent stages to
        # finish (mincaverage stages for instance), all files need to have the same
        # image parameters:
        check_MINC_files_have_equal_dimensions_and_resolution([minc_atom.path for minc_atom in all_Minc_atoms],
                                                              additional_msg="Given that the input images are "
                                                                             "already in " + str(options.registration.input_space) +
                                                                             " space, all input files need to have "
                                                                             "the same dimensions/starts/step sizes.")

    if options.registration.input_space not in InputSpace.__members__.values():
        raise ValueError('unrecognized input space: %s; choices: %s' %
                         (options.registration.input_space, ','.join(InputSpace.__members__)))
    
    if options.registration.input_space == InputSpace.native:
        if options.lsq6.target_type == TargetType.bootstrap:
            raise ValueError("\nA bootstrap model is ill-defined for the registration chain. "
                             "(Which file is the 'first' input file?). Please use the --lsq6-target "
                             "flag to specify a target for the lsq6 stage, or use an initial model.")
        if options.lsq6.target_type == TargetType.pride_of_models:
            pride_of_models_dict = get_pride_of_models_mapping(pride_csv=options.lsq6.target_file,
                                                               output_dir=options.application.output_directory,
                                                               pipeline_name=options.application.pipeline_name)
            subj_id_to_subj_with_lsq6_xfm_dict = map_with_index_over_time_pt_dict_in_Subject(
                                    lambda subj_atom, time_point:
                                        s.defer(lsq6_nuc_inorm([subj_atom],
                                                               registration_targets=get_closest_model_from_pride_of_models(
                                                                                        pride_of_models_dict, time_point),
                                                               resolution=options.registration.resolution,
                                                               lsq6_options=options.lsq6,
                                                               lsq6_dir=None,  # never used since no average
                                                               # (could call this "average_dir" with None -> no avg ?)
                                                               subject_matter=options.registration.subject_matter,
                                                               create_qc_images=False,
                                                               create_average=False))[0],
                                        pipeline_subject_info)  # type: Dict[str, Subject[XfmHandler]]
        else:
            # if we are not dealing with a pride of models, we can retrieve a fixed
            # registration target for all input files:
            targets = registration_targets(lsq6_conf=options.lsq6,
                                           app_conf=options.application)
            
            # we want to store the xfm handlers in the same shape as pipeline_subject_info,
            # as such we will call lsq6_nuc_inorm for each file individually and simply extract
            # the first (and only) element from the resulting list via s.defer(...)[0].
            subj_id_to_subj_with_lsq6_xfm_dict = map_over_time_pt_dict_in_Subject(
                                         lambda subj_atom:
                                           s.defer(lsq6_nuc_inorm([subj_atom],
                                                                  registration_targets=targets,
                                                                  resolution=options.registration.resolution,
                                                                  lsq6_options=options.lsq6,
                                                                  lsq6_dir=None, # no average will be create, is just one file...
                                                                  create_qc_images=False,
                                                                  create_average=False,
                                                                  subject_matter=options.registration.subject_matter)
                                                   )[0],
                                         pipeline_subject_info)  # type: Dict[str, Subject[XfmHandler]]

        # create verification images to show the 6 parameter alignment
        montageLSQ6 = pipeline_montage_dir + "/quality_control_montage_lsq6.png"
        # TODO, base scaling factor on resolution of initial model or target
        filesToCreateImagesFrom = []
        for subj_id, subj in subj_id_to_subj_with_lsq6_xfm_dict.items():
            for time_pt, subj_time_pt_xfm in subj.time_pt_dict.items():
                filesToCreateImagesFrom.append(subj_time_pt_xfm.resampled)

        # TODO it's strange that create_quality_control_images gets the montage directory twice
        # TODO (in montages=output=montageLSQ6 and in montage_dir), suggesting a weirdness in create_q_c_images
        lsq6VerificationImages = s.defer(create_quality_control_images(filesToCreateImagesFrom,
                                                                       montage_output=montageLSQ6,
                                                                       montage_dir=pipeline_montage_dir,
                                                                       message=" the input images after the lsq6 alignment"))

    # NB currently LSQ6 expects an array of files, but we have a map.
    # possibilities:
    # - note that pairwise is enough (except for efficiency -- redundant blurring, etc.)
    #   and just use the map fn above with an LSQ6 fn taking only a single source
    # - rewrite LSQ6 to use such a (nested) map
    # - write conversion which creates a tagged array from the map, performs LSQ6,
    #   and converts back
    # - write 'over' which takes a registration, a data structure, and 'get/set' fns ...?
    

    # Intersubject registration: LSQ12/NLIN registration of common-timepoint images
    # The assumption here is that all these files are roughly aligned. Here is a toy
    # schematic of what happens. In this example, the common timepoint is set timepoint 2: 
    #
    #                            ------------
    # subject A    A_time_1   -> | A_time_2 | ->   A_time_3
    # subject B    B_time_1   -> | B_time_2 | ->   B_time_3
    # subject C    C_time_1   -> | C_time_2 | ->   C_time_3
    #                            ------------
    #                                 |
    #                            group_wise registration on time point 2
    #

    # dictionary that holds the transformations from the intersubject images
    # to the final common space average
    intersubj_img_to_xfm_to_common_avg_dict = {}  # type: Dict[MincAtom, XfmHandler]
    if options.registration.input_space in (InputSpace.lsq6, InputSpace.lsq12):
        # no registrations have been performed yet, so we can point to the input files
        s_id_to_intersubj_img_dict = { s_id : subj.intersubject_registration_image
                          for s_id, subj in pipeline_subject_info.items() }
    else:
        # lsq6 aligned images
        # When we ran the lsq6 alignment, we stored the XfmHandlers in the Subject dictionary. So when we call
        # xfmhandler.intersubject_registration_image, this returns an XfmHandler. From which
        # we want to extract the resampled file (in order to continue the registration with)
        s_id_to_intersubj_img_dict = { s_id : subj_with_xfmhandler.intersubject_registration_image.resampled
                          for s_id, subj_with_xfmhandler in subj_id_to_subj_with_lsq6_xfm_dict.items() }
    
    if options.application.verbose:
        print("\nImages that are used for the inter-subject registration:")
        print("ID\timage")
        for subject in s_id_to_intersubj_img_dict:
            print(subject + '\t' + s_id_to_intersubj_img_dict[subject].path)

    # determine what configuration to use for the non linear registration
    nonlinear_configuration = get_nonlinear_configuration_from_options(options.nlin.nlin_protocol,
                                                                       options.nlin.reg_method,
                                                                       options.registration.resolution)

    if options.registration.input_space in [InputSpace.lsq6, InputSpace.native]:
        intersubj_xfms = s.defer(lsq12_nlin_build_model(imgs=list(s_id_to_intersubj_img_dict.values()),
                                                lsq12_conf=options.lsq12,
                                                nlin_conf=nonlinear_configuration,
                                                resolution=options.registration.resolution,
                                                lsq12_dir=pipeline_lsq12_common_dir,
                                                nlin_dir=pipeline_nlin_common_dir,
                                                nlin_prefix="common"))
                                                #, like={atlas_from_init_model_at_this_tp}
    elif options.registration.input_space == InputSpace.lsq12:
        #TODO: write reader that creates a mincANTS configuration out of an input protocol
        # if we're starting with files that are already aligned with an affine transformation
        # (overall scaling is also dealt with), then the target for the non linear registration
        # should be the averge of the current input files.
        first_nlin_target = s.defer(mincaverage(imgs=list(s_id_to_intersubj_img_dict.values()),
                                                name_wo_ext="avg_of_input_files",
                                                output_dir=pipeline_nlin_common_dir))
        intersubj_xfms = s.defer(mincANTS_NLIN_build_model(imgs=list(s_id_to_intersubj_img_dict.values()),
                                                   initial_target=first_nlin_target,
                                                   nlin_dir=pipeline_nlin_common_dir,
                                                   conf=nonlinear_configuration))


    intersubj_img_to_xfm_to_common_avg_dict = { xfm.source : xfm for xfm in intersubj_xfms.output }

    # create one more convenience data structure: a mapping from subject_ID to the xfm_handler
    # that contains the transformation from the subject at the common time point to the
    # common time point average.
    subj_ID_to_xfm_handler_to_common_avg = {}
    for s_id, subj_at_common_tp in s_id_to_intersubj_img_dict.items():
        subj_ID_to_xfm_handler_to_common_avg[s_id] = intersubj_img_to_xfm_to_common_avg_dict[subj_at_common_tp]

    # create verification images to show the inter-subject  alignment
    montage_inter_subject = pipeline_montage_dir + "/quality_control_montage_inter_subject_registration.png"
    avg_and_inter_subject_images = []
    avg_and_inter_subject_images.append(intersubj_xfms.avg_img)
    for xfmh in intersubj_xfms.output:
        avg_and_inter_subject_images.append(xfmh.resampled)

    inter_subject_verification_images = s.defer(create_quality_control_images(
                                                  imgs=avg_and_inter_subject_images,
                                                  montage_output=montage_inter_subject,
                                                  montage_dir=pipeline_montage_dir,
                                                  message=" the result of the inter-subject alignment"))

    if options.application.verbose:
        print("\nTransformations for intersubject images to final nlin common space:")
        print("MincAtom\ttransformation")
        for subj_atom, xfm_handler in intersubj_img_to_xfm_to_common_avg_dict.items():
            print(subj_atom.path + '\t' + xfm_handler.xfm.path)


    ## within-subject registration
    # In the toy scenario below: 
    # subject A    A_time_1   ->   A_time_2   ->   A_time_3
    # subject B    B_time_1   ->   B_time_2   ->   B_time_3
    # subject C    C_time_1   ->   C_time_2   ->   C_time_3
    # 
    # The following registrations are run:
    # 1) A_time_1   ->   A_time_2
    # 2) A_time_2   ->   A_time_3
    #
    # 3) B_time_1   ->   B_time_2
    # 4) B_time_2   ->   B_time_3
    #
    # 5) C_time_1   ->   C_time_2
    # 6) C_time_2   ->   C_time_3    

    subj_id_to_Subjec_for_within_dict = pipeline_subject_info
    if options.registration.input_space == InputSpace.native:
        # we started with input images that were not aligned whatsoever
        # in this case we should use the images that were rigidly
        # aligned files to continue the within-subject registration with
        # # type: Dict[str, Subject[XfmHandler]]
        subj_id_to_Subjec_for_within_dict = map_over_time_pt_dict_in_Subject(lambda x: x.resampled,
                                                                             subj_id_to_subj_with_lsq6_xfm_dict)

    if options.application.verbose:
        print("\n\nWithin subject registrations:")
        for s_id, subj in subj_id_to_Subjec_for_within_dict.items():
            print("ID: ", s_id)
            for time_pt, subj_img in subj.time_pt_dict.items():
                print(time_pt, " ", subj_img.path)
            print("\n")

    # dictionary that maps subject IDs to a list containing:
    # ( [(time_pt_n, time_pt_n+1, XfmHandler_from_n_to_n+1), ..., (,,,)],
    #   index_of_common_time_pt)
    chain_xfms = { s_id : s.defer(intrasubject_registrations(
                                    subj=subj,
                                    linear_conf=default_lsq12_multilevel_minctracc,
                                    nlin_conf=mincANTS_default_conf.replace(
                                        file_resolution=options.registration.resolution,
                                        iterations="100x100x100x50")))
                   for s_id, subj in subj_id_to_Subjec_for_within_dict.items() }

    # create a montage image for each pair of time points
    for s_id, output_from_intra in chain_xfms.items():
        for time_pt_n, time_pt_n_plus_1, transform in output_from_intra[0]:
            montage_chain = pipeline_montage_dir + "/quality_control_chain_ID_" + s_id + \
                            "_timepoint_" + str(time_pt_n) + "_to_" + str(time_pt_n_plus_1) + ".png"
            chain_images = [transform.resampled, transform.target]
            chain_verification_images = s.defer(create_quality_control_images(chain_images,
                                                                              montage_output=montage_chain,
                                                                              montage_dir=pipeline_montage_dir,
                                                                              message="the alignment between ID " + s_id + " time point " +
                                                                                      str(time_pt_n) + " and " + str(time_pt_n_plus_1)))

    if options.application.verbose:
        print("\n\nTransformations gotten from the intrasubject registrations:")
        for s_id, output_from_intra in chain_xfms.items():
            print("ID: ", s_id)
            for time_pt_n, time_pt_n_plus_1, transform in output_from_intra[0]:
                print("Time point: ", time_pt_n, " to ", time_pt_n_plus_1, " trans: ", transform.xfm.path)
            print("\n")

    ## stats
    #
    # The statistic files we want to create are the following:
    # 1) subject <----- subject_common_time_point                              (resampled to common average)
    # 2) subject <----- subject_common_time_point <- common_time_point_average (incorporates inter subject differences)
    # 3) subject_time_point_n <----- subject_time_point_n+1                    (resampled to common average)

    # create transformation from each subject to the final common time point average,
    # and from each subject to the subject's common time point
    (non_rigid_xfms_to_common_avg, non_rigid_xfms_to_common_subj) = s.defer(get_chain_transforms_for_stats(subj_id_to_Subjec_for_within_dict,
                                                                            intersubj_img_to_xfm_to_common_avg_dict,
                                                                            chain_xfms))

    # Ad 1) provide transformations from the subject's common time point to each subject
    #       These are temporary, because they still need to be resampled into the
    #       average common time point space
    determinants_from_subject_common_to_subject = map_over_time_pt_dict_in_Subject(
        lambda xfm: s.defer(determinants_at_fwhms(xfms=[s.defer(invert_xfmhandler(xfm))],
                                                  inv_xfms=[xfm],  # determinants_at_fwhms now vectorized-unhelpful here
                                                  blur_fwhms=options.stats.stats_kernels)),
        non_rigid_xfms_to_common_subj)
    # the content of determinants_from_subject_common_to_subject is:
    #
    # {subject_ID : Subject(inter_subject_time_pt, time_pt_dict)
    #
    # where time_pt_dict contains:
    #
    # {time_point : Tuple(List[Tuple(float, Tuple(MincAtom, MincAtom))],
    #                     List[Tuple(float, Tuple(MincAtom, MincAtom))])
    #
    # And to be a bit more verbose:
    #
    # {time_point : Tuple(relative_stat_files,
    #                     absolute_stat_files)
    #
    # where either the relative_stat_files or the absolute_stat_files look like:
    #
    # [blur_kernel_1, (determinant_file_1, log_of_determinant_file_1),
    #  ...,
    #  blur_kernel_n, (determinant_file_n, log_of_determinant_file_n)]
    #
    # Now the only thing we really want to do, is to resample the actual log
    # determinants that were generated into the space of the common average.
    # To make that a little easier, I'll create a mapping that will contain:
    #
    # {subject_ID: Subject(intersubject_timepoint, {time_pt_1: [stat_file_1, ..., stat_file_n],
    #                                               ...,
    #                                               time_pt_n: [stat_file_1, ..., stat_file_n]}
    # }
    for s_id, subject_with_determinants in determinants_from_subject_common_to_subject.items():
        transform_from_common_subj_to_common_avg = subj_ID_to_xfm_handler_to_common_avg[s_id].xfm
        for time_pt, determinant_info in subject_with_determinants.time_pt_dict.items():
            # here, each determinant_info is a DataFrame where each row contains
            # 'abs_det', 'nlin_det', 'log_nlin_det', 'log_abs_det', 'fwhm' fields
            # of the log-determinants, blurred at various fwhms (corresponding to different rows)
            for _ix, row in determinant_info.iterrows():
                for log_det_file_to_resample in (row.log_full_det, row.log_nlin_det):
                    # TODO the MincAtoms corresponding to the resampled files are never returned
                    new_name_wo_ext = log_det_file_to_resample.filename_wo_ext + "_resampled_to_common"
                    s.defer(mincresample(img=log_det_file_to_resample,
                                         xfm=transform_from_common_subj_to_common_avg,
                                         like=log_det_file_to_resample,
                                         new_name_wo_ext=new_name_wo_ext,
                                         subdir="stats-volumes"))

    # Ad 2) provide transformations from the common avg to each subject. That's the
    #       inverse of what was provided by get_chain_transforms_for_stats()
    determinants_from_common_avg_to_subject = map_over_time_pt_dict_in_Subject(
        lambda xfm: s.defer(determinants_at_fwhms(xfms=[s.defer(invert_xfmhandler(xfm))],
                                                  inv_xfms=[xfm],  # determinants_at_fwhms now vectorized-unhelpful here
                                                  blur_fwhms=options.stats.stats_kernels)),
        non_rigid_xfms_to_common_avg)

    # TODO don't just return an (unnamed-)tuple here
    return Result(stages=s, output=Namespace(non_rigid_xfms_to_common=non_rigid_xfms_to_common_avg,
                                             determinants_from_common_avg_to_subject=determinants_from_common_avg_to_subject,
                                             determinants_from_subject_common_to_subject=determinants_from_subject_common_to_subject))
예제 #9
0
파일: MAGeT.py 프로젝트: psteadman/pydpiper
def maget(imgs : List[MincAtom], options, prefix, output_dir):     # FIXME prefix, output_dir aren't used !!
    """Stage a MAGeT segmentation of the input images.

    Each input image is registered (LSQ12 followed by a non-linear stage built from
    `options.maget.lsq12` / `options.maget.nlin`) to every atlas found in
    `options.maget.maget.atlas_lib`, and each atlas's labels are resampled into the image's space.
    If `pairwise` is set, some of the input images are additionally used as templates: their
    atlas-derived labels are propagated in the same way to every other input image.  All candidate
    label files for an image are then combined with `voxel_vote`.

    If `mask` or `mask_only` is set, the inputs are first masked via `maget_mask`; with `mask_only`
    only the masked images are produced.

    :param imgs: input images
    :param options: full pipeline options (`options.maget`, `options.registration` and
                    `options.application` are read)
    :param prefix: currently unused
    :param output_dir: currently unused
    :return: Result whose `output` is the masked images if `mask_only` is set, otherwise the
             (possibly masked) input images with `labels` set to their voted labels
    :raises ValueError: if no atlas library was specified
    """
    s = Stages()

    maget_options = options.maget.maget

    resolution = options.registration.resolution  # TODO or get_resolution_from_file(...) -- only if file always exists!

    pipeline_sub_dir = os.path.join(options.application.output_directory,
                                    options.application.pipeline_name + "_atlases")

    if maget_options.atlas_lib is None:
        raise ValueError("Need some atlases ...")

    # TODO should alternately accept a CSV file ...
    atlases = atlases_from_dir(atlas_lib=maget_options.atlas_lib,
                               max_templates=maget_options.max_templates,
                               pipeline_sub_dir=pipeline_sub_dir)

    lsq12_conf = get_linear_configuration_from_options(options.maget.lsq12,
                                                       transform_type=LinearTransType.lsq12,
                                                       file_resolution=resolution)

    nlin_hierarchy = get_nonlinear_configuration_from_options(options.maget.nlin.nlin_protocol,
                                                              reg_method=options.maget.nlin.reg_method,
                                                              file_resolution=resolution)

    if maget_options.mask or maget_options.mask_only:

        # this used to return alignments but doesn't currently do so
        masked_img = s.defer(maget_mask(imgs=imgs,
                                        maget_options=options.maget, atlases=atlases,
                                        pipeline_sub_dir=pipeline_sub_dir + "_masking", # FIXME repeats all alignments!!!
                                        resolution=resolution))

        # now propagate only the masked form of the images and atlases:
        imgs    = masked_img
        #atlases = masked_atlases  # TODO is this needed?

    if maget_options.mask_only:
        # register each input to each atlas, creating a mask
        return Result(stages=s, output=masked_img)   # TODO rename `alignments` to `registrations`??

    if maget_options.mask:
        del masked_img
    # this `del` is just to verify that we don't accidentally use this later, since these potentially
    # coarser alignments shouldn't be re-used (but if the protocols for masking and alignment are the same,
    # hash-consing will take care of things), just the masked images they create; can be removed later
    # if a sensible use is found

    def resample_labels(labels, source, target):
        # Stage an LSQ12 + non-linear registration of `source` to `target`, then resample `labels`
        # (belonging to `target`) into `source`'s space through the inverse of that transform.
        # Nearest-neighbour interpolation and -keep_real_range keep label values intact.
        xfm = s.defer(lsq12_nlin(source=source,
                                 target=target,
                                 lsq12_conf=lsq12_conf,
                                 nlin_conf=nlin_hierarchy,
                                 resample_source=False)).xfm
        return s.defer(mincresample_new(img=labels,
                                        xfm=xfm,
                                        like=source,
                                        invert=True,
                                        interpolation=Interpolation.nearest_neighbour,
                                        extra_flags=('-keep_real_range',)))

    # images with labels from atlases
    # N.B.: Even though we've already registered each image to each initial atlas, this happens again here,
    #       but using `nlin_hierarchy` instead of `masking_nlin_hierarchy` as options.
    #       This is not 'work-efficient' in the sense that this computation happens twice (although
    #       hopefully at greater precision the second time!), but the idea is to run a coarse initial
    #       registration to get a mask and then do a better registration with that mask (though I'm not
    #       sure exactly when this is faster than doing a single registration).
    #       This _can_ allow the overall computation to finish more rapidly
    #       (depending on the relative speed of the two alignment methods/parameters,
    #       number of atlases and other templates used, number of cores available, etc.).
    atlas_labelled_imgs = pd.DataFrame({ 'img'        : img,
                                         # can't use `label` in a pd.DataFrame index!
                                         'label_file' : resample_labels(atlas.labels, source=img, target=atlas) }
                                       for img in imgs for atlas in atlases)

    if maget_options.pairwise:

        def choose_new_templates(ts, n):
            # currently silly, but we might implement a smarter method ...
            # FIXME what if there aren't enough other imgs around?!  This silently goes weird ...
            # n+1 instead of n: choose one more since we won't use image as its own template ...
            # (clamped at 0: a negative slice stop would otherwise select all but the last few images)
            return pd.Series(ts[:max(n + 1, 0)])

        # FIXME: the --max-templates flag is ambiguously named ... should be --max-new-templates
        # (and just use all atlases)
        new_templates = choose_new_templates(ts=imgs, n=maget_options.max_templates - len(atlases))
        # note these images are the masked ones if masking was done ...

        if len(new_templates) == 0:
            # no additional templates requested, so only the atlases vote
            imgs_with_all_labels = atlas_labelled_imgs
        else:
            templates = pd.DataFrame({ 'template' : new_templates })

            # the templates together with their (multiple) labels from the atlases (this merge just acts as a filter)
            labelled_templates = pd.merge(left=atlas_labelled_imgs, right=templates,
                                          left_on="img", right_on="template").drop('img', axis=1)

            # every (image, labelled template) combination, via a cross join on a constant column
            imgs_and_templates = pd.merge(left=pd.DataFrame({ "img" : imgs }).assign(fake=1),
                                          right=labelled_templates.assign(fake=1),
                                          on='fake')

            # don't register template to itself, since otherwise atlases would vote on that template twice
            # (boolean indexing instead of `DataFrame.select`, which was removed in pandas 1.0)
            is_other_img = (imgs_and_templates.img.map(lambda img: img.path)
                            != imgs_and_templates.template.map(lambda template: template.path))  # TODO hardcoded name
            img_template_pairs = (imgs_and_templates[is_other_img]
                                  .rename(columns={ 'label_file' : 'template_label_file' }))

            # images with new labels from the templates; a list comprehension (unlike `df.apply(axis=1)`)
            # also behaves sensibly when no pairs remain
            template_labelled_imgs = img_template_pairs.assign(
                label_file=[resample_labels(row.template_label_file, source=row.img, target=row.template)
                            for row in img_template_pairs.itertuples()])

            imgs_with_all_labels = pd.concat([atlas_labelled_imgs[['img', 'label_file']],
                                              template_labelled_imgs[['img', 'label_file']]],
                                             ignore_index=True)
    else:
        imgs_with_all_labels = atlas_labelled_imgs

    # gather all label files propagated to each image, vote on them, and attach the result as `labels`
    segmented_imgs = (
            imgs_with_all_labels
            .groupby('img')
            .aggregate({'label_file' : lambda resampled_label_files: list(resampled_label_files)})
            .rename(columns={ 'label_file' : 'label_files' })
            .reset_index()
            .assign(voted_labels=lambda df: df.apply(axis=1, func=lambda row:
                      s.defer(voxel_vote(label_files=row.label_files,
                                         output_dir=os.path.join(row.img.pipeline_sub_dir, row.img.output_sub_dir)))))
            .apply(axis=1, func=lambda row: row.img._replace(labels=row.voted_labels))
    )

    return Result(stages=s, output=segmented_imgs)
예제 #10
0
파일: MAGeT.py 프로젝트: psteadman/pydpiper
def maget_mask(imgs : List[MincAtom], maget_options, resolution : float, pipeline_sub_dir : str, atlases=None):
    """Stage the creation of a mask for each input image by propagating atlas masks to it.

    Every image is aligned (LSQ12 followed by the non-linear stage given by
    `maget_options.maget.masking_nlin_protocol` / `mask_method`) to every atlas.  Each atlas's
    `mask` is resampled into the image's space, and the masks propagated to an image are combined
    with `mincmath` op "max" (for binary masks, their union).  The atlas images themselves are also
    resampled to each input and staged as additional outputs, which are not returned.

    Intermediate files go under a "masking" subdirectory of each image's `output_sub_dir`.  The
    returned images have their original `output_sub_dir` restored.

    :param imgs: images to mask (not modified -- copies are used internally)
    :param maget_options: the MAGeT option group (`options.maget`); its `.maget` and `.lsq12`
                          sub-groups are read
    :param resolution: file resolution used to build the registration configurations
    :param pipeline_sub_dir: passed to `atlases_from_dir` (only used when `atlases` is None)
    :param atlases: atlases to use; if None, they're read from `maget_options.maget.atlas_lib`
    :return: Result whose `output` holds copies of the input images with `mask` set
    :raises ValueError: if `atlases` is None and no atlas library was given
    """
    s = Stages()

    resample  = np.vectorize(mincresample_new, excluded={"extra_flags"})
    defer     = np.vectorize(s.defer)

    # Remember where each input's outputs normally go, keyed by path.  A plain dict is used instead of
    # `pd.Series(...).ix[...]`: `.ix` was removed in pandas 1.0, and `pd.Series(series, index=...)`
    # reindexes (blanking out the values) when `imgs` is itself a Series.
    original_output_sub_dirs = { img.path : img.output_sub_dir for img in imgs }

    # Work on copies so the caller's images aren't modified.  Copy element-wise, sharing one memo as
    # deep-copying a list would; deep-copying a pd.Series doesn't copy the objects it holds.
    memo = {}
    imgs = [copy.deepcopy(img, memo) for img in imgs]
    for img in imgs:
        img.output_sub_dir = os.path.join(img.output_sub_dir, "masking")

    # TODO dereference maget_options -> maget_options.maget outside maget_mask call?
    if atlases is None:
        if maget_options.maget.atlas_lib is None:
            raise ValueError("need some atlases for MAGeT-based masking ...")
        atlases = atlases_from_dir(atlas_lib=maget_options.maget.atlas_lib,
                                   max_templates=maget_options.maget.max_templates,
                                   pipeline_sub_dir=pipeline_sub_dir)

    lsq12_conf = get_linear_configuration_from_options(maget_options.lsq12,
                                                       LinearTransType.lsq12,
                                                       resolution)

    masking_nlin_hierarchy = get_nonlinear_configuration_from_options(maget_options.maget.masking_nlin_protocol,
                                                                      maget_options.maget.mask_method,
                                                                      resolution)

    # TODO lift outside then delete
    #masking_imgs = copy.deepcopy(imgs)
    #for img in masking_imgs:
    #    img.pipeline_sub_dir = os.path.join(img.pipeline_sub_dir, "masking")

    # one row per (image, atlas) pair, holding the registration of the image to the atlas
    masking_alignments = pd.DataFrame({ 'img'   : img,
                                        'atlas' : atlas,
                                        'xfm'   : s.defer(lsq12_nlin(source=img, target=atlas,
                                                                     lsq12_conf=lsq12_conf,
                                                                     nlin_conf=masking_nlin_hierarchy,
                                                                     resample_source=False))}
                                      for img in imgs for atlas in atlases)

    # propagate a mask to each image using the above `alignments` as follows:
    # - resample each atlas's mask into the space of each image it was aligned to
    # - for each image, combine the masks propagated to it with `mincmath` op "max"
    # - attach the combined mask to (a copy of) the image via `_replace(mask=...)`
    masked_img = (
        masking_alignments
        .assign(resampled_mask=lambda df: defer(resample(img=df.atlas.apply(lambda x: x.mask),
                                                         xfm=df.xfm.apply(lambda x: x.xfm),
                                                         like=df.img,
                                                         invert=True,
                                                         interpolation=Interpolation.nearest_neighbour,
                                                         postfix="-input-mask",
                                                         subdir="tmp",
                                                         # TODO annoying hack; fix mincresample(_mask) ...:
                                                         #new_name_wo_ext=df.apply(lambda row:
                                                         #    "%s_to_%s-input-mask" % (row.atlas.filename_wo_ext,
                                                         #                             row.img.filename_wo_ext),
                                                         #    axis=1),
                                                         extra_flags=("-keep_real_range",))))
        .groupby('img', as_index=False)
        .aggregate({'resampled_mask' : lambda masks: list(masks)})
        .rename(columns={"resampled_mask" : "resampled_masks"})
        .assign(voted_mask=lambda df: df.apply(axis=1,
                                               func=lambda row:
                                                 s.defer(mincmath(op="max", vols=sorted(row.resampled_masks),
                                                                  new_name="%s_max_mask" % row.img.filename_wo_ext,
                                                                  subdir="tmp"))))
        .apply(axis=1, func=lambda row: row.img._replace(mask=row.voted_mask)))

    # resample the atlas images back to the input images:
    # (note: this doesn't modify `masking_alignments`, but only stages additional outputs)
    masking_alignments.assign(resampled_img=lambda df:
      defer(resample(img=df.atlas,
                     xfm=df.xfm.apply(lambda x: x.xfm),
                     subdir="tmp",
                     # TODO delete this stupid hack:
                     #new_name_wo_ext=df.apply(lambda row:
                     #  "%s_to_%s-resampled" % (row.atlas.filename_wo_ext,
                     #                          row.img.filename_wo_ext),
                     #                          axis=1),
                     like=df.img, invert=True)))

    # undo the "masking" subdirectory so downstream outputs of the masked images go where the originals' would
    for img in masked_img:
        img.output_sub_dir = original_output_sub_dirs[img.path]

    return Result(stages=s, output=masked_img)