def BuildAtlas(cf):
    """Worker that builds the atlas from this process's share of subjects.

    The subjects are split across processes with a strided slice
    (``rank::cf.compute.numProcesses``).  In every outer iteration one warp
    iteration is run sequentially for each local subject; the per-subject
    ``sumSplatI`` / ``sumJac`` images and the energies are accumulated and,
    when ``cf.compute.useMPI`` is set, summed across processes.  The atlas is
    then updated as ``sumSplatI / sumJac``.

    Args:
        cf: parsed atlas configuration.  This function reads ``cf.io``
            (outputPrefix, plotEvery), ``cf.study`` (subjectIds,
            subjectImages, subjectWeights), ``cf.compute`` (useMPI, useCUDA,
            numProcesses), ``cf.optim`` and ``cf.vectormomentum``.

    Raises:
        Exception: if MPI is enabled and there are fewer subjects than
            processes.

    Side effects:
        Sets ``cf.study.numSubjects``; creates the output directory; on
        rank 0 writes ``<outputPrefix>parsedconfig.yaml``, prints the energy
        of every iteration and calls ``AtlasPlots``; finally calls
        ``AtlasWriteOutput``.
    """
    mpiInfo = Compute.GetMPIInfo()
    localRank = mpiInfo['local_rank']
    rank = mpiInfo['rank']

    # Prepare the output directory.  The prefix may be None (the config dump
    # below already allows for that) and may have no directory component, in
    # which case there is nothing to create.
    outputPrefix = cf.io.outputPrefix
    if outputPrefix is not None:
        outputDir = os.path.dirname(outputPrefix)
        if outputDir:
            common.Mkdir_p(outputDir)

    # Only the process with global rank 0 reports (config dump, plots,
    # energy printout).
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectImages)

    if isReporter and outputPrefix is not None:
        # Output loaded config
        cfstr = Config.ConfigToYAML(AtlasConfigSpec, cf)
        with open(outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    # With MPI, having more processes than subjects is an error; having more
    # subjects than processes is fine.
    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this process to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeImages = cf.study.subjectImages[rank::cf.compute.numProcesses]
    nodeWeights = cf.study.subjectWeights[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeImages)
    # Single pre-formatted string: same text as the former print statement,
    # and valid on both Python 2 and Python 3.
    print('rank: %s , localRank: %s , nodeImages: %s , nodeWeights: %s' %
          (rank, localRank, nodeImages, nodeWeights))

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # Load the subject images; entries that are not file names are assumed
    # to be already-loaded images and are used as they are.
    J_array = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeImages
    ]

    # get imGrid from data
    imGrid = J_array[0].grid()

    # atlas image
    atlas = ca.Image3D(imGrid, mType)

    # one initial momentum per local subject
    m_array = [ca.Field3D(imGrid, mType) for _ in range(numLocalSubjects)]

    # a single set of scratch variables, reused for every local subject
    p = WarpVariables(imGrid,
                      mType,
                      cf.vectormomentum.diffOpParams[0],
                      cf.vectormomentum.diffOpParams[1],
                      cf.vectormomentum.diffOpParams[2],
                      cf.optim.NIterForInverse,
                      cf.vectormomentum.sigma,
                      cf.optim.stepSize,
                      integMethod=cf.optim.integMethod)

    # accumulators for the atlas numerator and denominator over the local
    # subjects; these are summed across MPI processes
    sumSplatI = ca.Image3D(imGrid, mType)
    sumJac = ca.Image3D(imGrid, mType)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # Host-side buffer handed to Compute.Reduce; only allocated when the
    # working memory is not already host memory.
    mpiImageBuff = None
    if cf.compute.useMPI and mType != ca.MEM_HOST:
        mpiImageBuff = ca.Image3D(imGrid, ca.MEM_HOST)

    # time discretisation of [0, 1] with nTimeSteps steps
    t = [
        x * 1. / (cf.optim.nTimeSteps) for x in range(cf.optim.nTimeSteps + 1)
    ]
    # t=0 is the identity deformation and is not checkpointed, hence the
    # checkpoint indices start at 1 and the last time point sits at position
    # len(t) - 2 of cpstates.
    cpinds = list(range(1, len(t)))
    msmtinds = [len(t) - 2]
    cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                for _ in cpinds]
    gradAtMsmts = [ca.Image3D(imGrid, mType) for _ in msmtinds]

    EnergyHistory = []

    # TODO: better initializations (e.g. start the atlas from the average of
    # the subject images, reduced across processes, instead of zeros).
    ca.SetMem(atlas, 0.0)
    for m0_individual in m_array:
        ca.SetMem(m0_individual, 0.0)

    # assign atlas reference to p.I0. This reference will not change.
    p.I0 = atlas

    # run the loop
    for it in range(cf.optim.Niter):
        # run one warp iteration for each local subject, updating its own
        # initial momentum and accumulating SplatI and Jac
        ca.SetMem(sumSplatI, 0.0)
        ca.SetMem(sumJac, 0.0)
        # one-element arrays so they can serve as in-place Allreduce buffers
        TotalVEnergy = np.array([0.0])
        TotalIEnergy = np.array([0.0])

        for itsub in range(numLocalSubjects):
            # only references are assigned here, nothing is copied
            p.m0 = m_array[itsub]
            Imsmts = [J_array[itsub]]

            # run warp iteration
            VEnergy, IEnergy = RunWarpIteration(nodeSubjectIds[itsub], cf, p,
                                                t, Imsmts, cpinds, cpstates,
                                                msmtinds, gradAtMsmts, it)

            # gather relevant results
            ca.Add_I(sumSplatI, p.sumSplatI)
            ca.Add_I(sumJac, p.sumJac)
            TotalVEnergy[0] += VEnergy
            TotalIEnergy[0] += IEnergy

        # with several processes, sum the accumulators across them
        if cf.compute.useMPI:
            Compute.Reduce(sumSplatI, mpiImageBuff)
            Compute.Reduce(sumJac, mpiImageBuff)

            # also sum up the energies of the other processes
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalVEnergy,
                                            op=mpi4py.MPI.SUM)
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalIEnergy,
                                            op=mpi4py.MPI.SUM)

        EnergyHistory.append([TotalVEnergy[0], TotalIEnergy[0]])

        # now divide to get the new atlas image
        ca.Div(atlas, sumSplatI, sumJac)

        if isReporter:
            isLastIter = it == cf.optim.Niter - 1
            if cf.io.plotEvery > 0 and (
                    (it + 1) % cf.io.plotEvery == 0 or isLastIter):
                AtlasPlots(cf, p, atlas, m_array, EnergyHistory)

            # print out the energy of this iteration
            (vecEnergy, imgEnergy) = EnergyHistory[-1]
            print('Iter %s of %s : %s (Total) =  %s (Vector) +  %s (Image)' %
                  (it, cf.optim.Niter, vecEnergy + imgEnergy, vecEnergy,
                   imgEnergy))

    # write output images and fields
    AtlasWriteOutput(cf, atlas, m_array, nodeSubjectIds, isReporter)
def WarpGradient(p, t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts):
    """Compute the momentum gradient and atlas update terms for one subject.

    Shoots the geodesic forward from ``p.m0``, builds the image residual at
    every measurement, integrates the adjoint equations backward and then
    assembles the closed-form terms used for the image update.

    Args:
        p: scratch/state container (``p.m0``, ``p.I0``, ``p.sigma``,
            ``p.diffOp``, scratch fields and images, ...).
        t: list of time points.
        Imsmts: measured images, one per measurement.
        cpinds: checkpoint indices handed to the integrators.
        cpstates: checkpointed ``(g, ginv)`` pairs.
        msmtinds: for each measurement, its index into ``cpstates``, or -1
            when the measurement is compared with ``p.I0`` directly.
        gradAtMsmts: output images, one per measurement, receiving the
            residual divided by ``sigma**2``.

    Returns:
        Tuple ``(p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)``.
        ``p.scratchV1`` holds the inverse operator applied to ``p.m0`` minus
        ``p.madj``.
    """
    # forward shooting of the geodesic
    CAvmCommon.IntegrateGeodesic(
        p.m0, t, p.diffOp,
        p.m, p.g, p.ginv,
        p.scratchV1, p.scratchV2, p.scratchV3,
        cpstates, cpinds,
        Ninv=p.nInv,
        integMethod=p.integMethod,
        RK4=p.scratchV4,
        scratchG=p.scratchV5)

    # residual and image energy at every measurement
    IEnergy = 0.0
    for k, measured in enumerate(Imsmts):
        residual = gradAtMsmts[k]
        cpIndex = msmtinds[k]
        if cpIndex == -1:
            # no deformation for this measurement: compare with I0 itself
            ca.Copy(residual, p.I0)
        else:
            _, ginv = cpstates[cpIndex]
            ca.ApplyH(residual, p.I0, ginv)
        ca.Sub_I(residual, measured)
        # while the plain residual is available, record the image energy
        IEnergy += ca.Sum2(residual) / (
            2 * p.sigma * p.sigma * float(p.I0.nVox()))
        # scale the residual into the gradient at the measurement
        ca.DivC_I(residual, p.sigma * p.sigma)

    # backward integration of the adjoint variables
    CAvmCommon.IntegrateAdjoints(
        p.Iadj, p.madj,
        p.I, p.m, p.Iadjtmp, p.madjtmp, p.scratchV1,
        p.scratchV2, p.scratchV3,
        p.I0, p.m0,
        t, cpstates, cpinds,
        gradAtMsmts, msmtinds,
        p.diffOp,
        p.integMethod, p.nInv,
        scratchV3=p.scratchV7,
        scratchV4=p.g,
        scratchV5=p.ginv,
        scratchV6=p.scratchV8,
        scratchV7=p.scratchV9,
        scratchV8=p.scratchV10,
        scratchV9=p.scratchV11,
        RK4=p.scratchV4,
        scratchG=p.scratchV5,
        scratchGinv=p.scratchV6,
        scratchI=p.scratchI1)

    # momentum gradient: inverse operator applied to m0, minus madj
    ca.Copy(p.scratchV1, p.m0)
    p.diffOp.applyInverseOperator(p.scratchV1)
    # while the velocity is available, record the vector energy
    VEnergy = 0.5 * ca.Dot(p.m0, p.scratchV1) / float(p.I0.nVox())
    ca.Sub_I(p.scratchV1, p.madj)
    # (a diffOp.applyOperator call on the result is disabled in this code)

    # closed-form terms for the image update;
    # p.scratchI1 and p.I are reused as scratch images (references only)
    splatScratch = p.scratchI1
    onesImage = p.I
    ca.SetMem(onesImage, 1.0)
    ca.SetMem(p.sumSplatI, 0.0)
    ca.SetMem(p.sumJac, 0.0)
    # TODO: check these indexings for cases when timepoint 0
    # is not checkpointed
    for k, measured in enumerate(Imsmts):
        cpIndex = msmtinds[k]
        if cpIndex == -1:
            ca.Add_I(p.sumSplatI, measured)
            ca.Add_I(p.sumJac, onesImage)
            continue
        _, ginv = cpstates[cpIndex]
        CAvmCommon.SplatSafe(splatScratch, ginv, measured)
        ca.Add_I(p.sumSplatI, splatScratch)
        CAvmCommon.SplatSafe(splatScratch, ginv, onesImage)
        ca.Add_I(p.sumJac, splatScratch)

    return (p.scratchV1, p.sumJac, p.sumSplatI, VEnergy, IEnergy)
# Example no. 3
# 0
def ElastReg(I0Orig,
             I1Orig,
             scales=[1],
             nIters=[1000],
             maxPert=[0.2],
             fluidParams=[0.1, 0.1, 0.001],
             VFC=0.2,
             Mask=None,
             plotEvery=100):
    """Multiscale elastic registration of ``I0Orig`` onto ``I1Orig``.

    A displacement field ``u`` is optimised by gradient descent, one scale
    level after another, so that ``I0`` deformed by ``u`` matches ``I1``.

    Args:
        I0Orig: image that gets deformed (provides memory type and grid).
        I1Orig: image that ``I0Orig`` is matched to.
        scales: scale levels, handed to the multiscale manager in this order.
        nIters: number of iterations for each entry of ``scales``.
        maxPert: for each entry of ``scales``, the upper bound on
            ``step * max gradient magnitude`` for a single update.
        fluidParams: ``[alpha, beta, gamma]`` of the differential operator.
        VFC: weight of the regularisation term on ``u``.
        Mask: optional image multiplied into the image difference.  It is
            resampled IN PLACE for every scale; a copy of the original is
            kept internally as the resampling source.
        plotEvery: plot every this many iterations and at the very last
            iteration of the last scale; values <= 0 disable plotting.

    Returns:
        Tuple ``(Idef, u, energy)``: the deformed image, the displacement
        field and ``energy = [image term, regularisation term, total]``,
        each a list with one entry per iteration over all scales.

    Note:
        The list defaults are never mutated here, so sharing them between
        calls is harmless; they are kept to leave the signature untouched.
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()
    # identity test: avoids whatever `!=` does on the image type
    useMask = Mask is not None

    # allocate vars
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # keep a full-resolution copy of the mask to resample from
    if useMask:
        MaskOrig = Mask.copy()

    # allocate diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    # initialize some vars
    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initialize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        """Switch every working buffer and the operator to scale `scale`."""
        scaleManager.set(scale)
        # local on purpose: the grid is only needed inside this helper, so
        # no module-level global is created
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        if useMask:
            # set the grid at every scale, exactly as done for I0 and I1
            Mask.setGrid(curGrid)
            if scaleManager.isLastScale():
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)

    # energy[0]: image term, energy[1]: regularisation term, energy[2]: sum
    energy = [[] for _ in range(3)]
    clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']

    for scale in range(nScales):

        setScale(scale)
        # the step size is re-derived from the gradient at every scale
        ustep = None
        # gradient of the (fixed) source image at this scale
        ca.Gradient(gI, I0)

        for it in range(nIters[scale]):
            print('iter %d' % it)

            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # image difference driving the update
            ca.Sub(diff, I1, Idef)

            if useMask:
                ca.Mul_I(diff, Mask)

            # NOTE(review): the 4th argument is a background-strategy
            # constant here but 1.0 in the ApplyV call above -- confirm
            # against the ApplyV signature that this is the intended slot.
            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)

            diffOp.applyInverseOperator(gU, scratchV)

            # u = u - ustep*(VFC*u + 2.0*gU), with gU negated first
            # (the sign flip and the factor 2 are applied in one scaling)
            ca.MulC_I(gU, -2.0)

            # subtract average if gamma is zero (result of nullspace
            # of L for K(L(u)))
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                # NOTE(review): this result is overwritten by the MulC
                # right below, so the mean subtraction currently has no
                # effect on the update -- confirm the intended formula.
                ca.SubC(scratchV, u, av)
            # continue computing gradient
            ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            # A zero gradient means there is nothing to update; skipping
            # also avoids dividing by zero when deriving the step.
            if gradmax > 0:
                # clamp the step so that the largest update stays within
                # this scale's perturbation limit
                if ustep is None or ustep * gradmax > maxPert[scale]:
                    ustep = maxPert[scale] / gradmax
                    print('step is %f' % ustep)

                ca.MulC_I(gU, ustep)
                # apply gradient
                ca.Sub_I(u, gU)

            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            isFinalIter = (scale == nScales - 1 and it == nIters[scale] - 1)
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or isFinalIter):
                print('plotting')
                plt.figure('energy')
                for i, series in enumerate(energy):
                    plt.plot(series, clrlist[i])
                    if i == 0:
                        # first curve replaces the old plot, the others
                        # are drawn on top of it
                        plt.hold(True)
                plt.hold(False)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

    return (Idef, u, energy)