Ejemplo n.º 1
0
def geodesic_shooting_diffOp(moving, target, m0, steps, mType, config):
    """Shoot a geodesic from the initial momentum ``m0`` and warp both images.

    The geodesic is integrated by ``CAvmCommon.IntegrateGeodesic`` with a
    FluidKernelFFT operator whose alpha/beta/gamma come from
    ``config['deformation_params']['diffOpParams']``.

    Args:
        moving: image pulled back through the inverse deformation (ginv).
        target: image pulled back through the forward deformation (g).
        m0: initial momentum field; its grid is reset to ``moving``'s grid.
        steps: number of time steps; when ``<= 0`` the value of
            ``config['deformation_params']['timeSteps']`` is used instead.
        mType: memory type; ``ca.MEM_HOST`` selects the CPU kernel, any
            other value the GPU kernel.
        config: nested dict; under ``'deformation_params'`` it must provide
            ``diffOpParams``, ``timeSteps`` (when steps <= 0),
            ``NIterForInverse`` and ``integMethod``.

    Returns:
        dict with ``'I1'`` (moving warped by ginv), ``'I1_inv'`` (target
        warped by g) and ``'phiinv'`` (the inverse deformation ginv).
    """
    grid = moving.grid()
    m0.setGrid(grid)
    ca.ThreadMemoryManager.init(grid, mType, 1)

    # same operator, CPU or GPU implementation depending on memory type
    kernel_cls = ca.FluidKernelFFTCPU if mType == ca.MEM_HOST else ca.FluidKernelFFTGPU
    diff_op = kernel_cls()
    op_params = config['deformation_params']['diffOpParams']
    diff_op.setAlpha(op_params[0])
    diff_op.setBeta(op_params[1])
    diff_op.setGamma(op_params[2])
    diff_op.setGrid(grid)

    phi = ca.Field3D(grid, mType)
    phi_inv = ca.Field3D(grid, mType)
    m_t = ca.Field3D(grid, mType)
    warped = ca.Image3D(grid, mType)
    warped_inv = ca.Image3D(grid, mType)

    # non-positive `steps` means "use the configured number of time steps"
    n_steps = config['deformation_params']['timeSteps'] if steps <= 0 else steps

    times = [k * 1. / n_steps for k in range(n_steps + 1)]
    keep_inds = range(1, len(times))
    keep_states = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                   for _ in keep_inds]

    scratch = [ca.Field3D(grid, mType) for _ in range(3)]

    deform_params = config['deformation_params']
    CAvmCommon.IntegrateGeodesic(m0, times, diff_op, m_t, phi, phi_inv,
                                 scratch[0], scratch[1], scratch[2],
                                 keepstates=keep_states, keepinds=keep_inds,
                                 Ninv=deform_params['NIterForInverse'],
                                 integMethod=deform_params['integMethod'])
    ca.ApplyH(warped, moving, phi_inv)
    ca.ApplyH(warped_inv, target, phi)

    return {'I1': warped, 'I1_inv': warped_inv, "phiinv": phi_inv}
Ejemplo n.º 2
0
    def __init__(self,
                 I0,
                 I1,
                 sigma,
                 t,
                 cpinds,
                 cpstates,
                 alpha,
                 beta,
                 gamma,
                 nIter,
                 stepSize,
                 maxPert,
                 nTimeSteps,
                 integMethod='RK4',
                 optMethod='FIXEDGD',
                 nInv=10,
                 plotEvery=None,
                 plotSlice=None,
                 quiverEvery=None,
                 outputPrefix='./'):
        """
        Allocate all state, adjoint and scratch objects for matching I0 to I1.

        Every Field3D/Image3D is allocated on I0's grid with I0's memory
        type, and the initial momentum m0 is zeroed. The differential
        operator is a CPU or GPU FluidKernelFFT (chosen by memory type)
        configured with alpha, beta and gamma.

        The remaining arguments are stored as given:
          - sigma: matching parameter
          - t: time array
          - cpinds, cpstates: checkpointing indices and states
          - nIter, stepSize, maxPert, nTimeSteps, optMethod, integMethod,
            nInv: optimization and integration settings
          - plotEvery, plotSlice, quiverEvery, outputPrefix: plotting and
            output settings
        """
        # debug output of the inputs (same output as the original prints)
        print(I0)
        self.I0 = I0
        self.I1 = I1
        self.grid = I0.grid()
        print(self.grid)
        self.memtype = I0.memType()

        grid, mem = self.grid, self.memtype

        def field():
            return ca.Field3D(grid, mem)

        def image():
            return ca.Image3D(grid, mem)

        # matching parameter
        self.sigma = sigma

        # initial momentum, starts at zero
        self.m0 = field()
        ca.SetMem(self.m0, 0.0)

        # state: forward/inverse deformation, momentum, image, residual
        self.g, self.ginv, self.m = field(), field(), field()
        self.I, self.residualIm = image(), image()

        # adjoint variables and their temporaries
        self.madj, self.Iadj = field(), image()
        self.madjtmp, self.Iadjtmp = field(), image()

        # time array and checkpointing
        self.t = t
        self.checkpointinds = cpinds
        self.checkpointstates = cpstates

        # differential operator: CPU or GPU implementation of the same kernel
        kernel = ca.FluidKernelFFTCPU if mem == ca.MEM_HOST else ca.FluidKernelFFTGPU
        self.diffOp = kernel()
        self.diffOp.setAlpha(alpha)
        self.diffOp.setBeta(beta)
        self.diffOp.setGamma(gamma)
        self.diffOp.setGrid(grid)

        self.Energy = None

        # optimization / integration settings
        self.nIter, self.stepSize, self.maxPert = nIter, stepSize, maxPert
        self.nTimeSteps = nTimeSteps
        self.optMethod, self.integMethod = optMethod, integMethod
        self.nInv = nInv  # iterations for the iterative inverse-deformation update

        # plotting / output settings
        self.plotEvery, self.plotSlice = plotEvery, plotSlice
        self.quiverEvery, self.outputPrefix = quiverEvery, outputPrefix

        # scratch fields scratchV1 .. scratchV11
        for idx in range(1, 12):
            setattr(self, 'scratchV%d' % idx, field())
    def __init__(self,
                 grid,
                 mType,
                 alpha,
                 beta,
                 gamma,
                 nInv,
                 sigma,
                 StepSize,
                 integMethod='EULER'):
        """
        Allocate state, adjoint, template-update and scratch objects on grid.

        I0 and m0 start as None. Per the original comments, I0 is later set
        to a reference to the atlas image, and m0 is assigned the momenta of
        one individual at a time. The differential operator is a CPU or GPU
        FluidKernelFFT (chosen by mType) configured with alpha, beta and
        gamma. nInv, integMethod, sigma and StepSize are stored as given;
        StepSize is stored as ``self.stepSize``.
        """
        self.grid = grid
        self.memtype = mType

        def field():
            return ca.Field3D(grid, mType)

        def image():
            return ca.Image3D(grid, mType)

        # references filled in later (atlas image / current momenta)
        self.I0 = None
        self.m0 = None

        # state: forward/inverse deformation, momentum, image
        self.g, self.ginv, self.m = field(), field(), field()
        self.I = image()

        # adjoint variables and their temporaries
        self.madj, self.Iadj = field(), image()
        self.madjtmp, self.Iadjtmp = field(), image()

        # accumulators for the closed-form template update
        self.sumSplatI, self.sumJac = image(), image()

        # differential operator: CPU or GPU implementation of the same kernel
        kernel = ca.FluidKernelFFTCPU if mType == ca.MEM_HOST else ca.FluidKernelFFTGPU
        self.diffOp = kernel()
        self.diffOp.setAlpha(alpha)
        self.diffOp.setBeta(beta)
        self.diffOp.setGamma(gamma)
        self.diffOp.setGrid(grid)

        self.nInv = nInv  # iterations for the iterative inverse-deformation update
        self.integMethod = integMethod
        self.sigma = sigma
        self.stepSize = StepSize

        # TODO: scratch variables to be changed to using managed memory
        for idx in range(1, 12):
            setattr(self, 'scratchV%d' % idx, field())
        # only used for geodesic regression with RK4
        self.scratchI1 = image()
def GeodesicShooting(cf):
    """Shoot a geodesic from the configured initial momentum and save results.

    Loads the image ``cf.study.I0`` and momentum ``cf.study.m0``, scales the
    momentum by ``cf.study.scaleMomenta`` and integrates the geodesic with a
    FluidKernelFFT operator configured from ``cf.diffOpParams``
    (alpha, beta, gamma). A scale of 0 skips the integration; the
    deformations are then set to the identity and I1 is a copy of I0.

    When ``cf.io.outputPrefix`` is not None:
      - the output directory is created;
      - the parsed config and the deformed image (I1) are written;
      - m1 is written: m0 transported by ``ca.CoAd`` with phiinv, divided
        back by the scale;
      - phi and phiinv are written;
      - plots are produced, and frames are saved if ``cf.io.saveFrames``
        is set.

    Side effect: ``cf.study.scaleMomenta`` is coerced to ``float`` in place.
    """
    outputPrefix = cf.io.outputPrefix

    if outputPrefix is not None:
        # Create the output directory only when output was requested: an
        # unconditional os.path.dirname(outputPrefix) fails for None. A
        # prefix without a directory part needs no directory at all.
        outDir = os.path.dirname(outputPrefix)
        if outDir:
            common.Mkdir_p(outDir)
        # Output loaded config
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    # set up diffOp (CPU or GPU implementation of the same kernel)
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    nSteps = cf.integration.nTimeSteps
    t = [x * 1. / nSteps for x in range(nSteps + 1)]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for _ in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    # scale momenta to shoot; stored back on cf so later readers see a float
    scale = float(cf.study.scaleMomenta)
    cf.study.scaleMomenta = scale
    shoot = abs(scale) > 0.0
    if shoot:
        ca.MulC_I(m0, scale)
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        # zero momentum: identity deformation, nothing moves
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if outputPrefix is not None:
        if shoot:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            # scale back shot momenta before writing
            ca.DivC_I(mt, scale)

        common.SaveITKImage(It, outputPrefix + "I1.mhd")
        common.SaveITKField(mt, outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
Ejemplo n.º 5
0
def ElastReg(I0Orig,
             I1Orig,
             scales=[1],
             nIters=[1000],
             maxPert=[0.2],
             fluidParams=[0.1, 0.1, 0.001],
             VFC=0.2,
             Mask=None,
             plotEvery=100):
    """Multiscale gradient-descent elastic registration of I0Orig to I1Orig.

    Minimizes sum((I1 - I0(x + u))**2) + 0.5 * VFC * <u, L u> over the
    displacement field u. The data-term gradient is smoothed by the inverse
    of the fluid operator L.

    Args:
        I0Orig: moving image (deformed by u).
        I1Orig: target image.
        scales: scale factors passed to ``scaleManager.addScaleLevel``, one
            per level; the last level uses the full-resolution images.
        nIters: number of iterations per scale level.
        maxPert: per-level cap on the largest per-voxel update magnitude,
            used to pick the step size.
        fluidParams: (alpha, beta, gamma) of the fluid kernel L.
        VFC: weight of the regularization term.
        Mask: optional image multiplied into the residual. It is resampled
            in place at each level and restored at the last one.
        plotEvery: plot every this many iterations; <= 0 disables plotting.

    Returns:
        (Idef, u, energy): the deformed I0, the displacement field, and three
        per-iteration lists: data term, regularization term, total.

    The list defaults are only read, never mutated.
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # allocate vars
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # keep a full-resolution copy of the mask; Mask itself is resampled per scale
    if Mask is not None:
        MaskOrig = Mask.copy()

    # allocate diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    # initialize some vars
    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initalize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        # Switch every working image/field (and the operator) to scale level
        # `scale`, resampling the inputs and carrying u over from the
        # previous level.
        global curGrid

        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        if Mask is not None:
            # Like I0/I1, the mask must be on the current grid at every level
            # (it is multiplied into `diff`), not only at the last one.
            Mask.setGrid(curGrid)
            if scaleManager.isLastScale():
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)

    # end function

    energy = [[] for _ in range(3)]
    clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']

    for scale in range(nScales):

        setScale(scale)
        ustep = None
        # gradient of the (resampled) moving image, fixed for this level
        ca.Gradient(gI, I0)

        for it in range(nIters[scale]):
            print('iter %d' % it)

            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # residual
            ca.Sub(diff, I1, Idef)

            if Mask is not None:
                ca.Mul_I(diff, Mask)

            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)

            diffOp.applyInverseOperator(gU, scratchV)

            # d/du of sum((I1 - I0(x+u))**2) is -2 * (I1 - I0(x+u)) * grad I0(x+u):
            # the minus sign comes from differentiating I1 - Idef.
            ca.MulC_I(gU, -2.0)

            # Regularization gradient VFC * K(L(u)). If gamma is zero, the
            # constant fields are the nullspace of L, so K(L(u)) = u - mean(u).
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)
                ca.MulC_I(scratchV, VFC)
            else:
                ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            # A zero gradient means there is nothing to update (and would
            # divide by zero below).
            if gradmax > 0:
                # Cap the largest update at maxPert[scale]. Within a level
                # the step size only ever shrinks.
                if ustep is None or ustep * gradmax > maxPert[scale]:
                    ustep = maxPert[scale] / gradmax
                    print('step is %f' % ustep)

                ca.MulC_I(gU, ustep)
                # apply gradient
                ca.Sub_I(u, gU)

            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            lastIter = scale == nScales - 1 and it == nIters[scale] - 1
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or lastIter):
                print('plotting')
                plt.figure('energy')
                # Redraw all energy curves from scratch. This replaces the
                # plt.hold toggling, which was removed in matplotlib 3.0.
                plt.clf()
                for curve, clr in zip(energy, clrlist):
                    plt.plot(curve, clr)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

            # end plot
        # end iteration
    # end scale
    return (Idef, u, energy)