Ejemplo n.º 1
0
 def test_SplatWorldAdjoint(self, disp=False):
     """Check that SplatWorld behaves as the adjoint of ResampleWorld.

     Two random host images are created on grids that differ only in
     spacing.  The inner product of the resampled regular image with the
     small image must match the inner product of the splatted small image
     with the regular image to within 2e-6.  ``disp`` is not used here.
     """
     noise = dict(nSig=1.0, gSig=0.0, mType=ca.MEM_HOST)

     # Random input image on the regular grid
     reg = common.RandImage(self.sz, sp=self.imSp, **noise)

     # Second image: rescale the spacing so that voxels don't align
     # perfectly with the regular grid
     spfactor = 0.77382
     small = common.RandImage(self.sz, sp=self.imSp * spfactor, **noise)

     # < ResampleWorld(reg), small >, evaluated on the small grid
     resampled = ca.Image3D(small.grid(), ca.MEM_HOST)
     ca.SetMem(resampled, 0)
     ca.ResampleWorld(resampled, reg)
     smallIP = ca.Dot(resampled, small)

     # < reg, SplatWorld(small) >, evaluated on the regular grid
     splatted = ca.Image3D(reg.grid(), ca.MEM_HOST)
     ca.SetMem(splatted, 0)
     ca.SplatWorld(splatted, small)
     regIP = ca.Dot(splatted, reg)

     mismatch = abs(smallIP - regIP)
     self.assertLess(mismatch, 2e-6)
Ejemplo n.º 2
0
 def test_SplatAdjoint(self, disp=False):
     """Check that Splat behaves as the adjoint of ApplyH.

     For random images ``hI``, ``hJ`` and a random field ``hPhi`` the two
     inner products <ApplyH(hI, hPhi), hJ> and <hI, Splat(hPhi, hJ)> must
     agree to within 2e-6.  ``disp`` is not used here.
     """
     imArgs = dict(nSig=1.0, gSig=0.0, mType=ca.MEM_HOST, sp=self.imSp)
     hI = common.RandImage(self.sz, **imArgs)
     hJ = common.RandImage(self.sz, **imArgs)
     hPhi = common.RandField(
         self.sz, nSig=5.0, gSig=4.0, mType=ca.MEM_HOST, sp=self.imSp)

     # one scratch image is reused for both sides of the identity
     scratch = ca.Image3D(self.grid, ca.MEM_HOST)

     # < I(Phi(x)), J(x) >
     ca.ApplyH(scratch, hI, hPhi)
     phiIdotJ = ca.Dot(scratch, hJ)

     # < I(y), |DPhi^{-1}(y)| J(Phi^{-1}(y)) >
     ca.Splat(scratch, hPhi, hJ)
     IdotphiJ = ca.Dot(scratch, hI)

     mismatch = abs(phiIdotJ - IdotphiJ)
     self.assertLess(mismatch, 2e-6)
Ejemplo n.º 3
0
 def test_GroupAdjointAction(self, disp=False):
     """Check that CoAd is the adjoint of Ad under the Dot pairing.

     For random fields ``hM``, ``hV`` and ``hPhi`` the inner products
     <hM, Ad(hPhi, hV)> and <CoAd(hPhi, hM), hV> must agree to within
     2e-6.  ``disp`` is not used here.
     """
     fieldArgs = dict(nSig=5.0, gSig=4.0, mType=ca.MEM_HOST, sp=self.imSp)
     hM = common.RandField(self.sz, **fieldArgs)
     hV = common.RandField(self.sz, **fieldArgs)
     hPhi = common.RandField(self.sz, **fieldArgs)

     # one scratch field is reused for both sides of the identity
     scratch = ca.Field3D(self.grid, ca.MEM_HOST)

     # < m, Ad_\phi v >
     ca.Ad(scratch, hPhi, hV)
     rhs = ca.Dot(scratch, hM)

     # < Ad^*_\phi m, v >
     ca.CoAd(scratch, hPhi, hM)
     lhs = ca.Dot(scratch, hV)

     mismatch = abs(rhs - lhs)
     self.assertLess(mismatch, 2e-6)
Ejemplo n.º 4
0
def HGMComputeJumpGradientsFromResidual(gradIntercept, gradSlope,
                                        residualState, mt, J1, n1):
    """Compute the intercept and slope jump gradients and their energies.

    The intercept gradient is written into ``gradIntercept`` and the slope
    gradient into ``gradSlope``.  ``residualState.p`` is overwritten several
    times and is left holding the inverse operator applied to
    ``residualState.p0``.

    Args:
        gradIntercept: output, receives the splatted, scaled image residual.
        gradSlope: output, receives the scaled ``Ad`` of the smoothed
            momentum residual.
        residualState: object providing ``J0``, ``rhoinv``, ``p``, ``p0``,
            ``diffOp``, ``Sigma``, ``SigmaIntercept`` and ``SigmaSlope``.
        mt: field passed to ``CoAd`` together with ``rhoinv``.
        J1: image subtracted from the deformed ``J0``.
        n1: field subtracted from the co-adjoint-transported ``mt``.

    Returns:
        Tuple ``(pEnergy, iEnergy, slopeEnergy)``.
    """
    state = residualState
    grid = state.J0.grid()
    mType = state.J0.memType()

    imResid = ca.ManagedImage3D(grid, mType)
    momResid = ca.ManagedField3D(grid, mType)

    nVox = float(state.p0.nVox())
    sigma = state.Sigma
    sigmaI = state.SigmaIntercept
    sigmaS = state.SigmaSlope

    # 1. Gradient for Intercept: residual of deformed J0 against J1
    ApplyH(imResid, state.J0, state.rhoinv)
    Sub_I(imResid, J1)
    # the energy has to be taken before the residual is rescaled below
    iEnergy = 0.5 * ca.Sum2(imResid) / (
        nVox * sigma * sigma * sigmaI * sigmaI)
    MulC_I(imResid, 1.0 / (sigmaI * sigmaI * sigma * sigma))
    SplatSafe(gradIntercept, state.rhoinv, imResid)

    # 2. Gradient for Slope: residual of transported mt against n1
    CoAd(state.p, state.rhoinv, mt)
    Sub_I(state.p, n1)
    # keep the unsmoothed residual for the slope energy term
    Copy(momResid, state.p)
    state.diffOp.applyInverseOperator(state.p)
    # apply Ad and get the result in gradSlope
    Ad(gradSlope, state.rhoinv, state.p)
    MulC_I(gradSlope, 1.0 / (sigmaS * sigmaS))

    # slope energy: smoothed residual paired with the raw residual
    slopeEnergy = ca.Dot(state.p, momResid) / (nVox * sigmaS * sigmaS)

    # intercept term; p is used as a scratch variable
    Copy(state.p, state.p0)
    state.diffOp.applyInverseOperator(state.p)
    pEnergy = ca.Dot(state.p0, state.p) / (nVox * sigmaI * sigmaI)

    return (pEnergy, iEnergy, slopeEnergy)
def MatchingImageMomentaComputeEnergy(geodesicState, m0, J1, n1):
    """Evaluate the three energy terms of the image/momenta matching.

    ``geodesicState.p`` is used as scratch storage and is left holding the
    inverse operator applied to ``geodesicState.p0``.

    Args:
        geodesicState: object providing ``J0``, ``rhoinv``, ``p``, ``p0``,
            ``diffOp``, ``Sigma``, ``SigmaIntercept`` and ``SigmaSlope``.
        m0: field passed to ``CoAd`` together with ``rhoinv``.
        J1: image subtracted from the deformed ``J0``.
        n1: field subtracted from the co-adjoint-transported ``m0``.

    Returns:
        Tuple ``(vecEnergy, imageMatchEnergy, momentaMatchEnergy)``.
    """
    state = geodesicState
    grid = state.J0.grid()
    mType = state.J0.memType()

    imResid = ca.ManagedImage3D(grid, mType)
    momResid = ca.ManagedField3D(grid, mType)

    nVox = float(state.p0.nVox())
    sigma = state.Sigma
    sigmaI = state.SigmaIntercept
    sigmaS = state.SigmaSlope

    # image match energy: deformed J0 against J1
    ca.ApplyH(imResid, state.J0, state.rhoinv)
    ca.Sub_I(imResid, J1)
    imageMatchEnergy = 0.5 * ca.Sum2(imResid) / (
        nVox * sigma * sigma * sigmaI * sigmaI)

    # momenta match energy: transported m0 against n1
    ca.CoAd(state.p, state.rhoinv, m0)
    ca.Sub_I(state.p, n1)
    # keep the unsmoothed residual; p itself is smoothed in place next
    ca.Copy(momResid, state.p)
    state.diffOp.applyInverseOperator(state.p)
    momentaMatchEnergy = ca.Dot(momResid, state.p) / (
        nVox * sigmaS * sigmaS)

    # vector energy; p is used as a scratch variable
    ca.Copy(state.p, state.p0)
    state.diffOp.applyInverseOperator(state.p)
    vecEnergy = 0.5 * ca.Dot(state.p0, state.p) / (nVox * sigmaI * sigmaI)

    return (vecEnergy, imageMatchEnergy, momentaMatchEnergy)
Ejemplo n.º 6
0
def MatchingGradient(p):
    """Compute the matching gradient and the two energy terms.

    The geodesic is integrated forward from ``p.m0``, the residual between
    the deformed ``p.I0`` and ``p.I1`` is formed in ``p.residualIm``, and the
    adjoint system is integrated backward from it.

    Args:
        p: state object holding the images, momenta, time array,
            checkpoints, operator and scratch buffers read below.

    Returns:
        Tuple ``(p.scratchV1, VEnergy, IEnergy)``; ``p.scratchV1`` holds the
        inverse operator applied to ``p.m0`` minus ``p.madj``.
    """
    # shoot the geodesic forward
    CAvmCommon.IntegrateGeodesic(
        p.m0, p.t, p.diffOp,
        p.m, p.g, p.ginv,
        p.scratchV1, p.scratchV2, p.scratchV3,
        p.checkpointstates, p.checkpointinds,
        Ninv=p.nInv,
        integMethod=p.integMethod,
        RK4=p.scratchV4,
        scratchG=p.scratchV5)

    # checkpoint slot that corresponds to the last timepoint
    lastT = len(p.t) - 1
    endidx = p.checkpointinds.index(lastT)

    # compute residual image
    ca.ApplyH(p.residualIm, p.I0, p.ginv)
    ca.Sub_I(p.residualIm, p.I1)
    # while we have the residual, save the image energy
    IEnergy = ca.Sum2(p.residualIm) / (
        2 * p.sigma * p.sigma * float(p.I0.nVox()))

    # gradient at measurement
    ca.DivC_I(p.residualIm, p.sigma * p.sigma)

    # integrate backward; p.g and p.ginv are handed over as scratch fields
    CAvmCommon.IntegrateAdjoints(
        p.Iadj, p.madj,
        p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
        p.scratchV2, p.scratchV3,
        p.I0, p.m0,
        p.t, p.checkpointstates, p.checkpointinds,
        [p.residualIm], [endidx],
        p.diffOp,
        p.integMethod, p.nInv,
        scratchV3=p.scratchV7,
        scratchV4=p.g,
        scratchV5=p.ginv,
        scratchV6=p.scratchV8,
        scratchV7=p.scratchV9,
        scratchV8=p.scratchV10,
        scratchV9=p.scratchV11,
        RK4=p.scratchV4,
        scratchG=p.scratchV5,
        scratchGinv=p.scratchV6)

    # compute gradient
    grad = p.scratchV1
    ca.Copy(grad, p.m0)
    p.diffOp.applyInverseOperator(grad)
    # while we have the velocity, save the vector energy
    VEnergy = 0.5 * ca.Dot(p.m0, grad) / float(p.I0.nVox())

    ca.Sub_I(grad, p.madj)
    return (grad, VEnergy, IEnergy)
def MatchingImageMomentaWriteOuput(cf, geodesicState, EnergyHistory, m0, n1):
    """Write the matching results under ``cf.io.outputPrefix``.

    Files written: ``p0.mhd`` (``geodesicState.p0``), ``m1.mhd`` (the
    co-adjoint-transported momenta), ``I1.mhd`` (the deformed ``J0``) and
    ``energy.csv`` (the energy history).  When
    ``cf.vectormomentum.matchImOnly`` is set, ``m0`` is reloaded from
    ``cf.study.m`` and ``testMomentaMatchEnergy.csv`` is written as well.
    ``geodesicState.p`` is overwritten.
    """
    state = geodesicState
    prefix = cf.io.outputPrefix
    imOnly = cf.vectormomentum.matchImOnly

    grid = state.J0.grid()
    mType = state.J0.memType()

    # save momenta for the geodesic
    common.SaveITKField(state.p0, prefix + "p0.mhd")

    # save matched momenta for the geodesic; in image-only mode the
    # momenta to transport are read from the study configuration
    if imOnly:
        m0 = common.LoadITKField(cf.study.m, mType)
    ca.CoAd(state.p, state.rhoinv, m0)
    common.SaveITKField(state.p, prefix + "m1.mhd")

    # momenta match energy (image-only mode)
    if imOnly:
        momResid = ca.ManagedField3D(grid, mType)
        ca.Sub_I(state.p, n1)
        # keep the unsmoothed residual; p itself is smoothed in place next
        ca.Copy(momResid, state.p)
        state.diffOp.applyInverseOperator(state.p)
        scale = float(state.p0.nVox()) * state.SigmaSlope * state.SigmaSlope
        momentaMatchEnergy = ca.Dot(momResid, state.p) / scale
        # save energy
        with open(prefix + "testMomentaMatchEnergy.csv", 'w') as f:
            print >> f, momentaMatchEnergy

    # save matched image for the geodesic
    matched = ca.ManagedImage3D(grid, mType)
    ca.ApplyH(matched, state.J0, state.rhoinv)
    common.SaveITKImage(matched, prefix + "I1.mhd")

    # save energy
    MatchingImageMomentaWriteEnergyHistoryToFile(EnergyHistory,
                                                 prefix + "energy.csv")
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Compute the momentum gradient, image-update sums and energies.

    Args:
        p: state object holding the images, momenta, operator and scratch
            buffers read below.
        t: time array handed to the integrators.
        Imsmts: sequence of measured images.
        cpinds: checkpoint indices handed to the integrators.
        cpstates: checkpoint states; each entry unpacks as ``(g, ginv)``.
        msmtinds: per-measurement index into ``cpstates``; ``-1`` means the
            measurement is compared with ``p.I0`` directly.
        gradAtMsmts: per-measurement output images; each receives the
            residual divided by ``p.sigma ** 2``.

    Returns:
        Tuple ``(p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)``.
    """
    # shoot the geodesic forward
    CAvmCommon.IntegrateGeodesic(
        p.m0, t, p.diffOp,
        p.m, p.g, p.ginv,
        p.scratchV1, p.scratchV2, p.scratchV3,
        cpstates, cpinds,
        Ninv=p.nInv,
        integMethod=p.integMethod,
        RK4=p.scratchV4,
        scratchG=p.scratchV5)

    # compute residuals for each measurement timepoint along with the energy
    IEnergy = 0.0
    for idx, measured in enumerate(Imsmts):
        resid = gradAtMsmts[idx]
        cp = msmtinds[idx]
        if cp == -1:
            # no checkpoint: compare with the undeformed image
            ca.Copy(resid, p.I0)
        else:
            _, ginv = cpstates[cp]
            ca.ApplyH(resid, p.I0, ginv)
        ca.Sub_I(resid, measured)
        # while we have the residual, save the image energy
        IEnergy += ca.Sum2(resid) / (
            2 * p.sigma * p.sigma * float(p.I0.nVox()))
        # gradient at measurement
        ca.DivC_I(resid, p.sigma * p.sigma)

    # integrate backward; p.g and p.ginv are handed over as scratch fields
    CAvmCommon.IntegrateAdjoints(
        p.Iadj, p.madj,
        p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
        p.scratchV2, p.scratchV3,
        p.I0, p.m0,
        t, cpstates, cpinds,
        gradAtMsmts, msmtinds,
        p.diffOp,
        p.integMethod, p.nInv,
        scratchV3=p.scratchV7,
        scratchV4=p.g,
        scratchV5=p.ginv,
        scratchV6=p.scratchV8,
        scratchV7=p.scratchV9,
        scratchV8=p.scratchV10,
        scratchV9=p.scratchV11,
        RK4=p.scratchV4,
        scratchG=p.scratchV5,
        scratchGinv=p.scratchV6,
        scratchI=p.scratchI1)

    # compute gradient
    grad = p.scratchV1
    ca.Copy(grad, p.m0)
    p.diffOp.applyInverseOperator(grad)
    # while we have the velocity, save the vector energy
    VEnergy = 0.5 * ca.Dot(p.m0, grad) / float(p.I0.nVox())
    ca.Sub_I(grad, p.madj)

    # compute closed form terms for the image update;
    # p.scratchI1 and p.I are used as scratch images (references, no copies)
    scratchI = p.scratchI1
    imOnes = p.I
    ca.SetMem(imOnes, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    for idx, measured in enumerate(Imsmts):
        # TODO: check these indexings for cases when timepoint 0
        # is not checkpointed
        cp = msmtinds[idx]
        if cp == -1:
            ca.Add_I(p.sumSplatI, measured)
            ca.Add_I(p.sumJac, imOnes)
        else:
            _, ginv = cpstates[cp]
            CAvmCommon.SplatSafe(scratchI, ginv, measured)
            ca.Add_I(p.sumSplatI, scratchI)
            CAvmCommon.SplatSafe(scratchI, ginv, imOnes)
            ca.Add_I(p.sumJac, scratchI)

    return (grad, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
Ejemplo n.º 9
0
def ElastReg(I0Orig,
             I1Orig,
             scales=[1],
             nIters=[1000],
             maxPert=[0.2],
             fluidParams=[0.1, 0.1, 0.001],
             VFC=0.2,
             Mask=None,
             plotEvery=100):
    """Multiscale gradient-descent registration of ``I0Orig`` to ``I1Orig``.

    A displacement field ``u`` is updated by gradient descent at each scale
    level.  The list defaults are only read, never mutated.

    Args:
        I0Orig: image that is deformed; its grid and memory type are used
            for every buffer allocated here.
        I1Orig: image that the deformed ``I0Orig`` is compared against.
        scales: scale levels, added to the multiscale manager in order.
        nIters: number of iterations, one entry per scale level.
        maxPert: one entry per scale level; the step is chosen so that
            ``step * max|gradient|`` does not exceed this value.
        fluidParams: ``[alpha, beta, gamma]`` set on the differential
            operator.
        VFC: weight of the regularisation term.
        Mask: optional image multiplied into the residual.  It is regridded
            and overwritten in place at each scale from a copy taken on
            entry.
        plotEvery: plot every this many iterations, and after the final
            iteration of the last scale; a value <= 0 disables plotting.

    Returns:
        Tuple ``(Idef, u, energy)``.  ``energy`` is a list of three lists
        holding, per iteration, ``Sum2(diff)``, ``0.5 * VFC * <u, L u>`` and
        their sum.
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # allocate vars
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # mask: keep an untouched copy, Mask itself is resampled per scale
    if Mask is not None:
        MaskOrig = Mask.copy()

    # allocate diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    # initialize some vars
    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initalize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        """Switch every buffer and operator to scale level ``scale``."""
        # NOTE(review): curGrid is published as a module-level global, yet
        # nothing in ElastReg reads it outside this helper -- kept for
        # compatibility; consider making it a plain local.
        global curGrid

        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        if Mask is not None:
            if scaleManager.isLastScale():
                Mask.setGrid(curGrid)
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)

    # end function

    energy = [[] for _ in range(3)]

    for scale in range(nScales):

        setScale(scale)
        # the step size is re-derived at the start of every scale
        ustep = None
        # update gradient
        ca.Gradient(gI, I0)

        for it in range(nIters[scale]):
            print('iter %d' % it)

            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # update u
            ca.Sub(diff, I1, Idef)

            if Mask is not None:
                ca.Mul_I(diff, Mask)

            # NOTE(review): the image ApplyV call above passes 1.0 as its
            # 4th positional argument while this call passes a
            # background-strategy constant in the same position -- confirm
            # the argument order of the Field3D overload.
            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)

            diffOp.applyInverseOperator(gU, scratchV)

            # why is this negative necessary?
            ca.MulC_I(gU, -1.0)

            # u =  u - ustep*(VFC*u + 2.0*gU)
            ca.MulC_I(gU, 2.0)

            # regularisation part of the gradient, VFC*u.  If gamma is zero
            # the average of u is subtracted first (result of nullspace
            # of L for K(L(u))); the scaling is applied in place so that
            # the subtraction is not overwritten.
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)
                ca.MulC_I(scratchV, VFC)
            else:
                ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            # A zero gradient leaves u unchanged; skipping the update also
            # avoids dividing by zero when choosing the step.
            if gradmax > 0.0:
                # shrink the step whenever it would move any voxel by more
                # than this scale's maximum perturbation
                if ustep is None or ustep * gradmax > maxPert[scale]:
                    ustep = maxPert[scale] / gradmax
                    print('step is %f' % ustep)

                ca.MulC_I(gU, ustep)
                # apply gradient
                ca.Sub_I(u, gU)

            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            isFinalIter = (scale == nScales - 1 and it == nIters[scale] - 1)
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or isFinalIter):
                print('plotting')
                # NOTE(review): pyplot.hold was removed in Matplotlib 3.0;
                # this plotting code assumes an older Matplotlib.
                clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
                plt.figure('energy')
                for i in range(len(energy)):
                    plt.plot(energy[i], clrlist[i])
                    if i == 0:
                        plt.hold(True)
                plt.hold(False)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

            # end plot
        # end iteration
    # end scale
    return (Idef, u, energy)