def BuildHGM(cf): """Worker for running Hierarchical Geodesic Model (HGM) n for group geodesic estimation on a subset of individuals. Runs HGM on this subset sequentially. The variations retuned are summed up to get update for all individuals""" size = Compute.GetMPIInfo()['size'] rank = Compute.GetMPIInfo()['rank'] name = Compute.GetMPIInfo()['name'] localRank = Compute.GetMPIInfo()['local_rank'] nodename = socket.gethostname() # prepare output directory common.Mkdir_p(os.path.dirname(cf.io.outputPrefix)) # just one reporter process on each node isReporter = rank == 0 cf.study.numSubjects = len(cf.study.subjectIntercepts) if isReporter: # Output loaded config if cf.io.outputPrefix is not None: cfstr = Config.ConfigToYAML(HGMConfigSpec, cf) with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f: f.write(cfstr) #common.DebugHere() # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses): raise Exception("Please don't use more processes " + "than total number of individuals") # subdivide data, create subsets for this thread to work on nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses] nodeIntercepts = cf.study.subjectIntercepts[rank::cf.compute.numProcesses] nodeSlopes = cf.study.subjectSlopes[rank::cf.compute.numProcesses] nodeBaselineTimes = cf.study.subjectBaselineTimes[rank::cf.compute. numProcesses] sys.stdout.write( "This is process %d of %d with name: %s on machinename: %s and local rank: %d.\nnodeIntercepts: %s\n nodeSlopes: %s\n nodeBaselineTimes: %s\n" % (rank, size, name, nodename, localRank, nodeIntercepts, nodeSlopes, nodeBaselineTimes)) # mem type is determined by whether or not we're using CUDA mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST # load data in memory # load intercepts J = [ common.LoadITKImage(f, mType) if isinstance(f, str) else f for f in nodeIntercepts ] # load slopes n = [ common.LoadITKField(f, mType) if isinstance(f, str) else f for f in nodeSlopes ] # get imGrid from data imGrid = J[0].grid() # create time array with checkpointing info for group geodesic (t, Jind, gCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsGroup, nodeBaselineTimes, 0.0000001) tdiscGroup = CAvmHGMCommon.HGMSetupTimeDiscretizationGroup( t, J, n, Jind, gCpinds, mType, nodeSubjectIds) # create time array with checkpointing info for residual geodesic (s, scratchInd, rCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsResidual, [1.0], 0.0000001) tdiscResidual = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual( s, rCpinds, imGrid, mType) # create group state and residual state groupState = CAvmHGMCommon.HGMGroupState( imGrid, mType, cf.vectormomentum.diffOpParamsGroup[0], cf.vectormomentum.diffOpParamsGroup[1], cf.vectormomentum.diffOpParamsGroup[2], t, cf.optim.NIterForInverse, cf.vectormomentum.varIntercept, cf.vectormomentum.varSlope, cf.vectormomentum.varInterceptReg, cf.optim.stepSizeGroup, integMethod=cf.optim.integMethodGroup) #ca.Copy(groupState.I0, common.LoadITKImage('/usr/sci/projects/ADNI/nikhil/software/vectormomentumtest/TestData/FlowerData/Longitudinal/GroupGeodesic/I0.mhd', mType)) # note that residual state is treated a scratch variable in this algorithm and reused for computing residual geodesics of multiple individual residualState = CAvmHGMCommon.HGMResidualState( None, None, imGrid, mType, cf.vectormomentum.diffOpParamsResidual[0], cf.vectormomentum.diffOpParamsResidual[1], cf.vectormomentum.diffOpParamsResidual[2], s, cf.optim.NIterForInverse, cf.vectormomentum.varIntercept, cf.vectormomentum.varSlope, cf.vectormomentum.varInterceptReg, cf.optim.stepSizeResidual, integMethod=cf.optim.integMethodResidual) # start up the memory manager for scratch variables ca.ThreadMemoryManager.init(imGrid, mType, 0) # need some host memory in np array format for MPI reductions if cf.compute.useMPI: mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D( imGrid, ca.MEM_HOST) mpiFieldBuff = None if mType == ca.MEM_HOST else ca.Field3D( imGrid, ca.MEM_HOST) for i in range(len(groupState.t) - 1, -1, -1): if tdiscGroup[i].J is not None: indx_last_individual = i break ''' # initial template image ca.SetMem(groupState.I0, 0.0) tmp = ca.ManagedImage3D(imGrid, mType) for tdisc in tdiscGroup: if tdisc.J is not None: ca.Copy(tmp, tdisc.J) groupState.I0 += tmp del tmp if cf.compute.useMPI: Compute.Reduce(groupState.I0, mpiImageBuff) # divide by total num subjects groupState.I0 /= cf.study.numSubjects ''' # run the loop for it in range(cf.optim.Niter): # compute HGM variation for group HGMGroupVariation(groupState, tdiscGroup, residualState, tdiscResidual, cf.io.outputPrefix, rank, it) common.CheckCUDAError("Error after HGM iteration") # compute gradient for momenta (m is used as scratch) # if there are multiple nodes we'll need to sum across processes now if cf.compute.useMPI: # do an MPI sum Compute.Reduce(groupState.sumSplatI, mpiImageBuff) Compute.Reduce(groupState.sumJac, mpiImageBuff) Compute.Reduce(groupState.madj, mpiFieldBuff) # also sum up energies of other nodes # intercept Eintercept = np.array([groupState.EnergyHistory[-1][1]]) mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, Eintercept, op=mpi4py.MPI.SUM) groupState.EnergyHistory[-1][1] = Eintercept[0] Eslope = np.array([groupState.EnergyHistory[-1][2]]) mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, Eslope, op=mpi4py.MPI.SUM) groupState.EnergyHistory[-1][2] = Eslope[0] ca.Copy(groupState.m, groupState.m0) groupState.diffOp.applyInverseOperator(groupState.m) ca.Sub_I(groupState.m, groupState.madj) #groupState.diffOp.applyOperator(groupState.m) # now take gradient step in momenta for group if cf.optim.method == 'FIXEDGD': # take fixed stepsize gradient step ca.Add_MulC_I(groupState.m0, groupState.m, -cf.optim.stepSizeGroup) else: raise Exception("Unknown optimization scheme: " + cf.optim.method) # end if # now divide to get the new base image for group ca.Div(groupState.I0, groupState.sumSplatI, groupState.sumJac) # keep track of energy in this iteration if isReporter and cf.io.plotEvery > 0 and (( (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)): HGMPlots(cf, groupState, tdiscGroup, residualState, tdiscResidual, indx_last_individual, writeOutput=True) if isReporter: (VEnergy, IEnergy, SEnergy) = groupState.EnergyHistory[-1] print datetime.datetime.now().time( ), " Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + SEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Intercept) + ', SEnergy, '(Slope)' # write output images and fields HGMWriteOutput(cf, groupState, tdiscGroup, isReporter)
def BuildAtlas(cf): """Worker for running Atlas construction on a subset of individuals. Runs Atlas on this subset sequentially. The variations retuned are summed up to get update for all individuals """ localRank = Compute.GetMPIInfo()['local_rank'] rank = Compute.GetMPIInfo()['rank'] # prepare output directory common.Mkdir_p(os.path.dirname(cf.io.outputPrefix)) # just one reporter process on each node isReporter = rank == 0 cf.study.numSubjects = len(cf.study.subjectImages) if isReporter: # Output loaded config if cf.io.outputPrefix is not None: cfstr = Config.ConfigToYAML(AtlasConfigSpec, cf) with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f: f.write(cfstr) #common.DebugHere() # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses): raise Exception("Please don't use more processes " + "than total number of individuals") # subdivide data, create subsets for this thread to work on nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses] nodeImages = cf.study.subjectImages[rank::cf.compute.numProcesses] nodeWeights = cf.study.subjectWeights[rank::cf.compute.numProcesses] numLocalSubjects = len(nodeImages) print 'rank:', rank, ', localRank:', localRank, ', nodeImages:', nodeImages, ', nodeWeights:', nodeWeights # mem type is determined by whether or not we're using CUDA mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST # load data in memory # load intercepts J_array = [ common.LoadITKImage(f, mType) if isinstance(f, str) else f for f in nodeImages ] # get imGrid from data imGrid = J_array[0].grid() # atlas image atlas = ca.Image3D(imGrid, mType) # allocate memory to store only the initial momenta for each individual in this thread m_array = [ca.Field3D(imGrid, mType) for i in range(numLocalSubjects)] # allocate only one copy of scratch memory to be reused for each local individual in this thread in loop p = WarpVariables(imGrid, mType, cf.vectormomentum.diffOpParams[0], cf.vectormomentum.diffOpParams[1], cf.vectormomentum.diffOpParams[2], cf.optim.NIterForInverse, cf.vectormomentum.sigma, cf.optim.stepSize, integMethod=cf.optim.integMethod) # memory to accumulate numerators and denominators for atlas from # local individuals which will be summed across MPI threads sumSplatI = ca.Image3D(imGrid, mType) sumJac = ca.Image3D(imGrid, mType) # start up the memory manager for scratch variables ca.ThreadMemoryManager.init(imGrid, mType, 0) # need some host memory in np array format for MPI reductions if cf.compute.useMPI: mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D( imGrid, ca.MEM_HOST) t = [ x * 1. / (cf.optim.nTimeSteps) for x in range(cf.optim.nTimeSteps + 1) ] cpinds = range(1, len(t)) msmtinds = [ len(t) - 2 ] # since t=0 is not in cpinds, thats just identity deformation so not checkpointed cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType)) for idx in cpinds] gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds] EnergyHistory = [] # TODO: better initializations # initialize atlas image with zeros. ca.SetMem(atlas, 0.0) # initialize momenta with zeros for m0_individual in m_array: ca.SetMem(m0_individual, 0.0) ''' # initial template image ca.SetMem(groupState.I0, 0.0) tmp = ca.ManagedImage3D(imGrid, mType) for tdisc in tdiscGroup: if tdisc.J is not None: ca.Copy(tmp, tdisc.J) groupState.I0 += tmp del tmp if cf.compute.useMPI: Compute.Reduce(groupState.I0, mpiImageBuff) # divide by total num subjects groupState.I0 /= cf.study.numSubjects ''' # preprocessinput # assign atlas reference to p.I0. This reference will not change. p.I0 = atlas # run the loop for it in range(cf.optim.Niter): # run one iteration of warp for each individual and update # their own initial momenta and also accumulate SplatI and Jac ca.SetMem(sumSplatI, 0.0) ca.SetMem(sumJac, 0.0) TotalVEnergy = np.array([0.0]) TotalIEnergy = np.array([0.0]) for itsub in range(numLocalSubjects): # initializations for this subject, this only assigns # reference to image variables p.m0 = m_array[itsub] Imsmts = [J_array[itsub]] # run warp iteration VEnergy, IEnergy = RunWarpIteration(nodeSubjectIds[itsub], cf, p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts, it) # gather relevant results ca.Add_I(sumSplatI, p.sumSplatI) ca.Add_I(sumJac, p.sumJac) TotalVEnergy[0] += VEnergy TotalIEnergy[0] += IEnergy # if there are multiple nodes we'll need to sum across processes now if cf.compute.useMPI: # do an MPI sum Compute.Reduce(sumSplatI, mpiImageBuff) Compute.Reduce(sumJac, mpiImageBuff) # also sum up energies of other nodes mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, TotalVEnergy, op=mpi4py.MPI.SUM) mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, TotalIEnergy, op=mpi4py.MPI.SUM) EnergyHistory.append([TotalVEnergy[0], TotalIEnergy[0]]) # now divide to get the new atlas image ca.Div(atlas, sumSplatI, sumJac) # keep track of energy in this iteration if isReporter and cf.io.plotEvery > 0 and (( (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)): # plots AtlasPlots(cf, p, atlas, m_array, EnergyHistory) if isReporter: # print out energy (VEnergy, IEnergy) = EnergyHistory[-1] print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image)' # write output images and fields AtlasWriteOutput(cf, atlas, m_array, nodeSubjectIds, isReporter)
def GeoRegIteration(subid, cf, p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts, EnergyHistory, it): # compute gradient for regression (grad_m, sumJac, sumSplatI, VEnergy, IEnergy) = GeoRegGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts) # do energy related stuff for printing and bookkeeping #if it>0: EnergyHistory.append([VEnergy + IEnergy, VEnergy, IEnergy]) print VEnergy + IEnergy, '(Total) = ', VEnergy, '(Vector)+', IEnergy, '(Image)' # plot some stuff if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery) == 0 or it == cf.optim.Niter - 1): GeoRegPlots(subid, cf, p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts, EnergyHistory) # end if if cf.optim.method == 'FIXEDGD': # automatic stepsize selection in the first three steps if it == 1: # TODO: BEWARE There are hardcoded numbers here for 2D and 3D #first find max absolute value across voxels in gradient temp = ca.Field3D(grad_m.grid(), ca.MEM_HOST) ca.Copy(temp, grad_m) temp_x, temp_y, temp_z = temp.asnp() temp1 = np.square(temp_x.flatten()) + np.square( temp_y.flatten()) + np.square(temp_z.flatten()) medianval = np.median(temp1[temp1 > 0.0000000001]) del temp, temp1, temp_x, temp_y, temp_z #2D images for 2000 iters #p.stepSize = float(0.000000002*medianval) #3D images for 2000 iters p.stepSize = float(0.000002 * medianval) print 'rank:', Compute.GetMPIInfo( )['rank'], ', localRank:', Compute.GetMPIInfo( )['local_rank'], 'subid: ', subid, ' Selecting initial step size in the beginning to be ', str( p.stepSize) if it > 3: totalEnergyDiff = EnergyHistory[-1][0] - EnergyHistory[-2][0] if totalEnergyDiff > 0.0: if cf.optim.maxPert is not None: print 'rank:', Compute.GetMPIInfo( )['rank'], ', localRank:', Compute.GetMPIInfo( )['local_rank'], 'subid: ', subid, ' Reducing stepsize for gradient descent by ', str( cf.optim.maxPert * 100), '%. The new step size is ', str( p.stepSize * (1 - cf.optim.maxPert)) p.stepSize = p.stepSize * (1 - cf.optim.maxPert) # take gradient descent step ca.Add_MulC_I(p.m0, grad_m, -p.stepSize) else: raise Exception("Unknown optimization scheme: " + cf.optim.optMethod) # end if # now divide to get new base image ca.Div(p.I0, sumSplatI, sumJac) return (EnergyHistory)