class FullIterativeLSQ12Nlin:
    """Run an iterative LSQ12 followed by an iterative NLIN registration.

    This is iterative model building that starts from LSQ6-aligned inputs and
    stops before any statistics are computed; it is meant to be embedded in a
    larger application. ``initModel`` is optional, every other constructor
    argument is required.

    After construction ``self.p`` holds the assembled pipeline, ``self.nlinFH``
    the final NLIN average, and every input has its LSQ12+NLIN transform set as
    the transform to use towards ``self.nlinFH``.
    """

    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        self.nlinFH = None
        self.p = Pipeline()
        self.buildPipeline()

    def _pickLikeFile(self):
        """Return the first initial-model entry, else options.lsq12_likeFile, else None."""
        if self.initModel:
            return self.initModel[0]
        if self.options.lsq12_likeFile:
            return self.options.lsq12_likeFile
        return None

    def _concatToLSQ6(self, inputFH, linXfm):
        """Chain one input's LSQ12 and NLIN transforms and make the result its xfm to nlinFH."""
        nlinXfm = inputFH.getLastXfm(self.nlinFH)
        combined = st.createOutputFileName(inputFH, nlinXfm, "transforms",
                                           "_with_additional.xfm")
        self.p.addStage(ma.xfmConcat([linXfm, nlinXfm], combined,
                                     fh.logFromFile(inputFH.logDir, combined)))
        inputFH.addAndSetXfmToUse(self.nlinFH, combined)

    def buildPipeline(self):
        opts = self.options
        likeFH = self._pickLikeFile()

        lsq12Stage = lsq12.FullLSQ12(self.inputs,
                                     self.dirs.lsq12Dir,
                                     likeFile=likeFH,
                                     maxPairs=opts.lsq12_max_pairs,
                                     lsq12_protocol=opts.lsq12_protocol,
                                     subject_matter=opts.lsq12_subject_matter)
        lsq12Stage.iterate()
        self.p.addPipeline(lsq12Stage.p)
        self.lsq12Params = lsq12Stage.lsq12Params

        # If the LSQ12 average has no mask, borrow the initial model's mask.
        averageFH = lsq12Stage.lsq12AvgFH
        if averageFH.getMask() is None and self.initModel:
            averageFH.setMask(self.initModel[0].getMask())

        self.avgPrefix = self.avgPrefix or opts.pipeline_name

        nlinStage = nlin.initializeAndRunNLIN(self.dirs.lsq12Dir,
                                              self.inputs,
                                              self.dirs.nlinDir,
                                              avgPrefix=self.avgPrefix,
                                              createAvg=False,
                                              targetAvg=averageFH,
                                              nlin_protocol=opts.nlin_protocol,
                                              reg_method=opts.reg_method)
        self.p.addPipeline(nlinStage.p)
        self.nlinFH = nlinStage.nlinAverages[-1]
        self.nlinParams = nlinStage.nlinParams
        self.initialTarget = nlinStage.initialTarget

        # Build the full transform needed to go back to LSQ6 space.
        for inputFH in self.inputs:
            self._concatToLSQ6(inputFH, lsq12Stage.lsq12AvgXfms[inputFH])
class HierarchicalMinctracc:
    """LSQ12 alignment followed by a hierarchy of non-linear minctracc stages.

    With the default arguments this:
      1. adds blur stages for the template and the input at every kernel in
         ``blurs`` other than -1 (five 0.25 kernels by default),
      2. runs a standard LSQ12 alignment (``lsq12.LSQ12`` with its defaults),
      3. runs six non-linear minctracc stages, one per entry of ``steps``:
         five with a 0.25 blur and a final one with no blur (-1).

    ``blurs``, ``gradients``, ``iterations`` and ``simplexes`` are indexed in
    parallel with ``steps`` and must each hold at least ``len(steps)`` entries.
    The last non-linear stage writes to the "transforms" directory; all other
    stages (including the LSQ12 one) use ``defaultDir``.

    ``linearparams`` is accepted for backward compatibility but is not used.
    ``similarity`` is forwarded to every minctracc stage; its default of 0.8
    is the value that used to be hard-coded.

    Raises:
        ValueError: if a per-stage sequence is shorter than ``steps``.
    """

    def __init__(self, inputPipeFH, templatePipeFH,
                 steps=(1, 0.5, 0.5, 0.2, 0.2, 0.1),
                 blurs=(0.25, 0.25, 0.25, 0.25, 0.25, -1),
                 gradients=(False, False, True, False, True, False),
                 iterations=(60, 60, 60, 10, 10, 4),
                 simplexes=(3, 3, 3, 1.5, 1.5, 1),
                 w_translations=0.2,
                 linearparams=None,
                 defaultDir="tmp",
                 similarity=0.8):
        # Tuple defaults (instead of lists/dicts) avoid sharing one mutable
        # object between every call that relies on the defaults.
        nStages = len(steps)
        perStage = (("blurs", blurs), ("gradients", gradients),
                    ("iterations", iterations), ("simplexes", simplexes))
        # Fail before building anything, rather than with an IndexError
        # halfway through adding stages.
        for label, values in perStage:
            if len(values) < nStages:
                raise ValueError(
                    "HierarchicalMinctracc: %s has %d entries but steps "
                    "requests %d stages" % (label, len(values), nStages))

        self.p = Pipeline()
        for b in blurs:
            # MF TODO: the -1 ("no blur") case is also handled in blur; it is
            # needed here too so that no stage is added for it.
            if b != -1:
                tblur = ma.blur(templatePipeFH, b, gradient=True)
                iblur = ma.blur(inputPipeFH, b, gradient=True)
                self.p.addStage(tblur)
                self.p.addStage(iblur)

        # Standard LSQ12 alignment prior to the non-linear stages.
        lsq12reg = lsq12.LSQ12(inputPipeFH, templatePipeFH, defaultDir=defaultDir)
        self.p.addPipeline(lsq12reg.p)

        lastStage = nStages - 1
        for i, step in enumerate(steps):
            # Only the final stage's transform goes to "transforms".
            stageDir = "transforms" if i == lastStage else defaultDir
            nlinStage = ma.minctracc(inputPipeFH,
                                     templatePipeFH,
                                     defaultDir=stageDir,
                                     blur=blurs[i],
                                     gradient=gradients[i],
                                     iterations=iterations[i],
                                     step=step,
                                     similarity=similarity,
                                     w_translations=w_translations,
                                     simplex=simplexes[i])
            self.p.addStage(nlinStage)
def MAGeTRegister(inputFH, templateFH, regMethod, name="initial", createMask=False,
                  lsq12_protocol=None, nlin_protocol=None):
    """Register inputFH to templateFH, then queue LabelAndFileResampling.

    Args:
        inputFH: file handler of the volume being registered.
        templateFH: file handler of the registration target.
        regMethod: "minctracc" (HierarchicalMinctracc) or "mincANTS"
            (LSQ12ANTSNlin).
        name: forwarded to LabelAndFileResampling.
        createMask: when True the registration uses defaultDir "tmp" instead
            of "transforms"; the flag is also forwarded to
            LabelAndFileResampling.
        lsq12_protocol, nlin_protocol: optional protocols forwarded to the
            registration class.

    Returns:
        A Pipeline holding the registration followed by the resampling.

    Raises:
        ValueError: if regMethod is not "minctracc" or "mincANTS". Without
            this check an unknown method silently skipped the registration but
            still queued the resampling.
    """
    if regMethod not in ("minctracc", "mincANTS"):
        raise ValueError("MAGeTRegister: unsupported registration method %r "
                         "(expected 'minctracc' or 'mincANTS')" % (regMethod,))

    p = Pipeline()
    defaultDir = "tmp" if createMask else "transforms"
    if regMethod == "minctracc":
        registration = HierarchicalMinctracc(inputFH, templateFH,
                                             lsq12_protocol=lsq12_protocol,
                                             nlin_protocol=nlin_protocol,
                                             defaultDir=defaultDir)
    else:
        registration = LSQ12ANTSNlin(inputFH, templateFH,
                                     lsq12_protocol=lsq12_protocol,
                                     nlin_protocol=nlin_protocol,
                                     defaultDir=defaultDir)
    p.addPipeline(registration.p)

    rp = LabelAndFileResampling(inputFH, templateFH, name=name, createMask=createMask)
    p.addPipeline(rp.p)
    return p
def MAGeTRegister(inputFH, templateFH, regMethod, name="initial", createMask=False,
                  lsq12_protocol=None, nlin_protocol=None):
    """Register inputFH to templateFH, then queue LabelAndFileResampling.

    Args:
        inputFH: file handler of the volume being registered.
        templateFH: file handler of the registration target.
        regMethod: "minctracc" (HierarchicalMinctracc) or "mincANTS"
            (LSQ12ANTSNlin).
        name: forwarded to LabelAndFileResampling.
        createMask: when True the registration uses defaultDir "tmp" instead
            of "transforms"; the flag is also forwarded to
            LabelAndFileResampling.
        lsq12_protocol, nlin_protocol: optional protocols forwarded to the
            registration class.

    Returns:
        A Pipeline holding the registration followed by the resampling.

    Raises:
        ValueError: if regMethod is not "minctracc" or "mincANTS". Without
            this check an unknown method silently skipped the registration but
            still queued the resampling.
    """
    if regMethod not in ("minctracc", "mincANTS"):
        raise ValueError("MAGeTRegister: unsupported registration method %r "
                         "(expected 'minctracc' or 'mincANTS')" % (regMethod,))

    p = Pipeline()
    defaultDir = "tmp" if createMask else "transforms"
    if regMethod == "minctracc":
        registration = HierarchicalMinctracc(inputFH, templateFH,
                                             lsq12_protocol=lsq12_protocol,
                                             nlin_protocol=nlin_protocol,
                                             defaultDir=defaultDir)
    else:
        registration = LSQ12ANTSNlin(inputFH, templateFH,
                                     lsq12_protocol=lsq12_protocol,
                                     nlin_protocol=nlin_protocol,
                                     defaultDir=defaultDir)
    p.addPipeline(registration.p)

    rp = LabelAndFileResampling(inputFH, templateFH, name=name, createMask=createMask)
    p.addPipeline(rp.p)
    return p
def MAGeTMask(atlases, inputs, numAtlases, regMethod, lsq12_protocol=None, nlin_protocol=None):
    """Build the pipeline that masks every atlas and input before MAGeT proper.

    Steps:
      1. Register every input to every atlas (HierarchicalMinctracc or
         mincANTS, chosen by regMethod) with createMask=True, so masks are
         propagated instead of labels.
      2. Voxel-vote the propagated masks to find the best mask (with a single
         atlas, that atlas's transformed mask is used).
      3. Multiply each original volume by its mask (mincMath) to get a
         _masked.mnc file. This is done for atlases and inputs alike; atlases
         need no voting.
      4. Replace lastBasevol with the masked version: once the mask exists the
         unmasked volume is no longer needed.
      5. Clear the label arrays, which were used to track masks, so they can be
         re-set with the actual labels.

    All data is placed in a newly created masking directory so it stays
    separate from the data generated by the actual MAGeT run.

    Returns:
        The assembled Pipeline.
    """
    pipeline = Pipeline()

    # Phase 1: move everything into the masking directory layout, then
    # register each input to each atlas in mask-creation mode.
    for atlas in atlases:
        maskDirectoryStructure(atlas, masking=True)
    for subject in inputs:
        maskDirectoryStructure(subject, masking=True)
        for atlas in atlases:
            pipeline.addPipeline(MAGeTRegister(subject, atlas, regMethod,
                                               name="initial",
                                               createMask=True,
                                               lsq12_protocol=lsq12_protocol,
                                               nlin_protocol=nlin_protocol))

    # Phase 2: restore the log and tmp directories before the final masking.
    for atlas in atlases:
        # Keep the atlas labels so they can be re-attached in the new group;
        # only one is assumed, so only the first is carried over.
        atlasLabels = atlas.returnLabels(True)
        maskDirectoryStructure(atlas, masking=False)
        pipeline.addPipeline(maskFiles(atlas, True))
        atlas.newGroup()
        atlas.addLabels(atlasLabels[0], inputLabel=True)

    for subject in inputs:
        maskDirectoryStructure(subject, masking=False)
        pipeline.addPipeline(maskFiles(subject, False, numAtlases))
        # Drop the "inputLabels" (labels that come directly from the atlas
        # library) ...
        subject.clearLabels(True)
        # ... and the "labels" (second-generation labels, i.e. labels derived
        # from the atlas library labels).
        subject.clearLabels(False)
        subject.newGroup()

    return pipeline
def MAGeTMask(atlases, inputs, numAtlases, regMethod, lsq12_protocol=None, nlin_protocol=None):
    """Create masks for all atlases and inputs and apply them.

    Procedure:
      1. Run HierarchicalMinctracc or mincANTS (per regMethod) with
         createMask=True, propagating masks instead of labels.
      2. Use voxel voting to pick the best mask (for a single atlas, that
         transform's mask is used as-is).
      3. Use mincMath to multiply each original volume by its mask, giving a
         _masked.mnc file, for both atlases and inputs (atlases skip voting).
      4. Replace lastBasevol with the masked version; the unmasked version is
         not needed once the mask exists.
      5. Clear the label arrays, which were used to track masks, so the actual
         labels can be set later.

    Everything is written to a newly created masking directory, kept separate
    from the data of the actual MAGeT run.

    Returns:
        The assembled Pipeline.
    """
    maskPipe = Pipeline()

    for atlasHandle in atlases:
        maskDirectoryStructure(atlasHandle, masking=True)

    for inputHandle in inputs:
        maskDirectoryStructure(inputHandle, masking=True)
        for atlasHandle in atlases:
            registration = MAGeTRegister(inputHandle, atlasHandle, regMethod,
                                         name="initial", createMask=True,
                                         lsq12_protocol=lsq12_protocol,
                                         nlin_protocol=nlin_protocol)
            maskPipe.addPipeline(registration)

    # Before the final masking, put the log and tmp directories back.
    for atlasHandle in atlases:
        # Labels to carry into the new group; only the first is used.
        savedLabels = atlasHandle.returnLabels(True)
        maskDirectoryStructure(atlasHandle, masking=False)
        atlasMasking = maskFiles(atlasHandle, True)
        maskPipe.addPipeline(atlasMasking)
        atlasHandle.newGroup()
        atlasHandle.addLabels(savedLabels[0], inputLabel=True)

    for inputHandle in inputs:
        maskDirectoryStructure(inputHandle, masking=False)
        inputMasking = maskFiles(inputHandle, False, numAtlases)
        maskPipe.addPipeline(inputMasking)
        # First clear the labels taken straight from the atlas library,
        # then the second-generation labels derived from them.
        inputHandle.clearLabels(True)
        inputHandle.clearLabels(False)
        inputHandle.newGroup()

    return maskPipe
class LSQ12ANTSNlin:
    """LSQ12 registration followed by a single mincANTS registration.

    Currently used in MAGeT, registration_chain and pairwise_nlin. The final
    transform is the concatenation of the LSQ12 and mincANTS transforms; it is
    registered on inputFH towards targetFH (via registerVolume) in the
    "transforms" directory.
    """

    def __init__(self, inputFH, targetFH, lsq12_protocol=None, nlin_protocol=None,
                 subject_matter=None, defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir
        # A file resolution is only needed when some parameters must be derived
        # from defaults. It is always taken from the registration target.
        lsq12Unspecified = lsq12_protocol is None and subject_matter is None
        if nlin_protocol is None or lsq12Unspecified:
            self.fileRes = rf.returnFinestResolution(targetFH)
        else:
            self.fileRes = None
        self.buildPipeline()

    def buildPipeline(self):
        src, dst = self.inputFH, self.targetFH

        # --- Linear stage ---
        self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                      subject_matter=self.subject_matter,
                                                      reg_protocol=self.lsq12_protocol)
        lin = self.lsq12Params
        linearReg = lsq12.LSQ12(src, dst,
                                blurs=lin.blurs,
                                step=lin.stepSize,
                                gradient=lin.useGradient,
                                simplex=lin.simplex,
                                w_translations=lin.w_translations,
                                defaultDir=self.defaultDir)
        self.p.addPipeline(linearReg.p)

        # Resample with the final LSQ12 transform and make that the input's
        # base volume for the non-linear step.
        resampled = ma.mincresample(src, dst, likeFile=dst, argArray=["-sinc"])
        self.p.addStage(resampled)
        src.setLastBasevol(resampled.outputFiles[0])
        linearXfm = src.getLastXfm(dst)

        # --- Non-linear stage: a SINGLE mincANTS generation is assumed ---
        self.nlinParams = mp.setOneGenMincANTSParams(self.fileRes,
                                                     reg_protocol=self.nlin_protocol)
        ants = self.nlinParams
        # ANTS blurs are an array of arrays; blur both volumes at every real kernel.
        kernels = [k for generation in ants.blurs for k in generation if k != -1]
        for k in kernels:
            self.p.addStage(ma.blur(dst, k, gradient=True))
            self.p.addStage(ma.blur(src, k, gradient=True))

        antsReg = ma.mincANTS(src, dst,
                              defaultDir=self.defaultDir,
                              blur=ants.blurs[0],
                              gradient=ants.gradient[0],
                              similarity_metric=ants.similarityMetric[0],
                              weight=ants.weight[0],
                              iterations=ants.iterations[0],
                              radius_or_histo=ants.radiusHisto[0],
                              transformation_model=ants.transformationModel[0],
                              regularization=ants.regularization[0],
                              useMask=ants.useMask[0])
        self.p.addStage(antsReg)
        nonlinearXfm = antsReg.outputFiles[0]

        # Go back to the original input as base volume for future registrations.
        src.setLastBasevol(setToOriginalInput=True)

        # registerVolume picks the output name and sets it as the last xfm;
        # concatenate lsq12 + nlin into it.
        combined = src.registerVolume(dst, "transforms")
        self.p.addStage(ma.xfmConcat([linearXfm, nonlinearXfm], combined,
                                     fh.logFromFile(src.logDir, combined)))
class LongitudinalStatsConcatAndResample:
    """Longitudinal statistics for every subject.

    For each subject:
      1. Calculate stats (displacement, absolute jacobians, relative jacobians)
         between time points i and i+1.
      2. Calculate the transform from the subject to the common space (nlinFH)
         and invert it. For most subjects this requires some transform
         concatenation.
      3. Calculate the stats (displacement, absolute jacobians, relative
         jacobians) from the common space to each time point.

    Args:
        subjects: mapping from subject key to an indexable sequence of file
            handlers, one per time point.
        timePoint: index of the time point that went into the common average,
            or -1 for "the last time point of each subject".
        nlinFH: file handler of the common-space average. When falsy, the
            common-space steps are skipped and only i -> i+1 stats are added.
        statsKernels: blurring kernels, either a list/tuple of numbers or a
            comma-separated string such as "0.5,1.0".
        commonName: name of the common space, used in output file names.
    """

    def __init__(self, subjects, timePoint, nlinFH, statsKernels, commonName):
        self.subjects = subjects
        self.timePoint = timePoint
        self.nlinFH = nlinFH
        self.blurs = []
        self.setupBlurs(statsKernels)
        self.commonName = commonName
        self.p = Pipeline()
        self.buildPipeline()

    def setupBlurs(self, statsKernels):
        """Fill self.blurs from a list/tuple of kernels or a comma-separated string.

        Any other type prints an error and exits with status 1.
        """
        if isinstance(statsKernels, (list, tuple)):
            # Copy, so later changes to the caller's list do not leak in.
            self.blurs = list(statsKernels)
        elif isinstance(statsKernels, str):
            self.blurs.extend(float(k) for k in statsKernels.split(","))
        else:
            print("Improper type of blurring kernels specified for stats calculation: " + str(statsKernels))
            # Non-zero status: a bare sys.exit() reports success to the caller.
            sys.exit(1)

    def statsCalculation(self, inputFH, targetFH, xfm=None, useChainStats=True):
        """Add a stats calculation between inputFH and targetFH to the pipeline.

        With useChainStats=True the stats are calculated between input and
        target; this is used for every i -> i+1 calculation. With
        useChainStats=False they are calculated the standard way, from target
        to input; this is used from the common space to every time point.

        If xfm is given, the resulting stats are also resampled through it into
        the common space (nlinFH, or targetFH when there is no nlinFH).
        """
        if useChainStats:
            stats = st.CalcChainStats(inputFH, targetFH, self.blurs)
        else:
            stats = st.CalcStats(inputFH, targetFH, self.blurs)
        self.p.addPipeline(stats.p)
        if xfm:
            likeFH = self.nlinFH or targetFH
            res = resampleToCommon(xfm, inputFH, stats.statsGroup, self.blurs, likeFH)
            self.p.addPipeline(res)

    def statsAndConcat(self, s, i, count, beforeAvg=True):
        """Extend the transform chain to the common space by time point i and add its stats.

        This builds on self.xfmToCommon as left by previous calls: the
        transform from s[i] towards the averaged time point (s[i+1] when
        beforeAvg, otherwise s[i-1]) is prepended to it.
        """
        if beforeAvg:
            xfm = s[i].getLastXfm(s[i + 1])
        else:
            xfm = s[i].getLastXfm(s[i - 1])

        if self.nlinFH:
            # Make this the last xfm from input to nlin and calculate nlin -> s[i] stats.
            self.xfmToCommon.insert(0, xfm)
            # Concatenate to get the xfm to the common space. The inverted
            # transform, which is what the stats need, is calculated in the
            # statistics module.
            xtc = createBaseName(s[i].transformsDir,
                                 s[i].basename + "_to_" + self.commonName + ".xfm")
            xc = ma.xfmConcat(self.xfmToCommon, xtc, fh.logFromFile(s[i].logDir, xtc))
            self.p.addStage(xc)
            # Resample this subject into the common space so its alignment
            # with the common time point can be inspected visually.
            inputResampledToCommon = createBaseName(s[i].resampledDir,
                                                    s[i].basename + "_to_" + self.commonName + ".mnc")
            logToCommon = fh.logFromFile(s[i].logDir, inputResampledToCommon)
            resampleCmd = ma.mincresample(s[i],
                                          self.nlinFH,
                                          likeFile=self.nlinFH,
                                          transform=xtc,
                                          output=inputResampledToCommon,
                                          logFile=logToCommon,
                                          argArray=["-sinc"])
            self.p.addStage(resampleCmd)
            s[i].addAndSetXfmToUse(self.nlinFH, xtc)
            self.statsCalculation(s[i], self.nlinFH, xfm=None, useChainStats=False)
        else:
            xtc = None

        # i -> i+1 stats for every time point except the final one.
        if count - i > 1:
            self.statsCalculation(s[i], s[i + 1], xfm=xtc, useChainStats=True)

    def buildPipeline(self):
        for subj in self.subjects:
            s = self.subjects[subj]
            count = len(s)
            # Wherever iterative model building was run, the indiv --> nlin xfm
            # is stored in the group named "final". Use that group to get the
            # transform and do the stats calculation, then switch back to the
            # current group. Stats are calculated first from the average to the
            # time point that was included in the average.
            if self.timePoint == -1:
                # The last file of each subject went into the common average,
                # so the time point varies and is determined per subject.
                timePointToUse = len(s) - 1
            else:
                timePointToUse = self.timePoint
            currGroup = s[timePointToUse].currentGroupIndex
            index = s[timePointToUse].getGroupIndex("final")
            xfmToNlin = s[timePointToUse].getLastXfm(self.nlinFH, groupIndex=index)
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []

            if self.nlinFH:
                s[timePointToUse].currentGroupIndex = index
                self.statsCalculation(s[timePointToUse], self.nlinFH, xfm=None,
                                      useChainStats=False)
                s[timePointToUse].currentGroupIndex = currGroup

            # If the averaged time point is NOT the final one, also calculate
            # its i -> i+1 stats.
            if count - timePointToUse > 1:
                self.statsCalculation(s[timePointToUse], s[timePointToUse + 1],
                                      xfm=xfmToNlin, useChainStats=True)

            if timePointToUse > 0:
                # The average used a time point other than the first one:
                # walk back over the earlier time points.
                for i in reversed(range(timePointToUse)):
                    self.statsAndConcat(s, i, count, beforeAvg=True)

            # Then loop over the points after the average. If the average is at
            # the first time point, this covers every other time point;
            # otherwise it covers those not handled above.
            #
            # xfmToCommon (possibly) needs to be reset first. If the average
            # time point is not the first one, xfmToCommon now holds the chain
            # from the first time point to the average. For example, with the
            # average at the third time point (index 2) it contains:
            # [ time_0_to_time_1.xfm, time_1_to_time_2.xfm, time_2_to_average_at_time_point2.xfm ]
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            for i in range(timePointToUse + 1, count):
                self.statsAndConcat(s, i, count, beforeAvg=False)
class FullIterativeLSQ12Nlin:
    """Does a full iterative LSQ12 and NLIN.

    This is iterative model building starting from LSQ6, without stats at the
    end, designed to be called as part of a larger application. Specifying an
    initModel or a fileResolution is optional; all other arguments are
    mandatory.

    The registration resolution is taken from the LSQ12 like-file (the first
    entry of initModel, else options.lsq12_likeFile). Only if no like-file is
    available is fileResolution used. If neither yields a resolution, the
    process prints an error and exits with status 1.
    """

    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None,
                 fileResolution=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        self.nlinFH = None
        self.providedResolution = fileResolution
        self.p = Pipeline()
        self.buildPipeline()

    def buildPipeline(self):
        lsq12LikeFH = None
        resolutionForLSQ12 = None
        if self.initModel:
            lsq12LikeFH = self.initModel[0]
        elif self.options.lsq12_likeFile:
            lsq12LikeFH = self.options.lsq12_likeFile

        if (lsq12LikeFH is None and self.options.lsq12_subject_matter is None
                and self.providedResolution is None):
            print("\nError: the FullIterativeLSQ12Nlin module was called without specifying either an initial model, nor an lsq12_subject_matter. Currently that means that the code can not determine the resolution at which the registrations should be run. Please specify one of the two. Exiting\n")
            # Non-zero status: a bare sys.exit() would report success.
            sys.exit(1)

        if lsq12LikeFH is not None:
            resolutionForLSQ12 = rf.returnFinestResolution(lsq12LikeFH)

        # NOTE: lsq12_subject_matter alone does not provide a resolution.
        # Without a like-file or fileResolution we still stop here.
        if resolutionForLSQ12 is None and self.providedResolution is None:
            print("\nError: the resolution at which the LSQ12 and the NLIN registration should be run could not be determined from either the initial model nor the LSQ12 like file. Please provide the fileResolution to the FullIterativeLSQ12Nlin module. Exiting\n")
            sys.exit(1)
        if resolutionForLSQ12 is None and self.providedResolution:
            resolutionForLSQ12 = self.providedResolution

        lsq12module = lsq12.FullLSQ12(self.inputs,
                                      self.dirs.lsq12Dir,
                                      queue_type=self.options.queue_type,
                                      likeFile=lsq12LikeFH,
                                      maxPairs=self.options.lsq12_max_pairs,
                                      lsq12_protocol=self.options.lsq12_protocol,
                                      subject_matter=self.options.lsq12_subject_matter,
                                      resolution=resolutionForLSQ12)
        lsq12module.iterate()
        self.p.addPipeline(lsq12module.p)
        self.lsq12Params = lsq12module.lsq12Params
        # If the LSQ12 average has no mask, inherit the initial model's mask.
        if lsq12module.lsq12AvgFH.getMask() is None:
            if self.initModel:
                lsq12module.lsq12AvgFH.setMask(self.initModel[0].getMask())
        if not self.avgPrefix:
            self.avgPrefix = self.options.pipeline_name

        # Same as in MBM.py: for now the NLIN stages use the same resolution as
        # the LSQ12 stage. At some point the subject matter option should be
        # looked into as well.
        nlinModule = nlin.initializeAndRunNLIN(self.dirs.lsq12Dir,
                                               self.inputs,
                                               self.dirs.nlinDir,
                                               avgPrefix=self.avgPrefix,
                                               createAvg=False,
                                               targetAvg=lsq12module.lsq12AvgFH,
                                               nlin_protocol=self.options.nlin_protocol,
                                               reg_method=self.options.reg_method,
                                               resolution=resolutionForLSQ12)
        self.p.addPipeline(nlinModule.p)
        self.nlinFH = nlinModule.nlinAverages[-1]
        self.nlinParams = nlinModule.nlinParams
        self.initialTarget = nlinModule.initialTarget

        # Now build the full transform needed to go back to LSQ6 space.
        for i in self.inputs:
            linXfm = lsq12module.lsq12AvgXfms[i]
            nlinXfm = i.getLastXfm(self.nlinFH)
            outXfm = st.createOutputFileName(i, nlinXfm, "transforms", "_with_additional.xfm")
            xc = ma.xfmConcat([linXfm, nlinXfm], outXfm, fh.logFromFile(i.logDir, outXfm))
            self.p.addStage(xc)
            i.addAndSetXfmToUse(self.nlinFH, outXfm)
class HierarchicalMinctracc:
    """Optional LSQ12 alignment followed by a multi-stage non-linear minctracc alignment.

    By default this runs a standard three-stage LSQ12 alignment (see the
    defaults of the LSQ12 module), then a six-generation non-linear minctracc
    alignment. Both can be overridden with lsq12_protocol / nlin_protocol, and
    the LSQ12 part can be skipped with includeLinear=False.

    The final non-linear stage writes to "transforms". Note that this is done
    by setting self.defaultDir to "transforms", so the attribute holds that
    value after construction.
    """

    def __init__(self, inputFH, targetFH, lsq12_protocol=None, nlin_protocol=None,
                 includeLinear=True, subject_matter=None, defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.includeLinear = includeLinear
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir
        # Only derive a resolution when some parameters come from defaults;
        # the registration resolution is based on the target.
        lsq12Unspecified = lsq12_protocol is None and subject_matter is None
        if nlin_protocol is None or lsq12Unspecified:
            self.fileRes = rf.returnFinestResolution(targetFH)
        else:
            self.fileRes = None
        self.buildPipeline()

    def buildPipeline(self):
        src, dst = self.inputFH, self.targetFH

        if self.includeLinear:
            self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                          subject_matter=self.subject_matter,
                                                          reg_protocol=self.lsq12_protocol)
            lin = self.lsq12Params
            linearReg = lsq12.LSQ12(src, dst,
                                    blurs=lin.blurs,
                                    step=lin.stepSize,
                                    gradient=lin.useGradient,
                                    simplex=lin.simplex,
                                    w_translations=lin.w_translations,
                                    defaultDir=self.defaultDir)
            self.p.addPipeline(linearReg.p)

        self.nlinParams = mp.setNlinMinctraccParams(self.fileRes,
                                                    reg_protocol=self.nlin_protocol)
        nl = self.nlinParams
        # Blur both volumes at every kernel except -1 ("no blur").
        for kernel in nl.blurs:
            if kernel == -1:
                continue
            self.p.addStage(ma.blur(src, kernel, gradient=True))
            self.p.addStage(ma.blur(dst, kernel, gradient=True))

        stageCount = len(nl.stepSize)
        for idx in range(stageCount):
            # The final stage's transform goes to "transforms".
            if idx == stageCount - 1:
                self.defaultDir = "transforms"
            stage = ma.minctracc(src, dst,
                                 defaultDir=self.defaultDir,
                                 blur=nl.blurs[idx],
                                 gradient=nl.useGradient[idx],
                                 iterations=nl.iterations[idx],
                                 step=nl.stepSize[idx],
                                 w_translations=nl.w_translations[idx],
                                 simplex=nl.simplex[idx],
                                 memory=nl.memory[idx] if nl.memory else None,
                                 optimization=nl.optimization[idx])
            self.p.addStage(stage)
class initializeAndRunNLIN(object):
    """Set up the initial target (if needed), build the requested NLIN module and run it.

    After construction:
        p             -- Pipeline holding the optional mincAverage stage and
                         the NLIN module's stages.
        initialTarget -- file handler of the starting target.
        nlinAverages  -- the averages list exposed by the NLIN module.
        nlinParams    -- the parameters used by the NLIN module.

    An invalid targetAvg type or an unknown reg_method prints/logs an error
    and exits with status 1.
    """

    def __init__(self,
                 targetOutputDir,  # Output directory for files related to initial target (often _lsq12)
                 inputFiles,
                 nlinDir,
                 avgPrefix,  # Prefix for nlin-1.mnc, ... nlin-k.mnc
                 createAvg=True,  # True=call mincAvg, False=targetAvg already exists
                 targetAvg=None,  # Optional path (str) or RegistrationPipeFH of the initial target - passing a name does not guarantee existence
                 targetMask=None,  # Optional path to mask for initial target
                 nlin_protocol=None,
                 reg_method=None):  # "mincANTS" or "minctracc"
        self.p = Pipeline()
        self.targetOutputDir = targetOutputDir
        self.inputFiles = inputFiles
        self.nlinDir = nlinDir
        self.avgPrefix = avgPrefix
        self.createAvg = createAvg
        self.targetAvg = targetAvg
        self.targetMask = targetMask
        self.nlin_protocol = nlin_protocol
        self.reg_method = reg_method

        # Set up initialTarget (if needed) and initialize the non-linear module.
        self.setupTarget()
        self.initNlinModule()

        # Iterate through the non-linear registration and set up averages.
        self.nlinModule.iterate()
        self.p.addPipeline(self.nlinModule.p)
        self.nlinAverages = self.nlinModule.nlinAverages
        self.nlinParams = self.nlinModule.nlinParams

    def setupTarget(self):
        """Create self.initialTarget / self.outputAvg and, if createAvg, queue mincAverage.

        With no targetAvg, the target defaults to
        <abspath(targetOutputDir)>/initial-target.mnc.
        """
        if self.targetAvg:
            if isinstance(self.targetAvg, str):
                self.initialTarget = RegistrationPipeFH(self.targetAvg,
                                                        mask=self.targetMask,
                                                        basedir=self.targetOutputDir)
                self.outputAvg = self.targetAvg
            elif isinstance(self.targetAvg, RegistrationPipeFH):
                self.initialTarget = self.targetAvg
                self.outputAvg = self.targetAvg.getLastBasevol()
                # Only attach targetMask if the handler has no mask yet.
                if not self.initialTarget.getMask():
                    if self.targetMask:
                        self.initialTarget.setMask(self.targetMask)
            else:
                print("You have passed a target average that is neither a string nor a file handler: " + str(self.targetAvg))
                print("Exiting...")
                # Actually exit as announced; otherwise initialTarget is never
                # set and a confusing AttributeError follows.
                sys.exit(1)
        else:
            self.targetAvg = abspath(self.targetOutputDir) + "/" + "initial-target.mnc"
            self.initialTarget = RegistrationPipeFH(self.targetAvg,
                                                    mask=self.targetMask,
                                                    basedir=self.targetOutputDir)
            self.outputAvg = self.targetAvg
        if self.createAvg:
            avg = mincAverage(self.inputFiles,
                              self.initialTarget,
                              output=self.outputAvg,
                              defaultDir=self.targetOutputDir)
            self.p.addStage(avg)

    def initNlinModule(self):
        """Instantiate the NLIN class selected by reg_method.

        "mincANTS" selects NLINANTS and "minctracc" selects NLINminctracc. Any
        other value, including the default None, logs an error and exits with
        status 1.
        """
        if self.reg_method == "mincANTS":
            self.nlinModule = NLINANTS(self.inputFiles, self.initialTarget, self.nlinDir,
                                       self.avgPrefix, self.nlin_protocol)
        elif self.reg_method == "minctracc":
            self.nlinModule = NLINminctracc(self.inputFiles, self.initialTarget, self.nlinDir,
                                            self.avgPrefix, self.nlin_protocol)
        else:
            # str() so that reg_method=None reaches this message instead of
            # raising TypeError on string concatenation.
            logger.error("Incorrect registration method specified: " + str(self.reg_method))
            sys.exit(1)
class HierarchicalMinctracc:
    """Optional LSQ12 alignment followed by a multi-stage non-linear minctracc alignment.

    Default HierarchicalMinctracc currently does:
      1. A standard three-stage LSQ12 alignment (see the defaults of the LSQ12
         module); skipped when includeLinear is False.
      2. A six-generation non-linear minctracc alignment.
    To override these defaults, lsq12 and nlin protocols may be specified.

    When a resolution is needed to derive default parameters, it is taken from
    the registration target, consistent with the other registration classes.
    The final non-linear stage writes to "transforms" by setting
    self.defaultDir, so the attribute holds that value after construction.
    """

    def __init__(self, inputFH, targetFH, lsq12_protocol=None, nlin_protocol=None,
                 includeLinear=True, subject_matter=None, defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.includeLinear = includeLinear
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir
        if ((self.lsq12_protocol is None and self.subject_matter is None)
                or self.nlin_protocol is None):
            # The resolution of the registration should be based on the
            # target (previously the input was used here).
            self.fileRes = rf.returnFinestResolution(self.targetFH)
        else:
            self.fileRes = None
        self.buildPipeline()

    def buildPipeline(self):
        # Do the LSQ12 alignment prior to the non-linear stages, if desired.
        if self.includeLinear:
            self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                          subject_matter=self.subject_matter,
                                                          reg_protocol=self.lsq12_protocol)
            lsq12reg = lsq12.LSQ12(self.inputFH,
                                   self.targetFH,
                                   blurs=self.lsq12Params.blurs,
                                   step=self.lsq12Params.stepSize,
                                   gradient=self.lsq12Params.useGradient,
                                   simplex=self.lsq12Params.simplex,
                                   w_translations=self.lsq12Params.w_translations,
                                   defaultDir=self.defaultDir)
            self.p.addPipeline(lsq12reg.p)

        # Create the non-linear registrations.
        self.nlinParams = mp.setNlinMinctraccParams(self.fileRes, reg_protocol=self.nlin_protocol)
        for b in self.nlinParams.blurs:
            if b != -1:  # -1 means "no blur"
                self.p.addStage(ma.blur(self.inputFH, b, gradient=True))
                self.p.addStage(ma.blur(self.targetFH, b, gradient=True))
        for i in range(len(self.nlinParams.stepSize)):
            # For the final stage, make sure the output directory is transforms.
            if i == (len(self.nlinParams.stepSize) - 1):
                self.defaultDir = "transforms"
            nlinStage = ma.minctracc(self.inputFH,
                                     self.targetFH,
                                     defaultDir=self.defaultDir,
                                     blur=self.nlinParams.blurs[i],
                                     gradient=self.nlinParams.useGradient[i],
                                     iterations=self.nlinParams.iterations[i],
                                     step=self.nlinParams.stepSize[i],
                                     w_translations=self.nlinParams.w_translations[i],
                                     simplex=self.nlinParams.simplex[i],
                                     optimization=self.nlinParams.optimization[i])
            self.p.addStage(nlinStage)
class initializeAndRunNLIN(object):
    """Set up the initial target (if needed), build the requested NLIN module and run it.

    After construction:
        p             -- Pipeline holding the optional mincAverage stage and
                         the NLIN module's stages.
        initialTarget -- file handler of the starting target.
        nlinAverages  -- the averages list exposed by the NLIN module.
        nlinParams    -- the parameters used by the NLIN module.

    An invalid targetAvg type or an unknown reg_method prints/logs an error
    and exits with status 1.
    """

    def __init__(
            self,
            targetOutputDir,  # Output directory for files related to initial target (often _lsq12)
            inputFiles,
            nlinDir,
            avgPrefix,  # Prefix for nlin-1.mnc, ... nlin-k.mnc
            createAvg=True,  # True=call mincAvg, False=targetAvg already exists
            targetAvg=None,  # Optional path (str) or RegistrationPipeFH of the initial target - passing a name does not guarantee existence
            targetMask=None,  # Optional path to mask for initial target
            nlin_protocol=None,
            reg_method=None):  # "mincANTS" or "minctracc"
        self.p = Pipeline()
        self.targetOutputDir = targetOutputDir
        self.inputFiles = inputFiles
        self.nlinDir = nlinDir
        self.avgPrefix = avgPrefix
        self.createAvg = createAvg
        self.targetAvg = targetAvg
        self.targetMask = targetMask
        self.nlin_protocol = nlin_protocol
        self.reg_method = reg_method

        # Set up initialTarget (if needed) and initialize the non-linear module.
        self.setupTarget()
        self.initNlinModule()

        # Iterate through the non-linear registration and set up averages.
        self.nlinModule.iterate()
        self.p.addPipeline(self.nlinModule.p)
        self.nlinAverages = self.nlinModule.nlinAverages
        self.nlinParams = self.nlinModule.nlinParams

    def setupTarget(self):
        """Create self.initialTarget / self.outputAvg and, if createAvg, queue mincAverage.

        With no targetAvg, the target defaults to
        <abspath(targetOutputDir)>/initial-target.mnc.
        """
        if self.targetAvg:
            if isinstance(self.targetAvg, str):
                self.initialTarget = RegistrationPipeFH(
                    self.targetAvg,
                    mask=self.targetMask,
                    basedir=self.targetOutputDir)
                self.outputAvg = self.targetAvg
            elif isinstance(self.targetAvg, RegistrationPipeFH):
                self.initialTarget = self.targetAvg
                self.outputAvg = self.targetAvg.getLastBasevol()
                # Only attach targetMask if the handler has no mask yet.
                if not self.initialTarget.getMask():
                    if self.targetMask:
                        self.initialTarget.setMask(self.targetMask)
            else:
                print("You have passed a target average that is neither a string nor a file handler: " + str(
                    self.targetAvg))
                print("Exiting...")
                # Actually exit as announced; otherwise initialTarget is never
                # set and a confusing AttributeError follows.
                sys.exit(1)
        else:
            self.targetAvg = abspath(
                self.targetOutputDir) + "/" + "initial-target.mnc"
            self.initialTarget = RegistrationPipeFH(
                self.targetAvg,
                mask=self.targetMask,
                basedir=self.targetOutputDir)
            self.outputAvg = self.targetAvg
        if self.createAvg:
            avg = mincAverage(self.inputFiles,
                              self.initialTarget,
                              output=self.outputAvg,
                              defaultDir=self.targetOutputDir)
            self.p.addStage(avg)

    def initNlinModule(self):
        """Instantiate the NLIN class selected by reg_method.

        "mincANTS" selects NLINANTS and "minctracc" selects NLINminctracc. Any
        other value, including the default None, logs an error and exits with
        status 1.
        """
        if self.reg_method == "mincANTS":
            self.nlinModule = NLINANTS(self.inputFiles,
                                       self.initialTarget,
                                       self.nlinDir,
                                       self.avgPrefix,
                                       self.nlin_protocol)
        elif self.reg_method == "minctracc":
            self.nlinModule = NLINminctracc(self.inputFiles,
                                            self.initialTarget,
                                            self.nlinDir,
                                            self.avgPrefix,
                                            self.nlin_protocol)
        else:
            # str() so that reg_method=None reaches this message instead of
            # raising TypeError on string concatenation.
            logger.error("Incorrect registration method specified: " +
                         str(self.reg_method))
            sys.exit(1)
class LSQ12ANTSNlin:
    """Class that runs a basic LSQ12 registration, followed by a single mincANTS call.

    Currently used in MAGeT, registration_chain and pairwise_nlin. The final
    transform is the LSQ12 transform concatenated with the mincANTS transform.
    It is registered on inputFH towards targetFH (via registerVolume) in the
    "transforms" directory.

    When a resolution is needed to derive default parameters, it is taken from
    the registration target. This block's own resampling step also uses the
    target as like-file.
    """

    def __init__(self, inputFH, targetFH, lsq12_protocol=None, nlin_protocol=None,
                 subject_matter=None, defaultDir="tmp"):
        self.p = Pipeline()
        self.inputFH = inputFH
        self.targetFH = targetFH
        self.lsq12_protocol = lsq12_protocol
        self.nlin_protocol = nlin_protocol
        self.subject_matter = subject_matter
        self.defaultDir = defaultDir
        if ((self.lsq12_protocol is None and self.subject_matter is None)
                or self.nlin_protocol is None):
            # Always base the resolution used for the registrations on the
            # target (previously the input was used here).
            self.fileRes = rf.returnFinestResolution(self.targetFH)
        else:
            self.fileRes = None
        self.buildPipeline()

    def buildPipeline(self):
        # Run the LSQ12 registration prior to the non-linear one.
        self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                      subject_matter=self.subject_matter,
                                                      reg_protocol=self.lsq12_protocol)
        lsq12reg = lsq12.LSQ12(self.inputFH,
                               self.targetFH,
                               blurs=self.lsq12Params.blurs,
                               step=self.lsq12Params.stepSize,
                               gradient=self.lsq12Params.useGradient,
                               simplex=self.lsq12Params.simplex,
                               w_translations=self.lsq12Params.w_translations,
                               defaultDir=self.defaultDir)
        self.p.addPipeline(lsq12reg.p)

        # Resample using the final LSQ12 transform and reset the last base volume.
        res = ma.mincresample(self.inputFH, self.targetFH, likeFile=self.targetFH,
                              argArray=["-sinc"])
        self.p.addStage(res)
        self.inputFH.setLastBasevol(res.outputFiles[0])
        lsq12xfm = self.inputFH.getLastXfm(self.targetFH)

        # Get registration parameters from the nlin protocol, blur and register.
        # A SINGLE generation is assumed here.
        self.nlinParams = mp.setOneGenMincANTSParams(self.fileRes, reg_protocol=self.nlin_protocol)
        for b in self.nlinParams.blurs:
            # Blurs for ANTS params are an array of arrays.
            for j in b:
                if j != -1:  # -1 means "no blur"
                    self.p.addStage(ma.blur(self.targetFH, j, gradient=True))
                    self.p.addStage(ma.blur(self.inputFH, j, gradient=True))
        sp = ma.mincANTS(self.inputFH,
                         self.targetFH,
                         defaultDir=self.defaultDir,
                         blur=self.nlinParams.blurs[0],
                         gradient=self.nlinParams.gradient[0],
                         similarity_metric=self.nlinParams.similarityMetric[0],
                         weight=self.nlinParams.weight[0],
                         iterations=self.nlinParams.iterations[0],
                         radius_or_histo=self.nlinParams.radiusHisto[0],
                         transformation_model=self.nlinParams.transformationModel[0],
                         regularization=self.nlinParams.regularization[0],
                         useMask=self.nlinParams.useMask[0])
        self.p.addStage(sp)
        nlinXfm = sp.outputFiles[0]

        # Reset the last base volume to the original input for future registrations.
        self.inputFH.setLastBasevol(setToOriginalInput=True)

        # Concatenate transforms to get the final lsq12 + nlin. registerVolume
        # handles the naming and the setting of lastXfm.
        output = self.inputFH.registerVolume(self.targetFH, "transforms")
        xc = ma.xfmConcat([lsq12xfm, nlinXfm], output,
                          fh.logFromFile(self.inputFH.logDir, output))
        self.p.addStage(xc)
class LongitudinalStatsConcatAndResample:
    """ For each subject:
        1. Calculate stats (displacement, absolute jacobians, relative
           jacobians) between i and i+1 time points.
        2. Calculate transform from subject to common space (nlinFH) and
           invert it. For most subjects this will require some amount of
           transform concatenation.
        3. Calculate the stats (displacement, absolute jacobians, relative
           jacobians) from common space to each timepoint.

        Args:
            subjects: container whose items (``subjects[subj]``) are
                sequences of file handlers indexed by time point.
            timePoint: index of the time point that was included in the
                average.
            nlinFH: file handler of the common-space average. When falsy,
                the common-space stats and transform concatenation are
                skipped.
            statsKernels: blurring kernels, either a list or a
                comma-separated string of numbers.
            commonName: used to name the ``<basename>_to_<commonName>.xfm``
                transforms.
    """
    def __init__(self, subjects, timePoint, nlinFH, statsKernels, commonName):
        self.subjects = subjects
        self.timePoint = timePoint
        self.nlinFH = nlinFH
        self.blurs = []
        self.setupBlurs(statsKernels)
        self.commonName = commonName
        self.p = Pipeline()
        self.buildPipeline()

    def setupBlurs(self, statsKernels):
        """Fill ``self.blurs`` from a list or a comma-separated string.

        Prints an error and exits with a non-zero status for any other type.
        """
        if isinstance(statsKernels, list):
            self.blurs = statsKernels
        elif isinstance(statsKernels, str):
            for kernel in statsKernels.split(","):
                self.blurs.append(float(kernel))
        else:
            print("Improper type of blurring kernels specified for stats calculation: "
                  + str(statsKernels))
            # Non-zero status so that callers and shells can detect the failure.
            sys.exit(1)

    def statsCalculation(self, inputFH, targetFH, xfm=None, useChainStats=True):
        """Add stats stages between ``inputFH`` and ``targetFH`` to ``self.p``.

        If useChainStats=True, calculate stats between input and target.
        This happens for all i to i+1 calcs.

        If useChainStats=False, calculate stats in the standard way, from
        target to input. We do this when we go from the common space to all
        others.

        If ``xfm`` is given, the stats are additionally resampled to the
        common space with it, using ``self.nlinFH`` as the like file when it
        is set and ``targetFH`` otherwise.
        """
        if useChainStats:
            stats = st.CalcChainStats(inputFH, targetFH, self.blurs)
        else:
            stats = st.CalcStats(inputFH, targetFH, self.blurs)
        self.p.addPipeline(stats.p)
        if xfm:
            likeFH = self.nlinFH or targetFH
            res = resampleToCommon(xfm, inputFH, stats.statsGroup, self.blurs, likeFH)
            self.p.addPipeline(res)

    def statsAndConcat(self, s, i, count, beforeAvg=True):
        """Construct the transform array to common space for time point ``i``.

        This builds upon ``self.xfmToCommon`` from previous calls: the
        transform from s[i] to its neighbour (s[i+1] when beforeAvg, s[i-1]
        otherwise) is prepended. When ``self.nlinFH`` is set, the array is
        concatenated into one transform, set as the xfm to use for s[i], and
        nlin -> s[i] stats are calculated. Finally i -> i+1 stats are
        calculated for all but the final time point.
        """
        if beforeAvg:
            xfm = s[i].getLastXfm(s[i + 1])
        else:
            xfm = s[i].getLastXfm(s[i - 1])
        xtc = None
        if self.nlinFH:
            self.xfmToCommon.insert(0, xfm)
            # Concat transforms to get xfmToCommon and calculate statistics.
            # Note that the inverted transform, which is what we want, is
            # calculated in the statistics module.
            xtc = fh.createBaseName(s[i].transformsDir,
                                    s[i].basename + "_to_" + self.commonName + ".xfm")
            xc = ma.xfmConcat(self.xfmToCommon, xtc, fh.logFromFile(s[i].logDir, xtc))
            self.p.addStage(xc)
            s[i].addAndSetXfmToUse(self.nlinFH, xtc)
            self.statsCalculation(s[i], self.nlinFH, xfm=None, useChainStats=False)
        # Calculate i to i+1 stats for all but final timePoint.
        if count - i > 1:
            self.statsCalculation(s[i], s[i + 1], xfm=xtc, useChainStats=True)

    def buildPipeline(self):
        """Add all longitudinal stats stages for every subject to ``self.p``."""
        for subj in self.subjects:
            s = self.subjects[subj]
            count = len(s)
            averaged = s[self.timePoint]
            # Wherever iterative model building was run, the indiv --> nlin
            # xfm is stored in the group with the name "final". Use this group
            # to get the transform and do the stats calculation, and then
            # reset to the current group. Stats are calculated first from the
            # average to the time point included in the average.
            currGroup = averaged.currentGroupIndex
            index = averaged.getGroupIndex("final")
            xfmToNlin = averaged.getLastXfm(self.nlinFH, groupIndex=index)
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            if self.nlinFH:
                averaged.currentGroupIndex = index
                self.statsCalculation(averaged, self.nlinFH, xfm=None, useChainStats=False)
                averaged.currentGroupIndex = currGroup
            # If the time point included in the average is NOT the final time
            # point, also calculate its i to i+1 stats.
            if count - self.timePoint > 1:
                self.statsCalculation(averaged, s[self.timePoint + 1],
                                      xfm=xfmToNlin, useChainStats=True)
            # Time points prior to the average, walking backwards from it.
            # The range is empty when the average is at the first time point.
            for i in reversed(range(self.timePoint)):
                self.statsAndConcat(s, i, count, beforeAvg=True)
            # Time points after the average. If the average is at the first
            # time point this hits all other time points; otherwise it hits
            # those not covered above. xfmToCommon needs to be reset first.
            self.xfmToCommon = [xfmToNlin] if xfmToNlin else []
            for i in range(self.timePoint + 1, count):
                self.statsAndConcat(s, i, count, beforeAvg=False)
class FullLSQ12(object):
    """Pairwise 12-parameter alignment and averaging of all inputs.

    This class takes an array of input file handlers along with an
    optionally specified protocol and does 12-parameter alignment and
    averaging of all of the pairs.

    NOTE(review): a second ``FullLSQ12`` class is defined later in this file;
    that later definition rebinds the name at import time, so this one is
    shadowed.

    Required arguments:
        inputArray -- array of file handlers to be registered
        outputDir -- an output directory to place the final average from
            this registration

    Optional arguments:
        likeFile -- a file handler that can be used as a likeFile for
            resampling each input into the final lsq12 space. If none is
            specified, the input will be used.
        maxPairs -- maximum number of pairs to register. Not yet
            implemented: iterate() prints a message and exits when it is set.
        lsq12_protocol -- an optional semicolon-delimited csv file that
            overrides the default parameters (see setParams).
        subject_matter -- currently supports "mousebrain". If specified, the
            minctracc parameters are set from mouse-brain defaults instead of
            from the file resolution.
    """
    def __init__(self, inputArray, outputDir, likeFile=None, maxPairs=None,
                 lsq12_protocol=None, subject_matter=None):
        self.p = Pipeline()
        # Initial inputs: file handlers with lastBasevol in lsq12 space.
        self.inputs = inputArray
        # Directory that receives the final lsq12 average.
        self.lsq12Dir = outputDir
        # likeFile for resampling
        self.likeFile = likeFile
        # Maximum number of pairs to calculate
        self.maxPairs = maxPairs
        # Final lsq12 average (path), set by iterate()
        self.lsq12Avg = None
        # File handler associated with lsq12Avg, set by iterate()
        self.lsq12AvgFH = None
        # Dictionary of lsq12 average transforms, one per input. Key is the
        # input file handler, value is the path of that subject's average
        # lsq12 transform. May be used later for statistics calculations.
        self.lsq12AvgXfms = {}
        # What sort of subject matter do we deal with?
        self.subject_matter = subject_matter
        # Create the blurring resolution from the file resolution.
        try:
            self.fileRes = rf.getFinestResolution(self.inputs[0])
        except Exception:
            # If this fails (because the file doesn't exist when the pipeline
            # is created) grab from the initial input volume, which should
            # exist.
            self.fileRes = rf.getFinestResolution(self.inputs[0].inputFileName)
        # Defaults may be overridden by a protocol; the number of generations
        # is the (common) length of the parameter arrays.
        self.defaultParams()
        if lsq12_protocol:
            self.setParams(lsq12_protocol)
        self.generations = self.getGenerations()
        # Create new lsq12 group for each input prior to registration
        for inputFH in self.inputs:
            inputFH.newGroup(groupName="lsq12")

    def defaultParams(self):
        """Set default minctracc parameters.

        Uses fixed mouse-brain values when subject_matter is "mousebrain",
        otherwise scales fixed factors by ``self.fileRes``.
        """
        blurfactors = [5, 10.0 / 3.0, 2.5]
        stepfactors = [50.0 / 3.0, 25.0 / 3.0, 5.5]
        simplexfactors = [50, 25, 50.0 / 3.0]
        if self.subject_matter == "mousebrain":
            # the default for mouse brains should be:
            # blurs: 0.3 0.2 0.15
            # steps: 1 0.5 0.333
            # simplex: 3 1.5 1
            self.blurs = [0.3, 0.2, 0.15]
            self.stepSize = [1, 0.5, 1.0 / 3.0]
            self.simplex = [3, 1.5, 1]
        else:
            self.blurs = [f * self.fileRes for f in blurfactors]
            self.stepSize = [f * self.fileRes for f in stepfactors]
            self.simplex = [f * self.fileRes for f in simplexfactors]
        self.useGradient = [False, True, False]

    def setParams(self, lsq12_protocol):
        """Set parameters from the specified protocol file.

        The protocol is a semicolon-delimited csv file. Each non-empty row
        starts with a parameter name ("blur", "step", "gradient" or
        "simplex") followed by one value per generation. Blurs and steps are
        parsed as floats, gradients as True/TRUE/False/FALSE, and simplex
        values as ints when integral and floats otherwise (e.g. 1.5). An
        unknown parameter name prints an error and exits with a non-zero
        status.
        """
        with open(abspath(lsq12_protocol), 'rb') as inputCsv:
            csvReader = csv.reader(inputCsv, delimiter=';', skipinitialspace=True)
            # csv yields [] for blank lines; skip them so p[0] is always valid.
            params = [row for row in csvReader if row]
        self.blurs = []
        self.stepSize = []
        self.useGradient = []
        self.simplex = []
        for p in params:
            name, values = p[0], p[1:]
            if name == "blur":
                self.blurs.extend(float(v) for v in values)
            elif name == "step":
                self.stepSize.extend(float(v) for v in values)
            elif name == "gradient":
                # Unrecognized values are skipped; the resulting length
                # mismatch is then reported by getGenerations().
                for v in values:
                    if v == "True" or v == "TRUE":
                        self.useGradient.append(True)
                    elif v == "False" or v == "FALSE":
                        self.useGradient.append(False)
            elif name == "simplex":
                for v in values:
                    # int() alone rejects values such as "1.5", which the
                    # mouse-brain defaults themselves use.
                    value = float(v)
                    self.simplex.append(int(value) if value.is_integer() else value)
            else:
                print("Improper parameter specified for minctracc protocol: " + str(name))
                print("Exiting...")
                sys.exit(1)

    def getGenerations(self):
        """Return the number of generations (length of the parameter arrays).

        Raises:
            ValueError: if the blur, step, gradient and simplex arrays do not
                all have the same length.
        """
        arrayLength = len(self.blurs)
        errorMsg = "Array lengths in lsq12 minctracc protocol do not match."
        if (len(self.stepSize) != arrayLength
                or len(self.useGradient) != arrayLength
                or len(self.simplex) != arrayLength):
            print(errorMsg)
            raise ValueError(errorMsg)
        return arrayLength

    def iterate(self):
        """Add all pairwise LSQ12 registrations and the averaging to ``self.p``.

        Each input is registered to every other input, the resulting
        transforms are averaged with xfmavg, the input is resampled with its
        average transform, and all resampled inputs are averaged into
        ``self.lsq12Avg`` / ``self.lsq12AvgFH``. Exits with a non-zero status
        if maxPairs is set (not yet supported).
        """
        if self.maxPairs:
            print("Registration using a specified number of max pairs not yet working. Check back soon!")
            sys.exit(1)
        xfmsToAvg = {}
        lsq12ResampledFiles = {}
        for inputFH in self.inputs:
            # Collect the xfms to every other input, to compute an average
            # lsq12 xfm for this input.
            xfmsToAvg[inputFH] = []
            for targetFH in self.inputs:
                if inputFH != targetFH:
                    pairReg = LSQ12(inputFH, targetFH, self.blurs, self.stepSize,
                                    self.useGradient, self.simplex)
                    self.p.addPipeline(pairReg.p)
                    xfmsToAvg[inputFH].append(inputFH.getLastXfm(targetFH))
            # Create average xfm for inputFH using the xfmsToAvg array.
            cmd = ["xfmavg"]
            cmd.extend(InputFile(xfm) for xfm in xfmsToAvg[inputFH])
            avgXfmOutput = createBaseName(inputFH.transformsDir,
                                          inputFH.basename + "-avg-lsq12.xfm")
            cmd.append(OutputFile(avgXfmOutput))
            xfmavg = CmdStage(cmd)
            xfmavg.setLogFile(LogFile(logFromFile(inputFH.logDir, avgXfmOutput)))
            self.p.addStage(xfmavg)
            self.lsq12AvgXfms[inputFH] = avgXfmOutput
            # Resample brain and add to array for mincAveraging.
            likeFile = self.likeFile if self.likeFile else inputFH
            rslOutput = createBaseName(inputFH.resampledDir,
                                       inputFH.basename + "-resampled-lsq12.mnc")
            res = ma.mincresample(inputFH, inputFH, transform=avgXfmOutput,
                                  likeFile=likeFile, output=rslOutput,
                                  argArray=["-sinc"])
            self.p.addStage(res)
            lsq12ResampledFiles[inputFH] = rslOutput
        # After all registrations complete, setLastBasevol for each subject to
        # be the resampled file in lsq12 space. We can then call mincAverage
        # on the fileHandlers, as it uses the lastBasevol of each by default.
        for inputFH in self.inputs:
            inputFH.setLastBasevol(lsq12ResampledFiles[inputFH])
        # mincAverage all resampled brains and put the result in lsq12Dir.
        self.lsq12Avg = abspath(self.lsq12Dir) + "/" + basename(self.lsq12Dir) + "-pairs.mnc"
        self.lsq12AvgFH = RegistrationPipeFH(self.lsq12Avg, basedir=self.lsq12Dir)
        avg = ma.mincAverage(self.inputs, self.lsq12AvgFH, output=self.lsq12Avg,
                             defaultDir=self.lsq12Dir)
        self.p.addStage(avg)
class FullLSQ12(object):
    """Pairwise 12-parameter alignment and averaging of all inputs.

    This class takes an array of input file handlers along with an
    optionally specified protocol and does 12-parameter alignment and
    averaging of all of the pairs.

    Required arguments:
        inputArray -- array of file handlers to be registered
        outputDir -- an output directory to place the final average from
            this registration

    Optional arguments:
        likeFile -- a file handler that can be used as a likeFile for
            resampling each input into the final lsq12 space. If none is
            specified, the input will be used.
        maxPairs -- maximum number of pairs to register. Not yet
            implemented: iterate() prints a message and exits when it is set.
        lsq12_protocol -- an optional csv file specifying a protocol that
            overrides the defaults.
        subject_matter -- currently supports "mousebrain". If specified, the
            parameters for the minctracc registrations are set based on
            defaults for mouse brains instead of the file resolution.
        resolution -- optional resolution passed on as ``fileRes`` when
            setting up the parameters. When omitted, the finest resolution of
            the first input is used if neither lsq12_protocol nor
            subject_matter is given, otherwise fileRes is None.
    """
    def __init__(self, inputArray, outputDir, likeFile=None, maxPairs=None,
                 lsq12_protocol=None, subject_matter=None, resolution=None):
        self.p = Pipeline()
        # Initial inputs: file handlers with lastBasevol in lsq12 space.
        self.inputs = inputArray
        # Directory that receives the final lsq12 average.
        self.lsq12Dir = outputDir
        # likeFile for resampling
        self.likeFile = likeFile
        # Maximum number of pairs to calculate
        self.maxPairs = maxPairs
        # Final lsq12 average (path), set by iterate()
        self.lsq12Avg = None
        # File handler associated with lsq12Avg, set by iterate()
        self.lsq12AvgFH = None
        # Dictionary of lsq12 average transforms, one per input. Key is the
        # input file handler, value is the path of that subject's average
        # lsq12 transform. May be used later for statistics calculations.
        self.lsq12AvgXfms = {}
        # Resolution used to derive the parameters; an explicit resolution
        # (e.g. from a like file chosen by the caller) takes precedence.
        if resolution is not None:
            self.fileRes = resolution
        elif subject_matter is None and lsq12_protocol is None:
            self.fileRes = rf.returnFinestResolution(self.inputs[0])
        else:
            self.fileRes = None
        # Set up parameter arrays
        self.lsq12Params = mp.setLSQ12MinctraccParams(self.fileRes,
                                                      subject_matter=subject_matter,
                                                      reg_protocol=lsq12_protocol)
        self.blurs = self.lsq12Params.blurs
        self.stepSize = self.lsq12Params.stepSize
        self.useGradient = self.lsq12Params.useGradient
        self.simplex = self.lsq12Params.simplex
        self.w_translations = self.lsq12Params.w_translations
        self.generations = self.lsq12Params.generations
        # Create new lsq12 group for each input prior to registration
        for inputFH in self.inputs:
            inputFH.newGroup(groupName="lsq12")

    def iterate(self):
        """Add all pairwise LSQ12 registrations and the averaging to ``self.p``.

        Each input is registered to every other input, the resulting
        transforms are averaged with xfmavg, the input is resampled with its
        average transform, and all resampled inputs are averaged into
        ``self.lsq12Avg`` / ``self.lsq12AvgFH``. Exits with a non-zero status
        if maxPairs is set (not yet supported).
        """
        if self.maxPairs:
            print("Registration using a specified number of max pairs not yet working. Check back soon!")
            sys.exit(1)
        xfmsToAvg = {}
        lsq12ResampledFiles = {}
        for inputFH in self.inputs:
            # Collect the xfms to every other input, to compute an average
            # lsq12 xfm for this input.
            xfmsToAvg[inputFH] = []
            for targetFH in self.inputs:
                if inputFH != targetFH:
                    pairReg = LSQ12(inputFH,
                                    targetFH,
                                    blurs=self.blurs,
                                    step=self.stepSize,
                                    gradient=self.useGradient,
                                    simplex=self.simplex,
                                    w_translations=self.w_translations)
                    self.p.addPipeline(pairReg.p)
                    xfmsToAvg[inputFH].append(inputFH.getLastXfm(targetFH))
            # Create average xfm for inputFH using the xfmsToAvg array.
            cmd = ["xfmavg"]
            cmd.extend(InputFile(xfm) for xfm in xfmsToAvg[inputFH])
            avgXfmOutput = createBaseName(inputFH.transformsDir,
                                          inputFH.basename + "-avg-lsq12.xfm")
            cmd.append(OutputFile(avgXfmOutput))
            xfmavg = CmdStage(cmd)
            xfmavg.setLogFile(LogFile(logFromFile(inputFH.logDir, avgXfmOutput)))
            self.p.addStage(xfmavg)
            self.lsq12AvgXfms[inputFH] = avgXfmOutput
            # Resample brain and add to array for mincAveraging.
            likeFile = self.likeFile if self.likeFile else inputFH
            rslOutput = createBaseName(inputFH.resampledDir,
                                       inputFH.basename + "-resampled-lsq12.mnc")
            res = ma.mincresample(inputFH,
                                  inputFH,
                                  transform=avgXfmOutput,
                                  likeFile=likeFile,
                                  output=rslOutput,
                                  argArray=["-sinc"])
            self.p.addStage(res)
            lsq12ResampledFiles[inputFH] = rslOutput
        # After all registrations complete, setLastBasevol for each subject to
        # be the resampled file in lsq12 space. We can then call mincAverage
        # on the fileHandlers, as it uses the lastBasevol of each by default.
        for inputFH in self.inputs:
            inputFH.setLastBasevol(lsq12ResampledFiles[inputFH])
        # mincAverage all resampled brains and put the result in lsq12Dir.
        self.lsq12Avg = abspath(self.lsq12Dir) + "/" + basename(self.lsq12Dir) + "-pairs.mnc"
        self.lsq12AvgFH = RegistrationPipeFH(self.lsq12Avg, basedir=self.lsq12Dir)
        avg = ma.mincAverage(self.inputs,
                             self.lsq12AvgFH,
                             output=self.lsq12Avg,
                             defaultDir=self.lsq12Dir)
        self.p.addStage(avg)
class FullIterativeLSQ12Nlin:
    """Does a full iterative LSQ12 and NLIN. Basically iterative model building
       starting from LSQ6 and without stats at the end. Designed to be called
       as part of a larger application.
       Specifying an initModel is optional, all other arguments are mandatory.

       After construction, ``self.nlinFH`` is the last NLIN average, and each
       input has a concatenated lsq12 + nlin transform to it set as its xfm
       to use.
    """
    def __init__(self, inputs, dirs, options, avgPrefix=None, initModel=None):
        self.inputs = inputs
        self.dirs = dirs
        self.options = options
        self.avgPrefix = avgPrefix
        self.initModel = initModel
        self.nlinFH = None
        self.p = Pipeline()
        self.buildPipeline()

    def buildPipeline(self):
        """Add the LSQ12, NLIN and transform-concatenation stages to ``self.p``.

        The LSQ12 like file is ``initModel[0]`` when an initModel is given,
        otherwise ``options.lsq12_likeFile``. Its finest resolution is used
        for both the LSQ12 and NLIN stages. If there is no like file and no
        ``options.lsq12_subject_matter`` either, the resolution cannot be
        determined, so an error is printed and the program exits with a
        non-zero status.
        """
        lsq12LikeFH = None
        resolutionForLSQ12 = None
        if self.initModel:
            lsq12LikeFH = self.initModel[0]
        elif self.options.lsq12_likeFile:
            lsq12LikeFH = self.options.lsq12_likeFile
        if lsq12LikeFH is None and self.options.lsq12_subject_matter is None:
            print("\nError: the FullIterativeLSQ12Nlin module was called without specifying an initial model, an lsq12_likeFile, or an lsq12_subject_matter. Currently that means that the code can not determine the resolution at which the registrations should be run. Please specify one of the three. Exiting\n")
            sys.exit(1)
        if lsq12LikeFH is not None:
            resolutionForLSQ12 = rf.returnFinestResolution(lsq12LikeFH)

        lsq12module = lsq12.FullLSQ12(self.inputs,
                                      self.dirs.lsq12Dir,
                                      likeFile=lsq12LikeFH,
                                      maxPairs=self.options.lsq12_max_pairs,
                                      lsq12_protocol=self.options.lsq12_protocol,
                                      subject_matter=self.options.lsq12_subject_matter,
                                      resolution=resolutionForLSQ12)
        lsq12module.iterate()
        self.p.addPipeline(lsq12module.p)
        self.lsq12Params = lsq12module.lsq12Params
        # Give the lsq12 average the initial model's mask if it has none.
        if lsq12module.lsq12AvgFH.getMask() is None and self.initModel:
            lsq12module.lsq12AvgFH.setMask(self.initModel[0].getMask())
        if not self.avgPrefix:
            self.avgPrefix = self.options.pipeline_name
        # same as in MBM.py:
        # for now we can use the same resolution for the NLIN stages as we did
        # for the LSQ12 stage. At some point we should look into the subject
        # matter option...
        nlinModule = nlin.initializeAndRunNLIN(self.dirs.lsq12Dir,
                                               self.inputs,
                                               self.dirs.nlinDir,
                                               avgPrefix=self.avgPrefix,
                                               createAvg=False,
                                               targetAvg=lsq12module.lsq12AvgFH,
                                               nlin_protocol=self.options.nlin_protocol,
                                               reg_method=self.options.reg_method,
                                               resolution=resolutionForLSQ12)
        self.p.addPipeline(nlinModule.p)
        self.nlinFH = nlinModule.nlinAverages[-1]
        self.nlinParams = nlinModule.nlinParams
        self.initialTarget = nlinModule.initialTarget
        # Now we need the full transform to go back to LSQ6 space.
        for inputFH in self.inputs:
            linXfm = lsq12module.lsq12AvgXfms[inputFH]
            nlinXfm = inputFH.getLastXfm(self.nlinFH)
            outXfm = st.createOutputFileName(inputFH, nlinXfm, "transforms",
                                             "_with_additional.xfm")
            xc = ma.xfmConcat([linXfm, nlinXfm], outXfm,
                              fh.logFromFile(inputFH.logDir, outXfm))
            self.p.addStage(xc)
            inputFH.addAndSetXfmToUse(self.nlinFH, outXfm)