Ejemplo n.º 1
0
def MatchingGradient(p):
    """Compute the momentum gradient and energies for a single-target match.

    Shoots the geodesic defined by ``p.m0`` forward over ``p.t``, forms the
    image residual at the final time point (``p.I0`` deformed through
    ``p.ginv`` by ``ca.ApplyH``, minus ``p.I1``), integrates the adjoint
    system backward from that residual and assembles the gradient.

    Parameters
    ----------
    p : object
        Workspace/parameter object.  Read: ``m0``, ``t``, ``diffOp``, ``I0``,
        ``I1``, ``sigma``, ``nInv``, ``integMethod``, ``checkpointstates``,
        ``checkpointinds``.  Written (used as outputs or scratch): ``m``,
        ``g``, ``ginv``, ``I``, ``Iadj``, ``madj``, ``Iadjtmp``, ``madjtmp``,
        ``residualIm`` and ``scratchV1`` .. ``scratchV11``.

    Returns
    -------
    tuple
        ``(grad, VEnergy, IEnergy)``: ``grad`` is the buffer ``p.scratchV1``
        holding ``diffOp^{-1}(m0) - madj``; ``VEnergy`` is
        ``0.5 * <m0, diffOp^{-1}(m0)> / nVox``; ``IEnergy`` is
        ``sum(residual^2) / (2 * sigma^2 * nVox)``.
    """
    # Forward pass: integrate the geodesic starting from m0, filling p.m,
    # p.g, p.ginv and the checkpoint buffers.
    CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 p.checkpointstates, p.checkpointinds,
                                 Ninv=p.nInv, integMethod=p.integMethod,
                                 RK4=p.scratchV4, scratchG=p.scratchV5)

    # The single measurement is taken at the last time point; find which
    # checkpoint slot that corresponds to.
    finalSlot = p.checkpointinds.index(len(p.t) - 1)

    resid = p.residualIm
    # Residual image: I0 deformed by ginv, minus the target I1.
    ca.ApplyH(resid, p.I0, p.ginv)
    ca.Sub_I(resid, p.I1)
    # The image energy must be read before the residual is rescaled below.
    imageEnergy = ca.Sum2(resid) / (2 * p.sigma * p.sigma * float(p.I0.nVox()))
    # Rescale in place so the buffer holds the gradient at the measurement.
    ca.DivC_I(resid, p.sigma * p.sigma)

    # Backward pass: integrate the adjoint equations, seeded by the residual
    # at the final checkpoint.  Note p.g / p.ginv are reused as scratch here.
    CAvmCommon.IntegrateAdjoints(p.Iadj, p.madj,
                                 p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
                                 p.scratchV2, p.scratchV3,
                                 p.I0, p.m0,
                                 p.t, p.checkpointstates, p.checkpointinds,
                                 [resid], [finalSlot],
                                 p.diffOp,
                                 p.integMethod, p.nInv,
                                 scratchV3=p.scratchV7, scratchV4=p.g,
                                 scratchV5=p.ginv, scratchV6=p.scratchV8,
                                 scratchV7=p.scratchV9, scratchV8=p.scratchV10,
                                 scratchV9=p.scratchV11,
                                 RK4=p.scratchV4, scratchG=p.scratchV5,
                                 scratchGinv=p.scratchV6)

    # Assemble the gradient in scratchV1 (free again after the adjoint pass).
    grad = p.scratchV1
    ca.Copy(grad, p.m0)
    p.diffOp.applyInverseOperator(grad)
    # grad currently holds the velocity; take the vector energy before the
    # adjoint momentum is subtracted.
    vectorEnergy = 0.5 * ca.Dot(p.m0, grad) / float(p.I0.nVox())
    ca.Sub_I(grad, p.madj)

    return (grad, vectorEnergy, imageEnergy)
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Compute the momentum gradient, energies and closed-form image terms
    for a set of time-indexed measurements.

    Parameters
    ----------
    p : object
        Workspace/parameter object (initial momentum ``m0``, template
        ``I0``, ``sigma``, ``diffOp``, integration settings and scratch
        buffers).  Several of its buffers are overwritten.
    t : list
        Time grid to integrate over.
    Imsmts : list
        Measured images, one per measurement.
    cpinds, cpstates : list
        Checkpoint indices and the matching ``(g, ginv)`` field pairs; the
        states are filled by the forward integration below.
    msmtinds : list
        For each measurement, an index into ``cpstates``, or ``-1`` when the
        measurement is compared against the undeformed ``I0`` (identity map).
    gradAtMsmts : list
        Preallocated images, one per measurement; overwritten with
        ``residual / sigma^2`` and used to seed the adjoint integration.

    Returns
    -------
    tuple
        ``(grad, sumJac, sumSplatI, VEnergy, IEnergy)``:

        - ``grad`` is ``p.scratchV1``, holding ``diffOp^{-1}(m0) - madj``.
        - ``sumJac`` (``p.sumJac``) is the sum over measurements of an
          all-ones image splatted through each ``ginv``; presumably a
          Jacobian-determinant accumulation, to be confirmed against
          ``CAvmCommon.SplatSafe``.
        - ``sumSplatI`` (``p.sumSplatI``) is the sum of the splatted
          measurements.
        - ``VEnergy`` and ``IEnergy`` are the per-voxel vector and image
          energies.
    """
    # Forward pass: shoot the geodesic, filling the checkpoint states.
    CAvmCommon.IntegrateGeodesic(p.m0, t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 cpstates, cpinds,
                                 Ninv=p.nInv, integMethod=p.integMethod,
                                 RK4=p.scratchV4, scratchG=p.scratchV5)

    # Loop-invariant scale factors for the residual energy / gradient.
    sigmaSq = p.sigma * p.sigma
    energyDenom = 2 * p.sigma * p.sigma * float(p.I0.nVox())

    # Residual at every measurement, accumulating the image energy.
    IEnergy = 0.0
    for i, Imsmt in enumerate(Imsmts):
        residual = gradAtMsmts[i]
        if msmtinds[i] != -1:
            _, ginv = cpstates[msmtinds[i]]
            ca.ApplyH(residual, p.I0, ginv)
        else:
            # No checkpoint: compare against the undeformed template.
            ca.Copy(residual, p.I0)
        ca.Sub_I(residual, Imsmt)
        # Read the energy off before the residual is rescaled in place.
        IEnergy += ca.Sum2(residual) / energyDenom
        ca.DivC_I(residual, sigmaSq)  # gradient at measurement

    # Backward pass: integrate the adjoints, seeded at each measurement.
    CAvmCommon.IntegrateAdjoints(p.Iadj, p.madj,
                                 p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
                                 p.scratchV2, p.scratchV3,
                                 p.I0, p.m0,
                                 t, cpstates, cpinds,
                                 gradAtMsmts, msmtinds,
                                 p.diffOp,
                                 p.integMethod, p.nInv,
                                 scratchV3=p.scratchV7, scratchV4=p.g,
                                 scratchV5=p.ginv, scratchV6=p.scratchV8,
                                 scratchV7=p.scratchV9, scratchV8=p.scratchV10,
                                 scratchV9=p.scratchV11,
                                 RK4=p.scratchV4, scratchG=p.scratchV5,
                                 scratchGinv=p.scratchV6,
                                 scratchI=p.scratchI1)

    # Momentum gradient, built in scratchV1 (free after the adjoint pass).
    ca.Copy(p.scratchV1, p.m0)
    p.diffOp.applyInverseOperator(p.scratchV1)
    # scratchV1 holds the velocity here; take the vector energy now.
    VEnergy = 0.5 * ca.Dot(p.m0, p.scratchV1) / float(p.I0.nVox())
    ca.Sub_I(p.scratchV1, p.madj)

    # Closed-form terms for the image update.  p.scratchI1 and p.I are
    # reused as scratch images (p.I is no longer needed after the adjoints).
    scratchI = p.scratchI1
    imOnes = p.I
    ca.SetMem(imOnes, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    # TODO: check these indexings for cases when timepoint 0
    # is not checkpointed
    for i, Imsmt in enumerate(Imsmts):
        if msmtinds[i] != -1:
            _, ginv = cpstates[msmtinds[i]]
            CAvmCommon.SplatSafe(scratchI, ginv, Imsmt)
            ca.Add_I(p.sumSplatI, scratchI)
            CAvmCommon.SplatSafe(scratchI, ginv, imOnes)
            ca.Add_I(p.sumJac, scratchI)
        else:
            # Identity map: accumulate the measurement and ones directly.
            ca.Add_I(p.sumSplatI, Imsmt)
            ca.Add_I(p.sumJac, imOnes)

    return (p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
def GeodesicShooting(cf):
    """Shoot a geodesic from an initial image and momentum, then save results.

    Loads ``cf.study.I0`` and ``cf.study.m0``.  When ``cf.study.scaleMomenta``
    is non-zero, it scales the momentum in place and integrates the geodesic
    over ``cf.integration.nTimeSteps`` uniform steps on [0, 1].  Otherwise it
    reports the inputs unchanged with identity deformations.

    When ``cf.io.outputPrefix`` is not None, it writes the parsed config, the
    deformed image (``I1.mhd``), the final momentum (``m1.mhd``), the inverse
    and forward deformations (``phiinv.mhd`` / ``phi.mhd``) and plots.
    Per-checkpoint frames are also written when ``cf.io.saveFrames`` is set.

    Parameters
    ----------
    cf : config object
        Configuration matching ``GeodesicShootingConfigSpec``.
        ``cf.study.scaleMomenta`` is coerced to ``float`` in place.
    """
    if cf.io.outputPrefix is not None:
        # Only create the output directory when output is requested:
        # os.path.dirname(None) raises TypeError.  A prefix without a
        # directory component (dirname == "") needs no directory.
        outdir = os.path.dirname(cf.io.outputPrefix)
        if outdir:
            common.Mkdir_p(outdir)
        # Record the effective configuration next to the outputs.
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    # Differential operator on the host or device, matching the memory type.
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    # Uniform time grid 0, 1/N, ..., 1 with N = nTimeSteps.
    t = [
        x * 1. / cf.integration.nTimeSteps
        for x in range(cf.integration.nTimeSteps + 1)
    ]
    # Checkpoint every time point after t=0.
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    shoot = abs(cf.study.scaleMomenta) > 0.0
    if shoot:
        # Scale the momenta (in place) and shoot.
        ca.MulC_I(m0, float(cf.study.scaleMomenta))
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        # Zero scale: no shooting; report the inputs with identity maps.
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)
        # No integration wrote the checkpoint buffers in this branch; set
        # them to identity so SaveFrames never sees never-written fields.
        for cpG, cpGinv in checkpointstates:
            ca.SetToIdentity(cpG)
            ca.SetToIdentity(cpGinv)

    # write output
    if cf.io.outputPrefix is not None:
        if shoot:
            # Deformed image and final momentum from the (scaled) m0, then
            # undo the momentum scaling before writing.
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, float(cf.study.scaleMomenta))

        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            # NOTE(review): when shooting, m0 is still the in-place scaled
            # momentum here; confirm SaveFrames expects the scaled value.
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)