Example #1
0
class FullIterativeLSQ12Nlin:
    """Does a full iterative LSQ12 and NLIN. Basically iterative model building starting from LSQ6
       and without stats at the end. Designed to be called as part of a larger application.
       Specifying an initModel is optional, all other arguments are mandatory.

       Arguments:
         inputs    -- file handlers to register; after the pipeline is built each one has
                      the concatenated lsq12 + nlin transform set as its transform to nlinFH
         dirs      -- object providing lsq12Dir and nlinDir
         options   -- object providing lsq12_likeFile, lsq12_max_pairs, lsq12_protocol,
                      lsq12_subject_matter, pipeline_name, nlin_protocol and reg_method
         avgPrefix -- prefix for the NLIN averages; defaults to options.pipeline_name
         initModel -- optional; its first element is used as the LSQ12 like file and as
                      the source of a mask for the LSQ12 average when that has none

       Attributes available after construction: p (the Pipeline), lsq12Params,
       nlinFH (the last NLIN average), nlinParams and initialTarget.
    """
    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        self.nlinFH = None

        self.p = Pipeline()

        self.buildPipeline()

    def buildPipeline(self):
        # The like file for LSQ12: the initial model takes precedence over the
        # command line option; both may be absent.
        lsq12LikeFH = None
        if self.initModel:
            lsq12LikeFH = self.initModel[0]
        elif self.options.lsq12_likeFile:
            lsq12LikeFH = self.options.lsq12_likeFile
        lsq12module = lsq12.FullLSQ12(self.inputs,
                                      self.dirs.lsq12Dir,
                                      likeFile=lsq12LikeFH,
                                      maxPairs=self.options.lsq12_max_pairs,
                                      lsq12_protocol=self.options.lsq12_protocol,
                                      subject_matter=self.options.lsq12_subject_matter)
        lsq12module.iterate()
        self.p.addPipeline(lsq12module.p)
        self.lsq12Params = lsq12module.lsq12Params

        # If the LSQ12 average ended up without a mask, borrow the initial model's.
        if lsq12module.lsq12AvgFH.getMask() is None and self.initModel:
            lsq12module.lsq12AvgFH.setMask(self.initModel[0].getMask())

        if not self.avgPrefix:
            self.avgPrefix = self.options.pipeline_name
        nlinModule = nlin.initializeAndRunNLIN(self.dirs.lsq12Dir,
                                               self.inputs,
                                               self.dirs.nlinDir,
                                               avgPrefix=self.avgPrefix,
                                               createAvg=False,
                                               targetAvg=lsq12module.lsq12AvgFH,
                                               nlin_protocol=self.options.nlin_protocol,
                                               reg_method=self.options.reg_method)
        self.p.addPipeline(nlinModule.p)
        self.nlinFH = nlinModule.nlinAverages[-1]
        self.nlinParams = nlinModule.nlinParams
        self.initialTarget = nlinModule.initialTarget

        # Now we need the full transform to go back to LSQ6 space:
        # concatenate the LSQ12 transform with the NLIN transform for every input.
        for inputFH in self.inputs:
            linXfm = lsq12module.lsq12AvgXfms[inputFH]
            nlinXfm = inputFH.getLastXfm(self.nlinFH)
            outXfm = st.createOutputFileName(inputFH, nlinXfm, "transforms", "_with_additional.xfm")
            xc = ma.xfmConcat([linXfm, nlinXfm], outXfm, fh.logFromFile(inputFH.logDir, outXfm))
            self.p.addStage(xc)
            inputFH.addAndSetXfmToUse(self.nlinFH, outXfm)
class HierarchicalMinctracc:
    """Hierarchical minctracc registration of inputPipeFH onto templatePipeFH.

    With the default arguments this currently does:
        1. An LSQ12 alignment using lsq12.LSQ12 with its own defaults.
        2. 5 nlin minctracc stages with a blur of 0.25.
        3. 1 final nlin minctracc stage with no blur (blur of -1).

    steps, blurs, gradients, iterations and simplexes are per-stage sequences
    that are indexed in lockstep with steps, so they should all have the same
    length as steps. The final non-linear stage writes to "transforms"; all
    earlier stages use defaultDir.

    linearparams is accepted for backwards compatibility but is not used.
    """
    # Immutable defaults: list/dict default arguments would be shared between
    # every call of __init__.
    _DEFAULT_STEPS = (1, 0.5, 0.5, 0.2, 0.2, 0.1)
    _DEFAULT_BLURS = (0.25, 0.25, 0.25, 0.25, 0.25, -1)
    _DEFAULT_GRADIENTS = (False, False, True, False, True, False)
    _DEFAULT_ITERATIONS = (60, 60, 60, 10, 10, 4)
    _DEFAULT_SIMPLEXES = (3, 3, 3, 1.5, 1.5, 1)

    def __init__(self,
                 inputPipeFH,
                 templatePipeFH,
                 steps=None,
                 blurs=None,
                 gradients=None,
                 iterations=None,
                 simplexes=None,
                 w_translations=0.2,
                 linearparams=None,
                 defaultDir="tmp"):
        if steps is None:
            steps = self._DEFAULT_STEPS
        if blurs is None:
            blurs = self._DEFAULT_BLURS
        if gradients is None:
            gradients = self._DEFAULT_GRADIENTS
        if iterations is None:
            iterations = self._DEFAULT_ITERATIONS
        if simplexes is None:
            simplexes = self._DEFAULT_SIMPLEXES

        self.p = Pipeline()

        for b in blurs:
            # MF TODO: -1 case is also handled in blur. Need here for addStage.
            # Fix this redundancy and/or better design?
            if b != -1:
                tblur = ma.blur(templatePipeFH, b, gradient=True)
                iblur = ma.blur(inputPipeFH, b, gradient=True)
                self.p.addStage(tblur)
                self.p.addStage(iblur)

        # Do standard LSQ12 alignment prior to non-linear stages
        lsq12reg = lsq12.LSQ12(inputPipeFH,
                               templatePipeFH,
                               defaultDir=defaultDir)
        self.p.addPipeline(lsq12reg.p)

        # Create the nonlinear registrations; the final stage writes to "transforms".
        lastStage = len(steps) - 1
        for i, step in enumerate(steps):
            stageDir = "transforms" if i == lastStage else defaultDir
            nlinStage = ma.minctracc(inputPipeFH,
                                     templatePipeFH,
                                     defaultDir=stageDir,
                                     blur=blurs[i],
                                     gradient=gradients[i],
                                     iterations=iterations[i],
                                     step=step,
                                     similarity=0.8,
                                     w_translations=w_translations,
                                     simplex=simplexes[i])
            self.p.addStage(nlinStage)
Example #3
0
def MAGeTRegister(inputFH,
                  templateFH,
                  regMethod,
                  name="initial",
                  createMask=False,
                  lsq12_protocol=None,
                  nlin_protocol=None):
    """Register inputFH to templateFH, then resample the template's files onto inputFH.

    Arguments:
      inputFH, templateFH -- file handlers for the input and the template (atlas)
      regMethod           -- "minctracc" (HierarchicalMinctracc) or
                             "mincANTS" (LSQ12ANTSNlin)
      name                -- forwarded to LabelAndFileResampling
      createMask          -- forwarded to LabelAndFileResampling; when True the
                             registration output directory is "tmp", otherwise
                             "transforms"
      lsq12_protocol, nlin_protocol -- optional protocols forwarded to the registration

    Returns the Pipeline containing the registration and resampling stages.

    Raises ValueError if regMethod is not "minctracc" or "mincANTS".
    """
    # Fail fast: with an unknown method no registration stage would be added,
    # yet the resampling stage below would still be added.
    if regMethod not in ("minctracc", "mincANTS"):
        raise ValueError("Unknown registration method %r; expected 'minctracc' or 'mincANTS'"
                         % (regMethod,))

    p = Pipeline()
    if createMask:
        defaultDir = "tmp"
    else:
        defaultDir = "transforms"
    if regMethod == "minctracc":
        sp = HierarchicalMinctracc(inputFH,
                                   templateFH,
                                   lsq12_protocol=lsq12_protocol,
                                   nlin_protocol=nlin_protocol,
                                   defaultDir=defaultDir)
        p.addPipeline(sp.p)
    else:  # regMethod == "mincANTS"
        register = LSQ12ANTSNlin(inputFH,
                                 templateFH,
                                 lsq12_protocol=lsq12_protocol,
                                 nlin_protocol=nlin_protocol,
                                 defaultDir=defaultDir)
        p.addPipeline(register.p)

    rp = LabelAndFileResampling(inputFH,
                                templateFH,
                                name=name,
                                createMask=createMask)
    p.addPipeline(rp.p)

    return p
Example #4
0
def MAGeTRegister(inputFH,
                  templateFH,
                  regMethod,
                  name="initial",
                  createMask=False,
                  lsq12_protocol=None,
                  nlin_protocol=None):
    """Register inputFH to templateFH, then resample the template's files onto inputFH.

    Arguments:
      inputFH, templateFH -- file handlers for the input and the template (atlas)
      regMethod           -- "minctracc" (HierarchicalMinctracc) or
                             "mincANTS" (LSQ12ANTSNlin)
      name                -- forwarded to LabelAndFileResampling
      createMask          -- forwarded to LabelAndFileResampling; when True the
                             registration output directory is "tmp", otherwise
                             "transforms"
      lsq12_protocol, nlin_protocol -- optional protocols forwarded to the registration

    Returns the Pipeline containing the registration and resampling stages.

    Raises ValueError if regMethod is not "minctracc" or "mincANTS".
    """
    # Fail fast: with an unknown method no registration stage would be added,
    # yet the resampling stage below would still be added.
    if regMethod not in ("minctracc", "mincANTS"):
        raise ValueError("Unknown registration method %r; expected 'minctracc' or 'mincANTS'"
                         % (regMethod,))

    p = Pipeline()
    defaultDir = "tmp" if createMask else "transforms"
    if regMethod == "minctracc":
        sp = HierarchicalMinctracc(inputFH,
                                   templateFH,
                                   lsq12_protocol=lsq12_protocol,
                                   nlin_protocol=nlin_protocol,
                                   defaultDir=defaultDir)
        p.addPipeline(sp.p)
    else:  # regMethod == "mincANTS"
        register = LSQ12ANTSNlin(inputFH,
                                 templateFH,
                                 lsq12_protocol=lsq12_protocol,
                                 nlin_protocol=nlin_protocol,
                                 defaultDir=defaultDir)
        p.addPipeline(register.p)

    rp = LabelAndFileResampling(inputFH, templateFH, name=name, createMask=createMask)
    p.addPipeline(rp.p)

    return p
Example #5
0
def MAGeTMask(atlases,
              inputs,
              numAtlases,
              regMethod,
              lsq12_protocol=None,
              nlin_protocol=None):
    """Build the pipeline that creates masks for the atlases and inputs using MAGeT.

    The masking algorithm:
      1. Register every input to every atlas (HierarchicalMinctracc or mincANTS,
         chosen by regMethod) with createMask=True, i.e. using masks instead of labels.
      2. Voxel-vote to find the best mask (with a single atlas, use that transform).
      3. Multiply the original volume by its mask (mincMath) to get a _masked.mnc
         file. This is done for atlases as well as inputs; atlases need no voting.
      4. Replace lastBasevol with the masked version, since once the mask exists
         the unmasked version is no longer needed.
      5. Clear the labels arrays, which were used to keep track of masks, so they
         can be re-set for the actual labels.

    All data is placed in a newly created masking directory to keep it separate
    from data generated during the actual MAGeT run.

    Returns the Pipeline with every stage created here.
    """
    pipeline = Pipeline()

    # Point all file handlers at the masking directory structure, then queue
    # every input-to-atlas registration.
    for atlas in atlases:
        maskDirectoryStructure(atlas, masking=True)
    for subject in inputs:
        maskDirectoryStructure(subject, masking=True)
        for atlas in atlases:
            pipeline.addPipeline(MAGeTRegister(subject,
                                               atlas,
                                               regMethod,
                                               name="initial",
                                               createMask=True,
                                               lsq12_protocol=lsq12_protocol,
                                               nlin_protocol=nlin_protocol))

    # Prior to final masking, restore the log and tmp directories.
    for atlas in atlases:
        # Keep the atlas' input labels so the first one (only one is assumed)
        # can be re-attached in the new group.
        savedLabels = atlas.returnLabels(True)
        maskDirectoryStructure(atlas, masking=False)
        pipeline.addPipeline(maskFiles(atlas, True))
        atlas.newGroup()
        atlas.addLabels(savedLabels[0], inputLabel=True)

    for subject in inputs:
        maskDirectoryStructure(subject, masking=False)
        pipeline.addPipeline(maskFiles(subject, False, numAtlases))
        # Drop the "inputLabels" (taken directly from the atlas library) ...
        subject.clearLabels(True)
        # ... and the "labels" (second generation: labels from labels of the library).
        subject.clearLabels(False)
        subject.newGroup()

    return pipeline
Example #6
0
def MAGeTMask(atlases, inputs, numAtlases, regMethod, lsq12_protocol=None, nlin_protocol=None):
    """Create masks for all atlases and inputs with a MAGeT-style registration.

    Algorithm:
      1. Run HierarchicalMinctracc or mincANTS (per regMethod) with createMask=True,
         propagating masks instead of labels.
      2. Voxel-vote for the best mask (a single atlas simply uses its transform).
      3. Multiply the original volume by the mask (mincMath) to get a _masked.mnc
         file, for atlases and inputs alike; atlases need no voting.
      4. Replace lastBasevol with the masked version; the unmasked one is no
         longer needed once the mask exists.
      5. Clear the labels arrays that tracked masks, so they can be re-set for
         the actual labels.

    All data goes into a newly created masking directory, separate from the data
    of the actual MAGeT run.

    Returns the Pipeline containing every stage created here.
    """
    def finaliseAtlas(atlas):
        # Save the atlas' input labels before switching groups; only the first
        # one is re-attached (a single label file is assumed).
        savedLabels = atlas.returnLabels(True)
        maskDirectoryStructure(atlas, masking=False)
        maskingPipeline.addPipeline(maskFiles(atlas, True))
        atlas.newGroup()
        atlas.addLabels(savedLabels[0], inputLabel=True)

    def finaliseInput(subject):
        maskDirectoryStructure(subject, masking=False)
        maskingPipeline.addPipeline(maskFiles(subject, False, numAtlases))
        # Remove the "inputLabels" (straight from the atlas library) and the
        # "labels" (second generation: labels from labels of the library).
        subject.clearLabels(True)
        subject.clearLabels(False)
        subject.newGroup()

    maskingPipeline = Pipeline()

    for atlas in atlases:
        maskDirectoryStructure(atlas, masking=True)
    for subject in inputs:
        maskDirectoryStructure(subject, masking=True)
        for atlas in atlases:
            registration = MAGeTRegister(subject,
                                         atlas,
                                         regMethod,
                                         name="initial",
                                         createMask=True,
                                         lsq12_protocol=lsq12_protocol,
                                         nlin_protocol=nlin_protocol)
            maskingPipeline.addPipeline(registration)

    # Prior to final masking, set log and tmp directories back as they were.
    for atlas in atlases:
        finaliseAtlas(atlas)
    for subject in inputs:
        finaliseInput(subject)

    return maskingPipeline
Example #7
0
class LSQ12ANTSNlin:
    """Class that runs a basic LSQ12 registration, followed by a single mincANTS call.
       Currently used in MAGeT, registration_chain and pairwise_nlin.

       Arguments:
         inputFH, targetFH -- file handlers; inputFH is registered onto targetFH
         lsq12_protocol    -- optional protocol for mp.setLSQ12MinctraccParams
         nlin_protocol     -- optional protocol for mp.setOneGenMincANTSParams
         subject_matter    -- optional subject matter for the LSQ12 parameters
         defaultDir        -- output directory for the LSQ12 and mincANTS stages;
                              the concatenated final transform goes to "transforms"

       Attributes set after construction: p, fileRes, lsq12Params, nlinParams.
    """
    def __init__(self,
                 inputFH,
                 targetFH,
                 lsq12_protocol=None,
                 nlin_protocol=None,
                 subject_matter=None,
                 defaultDir="tmp"):

        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir

        # The resolution is only computed when not all parameters come from the
        # protocols/subject matter; otherwise fileRes stays None (presumably the
        # protocols then fully determine the parameters -- confirm in mp).
        if (self.lsq12_protocol is None and self.subject_matter is None) or self.nlin_protocol is None:
            # always base the resolution to be used on the target for the registrations
            self.fileRes = rf.returnFinestResolution(self.targetFH)
        else:
            self.fileRes = None

        self.buildPipeline()

    def buildPipeline(self):
        # Run lsq12 registration prior to non-linear
        self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                      subject_matter=self.subject_matter,
                                                      reg_protocol=self.lsq12_protocol)
        lsq12reg = lsq12.LSQ12(self.inputFH,
                               self.targetFH,
                               blurs=self.lsq12Params.blurs,
                               step=self.lsq12Params.stepSize,
                               gradient=self.lsq12Params.useGradient,
                               simplex=self.lsq12Params.simplex,
                               w_translations=self.lsq12Params.w_translations,
                               defaultDir=self.defaultDir)
        self.p.addPipeline(lsq12reg.p)

        # Resample using final LSQ12 transform and reset last base volume.
        res = ma.mincresample(self.inputFH, self.targetFH, likeFile=self.targetFH, argArray=["-sinc"])
        self.p.addStage(res)
        self.inputFH.setLastBasevol(res.outputFiles[0])
        lsq12xfm = self.inputFH.getLastXfm(self.targetFH)

        # Get registration parameters from nlin protocol, blur and register.
        # Assume a SINGLE generation here: only index 0 of each parameter is used.
        self.nlinParams = mp.setOneGenMincANTSParams(self.fileRes, reg_protocol=self.nlin_protocol)
        for generationBlurs in self.nlinParams.blurs:
            # Note that blurs for ANTS params are an array of arrays.
            for j in generationBlurs:
                if j != -1:
                    self.p.addStage(ma.blur(self.targetFH, j, gradient=True))
                    self.p.addStage(ma.blur(self.inputFH, j, gradient=True))

        sp = ma.mincANTS(self.inputFH,
                         self.targetFH,
                         defaultDir=self.defaultDir,
                         blur=self.nlinParams.blurs[0],
                         gradient=self.nlinParams.gradient[0],
                         similarity_metric=self.nlinParams.similarityMetric[0],
                         weight=self.nlinParams.weight[0],
                         iterations=self.nlinParams.iterations[0],
                         radius_or_histo=self.nlinParams.radiusHisto[0],
                         transformation_model=self.nlinParams.transformationModel[0],
                         regularization=self.nlinParams.regularization[0],
                         useMask=self.nlinParams.useMask[0])
        self.p.addStage(sp)
        nlinXfm = sp.outputFiles[0]
        # Reset last base volume to original input for future registrations.
        self.inputFH.setLastBasevol(setToOriginalInput=True)
        # Concatenate transforms to get final lsq12 + nlin. registerVolume handles
        # naming and setting of lastXfm.
        output = self.inputFH.registerVolume(self.targetFH, "transforms")
        xc = ma.xfmConcat([lsq12xfm, nlinXfm], output, fh.logFromFile(self.inputFH.logDir, output))
        self.p.addStage(xc)
Example #8
0
class LongitudinalStatsConcatAndResample:
    """ For each subject:
        1. Calculate stats (displacement, absolute jacobians, relative jacobians) between i and i+1 time points
        2. Calculate transform from subject to common space (nlinFH) and invert it.
           For most subjects this will require some amount of transform concatenation.
        3. Calculate the stats (displacement, absolute jacobians, relative jacobians) from common space
           to each timepoint.

        Constructor arguments:
          subjects     -- mapping from subject key to an indexable sequence of
                          per-timepoint file handlers
          timePoint    -- index of the timepoint that went into the common average;
                          -1 means the last timepoint of each subject
          nlinFH       -- the common-space average; when falsy, no common-space
                          stats, concatenation or resampling are done
          statsKernels -- blurring kernels, as a list or a comma separated string
          commonName   -- name used in file names of outputs mapped to common space
    """
    def __init__(self, subjects, timePoint, nlinFH, statsKernels, commonName):

        self.subjects = subjects
        self.timePoint = timePoint
        self.nlinFH = nlinFH
        self.blurs = []
        self.setupBlurs(statsKernels)
        self.commonName = commonName

        self.p = Pipeline()

        self.buildPipeline()

    def setupBlurs(self, statsKernels):
        """Set self.blurs from a list (used as-is) or a comma separated string of floats.

        Any other type is a fatal configuration error: a message is printed and the
        program exits with a non-zero status.
        """
        if isinstance(statsKernels, list):
            self.blurs = statsKernels
        elif isinstance(statsKernels, str):
            self.blurs = [float(kernel) for kernel in statsKernels.split(",")]
        else:
            print("Improper type of blurring kernels specified for stats calculation: " + str(statsKernels))
            # sys.exit() without an argument reports success (status 0); this is an error.
            sys.exit(1)

    def statsCalculation(self, inputFH, targetFH, xfm=None, useChainStats=True):
        """If useChainStats=True, calculate stats between input and target.
           This happens for all i to i+1 calcs.

           If useChainStats=False, calculate stats in the standard way, from target to
           input, We do this, when we go from the common space to all others. """
        if useChainStats:
            stats = st.CalcChainStats(inputFH, targetFH, self.blurs)
        else:
            stats = st.CalcStats(inputFH, targetFH, self.blurs)
        self.p.addPipeline(stats.p)
        # If an xfm is specified, resample all to this common space.
        if xfm:
            likeFH = self.nlinFH if self.nlinFH else targetFH
            res = resampleToCommon(xfm, inputFH, stats.statsGroup, self.blurs, likeFH)
            self.p.addPipeline(res)

    def statsAndConcat(self, s, i, count, beforeAvg=True):
        """Construct array to common space for this timepoint.
           This builds upon arrays from previous calls."""
        # Before the average we step forward in time (i -> i+1), after it backward (i -> i-1).
        if beforeAvg:
            xfm = s[i].getLastXfm(s[i+1])
        else:
            xfm = s[i].getLastXfm(s[i-1])
        # Set this transform as last xfm from input to nlin and calculate nlin to s[i] stats.
        if self.nlinFH:
            self.xfmToCommon.insert(0, xfm)
            # Concat transforms to get xfmToCommon and calculate statistics.
            # Note that the inverted transform, which is what we want, is calculated in
            # the statistics module.
            xtc = createBaseName(s[i].transformsDir, s[i].basename + "_to_" + self.commonName + ".xfm")
            xc = ma.xfmConcat(self.xfmToCommon, xtc, fh.logFromFile(s[i].logDir, xtc))
            self.p.addStage(xc)
            # here in order to visually inspect the alignment with the common
            # time point, we should resample this subject:
            inputResampledToCommon = createBaseName(s[i].resampledDir, s[i].basename + "_to_" + self.commonName + ".mnc")
            logToCommon = fh.logFromFile(s[i].logDir, inputResampledToCommon)
            resampleCmd = ma.mincresample(s[i],
                                          self.nlinFH,
                                          likeFile=self.nlinFH,
                                          transform=xtc,
                                          output=inputResampledToCommon,
                                          logFile=logToCommon,
                                          argArray=["-sinc"])
            self.p.addStage(resampleCmd)
            s[i].addAndSetXfmToUse(self.nlinFH, xtc)
            self.statsCalculation(s[i], self.nlinFH, xfm=None, useChainStats=False)
        else:
            xtc = None
        # Calculate i to i+1 stats for all but the final timePoint.
        if count - i > 1:
            self.statsCalculation(s[i], s[i+1], xfm=xtc, useChainStats=True)

    def buildPipeline(self):
        for subj in self.subjects:
            s = self.subjects[subj]
            count = len(s)
            # Wherever iterative model building was run, the indiv --> nlin xfm is stored
            # in the group with the name "final". We need to use this group to get the
            # transform and do the stats calculation, and then reset to the current group.
            # Calculate stats first from average to timepoint included in average.

            if self.timePoint == -1:
                # This means that we used the last file for each of the subjects
                # to create the common average. This will be a variable time
                # point, so we have to determine it for each of the input files
                timePointToUse = len(s) - 1
            else:
                timePointToUse = self.timePoint

            currGroup = s[timePointToUse].currentGroupIndex
            index = s[timePointToUse].getGroupIndex("final")
            xfmToNlin = s[timePointToUse].getLastXfm(self.nlinFH, groupIndex=index)

            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            if self.nlinFH:
                s[timePointToUse].currentGroupIndex = index
                self.statsCalculation(s[timePointToUse], self.nlinFH, xfm=None, useChainStats=False)
                s[timePointToUse].currentGroupIndex = currGroup
            # Next: if the timepoint included in the average is NOT the final timepoint,
            # also calculate i to i+1 stats.
            if count - timePointToUse > 1:
                self.statsCalculation(s[timePointToUse], s[timePointToUse+1], xfm=xfmToNlin, useChainStats=True)
            if timePointToUse > 0:
                # Average happened at a time point other than the first time point.
                # Loop over points prior to the average.
                for i in reversed(range(timePointToUse)):
                    self.statsAndConcat(s, i, count, beforeAvg=True)

            # Loop over points after average. If average is at first time point, this loop
            # will hit all time points (other than first). If average is at subsequent time
            # point, it hits all time points not covered previously.
            #
            # xfmToCommon (possibly) needs to be reset: if the average time point is not the first time point
            # then the array xfmToCommon now contains a list of transformations from the first time point to the
            # average. For instance if the average time point is time point 3, then xfmToCommon now contains:
            # [ time_0_to_time_1.xfm, time_1_to_time_2.xfm, time_2_to_average_at_time_point2.xfm ]
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            for i in range(timePointToUse + 1, count):
                self.statsAndConcat(s, i, count, beforeAvg=False)
Example #9
0
class FullIterativeLSQ12Nlin:
    """Does a full iterative LSQ12 and NLIN. Basically iterative model building starting from LSQ6
       and without stats at the end. Designed to be called as part of a larger application.
       Specifying an initModel is optional, all other arguments are mandatory.

       Arguments:
         inputs         -- file handlers to register; after the pipeline is built each one
                           has the concatenated lsq12 + nlin transform set towards nlinFH
         dirs           -- object providing lsq12Dir and nlinDir
         options        -- object providing lsq12_likeFile, lsq12_subject_matter, queue_type,
                           lsq12_max_pairs, lsq12_protocol, pipeline_name, nlin_protocol,
                           reg_method
         avgPrefix      -- prefix for the NLIN averages; defaults to options.pipeline_name
         initModel      -- optional; its first element is used as the LSQ12 like file
         fileResolution -- resolution to register at when it cannot be taken from the
                           like file

       The resolution is taken from the like file (initModel[0] or options.lsq12_likeFile)
       when one is available, otherwise from fileResolution. If neither is available an
       error is printed and the program exits with status 1.

       Attributes available after construction: p, lsq12Params, nlinFH (the last NLIN
       average), nlinParams and initialTarget.
    """
    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None, fileResolution=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        self.nlinFH = None
        self.providedResolution = fileResolution

        self.p = Pipeline()

        self.buildPipeline()

    def _determineResolution(self, lsq12LikeFH):
        """Return the resolution for the LSQ12 and NLIN stages, or exit with status 1."""
        if (lsq12LikeFH is None and self.options.lsq12_subject_matter is None
                and self.providedResolution is None):
            print("\nError: the FullIterativeLSQ12Nlin module was called without specifying either an initial model, nor an lsq12_subject_matter. Currently that means that the code can not determine the resolution at which the registrations should be run. Please specify one of the two. Exiting\n")
            # sys.exit() without an argument reports success (status 0); this is an error.
            sys.exit(1)

        resolution = None
        if lsq12LikeFH is not None:
            resolution = rf.returnFinestResolution(lsq12LikeFH)

        if resolution is None and self.providedResolution is None:
            print("\nError: the resolution at which the LSQ12 and the NLIN registration should be run could not be determined from either the initial model nor the LSQ12 like file. Please provide the fileResolution to the FullIterativeLSQ12Nlin module. Exiting\n")
            sys.exit(1)

        # The like file's resolution takes precedence over the provided one.
        if resolution is None and self.providedResolution:
            resolution = self.providedResolution
        return resolution

    def buildPipeline(self):
        # The like file for LSQ12: the initial model takes precedence over the
        # command line option; both may be absent.
        lsq12LikeFH = None
        if self.initModel:
            lsq12LikeFH = self.initModel[0]
        elif self.options.lsq12_likeFile:
            lsq12LikeFH = self.options.lsq12_likeFile

        resolutionForLSQ12 = self._determineResolution(lsq12LikeFH)

        lsq12module = lsq12.FullLSQ12(self.inputs,
                                      self.dirs.lsq12Dir,
                                      queue_type=self.options.queue_type,
                                      likeFile=lsq12LikeFH,
                                      maxPairs=self.options.lsq12_max_pairs,
                                      lsq12_protocol=self.options.lsq12_protocol,
                                      subject_matter=self.options.lsq12_subject_matter,
                                      resolution=resolutionForLSQ12)
        lsq12module.iterate()
        self.p.addPipeline(lsq12module.p)
        self.lsq12Params = lsq12module.lsq12Params

        # If the LSQ12 average ended up without a mask, borrow the initial model's.
        if lsq12module.lsq12AvgFH.getMask() is None and self.initModel:
            lsq12module.lsq12AvgFH.setMask(self.initModel[0].getMask())

        if not self.avgPrefix:
            self.avgPrefix = self.options.pipeline_name
        # same as in MBM.py:
        # for now we can use the same resolution for the NLIN stages as we did for the
        # LSQ12 stage. At some point we should look into the subject matter option...
        nlinModule = nlin.initializeAndRunNLIN(self.dirs.lsq12Dir,
                                               self.inputs,
                                               self.dirs.nlinDir,
                                               avgPrefix=self.avgPrefix,
                                               createAvg=False,
                                               targetAvg=lsq12module.lsq12AvgFH,
                                               nlin_protocol=self.options.nlin_protocol,
                                               reg_method=self.options.reg_method,
                                               resolution=resolutionForLSQ12)
        self.p.addPipeline(nlinModule.p)
        self.nlinFH = nlinModule.nlinAverages[-1]
        self.nlinParams = nlinModule.nlinParams
        self.initialTarget = nlinModule.initialTarget

        # Now we need the full transform to go back to LSQ6 space:
        # concatenate the LSQ12 transform with the NLIN transform for every input.
        for inputFH in self.inputs:
            linXfm = lsq12module.lsq12AvgXfms[inputFH]
            nlinXfm = inputFH.getLastXfm(self.nlinFH)
            outXfm = st.createOutputFileName(inputFH, nlinXfm, "transforms", "_with_additional.xfm")
            xc = ma.xfmConcat([linXfm, nlinXfm], outXfm, fh.logFromFile(inputFH.logDir, outXfm))
            self.p.addStage(xc)
            inputFH.addAndSetXfmToUse(self.nlinFH, outXfm)
Example #10
0
class HierarchicalMinctracc:
    """Default HierarchicalMinctracc currently does:
        1. A standard three stage LSQ12 alignment. (See defaults for LSQ12 module.)
        2. A six generation non-linear minctracc alignment.
       To override these defaults, lsq12 and nlin protocols may be specified.

       Arguments:
         inputFH, targetFH -- file handlers; inputFH is registered onto targetFH
         lsq12_protocol    -- optional protocol for mp.setLSQ12MinctraccParams
         nlin_protocol     -- optional protocol for mp.setNlinMinctraccParams
         includeLinear     -- run the LSQ12 alignment first (when False, lsq12Params
                              is not set)
         subject_matter    -- optional subject matter for the LSQ12 parameters
         defaultDir        -- output directory for every stage except the final
                              non-linear one, which always goes to "transforms"
    """
    def __init__(self,
                 inputFH,
                 targetFH,
                 lsq12_protocol=None,
                 nlin_protocol=None,
                 includeLinear = True,
                 subject_matter = None,
                 defaultDir="tmp"):

        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.includeLinear = includeLinear
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir

        if (self.lsq12_protocol is None and self.subject_matter is None) or self.nlin_protocol is None:
            # the resolution of the registration should be based on the target
            self.fileRes = rf.returnFinestResolution(self.targetFH)
        else:
            self.fileRes = None

        self.buildPipeline()

    def buildPipeline(self):

        # Do LSQ12 alignment prior to non-linear stages if desired
        if self.includeLinear:
            self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                            subject_matter=self.subject_matter,
                                            reg_protocol=self.lsq12_protocol)
            lsq12reg = lsq12.LSQ12(self.inputFH,
                                   self.targetFH,
                                   blurs=self.lsq12Params.blurs,
                                   step=self.lsq12Params.stepSize,
                                   gradient=self.lsq12Params.useGradient,
                                   simplex=self.lsq12Params.simplex,
                                   w_translations=self.lsq12Params.w_translations,
                                   defaultDir=self.defaultDir)
            self.p.addPipeline(lsq12reg.p)

        # create the nonlinear registrations
        self.nlinParams = mp.setNlinMinctraccParams(self.fileRes, reg_protocol=self.nlin_protocol)
        # A blur of -1 means "no blur", so no blurring stage is added for it.
        for b in self.nlinParams.blurs:
            if b != -1:
                self.p.addStage(ma.blur(self.inputFH, b, gradient=True))
                self.p.addStage(ma.blur(self.targetFH, b, gradient=True))

        numStages = len(self.nlinParams.stepSize)
        for i in range(numStages):
            # For the final stage, make sure the output directory is transforms.
            # Use a local so self.defaultDir is not permanently overwritten.
            stageDir = "transforms" if i == numStages - 1 else self.defaultDir
            nlinStage = ma.minctracc(self.inputFH,
                                     self.targetFH,
                                     defaultDir=stageDir,
                                     blur=self.nlinParams.blurs[i],
                                     gradient=self.nlinParams.useGradient[i],
                                     iterations=self.nlinParams.iterations[i],
                                     step=self.nlinParams.stepSize[i],
                                     w_translations=self.nlinParams.w_translations[i],
                                     simplex=self.nlinParams.simplex[i],
                                     memory=self.nlinParams.memory[i] if self.nlinParams.memory else None,
                                     optimization=self.nlinParams.optimization[i])
            self.p.addStage(nlinStage)
Example #11
0
class initializeAndRunNLIN(object):
    """Class to setup target average (if needed), 
       instantiate correct version of NLIN class,
       and run NLIN registration.

       After construction, self.p contains the target-averaging stage (only if
       createAvg is True) followed by the NLIN module's pipeline, and
       self.nlinAverages / self.nlinParams are copied from the NLIN module."""
    def __init__(self, 
                  targetOutputDir, #Output directory for files related to initial target (often _lsq12)
                  inputFiles, 
                  nlinDir, 
                  avgPrefix, #Prefix for nlin-1.mnc, ... nlin-k.mnc 
                  createAvg=True, #True=call mincAvg, False=targetAvg already exists
                  targetAvg=None, #Optional path or RegistrationPipeFH for initial target - passing name does not guarantee existence
                  targetMask=None, #Optional path to mask for initial target
                  nlin_protocol=None,
                  reg_method=None): #"mincANTS" or "minctracc"; anything else is fatal
        self.p = Pipeline()
        self.targetOutputDir = targetOutputDir
        self.inputFiles = inputFiles
        self.nlinDir = nlinDir
        self.avgPrefix = avgPrefix
        self.createAvg = createAvg
        self.targetAvg = targetAvg
        self.targetMask = targetMask
        self.nlin_protocol = nlin_protocol
        self.reg_method = reg_method
        
        # setup initialTarget (if needed) and initialize non-linear module
        self.setupTarget()
        self.initNlinModule()
        
        #iterate through non-linear registration and setup averages
        self.nlinModule.iterate()
        self.p.addPipeline(self.nlinModule.p)
        self.nlinAverages = self.nlinModule.nlinAverages
        self.nlinParams = self.nlinModule.nlinParams
        
    def setupTarget(self):
        """Set self.initialTarget and self.outputAvg from self.targetAvg.

           targetAvg may be a path (str), a RegistrationPipeFH, or empty/None
           (then <abspath(targetOutputDir)>/initial-target.mnc is used). Any other
           type is a fatal error (exit status 1). If createAvg is True, a
           mincAverage stage writing outputAvg from the inputs is added to self.p."""
        if self.targetAvg:
            if isinstance(self.targetAvg, str): 
                self.initialTarget = RegistrationPipeFH(self.targetAvg, 
                                                        mask=self.targetMask, 
                                                        basedir=self.targetOutputDir)
                self.outputAvg = self.targetAvg
            elif isinstance(self.targetAvg, RegistrationPipeFH):
                self.initialTarget = self.targetAvg
                self.outputAvg = self.targetAvg.getLastBasevol()
                # Only apply targetMask if the handler has no mask of its own.
                if not self.initialTarget.getMask():
                    if self.targetMask:
                        self.initialTarget.setMask(self.targetMask)
            else:
                print("You have passed a target average that is neither a string nor a file handler: " + str(self.targetAvg))
                print("Exiting...")
                # initialTarget/outputAvg are not set for this type, so stop here
                # instead of failing later with an AttributeError.
                sys.exit(1)
        else:
            self.targetAvg = abspath(self.targetOutputDir) + "/" + "initial-target.mnc" 
            self.initialTarget = RegistrationPipeFH(self.targetAvg, 
                                                    mask=self.targetMask, 
                                                    basedir=self.targetOutputDir)
            self.outputAvg = self.targetAvg
        if self.createAvg:
            avg = mincAverage(self.inputFiles, 
                              self.initialTarget, 
                              output=self.outputAvg,
                              defaultDir=self.targetOutputDir)
            self.p.addStage(avg)
            
    def initNlinModule(self):
        """Create self.nlinModule for self.reg_method ("mincANTS" -> NLINANTS,
           "minctracc" -> NLINminctracc). Any other value, including the default
           None, is logged and is a fatal error (exit status 1)."""
        if self.reg_method=="mincANTS":
            self.nlinModule = NLINANTS(self.inputFiles, self.initialTarget, self.nlinDir, self.avgPrefix, self.nlin_protocol)
        elif self.reg_method=="minctracc":
            self.nlinModule = NLINminctracc(self.inputFiles, self.initialTarget, self.nlinDir, self.avgPrefix, self.nlin_protocol)
        else:
            # str() so that reg_method=None is reported rather than raising TypeError.
            logger.error("Incorrect registration method specified: " + str(self.reg_method))
            sys.exit(1)
Example #12
0
class HierarchicalMinctracc:
    """Register inputFH to targetFH with minctracc, optionally preceded by LSQ12.

       Default behaviour:
        1. A standard three stage LSQ12 alignment (see defaults for the LSQ12
           module). It runs only when includeLinear is True.
        2. A six generation non-linear minctracc alignment.
       Pass lsq12_protocol / nlin_protocol to override these defaults.
       All stages are collected in self.p."""
    def __init__(self,
                 inputFH,
                 targetFH,
                 lsq12_protocol=None,
                 nlin_protocol=None,
                 includeLinear=True,
                 subject_matter=None,
                 defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.includeLinear = includeLinear
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir

        # The file resolution is only needed when a stage falls back to
        # resolution-derived default parameters.
        linearUsesDefaults = lsq12_protocol is None and subject_matter is None
        nonlinearUsesDefaults = nlin_protocol is None
        if linearUsesDefaults or nonlinearUsesDefaults:
            self.fileRes = rf.returnFinestResolution(inputFH)
        else:
            self.fileRes = None

        self.buildPipeline()

    def buildPipeline(self):
        """Add the optional LSQ12 stages, then blur and non-linear minctracc stages, to self.p."""
        src = self.inputFH
        dst = self.targetFH

        # Linear stage, only if requested.
        if self.includeLinear:
            self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                          subject_matter=self.subject_matter,
                                                          reg_protocol=self.lsq12_protocol)
            linear = self.lsq12Params
            linearReg = lsq12.LSQ12(src,
                                    dst,
                                    blurs=linear.blurs,
                                    step=linear.stepSize,
                                    gradient=linear.useGradient,
                                    simplex=linear.simplex,
                                    w_translations=linear.w_translations,
                                    defaultDir=self.defaultDir)
            self.p.addPipeline(linearReg.p)

        # Non-linear stage parameters.
        self.nlinParams = mp.setNlinMinctraccParams(self.fileRes, reg_protocol=self.nlin_protocol)
        nlin = self.nlinParams

        # Blur both volumes for every kernel; a value of -1 is skipped
        # (presumably "no blurring" for that generation).
        for kernel in nlin.blurs:
            if kernel == -1:
                continue
            for volume in (src, dst):
                self.p.addStage(ma.blur(volume, kernel, gradient=True))

        # One minctracc call per generation; the last one writes into "transforms".
        lastGeneration = len(nlin.stepSize) - 1
        for gen in range(lastGeneration + 1):
            if gen == lastGeneration:
                self.defaultDir = "transforms"
            self.p.addStage(ma.minctracc(src,
                                         dst,
                                         defaultDir=self.defaultDir,
                                         blur=nlin.blurs[gen],
                                         gradient=nlin.useGradient[gen],
                                         iterations=nlin.iterations[gen],
                                         step=nlin.stepSize[gen],
                                         w_translations=nlin.w_translations[gen],
                                         simplex=nlin.simplex[gen],
                                         optimization=nlin.optimization[gen]))
Example #13
0
class initializeAndRunNLIN(object):
    """Class to setup target average (if needed), 
       instantiate correct version of NLIN class,
       and run NLIN registration.

       After construction, self.p contains the target-averaging stage (only if
       createAvg is True) followed by the NLIN module's pipeline, and
       self.nlinAverages / self.nlinParams are copied from the NLIN module."""
    def __init__(
            self,
            targetOutputDir,  #Output directory for files related to initial target (often _lsq12)
            inputFiles,
            nlinDir,
            avgPrefix,  #Prefix for nlin-1.mnc, ... nlin-k.mnc 
            createAvg=True,  #True=call mincAvg, False=targetAvg already exists
            targetAvg=None,  #Optional path or RegistrationPipeFH for initial target - passing name does not guarantee existence
            targetMask=None,  #Optional path to mask for initial target
            nlin_protocol=None,
            reg_method=None):  #"mincANTS" or "minctracc"; anything else is fatal
        self.p = Pipeline()
        self.targetOutputDir = targetOutputDir
        self.inputFiles = inputFiles
        self.nlinDir = nlinDir
        self.avgPrefix = avgPrefix
        self.createAvg = createAvg
        self.targetAvg = targetAvg
        self.targetMask = targetMask
        self.nlin_protocol = nlin_protocol
        self.reg_method = reg_method

        # setup initialTarget (if needed) and initialize non-linear module
        self.setupTarget()
        self.initNlinModule()

        #iterate through non-linear registration and setup averages
        self.nlinModule.iterate()
        self.p.addPipeline(self.nlinModule.p)
        self.nlinAverages = self.nlinModule.nlinAverages
        self.nlinParams = self.nlinModule.nlinParams

    def setupTarget(self):
        """Set self.initialTarget and self.outputAvg from self.targetAvg.

           targetAvg may be a path (str), a RegistrationPipeFH, or empty/None
           (then <abspath(targetOutputDir)>/initial-target.mnc is used). Any other
           type is a fatal error (exit status 1). If createAvg is True, a
           mincAverage stage writing outputAvg from the inputs is added to self.p."""
        if self.targetAvg:
            if isinstance(self.targetAvg, str):
                self.initialTarget = RegistrationPipeFH(
                    self.targetAvg,
                    mask=self.targetMask,
                    basedir=self.targetOutputDir)
                self.outputAvg = self.targetAvg
            elif isinstance(self.targetAvg, RegistrationPipeFH):
                self.initialTarget = self.targetAvg
                self.outputAvg = self.targetAvg.getLastBasevol()
                # Only apply targetMask if the handler has no mask of its own.
                if not self.initialTarget.getMask():
                    if self.targetMask:
                        self.initialTarget.setMask(self.targetMask)
            else:
                print("You have passed a target average that is neither a string nor a file handler: " + str(
                    self.targetAvg))
                print("Exiting...")
                # initialTarget/outputAvg are not set for this type, so stop
                # here instead of failing later with an AttributeError.
                sys.exit(1)
        else:
            self.targetAvg = abspath(
                self.targetOutputDir) + "/" + "initial-target.mnc"
            self.initialTarget = RegistrationPipeFH(
                self.targetAvg,
                mask=self.targetMask,
                basedir=self.targetOutputDir)
            self.outputAvg = self.targetAvg
        if self.createAvg:
            avg = mincAverage(self.inputFiles,
                              self.initialTarget,
                              output=self.outputAvg,
                              defaultDir=self.targetOutputDir)
            self.p.addStage(avg)

    def initNlinModule(self):
        """Create self.nlinModule for self.reg_method ("mincANTS" -> NLINANTS,
           "minctracc" -> NLINminctracc). Any other value, including the default
           None, is logged and is a fatal error (exit status 1)."""
        if self.reg_method == "mincANTS":
            self.nlinModule = NLINANTS(self.inputFiles, self.initialTarget,
                                       self.nlinDir, self.avgPrefix,
                                       self.nlin_protocol)
        elif self.reg_method == "minctracc":
            self.nlinModule = NLINminctracc(self.inputFiles,
                                            self.initialTarget, self.nlinDir,
                                            self.avgPrefix, self.nlin_protocol)
        else:
            # str() so that reg_method=None is reported rather than raising TypeError.
            logger.error("Incorrect registration method specified: " +
                         str(self.reg_method))
            sys.exit(1)
Example #14
0
class LSQ12ANTSNlin:
    """Run a basic LSQ12 registration followed by a single mincANTS call.

       Currently used in MAGeT, registration_chain and pairwise_nlin.

       The steps are:
        1. Register the input to the target with LSQ12.
        2. Resample the input with that result.
        3. Run one mincANTS generation on the resampled volume.
        4. Concatenate the LSQ12 and mincANTS transforms into the final
           input-to-target transform, registered under "transforms".
       All stages are collected in self.p."""
    def __init__(self,
                 inputFH,
                 targetFH,
                 lsq12_protocol=None,
                 nlin_protocol=None,
                 subject_matter=None,
                 defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir

        # Resolution is only required when some stage uses resolution-based defaults.
        needsResolution = ((lsq12_protocol is None and subject_matter is None)
                           or nlin_protocol is None)
        self.fileRes = rf.returnFinestResolution(inputFH) if needsResolution else None

        self.buildPipeline()

    def buildPipeline(self):
        """Add LSQ12, resampling, blurring, mincANTS and xfm-concatenation stages to self.p."""
        src = self.inputFH
        dst = self.targetFH

        # Step 1: linear registration.
        self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                      subject_matter=self.subject_matter,
                                                      reg_protocol=self.lsq12_protocol)
        linear = self.lsq12Params
        linearReg = lsq12.LSQ12(src,
                                dst,
                                blurs=linear.blurs,
                                step=linear.stepSize,
                                gradient=linear.useGradient,
                                simplex=linear.simplex,
                                w_translations=linear.w_translations,
                                defaultDir=self.defaultDir)
        self.p.addPipeline(linearReg.p)

        # Step 2: resample with the final LSQ12 transform. Make the result the
        # input's base volume so mincANTS starts from it.
        resampled = ma.mincresample(src, dst, likeFile=dst, argArray=["-sinc"])
        self.p.addStage(resampled)
        src.setLastBasevol(resampled.outputFiles[0])
        linearXfm = src.getLastXfm(dst)

        # Step 3: a SINGLE mincANTS generation. Blurs for ANTS params are an array
        # of arrays; every kernel other than -1 is pre-computed for both volumes.
        self.nlinParams = mp.setOneGenMincANTSParams(self.fileRes, reg_protocol=self.nlin_protocol)
        ants = self.nlinParams
        kernels = [k for kernelGroup in ants.blurs for k in kernelGroup if k != -1]
        for k in kernels:
            self.p.addStage(ma.blur(dst, k, gradient=True))
            self.p.addStage(ma.blur(src, k, gradient=True))

        antsStage = ma.mincANTS(src,
                                dst,
                                defaultDir=self.defaultDir,
                                blur=ants.blurs[0],
                                gradient=ants.gradient[0],
                                similarity_metric=ants.similarityMetric[0],
                                weight=ants.weight[0],
                                iterations=ants.iterations[0],
                                radius_or_histo=ants.radiusHisto[0],
                                transformation_model=ants.transformationModel[0],
                                regularization=ants.regularization[0],
                                useMask=ants.useMask[0])
        self.p.addStage(antsStage)
        nonlinearXfm = antsStage.outputFiles[0]

        # Step 4: restore the original input as base volume for later
        # registrations. Then concatenate linear + non-linear transforms;
        # registerVolume handles the output naming and sets lastXfm.
        src.setLastBasevol(setToOriginalInput=True)
        finalXfm = src.registerVolume(dst, "transforms")
        concat = ma.xfmConcat([linearXfm, nonlinearXfm], finalXfm, fh.logFromFile(src.logDir, finalXfm))
        self.p.addStage(concat)
Example #15
0
class LongitudinalStatsConcatAndResample:
    """ For each subject:
        1. Calculate stats (displacement, absolute jacobians, relative jacobians) between i and i+1 time points 
        2. Calculate transform from subject to common space (nlinFH) and invert it. 
           For most subjects this will require some amount of transform concatenation. 
        3. Calculate the stats (displacement, absolute jacobians, relative jacobians) from common space
           to each timepoint.

        Arguments:
        subjects     -- iterated for keys and indexed by them (presumably a dict); each
                        value is an indexable sequence of file handlers, one per time point
        timePoint    -- index of the time point that was included in the common-space average
        nlinFH       -- file handler of the common space; if falsy, no common-space stats are computed
        statsKernels -- blurring kernels: a list/tuple of numbers or a comma-separated string
        commonName   -- used in the "<basename>_to_<commonName>.xfm" transform file names
    """
    def __init__(self, subjects, timePoint, nlinFH, statsKernels, commonName):
        
        self.subjects = subjects
        self.timePoint = timePoint
        self.nlinFH = nlinFH
        self.blurs = [] 
        self.setupBlurs(statsKernels)
        self.commonName = commonName
        
        self.p = Pipeline()
        
        self.buildPipeline()
    
    def setupBlurs(self, statsKernels):
        """Set self.blurs from statsKernels.

           statsKernels may be a list/tuple of numbers, or a comma-separated
           string such as "0.5, 0.2,0.1". Surrounding whitespace and empty
           entries are ignored. Any other type is a fatal error (exit status 1)."""
        if isinstance(statsKernels, (list, tuple)):
            # Copy so later changes to the caller's sequence do not leak in.
            self.blurs = list(statsKernels)
        elif isinstance(statsKernels, str):
            self.blurs = [float(k) for k in statsKernels.split(",") if k.strip()]
        else:
            print("Improper type of blurring kernels specified for stats calculation: " + str(statsKernels))
            sys.exit(1)
    
    def statsCalculation(self, inputFH, targetFH, xfm=None, useChainStats=True):
        """If useChainStats=True, calculate stats between input and target. 
           This happens for all i to i+1 calcs.
           
           If useChainStats=False, calculate stats in the standard way, from target to
           input. We do this when we go from the common space to all others.

           If xfm is given, the resulting stats are also resampled to the common
           space with it (like-file: nlinFH if set, otherwise targetFH)."""
        if useChainStats:
            stats = st.CalcChainStats(inputFH, targetFH, self.blurs)
        else:
            stats = st.CalcStats(inputFH, targetFH, self.blurs)
        self.p.addPipeline(stats.p)
        # If an xfm is specified, resample all to this common space.
        if xfm:
            likeFH = self.nlinFH if self.nlinFH else targetFH
            res = resampleToCommon(xfm, inputFH, stats.statsGroup, self.blurs, likeFH)
            self.p.addPipeline(res)
    
    def statsAndConcat(self, s, i, count, beforeAvg=True):
        """Construct array to common space for this timepoint.
           This builds upon arrays from previous calls (self.xfmToCommon).

           beforeAvg selects the neighbour whose transform is prepended:
           s[i+1] for time points before the average, s[i-1] for those after it."""
        if beforeAvg:
            xfm = s[i].getLastXfm(s[i+1]) 
        else:
            xfm = s[i].getLastXfm(s[i-1])
        # Set this transform as last xfm from input to nlin and calculate nlin to s[i] stats.
        if self.nlinFH:
            self.xfmToCommon.insert(0, xfm)
            # Concat transforms to get xfmToCommon and calculate statistics.
            # Note that the inverted transform, which is what we want, is
            # calculated in the statistics module.
            xtc = fh.createBaseName(s[i].transformsDir, s[i].basename + "_to_" + self.commonName + ".xfm")
            xc = ma.xfmConcat(self.xfmToCommon, xtc, fh.logFromFile(s[i].logDir, xtc))
            self.p.addStage(xc)
            s[i].addAndSetXfmToUse(self.nlinFH, xtc)
            self.statsCalculation(s[i], self.nlinFH, xfm=None, useChainStats=False)
        else:
            xtc = None
        # Calculate i to i+1 stats for all but the final timePoint.
        if count - i > 1:
            self.statsCalculation(s[i], s[i+1], xfm=xtc, useChainStats=True)
        
    def buildPipeline(self):
        """Add all longitudinal stats stages for every subject to self.p."""
        for subj in self.subjects:
            s = self.subjects[subj]
            count = len(s)
            # Wherever iterative model building was run, the indiv --> nlin xfm is stored
            # in the group with the name "final". We need to use this group to get the
            # transform and do the stats calculation, and then reset to the current group.
            # Calculate stats first from average to timepoint included in average.
            currGroup = s[self.timePoint].currentGroupIndex
            index = s[self.timePoint].getGroupIndex("final")
            xfmToNlin = s[self.timePoint].getLastXfm(self.nlinFH, groupIndex=index)
            
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            if self.nlinFH:
                s[self.timePoint].currentGroupIndex = index
                self.statsCalculation(s[self.timePoint], self.nlinFH, xfm=None, useChainStats=False)
                s[self.timePoint].currentGroupIndex = currGroup
            # Next: if the timepoint included in the average is NOT the final
            # timepoint, also calculate i to i+1 stats.
            if count - self.timePoint > 1:
                self.statsCalculation(s[self.timePoint], s[self.timePoint+1], xfm=xfmToNlin, useChainStats=True)
            if self.timePoint > 0:
                # Average happened at a time point other than the first.
                # Loop over points prior to the average, walking backwards.
                for i in reversed(range(self.timePoint)): 
                    self.statsAndConcat(s, i, count, beforeAvg=True)
                         
            # Loop over points after the average. If the average is at the first time
            # point, this loop hits all time points (other than the first). Otherwise it
            # hits all time points not covered above. xfmToCommon needs to be reset.
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            for i in range(self.timePoint + 1, count):
                self.statsAndConcat(s, i, count, beforeAvg=False)
Example #16
0
class FullLSQ12(object):
    """
        This class takes an array of input file handlers along with an optionally specified 
        protocol and does 12-parameter alignment and averaging of all of the pairs. 
        
        Required arguments:
        inputArray = array of file handlers to be registered
        outputDir = an output directory to place the final average from this registration
       
        Optional arguments include: 
        --likeFile = a file handler that can be used as a likeFile for resampling
            each input into the final lsq12 space. If none is specified, the input
            will be used
        --maxPairs = maximum number of pairs to register. If this is specified, 
            then each subject will only be registered to a subset of the other subjects.
            (Not implemented yet: iterate() exits with an error when maxPairs is set.)
        --lsq12_protocol = an optional csv file to specify a protocol that overrides the defaults.
        --subject_matter = currently supports "mousebrain". If this is specified, the parameters for
        the minctracc registrations are set based on defaults for mouse brains instead of the file
        resolution. 
    """
    
    def __init__(self, inputArray, 
                 outputDir, 
                 likeFile=None, 
                 maxPairs=None, 
                 lsq12_protocol=None,
                 subject_matter=None):
        self.p = Pipeline()
        # Initial inputs: an array of fileHandlers with lastBasevol in lsq12 space
        self.inputs = inputArray
        # Output directory for the lsq12 average (typically the _lsq12 directory)
        self.lsq12Dir = outputDir
        # likeFile for resampling (each input is used when this is not set)
        self.likeFile = likeFile
        # Maximum number of pairs to calculate
        self.maxPairs = maxPairs
        # Final lsq12 average (path)
        self.lsq12Avg = None
        # Final lsq12 average file handler (the handler associated with lsq12Avg)
        self.lsq12AvgFH = None
        # Averaged lsq12 transforms, one per input: key is the input file handler,
        # value is the path of that subject's average lsq12 xfm. These xfms may be
        # used subsequently for statistics calculations.
        self.lsq12AvgXfms = {}
        # what sort of subject matter do we deal with?
        self.subject_matter = subject_matter
        
        # File resolution, used to scale the default blurring/step/simplex values.
        try:
            self.fileRes = rf.getFinestResolution(self.inputs[0])
        except Exception:
            # if this fails (because file doesn't exist when pipeline is created) grab from
            # initial input volume, which should exist. 
            self.fileRes = rf.getFinestResolution(self.inputs[0].inputFileName)
        
        # Similarly to the LSQ6 and NLIN modules, an optional SEMI-COLON delimited
        # csv may be specified to override the default registration protocol.
        # If no protocol is specified, the defaults are used. The number of
        # generations is the (common) length of the parameter arrays.
        self.defaultParams()
        if lsq12_protocol:
            self.setParams(lsq12_protocol)
        self.generations = self.getGenerations() 
        
        # Create new lsq12 group for each input prior to registration
        for inputFH in self.inputs:
            inputFH.newGroup(groupName="lsq12")
    
    def defaultParams(self):
        """Set the default three-generation minctracc parameters.

           For subject_matter == "mousebrain", fixed values are used
           (blurs 0.3/0.2/0.15, steps 1/0.5/0.333, simplex 3/1.5/1). Otherwise
           blurs, steps and simplex are fixed multiples of self.fileRes.
           useGradient is [False, True, False] in both cases.
        """
        if self.subject_matter == "mousebrain":
            self.blurs = [0.3, 0.2, 0.15]
            self.stepSize = [1, 0.5, 1.0/3.0]
            self.simplex = [3, 1.5, 1]
        else:
            blurfactors = [5, 10.0/3.0, 2.5]
            stepfactors = [50.0/3.0, 25.0/3.0, 5.5]
            simplexfactors = [50, 25, 50.0/3.0]
            self.blurs = [f * self.fileRes for f in blurfactors]
            self.stepSize = [f * self.fileRes for f in stepfactors]
            self.simplex = [f * self.fileRes for f in simplexfactors]
        
        self.useGradient = [False, True, False]
        
    def setParams(self, lsq12_protocol):
        """Replace the parameter arrays with values read from a protocol file.

           The file is SEMI-COLON delimited. Each row is a parameter name
           ("blur", "step", "gradient" or "simplex") followed by one value per
           generation, e.g. ``blur;0.3;0.2;0.15``.
            - blur, step and simplex values are parsed as floats.
            - gradient values must be True/False (case-insensitive).
            - Blank rows and empty values are ignored.
           An unknown parameter name or gradient value is a fatal error (exit status 1).
        """
        # 'rb' is the mode the (Python 2) csv module expects; `with` closes the file.
        with open(abspath(lsq12_protocol), 'rb') as inputCsv:
            rows = list(csv.reader(inputCsv, delimiter=';', skipinitialspace=True))

        self.blurs = []
        self.stepSize = []
        self.useGradient = []
        self.simplex = []

        for row in rows:
            if not row:
                continue
            name = row[0]
            values = [v.strip() for v in row[1:] if v.strip()]
            if name == "blur":
                self.blurs.extend(float(v) for v in values)
            elif name == "step":
                self.stepSize.extend(float(v) for v in values)
            elif name == "gradient":
                for v in values:
                    if v.lower() == "true":
                        self.useGradient.append(True)
                    elif v.lower() == "false":
                        self.useGradient.append(False)
                    else:
                        print("Improper gradient value specified for minctracc protocol: " + str(v))
                        print("Exiting...")
                        sys.exit(1)
            elif name == "simplex":
                # float, not int: the default simplex values (e.g. 1.5) are fractional.
                self.simplex.extend(float(v) for v in values)
            else:
                print("Improper parameter specified for minctracc protocol: " + str(name))
                print("Exiting...")
                sys.exit(1)
        
    def getGenerations(self):
        """Return the number of registration generations.

           Raises ValueError if blurs, stepSize, useGradient and simplex do not
           all have the same length.
        """
        arrayLength = len(self.blurs)
        if (len(self.stepSize) != arrayLength 
                or len(self.useGradient) != arrayLength
                or len(self.simplex) != arrayLength):
            raise ValueError("Array lengths in lsq12 minctracc protocol do not match.")
        return arrayLength 
        
    def iterate(self):
        """Add the all-pairs lsq12 registration and averaging stages to self.p.

           Each input is registered to every other input. Its pairwise xfms are
           averaged with xfmavg, and it is resampled with that average xfm. All
           resampled inputs are then averaged into
           <abspath(lsq12Dir)>/<dir name>-pairs.mnc (self.lsq12Avg / self.lsq12AvgFH).
           Registration limited by maxPairs is not implemented and is a fatal
           error (exit status 1).
        """
        if self.maxPairs:
            print("Registration using a specified number of max pairs not yet working. Check back soon!")
            sys.exit(1)

        lsq12ResampledFiles = {}
        for inputFH in self.inputs:
            # Register this input to every other input, collecting the xfms to average.
            pairXfms = []
            for targetFH in self.inputs:
                if inputFH != targetFH:
                    lsq12reg = LSQ12(inputFH,
                                     targetFH,
                                     self.blurs,
                                     self.stepSize,
                                     self.useGradient,
                                     self.simplex)
                    self.p.addPipeline(lsq12reg.p)
                    pairXfms.append(inputFH.getLastXfm(targetFH))
            
            # Create average xfm for inputFH from its pairwise xfms.
            cmd = ["xfmavg"] + [InputFile(x) for x in pairXfms]
            avgXfmOutput = createBaseName(inputFH.transformsDir, inputFH.basename + "-avg-lsq12.xfm")
            cmd.append(OutputFile(avgXfmOutput))
            xfmavg = CmdStage(cmd)
            xfmavg.setLogFile(LogFile(logFromFile(inputFH.logDir, avgXfmOutput)))
            self.p.addStage(xfmavg)
            self.lsq12AvgXfms[inputFH] = avgXfmOutput

            # Resample the input with its average xfm, for mincAveraging below.
            likeFile = self.likeFile if self.likeFile else inputFH
            rslOutput = createBaseName(inputFH.resampledDir, inputFH.basename + "-resampled-lsq12.mnc")
            res = ma.mincresample(inputFH, 
                                  inputFH,
                                  transform=avgXfmOutput, 
                                  likeFile=likeFile, 
                                  output=rslOutput,
                                  argArray=["-sinc"])   
            self.p.addStage(res)
            lsq12ResampledFiles[inputFH] = rslOutput

        # After all registrations are set up, make each subject's lastBasevol the
        # resampled file in lsq12 space. mincAverage uses the lastBasevol of each
        # file handler by default.
        for inputFH in self.inputs:
            inputFH.setLastBasevol(lsq12ResampledFiles[inputFH])

        # mincAverage all resampled brains and put the result in the lsq12 directory.
        # Take basename of the absolute path so a trailing "/" or "." in lsq12Dir
        # does not yield an empty or "." prefix.
        avgDir = abspath(self.lsq12Dir)
        self.lsq12Avg = avgDir + "/" + basename(avgDir) + "-pairs.mnc"
        self.lsq12AvgFH = RegistrationPipeFH(self.lsq12Avg, basedir=self.lsq12Dir)
        avg = ma.mincAverage(self.inputs, 
                             self.lsq12AvgFH, 
                             output=self.lsq12Avg,
                             defaultDir=self.lsq12Dir)
        self.p.addStage(avg)
Example #17
0
class FullLSQ12(object):
    """
        This class takes an array of input file handlers along with an optionally specified 
        protocol and does 12-parameter alignment and averaging of all of the pairs.

        iterate() registers every input to every other input, averages each input's
        resulting transforms with xfmavg, resamples each input with its average
        transform and averages the resampled inputs into the final lsq12 average.
        
        Required arguments:
        inputArray = array of file handlers to be registered
        outputDir = an output directory to place the final average from this registration
       
        Optional arguments include: 
        --likeFile = a file handler that can be used as a likeFile for resampling
            each input into the final lsq12 space. If none is specified, the input
            will be used
        --maxPairs = maximum number of pairs to register. If this is specified, 
            then each subject will only be registered to a subset of the other subjects.
            (Not implemented yet: iterate() exits with an error status when it is set.)
        --lsq12_protocol = an optional csv file to specify a protocol that overrides the defaults.
        --subject_matter = currently supports "mousebrain". If this is specified, the parameters for
            the minctracc registrations are set based on defaults for mouse brains instead of the file
            resolution.
        --resolution = resolution to use instead of the one detected from the first input.
            It is only consulted when neither subject_matter nor lsq12_protocol is given
            (in those cases no file resolution is passed on, as before).
    """
    def __init__(self,
                 inputArray,
                 outputDir,
                 likeFile=None,
                 maxPairs=None,
                 lsq12_protocol=None,
                 subject_matter=None,
                 resolution=None):
        self.p = Pipeline()
        # Input file handlers to register. NOTE(review): an older comment said their
        # lastBasevol is "in lsq12 space"; presumably lsq6 space -- confirm.
        self.inputs = inputArray
        # Output directory for the final lsq12 average
        self.lsq12Dir = outputDir
        # likeFile for resampling; each input is its own likeFile when this is unset
        self.likeFile = likeFile
        # Maximum number of pairs to calculate (not supported by iterate() yet)
        self.maxPairs = maxPairs
        # Final lsq12 average (path) and its file handler; both are set by iterate()
        self.lsq12Avg = None
        self.lsq12AvgFH = None
        # Average lsq12 transforms, one per input: key is the input file handler, value
        # is the path of that input's final average lsq12 xfm. These xfms may be used
        # subsequently for statistics calculations.
        self.lsq12AvgXfms = {}
        # The blurring resolution comes from the file resolution only when neither a
        # subject matter nor a protocol supplies the registration parameters.
        if subject_matter is None and lsq12_protocol is None:
            if resolution is not None:
                self.fileRes = resolution
            else:
                self.fileRes = rf.returnFinestResolution(self.inputs[0])
        else:
            self.fileRes = None
        # Set up the parameter array
        self.lsq12Params = mp.setLSQ12MinctraccParams(
            self.fileRes,
            subject_matter=subject_matter,
            reg_protocol=lsq12_protocol)
        self.blurs = self.lsq12Params.blurs
        self.stepSize = self.lsq12Params.stepSize
        self.useGradient = self.lsq12Params.useGradient
        self.simplex = self.lsq12Params.simplex
        self.w_translations = self.lsq12Params.w_translations
        self.generations = self.lsq12Params.generations

        # Create new lsq12 group for each input prior to registration
        for inputFH in self.inputs:
            inputFH.newGroup(groupName="lsq12")

    @staticmethod
    def _pairsAverageName(outputDir):
        """Return "<absolute outputDir>/<outputDir name>-pairs.mnc".

        The directory is normalised with abspath first. Without that, a trailing slash
        makes basename() return "" and the average would be named "-pairs.mnc"."""
        outDir = abspath(outputDir)
        return outDir + "/" + basename(outDir) + "-pairs.mnc"

    def iterate(self):
        """Add all pairwise lsq12 registrations and the averaging stages to self.p.

        For every input this fills self.lsq12AvgXfms with its average lsq12 xfm. It then
        sets the input's lastBasevol to its resampled lsq12 file, and finally sets
        self.lsq12Avg / self.lsq12AvgFH to the average of all resampled inputs.

        Exits with an error status if maxPairs is set, because registering only a
        subset of pairs is not implemented yet."""
        if self.maxPairs:
            # sys.exit(message) prints to stderr and exits with status 1. A bare
            # sys.exit() exits with status 0, which tells calling scripts this succeeded.
            sys.exit("Registration using a specified number of max pairs not yet working. Check back soon!")

        lsq12ResampledFiles = {}
        for inputFH in self.inputs:
            # Register inputFH to every other input and collect the resulting xfms
            xfmsToAvg = []
            for targetFH in self.inputs:
                if inputFH != targetFH:
                    lsq12 = LSQ12(inputFH,
                                  targetFH,
                                  blurs=self.blurs,
                                  step=self.stepSize,
                                  gradient=self.useGradient,
                                  simplex=self.simplex,
                                  w_translations=self.w_translations)
                    self.p.addPipeline(lsq12.p)
                    xfmsToAvg.append(inputFH.getLastXfm(targetFH))
            # Average the collected xfms into one lsq12 xfm for inputFH
            cmd = ["xfmavg"]
            cmd.extend(InputFile(xfm) for xfm in xfmsToAvg)
            avgXfmOutput = createBaseName(
                inputFH.transformsDir, inputFH.basename + "-avg-lsq12.xfm")
            cmd.append(OutputFile(avgXfmOutput))
            xfmavg = CmdStage(cmd)
            xfmavg.setLogFile(
                LogFile(logFromFile(inputFH.logDir, avgXfmOutput)))
            self.p.addStage(xfmavg)
            self.lsq12AvgXfms[inputFH] = avgXfmOutput
            # Resample inputFH with its average xfm, for the mincAverage below
            likeFile = self.likeFile if self.likeFile else inputFH
            rslOutput = createBaseName(
                inputFH.resampledDir,
                inputFH.basename + "-resampled-lsq12.mnc")
            res = ma.mincresample(inputFH,
                                  inputFH,
                                  transform=avgXfmOutput,
                                  likeFile=likeFile,
                                  output=rslOutput,
                                  argArray=["-sinc"])
            self.p.addStage(res)
            lsq12ResampledFiles[inputFH] = rslOutput

        # Only after all registrations are set up, make each subject's resampled lsq12
        # file its lastBasevol. Presumably this keeps the pairwise registrations above
        # on the original inputs. mincAverage on the file handlers then uses lastBasevol
        # by default.
        for inputFH in self.inputs:
            inputFH.setLastBasevol(lsq12ResampledFiles[inputFH])

        # mincAverage all resampled inputs and put the result in the lsq12 directory
        self.lsq12Avg = self._pairsAverageName(self.lsq12Dir)
        self.lsq12AvgFH = RegistrationPipeFH(self.lsq12Avg,
                                             basedir=self.lsq12Dir)
        avg = ma.mincAverage(self.inputs,
                             self.lsq12AvgFH,
                             output=self.lsq12Avg,
                             defaultDir=self.lsq12Dir)
        self.p.addStage(avg)
Example #18
0
class FullIterativeLSQ12Nlin:
    """Does a full iterative LSQ12 and NLIN. Basically iterative model building starting from LSQ6
       and without stats at the end. Designed to be called as part of a larger application. 
       Specifying an initModel is optional, all other arguments are mandatory.

       The registration resolution is taken from the initModel's first file handler,
       or else from options.lsq12_likeFile. If neither is available,
       options.lsq12_subject_matter must be set, otherwise buildPipeline() exits with
       an error status.

       After construction self.p holds the pipeline and self.nlinFH is the final nlin
       average. self.lsq12Params / self.nlinParams hold the registration parameters.
       Each input's xfm to self.nlinFH is set to its average lsq12 xfm concatenated
       with its nlin xfm."""
    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        # Final nlin average file handler; set by buildPipeline()
        self.nlinFH = None

        self.p = Pipeline()

        self.buildPipeline()

    def buildPipeline(self):
        """Add the LSQ12 stage, the NLIN stages and the lsq12+nlin xfm concatenation to self.p."""
        # likeFile for the lsq12 resampling: the initial model if given, else the option
        lsq12LikeFH = None
        if self.initModel:
            lsq12LikeFH = self.initModel[0]
        elif self.options.lsq12_likeFile:
            lsq12LikeFH = self.options.lsq12_likeFile

        if lsq12LikeFH is None and self.options.lsq12_subject_matter is None:
            # sys.exit(message) prints to stderr and exits with status 1. A bare
            # sys.exit() exits with status 0, which reports success to the caller.
            sys.exit("\nError: the FullIterativeLSQ12Nlin module was called without specifying an initial model, an lsq12_likeFile or an lsq12_subject_matter. Currently that means that the code can not determine the resolution at which the registrations should be run. Please specify one of these. Exiting\n")

        # Registration resolution comes from the likeFile when there is one
        resolutionForLSQ12 = None
        if lsq12LikeFH is not None:
            resolutionForLSQ12 = rf.returnFinestResolution(lsq12LikeFH)

        lsq12module = lsq12.FullLSQ12(
            self.inputs,
            self.dirs.lsq12Dir,
            likeFile=lsq12LikeFH,
            maxPairs=self.options.lsq12_max_pairs,
            lsq12_protocol=self.options.lsq12_protocol,
            subject_matter=self.options.lsq12_subject_matter,
            resolution=resolutionForLSQ12)
        lsq12module.iterate()
        self.p.addPipeline(lsq12module.p)
        self.lsq12Params = lsq12module.lsq12Params
        # Give the lsq12 average the initial model's mask when it has none of its own
        if lsq12module.lsq12AvgFH.getMask() is None and self.initModel:
            lsq12module.lsq12AvgFH.setMask(self.initModel[0].getMask())
        if not self.avgPrefix:
            self.avgPrefix = self.options.pipeline_name
        # same as in MBM.py:
        # for now we can use the same resolution for the NLIN stages as we did for the
        # LSQ12 stage. At some point we should look into the subject matter option...
        nlinModule = nlin.initializeAndRunNLIN(
            self.dirs.lsq12Dir,
            self.inputs,
            self.dirs.nlinDir,
            avgPrefix=self.avgPrefix,
            createAvg=False,
            targetAvg=lsq12module.lsq12AvgFH,
            nlin_protocol=self.options.nlin_protocol,
            reg_method=self.options.reg_method,
            resolution=resolutionForLSQ12)
        self.p.addPipeline(nlinModule.p)
        self.nlinFH = nlinModule.nlinAverages[-1]
        self.nlinParams = nlinModule.nlinParams
        self.initialTarget = nlinModule.initialTarget
        # Now we need the full transform to go back to LSQ6 space
        for i in self.inputs:
            linXfm = lsq12module.lsq12AvgXfms[i]
            nlinXfm = i.getLastXfm(self.nlinFH)
            outXfm = st.createOutputFileName(i, nlinXfm, "transforms",
                                             "_with_additional.xfm")
            xc = ma.xfmConcat([linXfm, nlinXfm], outXfm,
                              fh.logFromFile(i.logDir, outXfm))
            self.p.addStage(xc)
            i.addAndSetXfmToUse(self.nlinFH, outXfm)