def test_SplatWorldAdjoint(self, disp=False):
    """SplatWorld must act as the adjoint of ResampleWorld.

    Builds a random image ``base`` on the test spacing and a random image
    ``scaled`` whose spacing is multiplied by 0.77382, then checks that

        <ResampleWorld(base), scaled> == <base, SplatWorld(scaled)>

    to within an absolute tolerance of 2e-6. ``disp`` is accepted but unused.
    """
    # random image on the reference spacing (created first: keeps RNG order)
    base = common.RandImage(self.sz, nSig=1.0, gSig=0.0,
                            mType=ca.MEM_HOST, sp=self.imSp)
    # an awkward spacing factor so the two grids' voxels don't align perfectly
    spacingFactor = 0.77382
    scaled = common.RandImage(self.sz, nSig=1.0, gSig=0.0,
                              mType=ca.MEM_HOST, sp=self.imSp * spacingFactor)

    # left side: resample base onto scaled's grid, then dot with scaled
    resampled = ca.Image3D(scaled.grid(), ca.MEM_HOST)
    ca.SetMem(resampled, 0)
    ca.ResampleWorld(resampled, base)
    scaledSideIP = ca.Dot(resampled, scaled)

    # right side: splat scaled onto base's grid, then dot with base
    splatted = ca.Image3D(base.grid(), ca.MEM_HOST)
    ca.SetMem(splatted, 0)
    ca.SplatWorld(splatted, scaled)
    baseSideIP = ca.Dot(splatted, base)

    self.assertLess(abs(scaledSideIP - baseSideIP), 2e-6)
def test_SplatAdjoint(self, disp=False):
    """Splat by a deformation must act as the adjoint of ApplyH by it.

    For random images I, J and a random field phi, checks that

        <ApplyH(I, phi), J> == <I, Splat(phi, J)>

    to within an absolute tolerance of 2e-6. ``disp`` is accepted but unused.
    """
    # random inputs, created in a fixed order (I, J, phi)
    hI = common.RandImage(self.sz, nSig=1.0, gSig=0.0,
                          mType=ca.MEM_HOST, sp=self.imSp)
    hJ = common.RandImage(self.sz, nSig=1.0, gSig=0.0,
                          mType=ca.MEM_HOST, sp=self.imSp)
    hPhi = common.RandField(self.sz, nSig=5.0, gSig=4.0,
                            mType=ca.MEM_HOST, sp=self.imSp)
    work = ca.Image3D(self.grid, ca.MEM_HOST)

    # < I o phi, J >
    ca.ApplyH(work, hI, hPhi)
    pulledBackIP = ca.Dot(work, hJ)

    # < I, Splat_phi(J) >  (work buffer is reused)
    ca.Splat(work, hPhi, hJ)
    splattedIP = ca.Dot(work, hI)

    self.assertLess(abs(pulledBackIP - splattedIP), 2e-6)
def test_GroupAdjointAction(self, disp=False):
    """CoAd must act as the adjoint of Ad.

    For random fields m, v and phi, checks that

        <m, Ad(phi, v)> == <CoAd(phi, m), v>

    to within an absolute tolerance of 2e-6. ``disp`` is accepted but unused.
    """
    # random inputs, created in a fixed order (m, v, phi)
    hM = common.RandField(self.sz, nSig=5.0, gSig=4.0,
                          mType=ca.MEM_HOST, sp=self.imSp)
    hV = common.RandField(self.sz, nSig=5.0, gSig=4.0,
                          mType=ca.MEM_HOST, sp=self.imSp)
    hPhi = common.RandField(self.sz, nSig=5.0, gSig=4.0,
                            mType=ca.MEM_HOST, sp=self.imSp)
    work = ca.Field3D(self.grid, ca.MEM_HOST)

    # < m, Ad_phi v >
    ca.Ad(work, hPhi, hV)
    adSideIP = ca.Dot(work, hM)

    # < CoAd_phi m, v >  (work buffer is reused)
    ca.CoAd(work, hPhi, hM)
    coAdSideIP = ca.Dot(work, hV)

    self.assertLess(abs(adSideIP - coAdSideIP), 2e-6)
def HGMComputeJumpGradientsFromResidual(gradIntercept, gradSlope,
                                        residualState, mt, J1, n1):
    """Compute intercept/slope jump gradients for one residual geodesic.

    Writes into the caller-supplied outputs and also returns three energy
    terms. All energies are normalised by the voxel count of
    ``residualState.p0``.

    Args:
        gradIntercept: output image; overwritten with the image residual
            (J0 warped by rhoinv, minus J1), divided by
            SigmaIntercept^2 * Sigma^2 and splatted back through rhoinv.
        gradSlope: output field; overwritten with
            Ad(rhoinv, K(CoAd(rhoinv, mt) - n1)) / SigmaSlope^2, where K is
            ``residualState.diffOp.applyInverseOperator``.
        residualState: state holding J0, p0, rhoinv, diffOp and the Sigma,
            SigmaIntercept and SigmaSlope weights. Its ``p`` field is used as
            scratch and is clobbered.
        mt: momentum field transported via CoAd(rhoinv, .).
        J1: target image.
        n1: target momentum field.

    Returns:
        (pEnergy, iEnergy, slopeEnergy):
            pEnergy     -- <p0, K p0> / (nVox * SigmaIntercept^2)
            iEnergy     -- 0.5 * ||J0 o rhoinv - J1||^2
                           / (nVox * Sigma^2 * SigmaIntercept^2)
            slopeEnergy -- <K r, r> / (nVox * SigmaSlope^2),
                           with r = CoAd(rhoinv, mt) - n1
    """
    rs = residualState
    nVox = float(rs.p0.nVox())
    grid = rs.J0.grid()
    mType = rs.J0.memType()
    imdiff = ca.ManagedImage3D(grid, mType)
    vecdiff = ca.ManagedField3D(grid, mType)

    # 1. Gradient for intercept: residual of the warped image against J1
    ApplyH(imdiff, rs.J0, rs.rhoinv)
    Sub_I(imdiff, J1)
    # image energy is taken while the unscaled residual is still available
    iEnergy = 0.5 * ca.Sum2(imdiff) / (
        nVox * rs.Sigma * rs.Sigma *
        rs.SigmaIntercept * rs.SigmaIntercept)
    MulC_I(imdiff, 1.0 / (rs.SigmaIntercept * rs.SigmaIntercept *
                          rs.Sigma * rs.Sigma))
    SplatSafe(gradIntercept, rs.rhoinv, imdiff)

    # 2. Gradient for slope: momentum residual, smoothed by K, mapped by Ad
    CoAd(rs.p, rs.rhoinv, mt)
    Sub_I(rs.p, n1)
    Copy(vecdiff, rs.p)  # keep the unsmoothed residual for the slope energy
    rs.diffOp.applyInverseOperator(rs.p)
    Ad(gradSlope, rs.rhoinv, rs.p)
    MulC_I(gradSlope, 1.0 / (rs.SigmaSlope * rs.SigmaSlope))

    # slope energy: <K r, r>
    slopeEnergy = ca.Dot(rs.p, vecdiff) / (
        nVox * rs.SigmaSlope * rs.SigmaSlope)

    # intercept (initial momentum) energy; p is reused as scratch for K p0
    Copy(rs.p, rs.p0)
    rs.diffOp.applyInverseOperator(rs.p)
    # NOTE(review): MatchingImageMomentaComputeEnergy applies a 0.5 factor
    # to the analogous <p0, K p0> term but this does not -- confirm intended.
    pEnergy = ca.Dot(rs.p0, rs.p) / (
        nVox * rs.SigmaIntercept * rs.SigmaIntercept)
    return (pEnergy, iEnergy, slopeEnergy)
def MatchingImageMomentaComputeEnergy(geodesicState, m0, J1, n1):
    """Evaluate the three energy terms of joint image/momenta matching.

    K below denotes ``geodesicState.diffOp.applyInverseOperator``. Every
    term is normalised by the voxel count of ``geodesicState.p0``.
    ``geodesicState.p`` is used as scratch and is clobbered.

    Returns:
        (vecEnergy, imageMatchEnergy, momentaMatchEnergy):
            vecEnergy          -- 0.5 * <p0, K p0> / (nVox * SigmaIntercept^2)
            imageMatchEnergy   -- 0.5 * ||J0 o rhoinv - J1||^2
                                  / (nVox * Sigma^2 * SigmaIntercept^2)
            momentaMatchEnergy -- <r, K r> / (nVox * SigmaSlope^2),
                                  with r = CoAd(rhoinv, m0) - n1
    """
    gs = geodesicState
    nVox = float(gs.p0.nVox())
    grid = gs.J0.grid()
    mType = gs.J0.memType()
    warpedDiff = ca.ManagedImage3D(grid, mType)
    momentaResidual = ca.ManagedField3D(grid, mType)

    # image term: warp J0 by rhoinv and compare with J1
    ca.ApplyH(warpedDiff, gs.J0, gs.rhoinv)
    ca.Sub_I(warpedDiff, J1)
    imageMatchEnergy = 0.5 * ca.Sum2(warpedDiff) / (
        nVox * gs.Sigma * gs.Sigma * gs.SigmaIntercept * gs.SigmaIntercept)

    # momenta term: transported m0 against n1, measured in the K metric
    ca.CoAd(gs.p, gs.rhoinv, m0)
    ca.Sub_I(gs.p, n1)
    ca.Copy(momentaResidual, gs.p)
    gs.diffOp.applyInverseOperator(gs.p)
    momentaMatchEnergy = ca.Dot(momentaResidual, gs.p) / (
        nVox * gs.SigmaSlope * gs.SigmaSlope)

    # regularity term on the initial momentum; p reused as scratch for K p0
    ca.Copy(gs.p, gs.p0)
    gs.diffOp.applyInverseOperator(gs.p)
    vecEnergy = 0.5 * ca.Dot(gs.p0, gs.p) / (
        nVox * gs.SigmaIntercept * gs.SigmaIntercept)

    return (vecEnergy, imageMatchEnergy, momentaMatchEnergy)
def MatchingGradient(p):
    """Gradient of the geodesic image-matching energy w.r.t. ``p.m0``.

    Steps:
      1. Shoot the geodesic forward from ``p.m0`` over times ``p.t``,
         recording checkpoints.
      2. Form the image residual (I0 warped by ginv, minus I1), scaled by
         1/sigma^2, at the final timepoint.
      3. Integrate the adjoint system backwards.

    The result is K m0 - madj, where K is
    ``p.diffOp.applyInverseOperator``. Many fields on ``p`` are used as
    scratch; ``p.g`` and ``p.ginv`` are passed to IntegrateAdjoints as
    scratch fields.

    Returns:
        (gradient, VEnergy, IEnergy): ``gradient`` is the ``p.scratchV1``
        field itself. VEnergy = 0.5 * <m0, K m0> / nVox.
        IEnergy = ||residual||^2 / (2 * sigma^2 * nVox).
    """
    # forward pass: shoot the geodesic, storing checkpointed states
    CAvmCommon.IntegrateGeodesic(
        p.m0, p.t, p.diffOp,
        p.m, p.g, p.ginv,
        p.scratchV1, p.scratchV2, p.scratchV3,
        p.checkpointstates, p.checkpointinds,
        Ninv=p.nInv, integMethod=p.integMethod,
        RK4=p.scratchV4, scratchG=p.scratchV5)
    # the single measurement is at the last timepoint
    finalIdx = p.checkpointinds.index(len(p.t) - 1)

    # residual image at the final time; record the image energy first
    ca.ApplyH(p.residualIm, p.I0, p.ginv)
    ca.Sub_I(p.residualIm, p.I1)
    IEnergy = ca.Sum2(p.residualIm) / (
        2 * p.sigma * p.sigma * float(p.I0.nVox()))
    ca.DivC_I(p.residualIm, p.sigma * p.sigma)  # gradient at measurement

    # backward pass: adjoint integration driven by the residual
    CAvmCommon.IntegrateAdjoints(
        p.Iadj, p.madj,
        p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
        p.scratchV2, p.scratchV3,
        p.I0, p.m0,
        p.t, p.checkpointstates, p.checkpointinds,
        [p.residualIm], [finalIdx],
        p.diffOp, p.integMethod, p.nInv,
        scratchV3=p.scratchV7, scratchV4=p.g, scratchV5=p.ginv,
        scratchV6=p.scratchV8, scratchV7=p.scratchV9,
        scratchV8=p.scratchV10, scratchV9=p.scratchV11,
        RK4=p.scratchV4, scratchG=p.scratchV5, scratchGinv=p.scratchV6)

    # gradient = K m0 - madj; the vector energy uses K m0 before subtraction
    gradient = p.scratchV1
    ca.Copy(gradient, p.m0)
    p.diffOp.applyInverseOperator(gradient)
    VEnergy = 0.5 * ca.Dot(p.m0, gradient) / float(p.I0.nVox())
    ca.Sub_I(gradient, p.madj)
    return (gradient, VEnergy, IEnergy)
def MatchingImageMomentaWriteOuput(cf, geodesicState, EnergyHistory, m0, n1):
    """Write the results of image/momenta matching for one geodesic.

    Every file name is prefixed by ``cf.io.outputPrefix``:
      p0.mhd                      -- initial momentum of the geodesic
      m1.mhd                      -- CoAd(rhoinv, m0)
      testMomentaMatchEnergy.csv  -- only when cf.vectormomentum.matchImOnly
      I1.mhd                      -- J0 warped by rhoinv
      energy.csv                  -- EnergyHistory

    When ``cf.vectormomentum.matchImOnly`` is set, ``m0`` is replaced by the
    field loaded from ``cf.study.m``. ``geodesicState.p`` is used as scratch.
    """
    gs = geodesicState
    grid = gs.J0.grid()
    mType = gs.J0.memType()

    # momentum of the geodesic
    common.SaveITKField(gs.p0, cf.io.outputPrefix + "p0.mhd")

    # matched momentum: transport m0 (loaded from the study when only the
    # images were matched) through rhoinv
    if cf.vectormomentum.matchImOnly:
        m0 = common.LoadITKField(cf.study.m, mType)
    ca.CoAd(gs.p, gs.rhoinv, m0)
    common.SaveITKField(gs.p, cf.io.outputPrefix + "m1.mhd")

    # momenta match energy, reported only in image-only matching mode
    if cf.vectormomentum.matchImOnly:
        residual = ca.ManagedField3D(grid, mType)
        ca.Sub_I(gs.p, n1)
        ca.Copy(residual, gs.p)
        gs.diffOp.applyInverseOperator(gs.p)
        momentaMatchEnergy = ca.Dot(residual, gs.p) / (
            float(gs.p0.nVox()) * gs.SigmaSlope * gs.SigmaSlope)
        energyPath = cf.io.outputPrefix + "testMomentaMatchEnergy.csv"
        with open(energyPath, 'w') as f:
            f.write('%s\n' % momentaMatchEnergy)

    # matched image for the geodesic
    warped = ca.ManagedImage3D(grid, mType)
    ca.ApplyH(warped, gs.J0, gs.rhoinv)
    common.SaveITKImage(warped, cf.io.outputPrefix + "I1.mhd")

    # energy history
    MatchingImageMomentaWriteEnergyHistoryToFile(
        EnergyHistory, cf.io.outputPrefix + "energy.csv")
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Gradient for geodesic regression of images against several measurements.

    Steps:
      1. Shoot the geodesic from ``p.m0`` over times ``t``, checkpointing
         into ``cpstates``/``cpinds``.
      2. For each measurement i, form the residual
         (I0 warped by that checkpoint's ginv, minus Imsmts[i]) / sigma^2
         into ``gradAtMsmts[i]``.
      3. Integrate the adjoints backwards.
      4. Accumulate closed-form sums for the image update.

    A ``msmtinds[i]`` of -1 means I0 is used unwarped (no checkpointed state
    is looked up for that measurement). Many fields on ``p`` are used as
    scratch; afterwards ``p.I`` holds all ones.

    Returns:
        (gradient, sumJac, sumSplatI, VEnergy, IEnergy):
            gradient  -- the ``p.scratchV1`` field, K m0 - madj, with K =
                         ``p.diffOp.applyInverseOperator``
            sumJac    -- ``p.sumJac``: sum over measurements of a ones image
                         splatted by ginv
            sumSplatI -- ``p.sumSplatI``: sum of measurements splatted by ginv
            VEnergy   -- 0.5 * <m0, K m0> / nVox
            IEnergy   -- sum of ||residual_i||^2 / (2 * sigma^2 * nVox)
    """
    # forward pass: shoot the geodesic
    CAvmCommon.IntegrateGeodesic(
        p.m0, t, p.diffOp,
        p.m, p.g, p.ginv,
        p.scratchV1, p.scratchV2, p.scratchV3,
        cpstates, cpinds,
        Ninv=p.nInv, integMethod=p.integMethod,
        RK4=p.scratchV4, scratchG=p.scratchV5)

    # residual (and image energy) at each measurement timepoint
    IEnergy = 0.0
    for i, Im in enumerate(Imsmts):
        residual = gradAtMsmts[i]
        if msmtinds[i] != -1:
            (_, ginv) = cpstates[msmtinds[i]]
            ca.ApplyH(residual, p.I0, ginv)
        else:
            ca.Copy(residual, p.I0)
        ca.Sub_I(residual, Im)
        # record the energy while the residual is unscaled
        IEnergy += ca.Sum2(residual) / (
            2 * p.sigma * p.sigma * float(p.I0.nVox()))
        ca.DivC_I(residual, p.sigma * p.sigma)  # gradient at measurement

    # backward pass: adjoints driven by the residuals
    CAvmCommon.IntegrateAdjoints(
        p.Iadj, p.madj,
        p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
        p.scratchV2, p.scratchV3,
        p.I0, p.m0,
        t, cpstates, cpinds,
        gradAtMsmts, msmtinds,
        p.diffOp,
        p.integMethod, p.nInv,
        scratchV3=p.scratchV7, scratchV4=p.g, scratchV5=p.ginv,
        scratchV6=p.scratchV8, scratchV7=p.scratchV9,
        scratchV8=p.scratchV10, scratchV9=p.scratchV11,
        RK4=p.scratchV4, scratchG=p.scratchV5, scratchGinv=p.scratchV6,
        scratchI=p.scratchI1)

    # momentum gradient K m0 - madj; vector energy from K m0 first
    ca.Copy(p.scratchV1, p.m0)
    p.diffOp.applyInverseOperator(p.scratchV1)
    VEnergy = 0.5 * ca.Dot(p.m0, p.scratchV1) / float(p.I0.nVox())
    ca.Sub_I(p.scratchV1, p.madj)

    # closed-form terms for the image update;
    # p.scratchI1 and p.I are reused as scratch images
    scratchI = p.scratchI1
    imOnes = p.I
    ca.SetMem(imOnes, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    for i, Im in enumerate(Imsmts):
        # TODO: check these indexings for cases when timepoint 0
        # is not checkpointed
        if msmtinds[i] != -1:
            (_, ginv) = cpstates[msmtinds[i]]
            CAvmCommon.SplatSafe(scratchI, ginv, Im)
            ca.Add_I(p.sumSplatI, scratchI)
            CAvmCommon.SplatSafe(scratchI, ginv, imOnes)
            ca.Add_I(p.sumJac, scratchI)
        else:
            ca.Add_I(p.sumSplatI, Im)
            ca.Add_I(p.sumJac, imOnes)

    return (p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
def _PlotElastRegProgress(energy, I0, I1, Idef, diff, u):
    """Redraw ElastReg's 'energy' and 'results' figures."""
    print('plotting')
    clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
    plt.figure('energy')
    # clear, then overlay every energy curve (replaces the removed plt.hold)
    plt.cla()
    for i, curve in enumerate(energy):
        plt.plot(curve, clrlist[i])
    plt.draw()

    plt.figure('results')
    plt.clf()
    plt.subplot(3, 2, 1)
    display.DispImage(I0, 'I0', newFig=False)
    plt.subplot(3, 2, 2)
    display.DispImage(I1, 'I1', newFig=False)
    plt.subplot(3, 2, 3)
    display.DispImage(Idef, 'def', newFig=False)
    plt.subplot(3, 2, 4)
    display.DispImage(diff, 'diff', newFig=False)
    plt.colorbar()
    plt.subplot(3, 2, 5)
    display.GridPlot(u, every=4)
    plt.subplot(3, 2, 6)
    display.JacDetPlot(u)
    plt.colorbar()
    plt.draw()
    plt.show()


def ElastReg(I0Orig, I1Orig, scales=(1,), nIters=(1000,), maxPert=(0.2,),
             fluidParams=(0.1, 0.1, 0.001), VFC=0.2, Mask=None,
             plotEvery=100):
    """Multiscale elastic registration by gradient descent on a displacement.

    At each scale, iteratively updates a displacement field ``u`` so that
    ``I0`` deformed by ``u`` (``ca.ApplyV``) approaches ``I1``. The
    descent direction combines:
      * the fluid-kernel-smoothed image-match gradient, and
      * ``VFC`` times a regularisation term.
    Each step is shrunk so that ``step * max|direction|`` stays at most
    ``maxPert[scale]``.

    Args:
        I0Orig: image that is deformed.
        I1Orig: target image (same grid as ``I0Orig``).
        scales: scale factors passed to ``scaleManager.addScaleLevel`` in order.
        nIters: iteration count per scale, indexed by scale position.
        maxPert: per-scale maximum perturbation, indexed by scale position.
        fluidParams: (alpha, beta, gamma) for the fluid kernel.
        VFC: weight of the regularisation term.
        Mask: optional image multiplied into the image residual. It is
            re-gridded in place at every scale and ends at the original grid.
            It is assumed to share I0Orig's grid -- TODO confirm.
        plotEvery: redraw figures every ``plotEvery`` iterations and at the
            final iteration; values <= 0 disable plotting.

    Returns:
        (Idef, u, energy), where ``energy`` holds three per-iteration lists:
        [image term Sum2(diff), regulariser 0.5*VFC*<u, L u>, their sum].
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # working storage at full resolution; re-gridded per scale in setScale
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # keep a full-resolution copy; Mask itself is re-gridded per scale
    if Mask is not None:
        MaskOrig = Mask.copy()

    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initialize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)
    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        # curGrid is kept as a module-level global for backward
        # compatibility; nothing in ElastReg reads it outside setScale.
        global curGrid
        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0
        resampler.setScaleLevel(scaleManager)
        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)
        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)
        if Mask is not None:
            # re-grid the mask at every scale (as for I0/I1) so it matches diff
            Mask.setGrid(curGrid)
            if scaleManager.isLastScale():
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)
        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)
        for buf in (gI, Idef, diff, gU, scratchI, scratchV):
            buf.setGrid(curGrid)

    energy = [[] for _ in range(3)]

    for scale in range(nScales):
        setScale(scale)
        ustep = None
        # gradient of I0 at this scale; pulled back through u each iteration
        ca.Gradient(gI, I0)
        for it in range(nIters[scale]):
            print('iter %d' % it)
            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)
            # image residual
            ca.Sub(diff, I1, Idef)
            if Mask is not None:
                ca.Mul_I(diff, Mask)
            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)
            diffOp.applyInverseOperator(gU, scratchV)
            # why is this negative necessary?
            ca.MulC_I(gU, -1.0)
            # u = u - ustep*(VFC*u + 2.0*gU)
            ca.MulC_I(gU, 2.0)
            # regulariser direction VFC * K(L(u)). With gamma == 0 the
            # constant field is in L's nullspace, so K(L(u)) is u minus its
            # mean. The previous code computed u - av into scratchV and then
            # overwrote it with VFC*u, discarding the subtraction.
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)
                ca.MulC_I(scratchV, VFC)
            else:
                ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)
            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            # shrink the step whenever it would move any point by more than
            # maxPert[scale]. The old code compared against the whole maxPert
            # list, so after the first iteration the step was never updated.
            if gradmax > 0 and (ustep is None or
                                ustep * gradmax > maxPert[scale]):
                ustep = maxPert[scale] / gradmax
                print('step is %f' % ustep)
            # ustep is None only when every gradient so far was zero, in which
            # case gU is zero and the update would be a no-op
            if ustep is not None:
                ca.MulC_I(gU, ustep)
                # apply gradient
                ca.Sub_I(u, gU)
            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            isFinalIter = (scale == nScales - 1 and it == nIters[scale] - 1)
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or isFinalIter):
                _PlotElastRegProgress(energy, I0, I1, Idef, diff, u)

    return (Idef, u, energy)