def GeodesicShootingPlots(g, ginv, I0, It, cf):
    """Plot the result of a geodesic shot as a 2x2 summary in figure 3.

    Panels: grid plot of ``g`` (titled phi), grid plot of ``ginv`` (titled
    phi^{-1}), the initial image ``I0`` and the shot image ``It`` (titled I1).

    Args:
        g, ginv: deformation fields drawn with CAvmCommon.MyGridPlot.
        I0, It: images drawn with display.DispImage.
        cf: config object; reads cf.io.gridEvery, cf.io.plotSliceDim,
            cf.io.plotSlice and cf.io.outputPrefix.

    When cf.io.outputPrefix is not None the figure is also saved as
    ``cf.io.outputPrefix + 'shooting.pdf'``.
    """
    sliceDim = cf.io.plotSliceDim
    sliceIdx = cf.io.plotSlice

    fig = plt.figure(3)
    plt.clf()
    fig.patch.set_facecolor('white')

    # Raw strings: '\p' is an invalid escape sequence (warning on newer
    # Pythons); the string values are unchanged.
    gridPanels = [(g, r'\phi'), (ginv, r'\phi^{-1}')]
    for panel, (field, title) in enumerate(gridPanels, start=1):
        plt.subplot(2, 2, panel)
        CAvmCommon.MyGridPlot(field, every=cf.io.gridEvery, color='k',
                              dim=sliceDim, sliceIdx=sliceIdx, isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)

    plt.subplot(2, 2, 3)
    display.DispImage(I0, 'I0', newFig=False, dim=sliceDim, sliceIdx=sliceIdx)
    plt.subplot(2, 2, 4)
    display.DispImage(It, 'I1', newFig=False, dim=sliceDim, sliceIdx=sliceIdx)
    plt.draw()
    # Save before show(): with a blocking show() the figure may already be
    # closed (and savefig would write an empty page) once show() returns.
    if cf.io.outputPrefix is not None:
        plt.savefig(cf.io.outputPrefix + 'shooting.pdf')
    plt.show()
def MatchingGradient(p):
    """Compute the matching-energy gradient w.r.t. the initial momentum p.m0.

    All work buffers live on the state object ``p`` and are overwritten:
      1. Shoot the geodesic from p.m0 (CAvmCommon.IntegrateGeodesic), filling
         p.m, p.g, p.ginv and the checkpoint states.
      2. Residual p.residualIm = (p.I0 warped by p.ginv via ca.ApplyH) - p.I1.
      3. IEnergy = Sum2(residual) / (2 * sigma^2 * nVox); the residual is then
         divided by sigma^2 and handed to IntegrateAdjoints as the measurement
         term at checkpoint index ``endidx`` (the last time point).
      4. Integrate the adjoint system backward, producing p.madj.
      5. gradient = diffOp.applyInverseOperator(m0) - p.madj and
         VEnergy = 0.5 * <m0, applyInverseOperator(m0)> / nVox.

    Args:
        p: matching state object (attributes read/written as listed above).

    Returns:
        (gradient, VEnergy, IEnergy). ``gradient`` is the p.scratchV1 buffer
        itself, not a copy, so later calls that use p.scratchV1 overwrite it.

    NOTE(review): p.g and p.ginv are passed to IntegrateAdjoints as
    scratchV4/scratchV5, so presumably they no longer hold phi / phi^{-1}
    after this call; callers needing them should re-shoot - confirm.
    """
    # shoot the geodesic forward
    CAvmCommon.IntegrateGeodesic(p.m0,p.t,p.diffOp, \
                                 p.m, p.g, p.ginv,\
                                 p.scratchV1, p.scratchV2,p. scratchV3,\
                                 p.checkpointstates, p.checkpointinds,\
                                 Ninv=p.nInv, integMethod = p.integMethod, RK4=p.scratchV4,scratchG=p.scratchV5)
    # position of the final time point within the checkpoint list
    endidx = p.checkpointinds.index(len(p.t)-1)
    # compute residual image
    ca.ApplyH(p.residualIm,p.I0,p.ginv)
    ca.Sub_I(p.residualIm, p.I1)
    # while we have residual, save the image energy
    IEnergy = ca.Sum2(p.residualIm)/(2*p.sigma*p.sigma*float(p.I0.nVox()))
    ca.DivC_I(p.residualIm, p.sigma*p.sigma) # gradient at measurement
    # integrate backward; the scaled residual is the only measurement, at endidx
    CAvmCommon.IntegrateAdjoints(p.Iadj,p.madj,\
                                 p.I,p.m,p.Iadjtmp, p.madjtmp,p.scratchV1,\
                                 p.scratchV2,p.scratchV3,\
                                 p.I0,p.m0,\
                                 p.t, p.checkpointstates, p.checkpointinds,\
                                 [p.residualIm], [endidx],\
                                 p.diffOp, p.integMethod, p.nInv, \
                                 scratchV3=p.scratchV7, scratchV4=p.g,scratchV5=p.ginv,scratchV6=p.scratchV8, scratchV7=p.scratchV9, \
                                 scratchV8=p.scratchV10,scratchV9=p.scratchV11,\
                                 RK4=p.scratchV4, scratchG=p.scratchV5, scratchGinv=p.scratchV6)
    # compute gradient: start from diffOp's inverse operator applied to m0
    ca.Copy(p.scratchV1, p.m0)
    p.diffOp.applyInverseOperator(p.scratchV1)
    # while we have velocity, save the vector energy
    VEnergy = 0.5*ca.Dot(p.m0,p.scratchV1)/float(p.I0.nVox())
    ca.Sub_I(p.scratchV1, p.madj)
    # (disabled alternative: apply the forward operator to the gradient)
    #p.diffOp.applyOperator(p.scratchV1)
    return (p.scratchV1, VEnergy, IEnergy)
def Matching(cf):
    """Run vector-momentum matching of image cf.study.I0 to cf.study.I1.

    Loads both images, allocates checkpoint storage for
    cf.optim.nTimeSteps steps, builds a MatchingVariables state and optimizes
    it with RunMatching.

    When cf.io.outputPrefix is not None:
      - before matching, the parsed config is written to
        '<prefix>parsedconfig.yaml' (the prefix's directory is created first);
      - after matching, the geodesic is re-shot and the final m0, phi^{-1} and
        phi are saved as '<prefix>m0.mhd', '<prefix>phiinv.mhd' and
        '<prefix>phi.mhd'.
    """
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    if cf.io.outputPrefix is not None:
        # Prepare the output directory only when output is requested:
        # os.path.dirname(None) raises, and a bare prefix such as 'run_'
        # has no directory component to create.
        outDir = os.path.dirname(cf.io.outputPrefix)
        if outDir:
            common.Mkdir_p(outDir)
        # Output loaded config
        cfstr = Config.ConfigToYAML(MatchingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST
    I0 = common.LoadITKImage(cf.study.I0, mType)
    I1 = common.LoadITKImage(cf.study.I1, mType)
    grid = I0.grid()
    ca.ThreadMemoryManager.init(grid, mType, 1)

    # Uniform time discretization of [0, 1]; every time point after t=0 is
    # checkpointed with a (g, ginv) pair of fields.
    # TODO: need to work on these
    nSteps = cf.optim.nTimeSteps
    t = [x * 1. / nSteps for x in range(nSteps + 1)]
    checkpointinds = list(range(1, len(t)))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for _ in checkpointinds]

    p = MatchingVariables(I0, I1, cf.vectormomentum.sigma, t,
                          checkpointinds, checkpointstates,
                          cf.vectormomentum.diffOpParams[0],
                          cf.vectormomentum.diffOpParams[1],
                          cf.vectormomentum.diffOpParams[2],
                          cf.optim.Niter, cf.optim.stepSize, cf.optim.maxPert,
                          cf.optim.nTimeSteps,
                          integMethod=cf.optim.integMethod,
                          optMethod=cf.optim.method,
                          nInv=cf.optim.NIterForInverse,
                          plotEvery=cf.io.plotEvery,
                          plotSlice=cf.io.plotSlice,
                          quiverEvery=cf.io.quiverEvery,
                          outputPrefix=cf.io.outputPrefix)

    RunMatching(p)

    # write output
    if cf.io.outputPrefix is not None:
        # reset all variables by shooting once, may have been overwritten
        CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                     p.m, p.g, p.ginv,
                                     p.scratchV1, p.scratchV2, p.scratchV3,
                                     p.checkpointstates, p.checkpointinds,
                                     Ninv=p.nInv, integMethod=p.integMethod)
        common.SaveITKField(p.m0, cf.io.outputPrefix + "m0.mhd")
        common.SaveITKField(p.ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(p.g, cf.io.outputPrefix + "phi.mhd")
def geodesic_shooting_diffOp(moving, target, m0, steps, mType, config):
    """Shoot along the geodesic generated by the initial momentum ``m0``.

    The fluid operator and integrator settings come from
    ``config['deformation_params']`` (keys 'diffOpParams', 'timeSteps',
    'NIterForInverse', 'integMethod'). ``steps`` overrides 'timeSteps' when
    it is positive.

    Args:
        moving: image warped by the resulting phi^{-1}; its grid is used
            for every allocation (``m0`` is re-gridded to it).
        target: image warped by the resulting phi.
        m0: initial momentum field.
        steps: number of time steps; <= 0 means use config 'timeSteps'.
        mType: ca.MEM_HOST selects the CPU kernel, anything else the GPU one.
        config: dict holding a 'deformation_params' sub-dict.

    Returns:
        dict with 'I1' (moving warped by phi^{-1}), 'I1_inv' (target warped
        by phi) and 'phiinv' (the phi^{-1} field).
    """
    grid = moving.grid()
    m0.setGrid(grid)
    ca.ThreadMemoryManager.init(grid, mType, 1)

    # Fluid regularization operator on host or device, matching mType.
    diffOp = ca.FluidKernelFFTCPU() if mType == ca.MEM_HOST else ca.FluidKernelFFTGPU()
    defParams = config['deformation_params']
    opParams = defParams['diffOpParams']
    diffOp.setAlpha(opParams[0])
    diffOp.setBeta(opParams[1])
    diffOp.setGamma(opParams[2])
    diffOp.setGrid(grid)

    def newField():
        return ca.Field3D(grid, mType)

    phi = newField()
    phiInv = newField()
    mFinal = newField()
    warpedMoving = ca.Image3D(grid, mType)
    warpedTarget = ca.Image3D(grid, mType)

    nSteps = defParams['timeSteps'] if steps <= 0 else steps
    times = [k * 1. / nSteps for k in range(nSteps + 1)]
    keepInds = range(1, len(times))
    keepStates = [(newField(), newField()) for _ in keepInds]
    scratch1 = newField()
    scratch2 = newField()
    scratch3 = newField()

    CAvmCommon.IntegrateGeodesic(m0, times, diffOp, mFinal, phi, phiInv,
                                 scratch1, scratch2, scratch3,
                                 keepstates=keepStates, keepinds=keepInds,
                                 Ninv=defParams['NIterForInverse'],
                                 integMethod=defParams['integMethod'])

    ca.ApplyH(warpedMoving, moving, phiInv)
    ca.ApplyH(warpedTarget, target, phi)
    return {'I1': warpedMoving, 'I1_inv': warpedTarget, "phiinv": phiInv}
def _MatchingFinishFigure(outputPrefix, filename):
    """Draw the current figure, save it as outputPrefix + filename (if a prefix is set), then show it."""
    plt.draw()
    # Save before show(): a blocking show() may close the figure first.
    if outputPrefix is not None:
        plt.savefig(outputPrefix + filename)
    plt.show()


def _MatchingPlotImages(p, g, ginv):
    """Fill the 2x2 'images' figure: I0, I0 warped by ginv, I1, I1 warped by g."""
    fig = plt.figure('images')
    plt.clf()
    fig.patch.set_facecolor('white')
    plt.subplot(2, 2, 1)
    display.DispImage(p.I0, 'I0', newFig=False, cmap='gray', sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 2)
    ca.ApplyH(p.I, p.I0, ginv)
    display.DispImage(p.I, r'\phi.I0', newFig=False, cmap='gray', sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 3)
    display.DispImage(p.I1, 'I1', newFig=False, cmap='gray', sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 4)
    ca.ApplyH(p.I, p.I1, g)
    display.DispImage(p.I, r'\phi^{-1}.I1', newFig=False, cmap='gray', sliceIdx=p.plotSlice)
    _MatchingFinishFigure(p.outputPrefix, 'images.pdf')


def _MatchingPlotDeformation(p, g, ginv):
    """Fill the 2x2 'def' figure: grids of ginv and g, Jacobian determinant of ginv, residual."""
    fig = plt.figure('def')
    plt.clf()
    fig.patch.set_facecolor('white')
    gridPanels = [(ginv, r'\phi^{-1}'), (g, r'\phi')]
    for panel, (field, title) in enumerate(gridPanels, start=1):
        plt.subplot(2, 2, panel)
        display.GridPlot(field, every=p.quiverEvery, color='k',
                         sliceIdx=p.plotSlice, isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)
    plt.subplot(2, 2, 3)
    ca.JacDetH(p.I, ginv)  # p.I used as scratch variable to compute jacobian
    display.DispImage(p.I, r'|D\phi^{-1}|', newFig=False, sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 4)
    # Recompute the raw residual phi.I0 - I1 into the p.I scratch image rather
    # than rescaling p.residualIm in place: the in-place MulC_I permanently
    # changed shared optimizer state and compounded on repeated calls.
    ca.ApplyH(p.I, p.I0, ginv)
    ca.Sub_I(p.I, p.I1)
    display.DispImage(p.I, r'\phi.I0-I1', newFig=False, sliceIdx=p.plotSlice)
    plt.colorbar()
    _MatchingFinishFigure(p.outputPrefix, 'def.pdf')


def _MatchingPlotEnergy(p):
    """Fill the 1x3 'energy' figure with total, vector and image energy curves from p.Energy."""
    fig = plt.figure('energy')
    # plt.hold() was removed in matplotlib 3.0; clearing the figure gives the
    # same "replace old curves" behaviour the hold(False) calls provided.
    plt.clf()
    fig.patch.set_facecolor('white')
    TE = [sum(row) for row in p.Energy]
    VE = [row[0] for row in p.Energy]
    IE = [row[1] for row in p.Energy]
    curves = [(TE, 'Total Energy'), (VE, 'Vector Energy'), (IE, 'Image Energy')]
    for panel, (curve, title) in enumerate(curves, start=1):
        plt.subplot(1, 3, panel)
        plt.plot(curve)
        plt.title(title)
    _MatchingFinishFigure(p.outputPrefix, 'energy.pdf')


def MatchingPlots(p):
    """
    Do some summary plots for image matching

    Re-shoots the geodesic from p.m0, then draws three figures ('images',
    'def', 'energy'). Each is saved as p.outputPrefix + '<name>.pdf' when
    p.outputPrefix is not None. p.I is used as a scratch image.
    """
    # reset all variables by shooting once, may have been overwritten
    CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 p.checkpointstates, p.checkpointinds,
                                 Ninv=p.nInv, integMethod=p.integMethod)
    # deformation pair stored for the final time point
    indx_of_last_tp = p.checkpointinds.index(len(p.t) - 1)
    (g, ginv) = p.checkpointstates[indx_of_last_tp]

    _MatchingPlotImages(p, g, ginv)
    _MatchingPlotDeformation(p, g, ginv)
    _MatchingPlotEnergy(p)
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Gradient of a multi-measurement warp energy w.r.t. p.m0, plus image-update terms.

    Args:
        p: state object holding p.m0, p.I0, p.sigma, p.diffOp and work buffers.
        t: list of time points.
        Imsmts: measured images, one per measurement.
        cpinds, cpstates: checkpoint indices and their (g, ginv) field pairs.
        msmtinds: for each measurement, its index into cpstates, or -1 for a
            measurement taken at the undeformed I0 (no warp applied).
        gradAtMsmts: per-measurement output images; each receives
            (warped I0 - measurement) / sigma^2.

    Returns:
        (gradient, sumJac, sumSplatI, VEnergy, IEnergy), where
          gradient  = diffOp.applyInverseOperator(m0) - p.madj (buffer p.scratchV1),
          sumJac    = sum over measurements of ones splatted by ginv (p.sumJac),
          sumSplatI = sum over measurements of Imsmts[i] splatted by ginv (p.sumSplatI),
          VEnergy   = 0.5 * <m0, applyInverseOperator(m0)> / nVox,
          IEnergy   = sum_i Sum2(residual_i) / (2 * sigma^2 * nVox).
        The returned images/fields are p's buffers themselves, not copies.
    """
    # shoot the geodesic forward
    CAvmCommon.IntegrateGeodesic(p.m0, t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 cpstates, cpinds,
                                 Ninv=p.nInv, integMethod=p.integMethod,
                                 RK4=p.scratchV4, scratchG=p.scratchV5)

    sigma2 = p.sigma * p.sigma
    energyNorm = 2 * sigma2 * float(p.I0.nVox())

    # compute residuals for each measurement timepoint along with the energy
    IEnergy = 0.0
    for i, msmt in enumerate(Imsmts):
        grad = gradAtMsmts[i]
        cpIdx = msmtinds[i]
        if cpIdx != -1:
            (g, ginv) = cpstates[cpIdx]
            ca.ApplyH(grad, p.I0, ginv)
        else:
            # measurement at the undeformed image: no warp
            ca.Copy(grad, p.I0)
        ca.Sub_I(grad, msmt)
        # while we have residual, save the image energy
        IEnergy += ca.Sum2(grad) / energyNorm
        ca.DivC_I(grad, sigma2)  # gradient at measurement

    # integrate backward
    CAvmCommon.IntegrateAdjoints(p.Iadj, p.madj,
                                 p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
                                 p.scratchV2, p.scratchV3,
                                 p.I0, p.m0,
                                 t, cpstates, cpinds,
                                 gradAtMsmts, msmtinds,
                                 p.diffOp,
                                 p.integMethod, p.nInv,
                                 scratchV3=p.scratchV7, scratchV4=p.g,
                                 scratchV5=p.ginv, scratchV6=p.scratchV8,
                                 scratchV7=p.scratchV9, scratchV8=p.scratchV10,
                                 scratchV9=p.scratchV11,
                                 RK4=p.scratchV4, scratchG=p.scratchV5,
                                 scratchGinv=p.scratchV6,
                                 scratchI=p.scratchI1)

    # compute gradient: diffOp inverse operator applied to m0, minus the adjoint
    ca.Copy(p.scratchV1, p.m0)
    p.diffOp.applyInverseOperator(p.scratchV1)
    # while we have velocity, save the vector energy
    VEnergy = 0.5 * ca.Dot(p.m0, p.scratchV1) / float(p.I0.nVox())
    ca.Sub_I(p.scratchV1, p.madj)

    # Closed-form terms for the image update. p.scratchI1 and p.I are reused
    # as scratch images (p.I is free again once the adjoints are integrated).
    scratchI = p.scratchI1
    imOnes = p.I
    ca.SetMem(imOnes, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    for i, msmt in enumerate(Imsmts):
        # TODO: check these indexings for cases when timepoint 0
        # is not checkpointed
        cpIdx = msmtinds[i]
        if cpIdx != -1:
            (g, ginv) = cpstates[cpIdx]
            CAvmCommon.SplatSafe(scratchI, ginv, msmt)
            ca.Add_I(p.sumSplatI, scratchI)
            CAvmCommon.SplatSafe(scratchI, ginv, imOnes)
            ca.Add_I(p.sumJac, scratchI)
        else:
            ca.Add_I(p.sumSplatI, msmt)
            ca.Add_I(p.sumJac, imOnes)
    return (p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
def GeodesicShooting(cf):
    """Shoot image cf.study.I0 along the geodesic of momentum cf.study.m0.

    The momentum is first multiplied by cf.study.scaleMomenta. A scale of
    zero skips shooting: the outputs are then I0, m0 and identity maps.

    When cf.io.outputPrefix is not None, the function:
      - writes the parsed config;
      - divides the final momentum back by the scale;
      - saves I1.mhd, m1.mhd, phiinv.mhd and phi.mhd under the prefix;
      - plots the result;
      - saves animation frames if cf.io.saveFrames is set.
    """
    if cf.io.outputPrefix is not None:
        # Prepare the output directory only when output is requested:
        # os.path.dirname(None) raises, and a bare prefix has no directory.
        outDir = os.path.dirname(cf.io.outputPrefix)
        if outDir:
            common.Mkdir_p(outDir)
        # Output loaded config
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST

    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)

    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)

    # Uniform time grid on [0, 1]; every time point after 0 is checkpointed.
    nSteps = cf.integration.nTimeSteps
    t = [x * 1. / nSteps for x in range(nSteps + 1)]
    checkpointinds = list(range(1, len(t)))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for _ in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    # scale momenta to shoot
    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    scale = cf.study.scaleMomenta
    shoot = abs(scale) > 0.0
    if shoot:
        ca.MulC_I(m0, scale)
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
        # Deformed image at t=1 (previously only computed when writing output).
        ca.ApplyH(It, I0, ginv)
    else:
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if cf.io.outputPrefix is not None:
        # scale back shot momenta before writing
        if shoot:
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, scale)
        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
def _SaveFrameFigure(fig, outfilename):
    """Resize fig to 4x4 inches and save the current figure tightly cropped at 100 dpi."""
    fig.set_size_inches(4, 4)
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)


def _SaveImageFrame(figNum, im, outfilename, cf):
    """Render image `im` (gray, configured slice) into figure `figNum` and save it."""
    fig = plt.figure(figNum, frameon=False)
    plt.clf()
    display.DispImage(im, '', newFig=False, cmap='gray',
                      dim=cf.io.plotSliceDim, sliceIdx=cf.io.plotSlice)
    plt.draw()
    _SaveFrameFigure(fig, outfilename)


def _SaveGridFrame(figNum, h, outfilename, cf):
    """Render deformation `h` as a bare grid (no ticks/axes) into figure `figNum` and save it."""
    fig = plt.figure(figNum, frameon=False)
    plt.clf()
    CAvmCommon.MyGridPlot(h, every=cf.io.gridEvery, color='k',
                          dim=cf.io.plotSliceDim, sliceIdx=cf.io.plotSlice,
                          isVF=False, plotBase=False)
    ax = fig.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    plt.axis('tight')
    plt.axis('image')
    plt.axis('off')
    plt.draw()
    _SaveFrameFigure(fig, outfilename)


def _SaveMomentaFrame(figNum, im, m, outfilename, cf, thresh, scaleArrows):
    """Render image `im` with momentum `m` overlaid as red arrows into figure `figNum` and save it."""
    fig = plt.figure(figNum, frameon=False)
    plt.clf()
    display.DispImage(im, '', newFig=False, cmap='gray',
                      dim=cf.io.plotSliceDim, sliceIdx=cf.io.plotSlice)
    # Axes keep existing artists by default, so the quiver overlays the image.
    # (plt.hold() was removed in matplotlib 3.0, and hold('False') passed a
    # truthy string, so it never turned holding off anyway.)
    CAvmCommon.MyQuiver(m, dim=cf.io.plotSliceDim, sliceIdx=cf.io.plotSlice,
                        every=cf.io.quiverEvery, thresh=thresh,
                        scaleArrows=scaleArrows, arrowCol='r',
                        lineWidth=0.5, width=0.005)
    plt.draw()
    _SaveFrameFigure(fig, outfilename)


def SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf):
    """Write per-time-point PNG frames (image, phi^{-1} grid, phi grid, momenta).

    Frames go to cf.io.outputPrefix + '/frames/' as I#####.png, invdef#####.png,
    def#####.png and m#####.png. Frame 0 is the undeformed state (I0, identity
    grids, m0). Frame i+1 uses checkpointstates[i] = (g, ginv).

    Side effects: It and mt are overwritten as scratch buffers for every
    checkpoint.
    """
    momentathresh = 0.00002
    # Create exactly the directory the frames are written into (it used to be
    # dirname(prefix) + '/frames/', which differs for prefixes like 'out/run_').
    framesDir = cf.io.outputPrefix + '/frames/'
    common.Mkdir_p(framesDir)

    def framePath(name, idx):
        return framesDir + name + str(idx).zfill(5) + '.png'

    # frame 0: undeformed image, identity deformation and initial momentum
    image_idx = 0
    _SaveImageFrame(1, I0, framePath('I', image_idx), cf)
    identity = ca.Field3D(I0.grid(), I0.memType())
    ca.SetToIdentity(identity)
    _SaveGridFrame(2, identity, framePath('invdef', image_idx), cf)
    _SaveGridFrame(3, identity, framePath('def', image_idx), cf)
    # NOTE(review): frame 0 uses scaleArrows=0.25 but later frames use 0.40;
    # kept as-is, confirm whether the mismatch is intentional.
    _SaveMomentaFrame(4, I0, m0, framePath('m', image_idx), cf,
                      momentathresh, 0.25)

    for i in range(len(checkpointinds)):
        image_idx = i + 1
        (g, ginv) = checkpointstates[i]
        ca.ApplyH(It, I0, ginv)
        _SaveImageFrame(1, It, framePath('I', image_idx), cf)
        _SaveGridFrame(2, ginv, framePath('invdef', image_idx), cf)
        _SaveGridFrame(3, g, framePath('def', image_idx), cf)
        ca.CoAd(mt, ginv, m0)
        _SaveMomentaFrame(4, It, mt, framePath('m', image_idx), cf,
                          momentathresh, 0.40)