def RigidReg(
        Is, It, theta_step=.0001, t_step=.01, a_step=0, maxIter=350,
        plot=True, origin=None,
        theta=0,  # only applies for 2D
        t=None,  # only applies for 2D
        Ain=None):
    """Estimate an affine (rotation + translation [+ scale]) aligning Is to It.

    Runs gradient descent on the sum of squared differences between
    ``Idef = cc.ApplyAffineReal(Is, A)`` and ``It``.  The 3D path (taken when
    ``cc.Is3D(Is)``) estimates a translation, an optional scale and
    incremental Z/Y/X rotations.  The 2D path (``cc.Is2D(Is)``) estimates a
    translation, a rotation angle and an optional scale, composed after
    ``Ain``.

    Parameters
    ----------
    Is, It :
        Source and target images.  All scratch buffers are allocated on
        ``It``'s grid and memory type.
    theta_step, t_step, a_step : float
        Step sizes for rotation, translation and scale.  Each is divided by
        ``numel * imrng`` before use.  ``numel`` is the voxel count of
        ``It`` and ``imrng`` is its intensity range, floored at .01.
        ``a_step <= 0`` disables the scale update.
    maxIter : int
        Iteration cap.  The 2D loop runs at most ``maxIter - 1`` times.
    plot : bool
        If true, show the difference image, the energy curve and the
        deformed image when the loop stops.
    origin : sequence of float, optional
        Point subtracted from the identity coordinates used by the
        rotation/scale gradients.  Defaults to ``(size + 1) / 2`` per axis
        of ``Is``'s grid.
    theta : float
        Initial rotation angle (2D only; a message is printed if nonzero in
        3D).
    t : sequence of 2 floats, optional
        Initial translation (2D only).  It is copied, so the caller's
        object is never modified.
    Ain : 3x3 matrix, optional
        Initial affine that the 2D estimate is composed with (``Ain * A``).
        Defaults to the identity.

    Returns
    -------
    A
        In 3D, a 4x4 ``np.matrix``.  In 2D, ``Ain * A_est``.  If the 2D
        loop never runs (``maxIter <= 1``), this is ``np.copy(Ain)``.

    Raises
    ------
    ValueError
        If ``Is`` is neither 3D nor 2D according to ``cc``.
    """
    grid, mType = It.grid(), It.memType()
    Idef = ca.Image3D(grid, mType)
    gradIdef = ca.Field3D(grid, mType)
    h = ca.Field3D(grid, mType)
    ca.SetToIdentity(h)
    x = ca.Image3D(grid, mType)
    y = ca.Image3D(grid, mType)
    DX = ca.Image3D(grid, mType)
    DY = ca.Image3D(grid, mType)
    diff = ca.Image3D(grid, mType)
    scratchI = ca.Image3D(grid, mType)

    # x, y: identity coordinate images, re-centred on `origin` below
    ca.Copy(x, h, 0)
    ca.Copy(y, h, 1)
    if origin is None:
        size = Is.grid().size()
        origin = [(size.x + 1) / 2.0,
                  (size.y + 1) / 2.0,
                  (size.z + 1) / 2.0]
    x -= origin[0]
    y -= origin[1]

    # normalise step sizes by voxel count and intensity range of It
    numel = It.size().x * It.size().y * It.size().z
    immin, immax = ca.MinMax(It)
    imrng = max(immax - immin, .01)
    t_step /= numel * imrng
    theta_step /= numel * imrng
    a_step /= numel * imrng

    energy = []
    a = 1

    def finished(i):
        """Return True (after reporting/plotting) if iteration i ends the loop.

        The loop ends at the last iteration, or once i > 75 and the newest
        energy differs from the one 49 entries earlier by less than immax.
        """
        if not (i == maxIter - 1 or
                (i > 75 and abs(energy[-1] - energy[-50]) < immax)):
            return False
        if i == maxIter - 1:
            print("not converged in  %s  Iterations" % (maxIter,))
        if plot:
            cd.DispImage(diff, title='Difference Image', colorbar=True)
            plt.figure()
            plt.plot(energy)
            cd.DispImage(Idef, title='Deformed Image')
        return True

    if cc.Is3D(Is):
        if theta:
            print("theta is not utilized in 3D registration")
        z = ca.Image3D(grid, mType)
        DZ = ca.Image3D(grid, mType)
        ca.Copy(z, h, 2)
        z -= origin[2]
        A = np.matrix(np.identity(4))
        cc.ApplyAffineReal(Idef, Is, A)
        t = [0, 0, 0]
        for i in range(maxIter):
            ca.Sub(diff, Idef, It)
            ca.Gradient(gradIdef, Idef)
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            ca.Copy(DZ, gradIdef, 2)

            # translation step (absolute values written into A's last column)
            ca.Mul(scratchI, DX, diff)
            t[0] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DY, diff)
            t[1] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DZ, diff)
            t[2] += t_step * ca.Sum(scratchI)
            A[0, 3] = t[0]
            A[1, 3] = t[1]
            A[2, 3] = t[2]

            if a_step > 0:
                # scale step: sum(diff * (DX*x + DY*y + DZ*z)).  This clobbers
                # DX/DY/DZ; the rotation steps re-copy them from gradIdef.
                DX *= x
                DY *= y
                DZ *= z
                DZ += DX
                DZ += DY
                DZ *= diff
                a_prev = a
                a += a_step * ca.Sum(DZ)
                # multiplying by a/a_prev is equivalent to adding (a - a_prev)
                s = a / a_prev
                A = A * np.matrix([[s, 0, 0, 0],
                                   [0, s, 0, 0],
                                   [0, 0, s, 0],
                                   [0, 0, 0, 1]])

            # Z rotation increment: -theta_step * sum(diff * (-DX*y + DY*x))
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            DX *= y
            ca.Neg_I(DX)
            DY *= x
            ca.Add(scratchI, DX, DY)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            A = A * np.matrix([[np.cos(theta), np.sin(theta), 0, 0],
                               [-np.sin(theta), np.cos(theta), 0, 0],
                               [0, 0, 1, 0],
                               [0, 0, 0, 1]])

            # Y rotation increment: -theta_step * sum(diff * (-DX*z + DZ*x))
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DZ, gradIdef, 2)
            DX *= z
            ca.Neg_I(DX)
            DZ *= x
            ca.Add(scratchI, DX, DZ)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            A = A * np.matrix([[np.cos(theta), 0, np.sin(theta), 0],
                               [0, 1, 0, 0],
                               [-np.sin(theta), 0, np.cos(theta), 0],
                               [0, 0, 0, 1]])

            # X rotation increment: -theta_step * sum(diff * (-DY*z + DZ*y))
            ca.Copy(DY, gradIdef, 1)
            ca.Copy(DZ, gradIdef, 2)
            DY *= z
            ca.Neg_I(DY)
            DZ *= y
            ca.Add(scratchI, DY, DZ)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            A = A * np.matrix([[1, 0, 0, 0],
                               [0, np.cos(theta), np.sin(theta), 0],
                               [0, -np.sin(theta), np.cos(theta), 0],
                               [0, 0, 0, 1]])

            cc.ApplyAffineReal(Idef, Is, A)
            # energy of the residual computed at the top of this iteration
            energy.append(ca.Sum2(diff))
            if finished(i):
                break

    elif cc.Is2D(Is):
        if Ain is None:
            Ain = np.matrix(np.identity(3))
        # copy so a caller-supplied translation is never modified in place
        t = [0, 0] if t is None else list(t)
        A = np.copy(Ain)
        cc.ApplyAffineReal(Idef, Is, A)
        for i in range(1, maxIter):
            ca.Sub(diff, Idef, It)
            ca.Gradient(gradIdef, Idef)
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)

            # translation step
            ca.Mul(scratchI, DX, diff)
            t[0] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DY, diff)
            t[1] += t_step * ca.Sum(scratchI)

            if a_step > 0:
                # scale step: sum(diff * (DX*x + DY*y)).  This clobbers DX/DY,
                # which are re-copied from gradIdef for the rotation step.
                DX *= x
                DY *= y
                DY += DX
                DY *= diff
                a += a_step * ca.Sum(DY)

            # rotation step: theta -= theta_step * sum(diff * (-DX*y + DY*x))
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            DX *= y
            ca.Neg_I(DX)
            DY *= x
            ca.Add(scratchI, DX, DY)
            scratchI *= diff
            theta -= theta_step * ca.Sum(scratchI)

            # NOTE(review): `a` multiplies only the diagonal, so this is a
            # pure isotropic scale only when theta == 0 -- confirm intent.
            A = Ain * np.matrix([[a * np.cos(theta), np.sin(theta), t[0]],
                                 [-np.sin(theta), a * np.cos(theta), t[1]],
                                 [0, 0, 1]])
            cc.ApplyAffineReal(Idef, Is, A)
            # energy of the residual computed at the top of this iteration
            energy.append(ca.Sum2(diff))
            if finished(i):
                break

    else:
        raise ValueError("RigidReg: Is must be a 2D or 3D image")
    return A
def ElastReg(I0Orig, I1Orig, scales=(1,), nIters=(1000,), maxPert=(0.2,),
             fluidParams=(0.1, 0.1, 0.001), VFC=0.2, Mask=None,
             plotEvery=100):
    """Multiscale registration of I0Orig toward I1Orig by gradient descent on u.

    At every scale level the deformed image ``Idef = ApplyV(I0, u)`` is
    compared with ``I1``.  The field is then updated as ``u -= ustep * g``,
    where

        g = VFC * u - 2 * K(ApplyV(gI, u) * (I1 - Idef))

    Here ``gI`` is the gradient of ``I0`` at the current scale and ``K`` is
    ``diffOp.applyInverseOperator``.  ``ustep`` is chosen so that the
    largest per-voxel magnitude of ``ustep * g`` is at most
    ``maxPert[scale]``.

    Parameters
    ----------
    I0Orig, I1Orig :
        Moving and fixed images at full resolution.  Working buffers use
        ``I0Orig``'s grid and memory type.
    scales : sequence
        Scale levels passed to ``ca.MultiscaleManager.addScaleLevel``, in
        processing order.
    nIters, maxPert : sequence
        Per scale level: the iteration count and the maximum per-voxel
        update magnitude.  Each must have at least ``len(scales)`` entries.
    fluidParams : sequence of 3 floats
        ``(alpha, beta, gamma)`` given to the fluid kernel.
    VFC : float
        Weight of the regularisation term.
    Mask : optional
        Image multiplied into the residual ``I1 - Idef``.  It is resampled
        to each scale alongside the images.
    plotEvery : int
        Plot every ``plotEvery`` iterations and after the very last one.
        A value ``<= 0`` disables plotting.

    Returns
    -------
    (Idef, u, energy)
        ``Idef`` is the deformed image computed at the start of the last
        iteration, i.e. before the final update of ``u``.  ``energy`` holds
        three per-iteration lists:

        - the masked residual ``Sum2(diff)``;
        - ``0.5 * VFC * Dot(u, applyOperator(u))``;
        - the sum of the two.

    Raises
    ------
    ValueError
        If ``nIters`` or ``maxPert`` has fewer entries than ``scales``.
    """
    # fail fast instead of hitting an IndexError after the earlier scales ran
    if len(nIters) < len(scales) or len(maxPert) < len(scales):
        raise ValueError('nIters and maxPert need one entry per scale level')

    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # allocate working buffers at full resolution; setScale shrinks their grids
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    if Mask is not None:
        MaskOrig = Mask.copy()

    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initialize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)
    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        """Move the operator, images, mask and u to scale level `scale`."""
        # NOTE(review): publishes curGrid as a module-level global; kept in
        # case other code reads it.
        global curGrid
        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D (NOTE(review): has no effect if spacing()
        # returns a copy -- confirm)
        curGrid.spacing().z = 1.0
        resampler.setScaleLevel(scaleManager)
        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images (full-resolution copy at the last scale)
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)
        if Mask is not None:
            # set the target grid first for both paths, as for I0/I1 above
            Mask.setGrid(curGrid)
            if scaleManager.isLastScale():
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        for buf in (gI, Idef, diff, gU, scratchI, scratchV):
            buf.setGrid(curGrid)

    energy = [[] for _ in range(3)]
    for scale in range(nScales):
        setScale(scale)
        ustep = None
        # gradient of the (undeformed) moving image, once per scale
        ca.Gradient(gI, I0)
        for it in range(nIters[scale]):
            print('iter %d' % it)
            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # residual (optionally masked)
            ca.Sub(diff, I1, Idef)
            if Mask is not None:
                ca.Mul_I(diff, Mask)

            # gU = -2 * K(ApplyV(gI, u) * diff)
            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)
            diffOp.applyInverseOperator(gU, scratchV)
            # why is this negative necessary?
            ca.MulC_I(gU, -1.0)
            ca.MulC_I(gU, 2.0)

            # subtract average if gamma is zero (result of nullspace
            # of L for K(L(u)))
            if fluidParams[2] == 0:
                # NOTE(review): scratchV is overwritten by the next MulC, so
                # this mean subtraction currently has no effect -- confirm
                # whether u itself was meant to be re-centred.
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)

            # gU += VFC * u  (full descent direction)
            ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            # (re)choose the step so the largest per-voxel update magnitude
            # stays within this scale's maxPert
            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            if ustep is None or ustep * gradmax > maxPert[scale]:
                # a zero gradient means gU is zero, so any finite step is safe
                if gradmax > 0:
                    ustep = maxPert[scale] / gradmax
                else:
                    ustep = maxPert[scale]
                print('step is %f' % ustep)

            # apply gradient: u -= ustep * gU
            ca.MulC_I(gU, ustep)
            ca.Sub_I(u, gU)

            # energies: residual is pre-update, regulariser is post-update
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            lastIter = (scale == nScales - 1 and it == nIters[scale] - 1)
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or lastIter):
                print('plotting')
                clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
                plt.figure('energy')
                # redraw from scratch: plt.hold() is gone from modern
                # matplotlib, and hold(False) also changed the global default
                plt.clf()
                for curve, clr in zip(energy, clrlist):
                    plt.plot(curve, clr)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

    return (Idef, u, energy)