def gather_file(args):
    # peek at the first file to establish the per-sample shape
    if args.momentum:
        first = common.AsNPCopy(common.LoadITKField(args.files[0], ca.MEM_HOST))
    else:
        #first = common.AsNPCopy(common.LoadITKImage(args.files[0], ca.MEM_HOST))
        first = nib.load(args.files[0]).get_data()
    all_size = (len(args.files), ) + (15, 15, 15, 1, 3)  # first.shape
    data = torch.zeros(all_size)
    for i in range(0, len(args.files)):
        if args.momentum:
            cur_slice = torch.from_numpy(
                common.AsNPCopy(common.LoadITKField(args.files[i],
                                                    ca.MEM_HOST)))
        else:
            cur_slice = nib.load(args.files[i]).get_data()
            # degenerate scalar volumes (15*15*15 = 3375 voxels) are replaced
            # by zeros of the expected vector shape
            if cur_slice.size == 3375:
                cur_slice = np.zeros([15, 15, 15, 1, 3])
            cur_slice = torch.from_numpy(cur_slice)
        data[i] = cur_slice
    if args.momentum:
        # transpose the dataset to fit the training format
        data = data.numpy()
        data = np.transpose(data, [0, 4, 1, 2, 3])
        data = torch.from_numpy(data)
    torch.save(data, args.output)
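# Usage sketch for gather_file (illustrative only): the function expects an
# argparse-style namespace with the attributes read above. The flag names
# below are assumptions inferred from those attributes, not a confirmed CLI.
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('files', nargs='+')          # .mhd momenta fields or NIfTI images
#   parser.add_argument('--momentum', action='store_true')
#   parser.add_argument('--output')                  # destination for torch.save
#   gather_file(parser.parse_args())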
def MatchingImageMomentaWriteOuput(cf, geodesicState, EnergyHistory, m0, n1):
    grid = geodesicState.J0.grid()
    mType = geodesicState.J0.memType()

    # save initial momenta for the geodesic
    common.SaveITKField(geodesicState.p0, cf.io.outputPrefix + "p0.mhd")

    # save matched momenta for the geodesic
    if cf.vectormomentum.matchImOnly:
        m0 = common.LoadITKField(cf.study.m, mType)
    ca.CoAd(geodesicState.p, geodesicState.rhoinv, m0)
    common.SaveITKField(geodesicState.p, cf.io.outputPrefix + "m1.mhd")

    # momenta match energy
    if cf.vectormomentum.matchImOnly:
        vecdiff = ca.ManagedField3D(grid, mType)
        ca.Sub_I(geodesicState.p, n1)
        ca.Copy(vecdiff, geodesicState.p)
        geodesicState.diffOp.applyInverseOperator(geodesicState.p)
        momentaMatchEnergy = ca.Dot(vecdiff, geodesicState.p) / (
            float(geodesicState.p0.nVox()) * geodesicState.SigmaSlope *
            geodesicState.SigmaSlope)
        # save energy
        energyFilename = cf.io.outputPrefix + "testMomentaMatchEnergy.csv"
        with open(energyFilename, 'w') as f:
            print >> f, momentaMatchEnergy

    # save matched image for the geodesic
    tempim = ca.ManagedImage3D(grid, mType)
    ca.ApplyH(tempim, geodesicState.J0, geodesicState.rhoinv)
    common.SaveITKImage(tempim, cf.io.outputPrefix + "I1.mhd")

    # save energy history
    energyFilename = cf.io.outputPrefix + "energy.csv"
    MatchingImageMomentaWriteEnergyHistoryToFile(EnergyHistory, energyFilename)
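# Note on the momenta-match term above (a sketch of the math, not repo API):
# with residual r = p(1) - n1 and K the inverse of the fluid differential
# operator, the energy computed is
#
#   E_m = <r, K r> / (nVox * SigmaSlope^2)
#
# ca.Sub_I forms r in place, applyInverseOperator maps r to K r, and ca.Dot
# takes the L2 inner product, so vecdiff holds r while geodesicState.p holds K r.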
def main():
    secNum = sys.argv[1]
    mkyNum = sys.argv[2]
    region = str(sys.argv[3])
    # channel = sys.argv[3]
    ext = 'M{0}/section_{1}/{2}/'.format(mkyNum, secNum, region)
    ss_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/side_light_microscope/'
    conf_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/confocal/'
    memT = ca.MEM_DEVICE

    try:
        with open(
                ss_dir +
                'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
                .format(mkyNum, secNum), 'r') as f:
            region_dict = json.load(f)
    except IOError:
        region_dict = {}
        region_dict[region] = {}
        region_dict['size'] = map(
            int,
            raw_input("What is the size of the full resolution image x,y? ").
            split(','))
        region_dict[region]['bbx'] = map(
            int,
            raw_input(
                "What are the x indices of the bounding box (Matlab format x_start,x_stop)? "
            ).split(','))
        region_dict[region]['bby'] = map(
            int,
            raw_input(
                "What are the y indices of the bounding box (Matlab format y_start,y_stop)? "
            ).split(','))

    if region not in region_dict:
        region_dict[region] = {}
        region_dict[region]['bbx'] = map(
            int,
            raw_input(
                "What are the x indices of the bounding box (Matlab format x_start,x_stop)? "
            ).split(','))
        region_dict[region]['bby'] = map(
            int,
            raw_input(
                "What are the y indices of the bounding box (Matlab format y_start,y_stop)? "
            ).split(','))

    img_region = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_{2}.tiff'.
        format(mkyNum, secNum, region), ca.MEM_HOST)
    ssiSrc = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0.nrrd'
        .format(mkyNum, secNum), ca.MEM_HOST)
    bfi_df = common.LoadITKField(
        ss_dir +
        'Blockface_registered/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0_to_bfi_real.mha'
        .format(mkyNum, secNum), ca.MEM_DEVICE)

    # Figure out the same region in the low resolution image: there is a
    # transpose from here to matlab, so dimensions are flipped
    low_sz = ssiSrc.size().tolist()
    yrng_raw = [(low_sz[1] * region_dict[region]['bbx'][0]) /
                np.float(region_dict['size'][0]),
                (low_sz[1] * region_dict[region]['bbx'][1]) /
                np.float(region_dict['size'][0])]
    xrng_raw = [(low_sz[0] * region_dict[region]['bby'][0]) /
                np.float(region_dict['size'][1]),
                (low_sz[0] * region_dict[region]['bby'][1]) /
                np.float(region_dict['size'][1])]
    yrng = [np.int(np.floor(yrng_raw[0])), np.int(np.ceil(yrng_raw[1]))]
    xrng = [np.int(np.floor(xrng_raw[0])), np.int(np.ceil(xrng_raw[1]))]
    low_sub = cc.SubVol(ssiSrc, xrng, yrng)

    # Figure out the grid for the sub region in relation to the sidescape
    originout = [
        ssiSrc.origin().x + ssiSrc.spacing().x * xrng[0],
        ssiSrc.origin().y + ssiSrc.spacing().y * yrng[0], 0
    ]
    spacingout = [
        (low_sub.size().x * ssiSrc.spacing().x) / (img_region.size().x),
        (low_sub.size().y * ssiSrc.spacing().y) / (img_region.size().y), 1
    ]
    gridout = cc.MakeGrid(img_region.size().tolist(), spacingout, originout)
    img_region.setGrid(gridout)

    only_sub = np.zeros(ssiSrc.size().tolist()[0:2])
    only_sub[xrng[0]:xrng[1], yrng[0]:yrng[1]] = np.squeeze(low_sub.asnp())
    only_sub = common.ImFromNPArr(only_sub)
    only_sub.setGrid(ssiSrc.grid())

    # Deform only the sub region
    only_sub.toType(ca.MEM_DEVICE)
    def_sub = ca.Image3D(bfi_df.grid(), bfi_df.memType())
    cc.ApplyHReal(def_sub, only_sub, bfi_df)
    def_sub.toType(ca.MEM_HOST)

    # Now have to find the bounding box in the deformation space (bfi space)
    if 'deformation_bbx' not in region_dict[region]:
        bb_def = np.squeeze(pp.LandmarkPicker([np.squeeze(def_sub.asnp())]))
        bb_def_y = [bb_def[0][0], bb_def[1][0]]
        bb_def_x = [bb_def[0][1], bb_def[1][1]]
        region_dict[region]['deformation_bbx'] = bb_def_x
        region_dict[region]['deformation_bby'] = bb_def_y

    with open(
            ss_dir +
            'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
            .format(mkyNum, secNum), 'w') as f:
        json.dump(region_dict, f)

    # Now need to extract the region and create a deformation and image that
    # have the same resolution as the img_region
    deform_sub = cc.SubVol(bfi_df, region_dict[region]['deformation_bbx'],
                           region_dict[region]['deformation_bby'])

    common.DebugHere()
    sizeout = [
        int(
            np.ceil((deform_sub.size().x * deform_sub.spacing().x) /
                    img_region.spacing().x)),
        int(
            np.ceil((deform_sub.size().y * deform_sub.spacing().y) /
                    img_region.spacing().y)), 1
    ]
    region_grid = cc.MakeGrid(sizeout,
                              img_region.spacing().tolist(),
                              deform_sub.origin().tolist())

    def_im_region = ca.Image3D(region_grid, deform_sub.memType())
    up_deformation = ca.Field3D(region_grid, deform_sub.memType())

    img_region.toType(ca.MEM_DEVICE)
    cc.ResampleWorld(up_deformation, deform_sub,
                     ca.BACKGROUND_STRATEGY_PARTIAL_ZERO)
    cc.ApplyHReal(def_im_region, img_region, up_deformation)

    ss_out = ss_dir + 'Blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)
    if not pth.exists(pth.expanduser(ss_out)):
        os.mkdir(pth.expanduser(ss_out))

    common.SaveITKImage(
        def_im_region,
        pth.expanduser(ss_out) +
        'M{0}_01_section_{1}_{2}_def_to_bfi.nrrd'.format(
            mkyNum, secNum, region))
    common.SaveITKImage(
        def_im_region,
        pth.expanduser(ss_out) +
        'M{0}_01_section_{1}_{2}_def_to_bfi.tiff'.format(
            mkyNum, secNum, region))
    del img_region, def_im_region, ssiSrc, deform_sub

    # Now apply the same deformation to the confocal images
    conf_grid = cc.LoadGrid(
        conf_dir +
        'sidelight_registered/M{0}/section_{1}/{2}/affine_registration_grid.txt'
        .format(mkyNum, secNum, region))
    cf_out = conf_dir + 'blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)
    # confocal.toType(ca.MEM_DEVICE)
    # def_conf = ca.Image3D(region_grid, deform_sub.memType())
    # cc.ApplyHReal(def_conf, confocal, up_deformation)

    for channel in range(0, 4):
        z_stack = []
        num_slices = len(
            glob.glob(
                conf_dir +
                'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/*.tiff'.
                format(mkyNum, secNum, channel, region)))
        for z in range(0, num_slices):
            src_im = common.LoadITKImage(
                conf_dir +
                'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/M{0}_01_section_{1}_LGN_RHS_Ch{2}_conf_aff_sidelight_z{4}.tiff'
                .format(mkyNum, secNum, channel, region, str(z).zfill(2)))
            src_im.setGrid(
                cc.MakeGrid(
                    ca.Vec3Di(conf_grid.size().x, conf_grid.size().y, 1),
                    conf_grid.spacing(), conf_grid.origin()))
            src_im.toType(ca.MEM_DEVICE)
            def_im = ca.Image3D(region_grid, ca.MEM_DEVICE)
            cc.ApplyHReal(def_im, src_im, up_deformation)
            def_im.toType(ca.MEM_HOST)
            common.SaveITKImage(
                def_im, cf_out +
                'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.tiff'
                .format(mkyNum, secNum, channel, region, str(z).zfill(2)))
            if z == 0:
                common.SaveITKImage(
                    def_im, cf_out +
                    'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.nrrd'
                    .format(mkyNum, secNum, channel, region, str(z).zfill(2)))
            z_stack.append(def_im)
            print('==> Done with Ch {0}: {1}/{2}'.format(
                channel, z, num_slices - 1))
        stacked = cc.Imlist_to_Im(z_stack)
        stacked.setSpacing(
            ca.Vec3Df(region_grid.spacing().x, region_grid.spacing().y,
                      conf_grid.spacing().z))
        common.SaveITKImage(
            stacked, cf_out +
            'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_stack.nrrd'
            .format(mkyNum, secNum, channel, region))
        if channel == 0:
            cc.WriteGrid(stacked.grid(),
                         cf_out + 'deformed_registration_grid.txt')
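# Invocation sketch (the directory layout above is hardcoded for this dataset;
# the script name and argument values here are hypothetical):
#
#   python this_script.py <secNum> <mkyNum> <region>
#   e.g. python this_script.py 0123 15 LGN
#
# secNum/mkyNum fill the M{mky}/section_{sec} path templates, and region names
# the bounding-box entry stored in the JSON regions .txt file.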
def MatchingImageMomenta(cf):
    """Runs matching for an image-momenta pair."""
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)
    common.DebugHere()

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(MatchingImageMomentaConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    I0 = common.LoadITKImage(cf.study.I, mType)
    m0 = common.LoadITKField(cf.study.m, mType)
    J1 = common.LoadITKImage(cf.study.J, mType)
    n1 = common.LoadITKField(cf.study.n, mType)

    # get imGrid from data
    imGrid = I0.grid()

    # create time array with checkpointing info for the geodesic to be estimated
    (s, scratchInd, rCpinds) = CAvmHGM.HGMSetUpTimeArray(
        cf.optim.nTimeSteps, [1.0], 0.001)
    tDiscGeodesic = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create the state variable for the geodesic that holds all info
    p0 = ca.Field3D(imGrid, mType)
    geodesicState = CAvmHGMCommon.HGMResidualState(
        I0,
        p0,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParams[0],
        cf.vectormomentum.diffOpParams[1],
        cf.vectormomentum.diffOpParams[2],
        s,
        cf.optim.NIterForInverse,
        1.0,
        cf.vectormomentum.sigmaM,
        cf.vectormomentum.sigmaI,
        cf.optim.stepSize,
        integMethod=cf.optim.integMethod)

    # initialize with zero
    ca.SetMem(geodesicState.p0, 0.0)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    EnergyHistory = []

    # run the optimization loop
    for it in range(cf.optim.Niter):
        # shoot the geodesic forward
        CAvmHGMCommon.HGMIntegrateGeodesic(
            geodesicState.p0, geodesicState.s, geodesicState.diffOp,
            geodesicState.p, geodesicState.rho, geodesicState.rhoinv,
            tDiscGeodesic, geodesicState.Ninv, geodesicState.integMethod)

        # integrate the adjoint equations backward
        CAvmHGMCommon.HGMIntegrateAdjointsResidual(geodesicState,
                                                   tDiscGeodesic, m0, J1, n1)

        # TODO: verify it should just be log map/simple image matching when sigmaM=\infty

        # gradient descent step for geodesicState.p0
        CAvmHGMCommon.HGMTakeGradientStepResidual(geodesicState)

        # compute and print energy
        (VEnergy, IEnergy, MEnergy) = MatchingImageMomentaComputeEnergy(
            geodesicState, m0, J1, n1)
        EnergyHistory.append(
            [VEnergy + IEnergy + MEnergy, VEnergy, IEnergy, MEnergy])
        print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + MEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image Match) + ', MEnergy, '(Momenta Match)'

        # plots
        if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery == 0) or
                                    (it == cf.optim.Niter - 1)):
            MatchingImageMomentaPlots(cf,
                                      geodesicState,
                                      tDiscGeodesic,
                                      EnergyHistory,
                                      m0,
                                      J1,
                                      n1,
                                      writeOutput=True)

    # write output
    MatchingImageMomentaWriteOuput(cf, geodesicState, EnergyHistory, m0, n1)
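# Config sketch for MatchingImageMomenta (illustrative; the authoritative
# schema is MatchingImageMomentaConfigSpec). Only fields actually read above
# are listed, and the values are hypothetical:
#
#   study:
#     I: I0.mhd        # source image
#     m: m0.mhd        # source momenta
#     J: J1.mhd        # target image
#     n: n1.mhd        # target momenta
#   vectormomentum:
#     diffOpParams: [0.01, 0.0, 0.001]   # alpha, beta, gamma
#     sigmaM: 1.0
#     sigmaI: 0.1
#     matchImOnly: false
#   optim:
#     Niter: 100
#     nTimeSteps: 10
#     NIterForInverse: 5
#     stepSize: 0.01
#     integMethod: RK4
#   compute: {useCUDA: true, gpuID: 0}
#   io: {outputPrefix: ./out/, plotEvery: 10}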
def GeodesicShooting(cf):
    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    #common.DebugHere()
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)

    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    t = [
        x * 1. / cf.integration.nTimeSteps
        for x in range(cf.integration.nTimeSteps + 1)
    ]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    # scale momenta to shoot
    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    if abs(cf.study.scaleMomenta) > 0.000000:
        ca.MulC_I(m0, float(cf.study.scaleMomenta))
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp,
                                     mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if cf.io.outputPrefix is not None:
        # scale back the shot momenta before writing
        if abs(cf.study.scaleMomenta) > 0.000000:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, float(cf.study.scaleMomenta))
        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
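# Background sketch: IntegrateGeodesic above integrates the EPDiff system
#
#   v = K m,    dm/dt = -ad_v^* m,    dg/dt = v o g,
#
# where K is the inverse of the fluid operator configured by (alpha, beta,
# gamma). The shot image is I1 = I0 o g^{-1} (ca.ApplyH with ginv) and the
# momenta are carried by the coadjoint action (ca.CoAd), which is also used
# above when undoing the momenta scaling before writing m1.mhd.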
def BuildHGM(cf):
    """Worker for running Hierarchical Geodesic Model (HGM)
    for group geodesic estimation on a subset of individuals.
    Runs HGM on this subset sequentially. The variations returned
    are summed up to get the update for all individuals."""

    size = Compute.GetMPIInfo()['size']
    rank = Compute.GetMPIInfo()['rank']
    name = Compute.GetMPIInfo()['name']
    localRank = Compute.GetMPIInfo()['local_rank']
    nodename = socket.gethostname()

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectIntercepts)
    if isReporter:
        # output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(HGMConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
        #common.DebugHere()

    # if using MPI, check that the number of processes does not exceed the
    # number of subjects; more subjects than processes is fine
    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this process to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeIntercepts = cf.study.subjectIntercepts[rank::cf.compute.numProcesses]
    nodeSlopes = cf.study.subjectSlopes[rank::cf.compute.numProcesses]
    nodeBaselineTimes = cf.study.subjectBaselineTimes[rank::cf.compute.
                                                      numProcesses]
    sys.stdout.write(
        "This is process %d of %d with name: %s on machinename: %s and local rank: %d.\nnodeIntercepts: %s\n nodeSlopes: %s\n nodeBaselineTimes: %s\n"
        % (rank, size, name, nodename, localRank, nodeIntercepts, nodeSlopes,
           nodeBaselineTimes))

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeIntercepts
    ]
    # load slopes
    n = [
        common.LoadITKField(f, mType) if isinstance(f, str) else f
        for f in nodeSlopes
    ]

    # get imGrid from data
    imGrid = J[0].grid()

    # create time array with checkpointing info for group geodesic
    (t, Jind, gCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsGroup,
                                           nodeBaselineTimes, 0.0000001)
    tdiscGroup = CAvmHGMCommon.HGMSetupTimeDiscretizationGroup(
        t, J, n, Jind, gCpinds, mType, nodeSubjectIds)

    # create time array with checkpointing info for residual geodesic
    (s, scratchInd, rCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsResidual,
                                                 [1.0], 0.0000001)
    tdiscResidual = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create group state and residual state
    groupState = CAvmHGMCommon.HGMGroupState(
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsGroup[0],
        cf.vectormomentum.diffOpParamsGroup[1],
        cf.vectormomentum.diffOpParamsGroup[2],
        t,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeGroup,
        integMethod=cf.optim.integMethodGroup)

    #ca.Copy(groupState.I0, common.LoadITKImage('/usr/sci/projects/ADNI/nikhil/software/vectormomentumtest/TestData/FlowerData/Longitudinal/GroupGeodesic/I0.mhd', mType))

    # note that the residual state is treated as a scratch variable in this
    # algorithm and reused for computing residual geodesics of multiple
    # individuals
    residualState = CAvmHGMCommon.HGMResidualState(
        None,
        None,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsResidual[0],
        cf.vectormomentum.diffOpParamsResidual[1],
        cf.vectormomentum.diffOpParamsResidual[2],
        s,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeResidual,
        integMethod=cf.optim.integMethodResidual)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)
        mpiFieldBuff = None if mType == ca.MEM_HOST else ca.Field3D(
            imGrid, ca.MEM_HOST)

    # find the last time index that carries an individual's data (for plotting)
    for i in range(len(groupState.t) - 1, -1, -1):
        if tdiscGroup[i].J is not None:
            indx_last_individual = i
            break
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)
    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # run the optimization loop
    for it in range(cf.optim.Niter):
        # compute HGM variation for group
        HGMGroupVariation(groupState, tdiscGroup, residualState,
                          tdiscResidual, cf.io.outputPrefix, rank, it)
        common.CheckCUDAError("Error after HGM iteration")

        # compute gradient for momenta (m is used as scratch)
        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(groupState.sumSplatI, mpiImageBuff)
            Compute.Reduce(groupState.sumJac, mpiImageBuff)
            Compute.Reduce(groupState.madj, mpiFieldBuff)
            # also sum up energies of other nodes
            # intercept
            Eintercept = np.array([groupState.EnergyHistory[-1][1]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eintercept,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][1] = Eintercept[0]
            # slope
            Eslope = np.array([groupState.EnergyHistory[-1][2]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eslope,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][2] = Eslope[0]

        ca.Copy(groupState.m, groupState.m0)
        groupState.diffOp.applyInverseOperator(groupState.m)
        ca.Sub_I(groupState.m, groupState.madj)
        #groupState.diffOp.applyOperator(groupState.m)

        # now take gradient step in momenta for group
        if cf.optim.method == 'FIXEDGD':
            # take fixed stepsize gradient step
            ca.Add_MulC_I(groupState.m0, groupState.m, -cf.optim.stepSizeGroup)
        else:
            raise Exception("Unknown optimization scheme: " + cf.optim.method)
        # end if

        # now divide to get the new base image for group
        ca.Div(groupState.I0, groupState.sumSplatI, groupState.sumJac)

        # plot and keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and (
            ((it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            HGMPlots(cf,
                     groupState,
                     tdiscGroup,
                     residualState,
                     tdiscResidual,
                     indx_last_individual,
                     writeOutput=True)

        if isReporter:
            (VEnergy, IEnergy, SEnergy) = groupState.EnergyHistory[-1]
            print datetime.datetime.now().time(
            ), " Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + SEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Intercept) + ', SEnergy, '(Slope)'

    # write output images and fields
    HGMWriteOutput(cf, groupState, tdiscGroup, isReporter)
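# Launch sketch (hypothetical driver/flags): BuildHGM shards subjects across
# MPI ranks with the strided slices above, subjectIds[rank::numProcesses], so
# 10 subjects on 4 ranks split as 3/3/2/2. A typical launch might look like
#
#   mpirun -np 4 python RunHGM.py hgm_config.yaml
#
# with cf.compute.numProcesses matching -np; the script and config names here
# are placeholders, not confirmed entry points.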
def BuildGeoReg(cf):
    """Worker for running geodesic estimation on a subset of individuals"""
    #common.DebugHere()
    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0

    # load filenames and times for all subjects
    (subjectsIds, subjectsImagePaths,
     subjectsTimes) = GeoRegLoadSubjectsDetails(cf.study.subjectFile)
    cf.study.numSubjects = len(subjectsIds)

    if isReporter:
        # output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(GeoRegConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)

    # if using MPI, check that the number of processes does not exceed the
    # number of subjects; more subjects than processes is fine
    if cf.compute.useMPI and (len(subjectsIds) < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    nodeSubjectsIds = subjectsIds[rank::cf.compute.numProcesses]
    nodeSubjectsImagePaths = subjectsImagePaths[rank::cf.compute.numProcesses]
    nodeSubjectsTimes = subjectsTimes[rank::cf.compute.numProcesses]
    numLocalSubjects = len(nodeSubjectsImagePaths)

    if cf.study.initializationsFile is not None:
        (subjectsInitialImages,
         subjectsInitialMomenta) = GeoRegLoadSubjectsInitializations(
             cf.study.initializationsFile)
        nodeSubjectsInitialImages = subjectsInitialImages[rank::cf.compute.
                                                          numProcesses]
        nodeSubjectsInitialMomenta = subjectsInitialMomenta[rank::cf.compute.
                                                            numProcesses]

    print 'rank:', rank, ', localRank:', localRank, ', numberSubjects/TotalSubjects:', len(
        nodeSubjectsImagePaths
    ), '/', cf.study.numSubjects, ', nodeSubjectsImagePaths:', nodeSubjectsImagePaths, ', nodeSubjectsTimes:', nodeSubjectsTimes

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # setting gpuid should be handled in gpu
    # if using GPU set device based on local rank
    #if cf.compute.useCUDA:
    #    ca.SetCUDADevice(localRank)

    # get image size information
    dummyImToGetGridInfo = common.LoadITKImage(nodeSubjectsImagePaths[0][0],
                                               mType)
    imGrid = dummyImToGetGridInfo.grid()
    if cf.study.setUnitSpacing:
        imGrid.setSpacing(ca.Vec3Df(1.0, 1.0, 1.0))
    if cf.study.setZeroOrigin:
        imGrid.setOrigin(ca.Vec3Df(0, 0, 0))
    #del dummyImToGetGridInfo

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # allocate memory
    p = GeoRegVariables(imGrid,
                        mType,
                        cf.vectormomentum.diffOpParams[0],
                        cf.vectormomentum.diffOpParams[1],
                        cf.vectormomentum.diffOpParams[2],
                        cf.optim.NIterForInverse,
                        cf.vectormomentum.sigma,
                        cf.optim.stepSize,
                        integMethod=cf.optim.integMethod)

    # run geodesic regression for each local subject
    for i in range(numLocalSubjects):
        # initializations for this subject
        if cf.study.initializationsFile is not None:
            # assuming the initializations are already preprocessed, in terms
            # of intensities, origin and voxel scalings
            p.I0 = common.LoadITKImage(nodeSubjectsInitialImages[i], mType)
            p.m0 = common.LoadITKField(nodeSubjectsInitialMomenta[i], mType)
        else:
            ca.SetMem(p.m0, 0.0)
            ca.SetMem(p.I0, 0.0)

        # allocate memory specific to this subject in steps a, b and c
        # a. create time array with checkpointing info for the regression
        #    geodesic, and allocate checkpoint memory
        (t, msmtinds, cpinds) = GeoRegSetUpTimeArray(cf.optim.nTimeSteps,
                                                     nodeSubjectsTimes[i],
                                                     0.001)
        cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                    for idx in cpinds]

        # b. allocate gradAtMsmts, one residual image per measurement index
        gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

        # c. load timepoint images for this subject
        Imsmts = [
            common.LoadITKImage(f, mType) if isinstance(f, str) else f
            for f in nodeSubjectsImagePaths[i]
        ]

        # reset stepsize in case the adaptive stepsize changed it inside
        p.stepSize = cf.optim.stepSize

        # preprocess images
        GeoRegPreprocessInput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                              cpstates, msmtinds, gradAtMsmts)

        # run regression for this subject
        # REMEMBER:
        #   msmtinds index into cpinds
        #   gradAtMsmts is parallel to msmtinds
        #   cpinds index into t
        EnergyHistory = RunGeoReg(nodeSubjectsIds[i], cf, p, t, Imsmts,
                                  cpinds, cpstates, msmtinds, gradAtMsmts)

        # write output images and fields for this subject
        # TODO: BEWARE there are hardcoded numbers inside the preprocessing
        # code specific to ADNI/OASIS brain data.
        GeoRegWriteOuput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                         cpstates, msmtinds, gradAtMsmts, EnergyHistory)

        # clean up memory specific to this subject
        del t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts
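# Index bookkeeping sketch for the regression loop above (values illustrative):
# with cf.optim.nTimeSteps = 10 and measurements at times [0.0, 0.5, 1.0],
# GeoRegSetUpTimeArray returns roughly
#
#   t        = [0.0, 0.1, ..., 1.0]   # integration grid (plus measurement times)
#   cpinds   : indices into t where checkpoint states are kept
#   msmtinds : indices into cpinds marking checkpoints that carry a measurement
#
# gradAtMsmts is parallel to msmtinds (one residual image per measurement) and
# Imsmts holds the corresponding timepoint images, as the REMEMBER note states.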