def MatchingGradient(p):
    """Compute the momentum gradient and energies for two-image matching.

    Shoots the initial momentum ``p.m0`` forward over the time points
    ``p.t``. It then forms the residual between ``p.I0`` warped by the final
    ``p.ginv`` and ``p.I1``, and integrates the adjoint system backward,
    seeded with that residual at the last time point. Finally it combines
    the adjoint momentum with the smoothed initial momentum.

    Args:
        p: workspace object. It supplies the images ``I0``/``I1``, the
            momentum ``m0``, the noise level ``sigma``, the operator
            ``diffOp``, the integration settings (``t``, ``nInv``,
            ``integMethod``), the checkpoint storage and many preallocated
            scratch images/fields. Most of those buffers are overwritten
            here.

    Returns:
        Tuple ``(grad, VEnergy, IEnergy)``:
        - ``grad`` is the buffer ``p.scratchV1`` holding
          ``K(m0) - p.madj``, where ``K`` is ``diffOp.applyInverseOperator``.
          It aliases workspace memory and is overwritten on the next call.
        - ``VEnergy`` is ``0.5 * <m0, K(m0)> / nVox``.
        - ``IEnergy`` is ``sum(residual**2) / (2 * sigma**2 * nVox)``.
        The result is returned without applying ``diffOp``'s forward
        operator.
    """
    # Forward shot; the final inverse deformation ends up in p.ginv.
    CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 p.checkpointstates, p.checkpointinds,
                                 Ninv=p.nInv,
                                 integMethod=p.integMethod,
                                 RK4=p.scratchV4,
                                 scratchG=p.scratchV5)
    # Checkpoint slot that holds the state at the last time point.
    finalCp = p.checkpointinds.index(len(p.t) - 1)

    nVox = float(p.I0.nVox())

    # residual = I0 o ginv - I1
    residual = p.residualIm
    ca.ApplyH(residual, p.I0, p.ginv)
    ca.Sub_I(residual, p.I1)
    # Record the data-match energy before the residual is rescaled.
    IEnergy = ca.Sum2(residual) / (2 * p.sigma * p.sigma * nVox)
    # Rescale in place: this is the adjoint seed at the measurement.
    ca.DivC_I(residual, p.sigma * p.sigma)

    # Backward (adjoint) integration from the single final measurement.
    CAvmCommon.IntegrateAdjoints(p.Iadj, p.madj,
                                 p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
                                 p.scratchV2, p.scratchV3,
                                 p.I0, p.m0,
                                 p.t, p.checkpointstates, p.checkpointinds,
                                 [residual], [finalCp],
                                 p.diffOp, p.integMethod, p.nInv,
                                 scratchV3=p.scratchV7,
                                 scratchV4=p.g,
                                 scratchV5=p.ginv,
                                 scratchV6=p.scratchV8,
                                 scratchV7=p.scratchV9,
                                 scratchV8=p.scratchV10,
                                 scratchV9=p.scratchV11,
                                 RK4=p.scratchV4,
                                 scratchG=p.scratchV5,
                                 scratchGinv=p.scratchV6)

    # grad = K(m0); the regularization energy is read off before
    # subtracting the adjoint momentum.
    grad = p.scratchV1
    ca.Copy(grad, p.m0)
    p.diffOp.applyInverseOperator(grad)
    VEnergy = 0.5 * ca.Dot(p.m0, grad) / nVox
    ca.Sub_I(grad, p.madj)

    return grad, VEnergy, IEnergy
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Compute gradient terms for several image measurements on one geodesic.

    The function runs these steps in order:
    1. Shoot ``p.m0`` forward over ``t``, passing ``cpstates``/``cpinds``
       as the checkpoint storage.
    2. Form the sigma-scaled residual of every measurement and accumulate
       the image-match energy.
    3. Integrate the adjoint system backward from all residuals.
    4. Compute the momentum gradient.
    5. Accumulate the closed-form sums used for the image update.

    Args:
        p: workspace object. It supplies ``I0``, ``m0``, ``sigma``,
            ``diffOp``, ``nInv``, ``integMethod`` and preallocated scratch
            buffers. Many of these buffers are overwritten, including
            ``p.I`` and ``p.scratchI1``.
        t: integration time points.
        Imsmts: measured images, one per measurement.
        cpinds: time indices that are checkpointed.
        cpstates: ``(g, ginv)`` field pairs, parallel to ``cpinds``.
        msmtinds: for each measurement, an index into ``cpstates``, or -1.
            A -1 measurement is compared against ``p.I0`` without any
            deformation (presumably time 0, which is not checkpointed;
            see the TODO below).
        gradAtMsmts: output images. On return, entry ``i`` holds
            ``(I0 o ginv_i - Imsmts[i]) / sigma**2``.

    Returns:
        Tuple ``(grad, sumJac, sumSplatI, VEnergy, IEnergy)``:
        - ``grad`` is ``p.scratchV1`` holding ``K(m0) - p.madj``, where
          ``K`` is ``diffOp.applyInverseOperator``.
        - ``sumJac`` is ``p.sumJac``: the sum over measurements of an
          all-ones image splatted by each ``ginv``.
        - ``sumSplatI`` is ``p.sumSplatI``: the sum of the measurements
          splatted by each ``ginv``.
        - ``VEnergy`` is ``0.5 * <m0, K(m0)> / nVox``.
        - ``IEnergy`` is the sum of squared residuals divided by
          ``2 * sigma**2 * nVox``.
        The returned arrays alias buffers in ``p`` and are overwritten on
        the next call.
    """
    # Shoot the geodesic forward, filling the checkpoint states.
    CAvmCommon.IntegrateGeodesic(p.m0, t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 cpstates, cpinds,
                                 Ninv=p.nInv, integMethod=p.integMethod,
                                 RK4=p.scratchV4, scratchG=p.scratchV5)

    sigmaSq = p.sigma * p.sigma
    nVox = float(p.I0.nVox())
    energyNorm = 2 * p.sigma * p.sigma * nVox

    # Residual and energy for each measurement. Only the choice of
    # "predicted" image differs between the two cases.
    IEnergy = 0.0
    for i, Imsmt in enumerate(Imsmts):
        cpidx = msmtinds[i]
        resid = gradAtMsmts[i]
        if cpidx == -1:
            # Measurement without a checkpointed deformation: use I0 as-is.
            ca.Copy(resid, p.I0)
        else:
            _, ginv = cpstates[cpidx]
            ca.ApplyH(resid, p.I0, ginv)
        ca.Sub_I(resid, Imsmt)
        # Record the energy before the residual is rescaled.
        IEnergy += ca.Sum2(resid) / energyNorm
        # Rescale in place: this is the adjoint seed at this measurement.
        ca.DivC_I(resid, sigmaSq)

    # Integrate backward from all measurement residuals.
    CAvmCommon.IntegrateAdjoints(p.Iadj, p.madj,
                                 p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
                                 p.scratchV2, p.scratchV3,
                                 p.I0, p.m0,
                                 t, cpstates, cpinds,
                                 gradAtMsmts, msmtinds,
                                 p.diffOp,
                                 p.integMethod, p.nInv,
                                 scratchV3=p.scratchV7, scratchV4=p.g,
                                 scratchV5=p.ginv, scratchV6=p.scratchV8,
                                 scratchV7=p.scratchV9,
                                 scratchV8=p.scratchV10,
                                 scratchV9=p.scratchV11,
                                 RK4=p.scratchV4, scratchG=p.scratchV5,
                                 scratchGinv=p.scratchV6,
                                 scratchI=p.scratchI1)

    # Momentum gradient: K(m0) - madj. The regularization energy is read
    # off before the adjoint momentum is subtracted.
    grad = p.scratchV1
    ca.Copy(grad, p.m0)
    p.diffOp.applyInverseOperator(grad)
    VEnergy = 0.5 * ca.Dot(p.m0, grad) / nVox
    ca.Sub_I(grad, p.madj)

    # Closed-form terms for the image update. p.scratchI1 is splat
    # scratch, and p.I (no longer needed after the adjoint pass) is reused
    # as a constant all-ones image.
    scratchI = p.scratchI1
    imOnes = p.I
    ca.SetMem(imOnes, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    for i, Imsmt in enumerate(Imsmts):
        # TODO: check these indexings for cases when timepoint 0
        # is not checkpointed
        cpidx = msmtinds[i]
        if cpidx == -1:
            # Undeformed measurement: add it and a unit Jacobian directly.
            ca.Add_I(p.sumSplatI, Imsmt)
            ca.Add_I(p.sumJac, imOnes)
        else:
            _, ginv = cpstates[cpidx]
            CAvmCommon.SplatSafe(scratchI, ginv, Imsmt)
            ca.Add_I(p.sumSplatI, scratchI)
            CAvmCommon.SplatSafe(scratchI, ginv, imOnes)
            ca.Add_I(p.sumJac, scratchI)

    return (grad, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
def GeodesicShooting(cf):
    """Shoot a geodesic from an initial image and momentum, and save results.

    Loads ``cf.study.I0`` and ``cf.study.m0``. It multiplies the momentum
    by ``cf.study.scaleMomenta`` and integrates the geodesic over
    ``cf.integration.nTimeSteps`` uniform steps on [0, 1]. When the scale
    is zero, no shooting is done and every deformation is the identity.

    If ``cf.io.outputPrefix`` is not None, the function also:
    - writes the parsed config;
    - writes the final image (I1), the final momentum rescaled back by the
      scale factor (m1), and the deformation fields (phi, phiinv);
    - produces plots;
    - optionally saves intermediate frames.

    Note: ``cf.study.scaleMomenta`` is converted to ``float`` in place on
    ``cf``.
    """
    outPrefix = cf.io.outputPrefix
    if outPrefix is not None:
        # Prepare the output directory. This must happen inside the None
        # check, since os.path.dirname(None) raises TypeError. dirname is
        # '' for a bare filename prefix, i.e. the current directory, so
        # there is nothing to create.
        outDir = os.path.dirname(outPrefix)
        if outDir:
            common.Mkdir_p(outDir)
        # Output loaded config
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(outPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST

    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)

    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)

    nSteps = cf.integration.nTimeSteps
    t = [x * 1. / nSteps for x in range(nSteps + 1)]
    # Every time point after 0 is checkpointed as a (g, ginv) pair.
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for _ in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    # scale momenta to shoot
    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    scale = cf.study.scaleMomenta
    shoot = abs(scale) > 0.0
    if shoot:
        ca.MulC_I(m0, scale)
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        # Zero momentum: no motion.
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)
        # The intermediate states are also the identity. Without this, the
        # checkpoint fields would keep whatever the Field3D constructor
        # left in them, and SaveFrames would read that data.
        for gk, ginvk in checkpointstates:
            ca.SetToIdentity(gk)
            ca.SetToIdentity(ginvk)

    # write output
    if outPrefix is not None:
        # scale back shot momenta before writing
        if shoot:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, scale)
        common.SaveITKImage(It, outPrefix + "I1.mhd")
        common.SaveITKField(mt, outPrefix + "m1.mhd")
        common.SaveITKField(ginv, outPrefix + "phiinv.mhd")
        common.SaveITKField(g, outPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            # NOTE(review): m0 is still multiplied by `scale` here while mt
            # has been divided back; confirm SaveFrames expects that.
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)