Exemple #1
0
def plotResults(I, Data, fig, cmap='gray', rng=[0, 1]):
    """
    Show an image, its denoised version and their difference side by side.

    Figure `fig` is selected and split into a 1x3 row:
      1. Data, titled 'Orig'
      2. I, titled 'Denoised'
      3. scratchI after Sub(scratchI, I, Data), titled 'diff'

    The first two panels are displayed with colormap `cmap` and intensity
    range `rng`; the difference panel passes rng=None to DispImage.

    Sub and scratchI are not parameters: they are looked up in the
    enclosing module scope.
    """
    plt.figure(fig)

    # The first two panels differ only in the image and its title.
    panels = ((Data, 'Orig'), (I, 'Denoised'))
    for column, (image, title) in enumerate(panels, 1):
        plt.subplot(1, 3, column)
        display.DispImage(image, title, cmap=cmap,
                          newFig=False, rng=rng, t=False)

    # presumably stores I - Data in scratchI -- confirm against Sub
    Sub(scratchI, I, Data)
    plt.subplot(1, 3, 3)
    display.DispImage(scratchI, 'diff', cmap=cmap,
                      newFig=False, rng=None, t=False)

    plt.draw()
    plt.show()
def GeodesicShootingPlots(g, ginv, I0, It, cf):
    """
    Draw a 2x2 summary of a geodesic shooting result in figure 3.

    Top row: grid plots of `g` and `ginv` (titled with the phi and
    phi^{-1} labels).  Bottom row: the images `I0` and `It` (titled 'I0'
    and 'I1').

    Settings read from `cf.io`: gridEvery, plotSliceDim, plotSlice and
    outputPrefix.  When outputPrefix is not None the figure is written to
    outputPrefix + 'shooting.pdf'.
    """
    fig = plt.figure(3)
    plt.clf()
    fig.patch.set_facecolor('white')

    # Raw strings: same text as before, without relying on '\p' being an
    # unrecognised escape sequence.
    gridPanels = ((g, r'\phi'), (ginv, r'\phi^{-1}'))
    for panel, (field, title) in enumerate(gridPanels, 1):
        plt.subplot(2, 2, panel)
        CAvmCommon.MyGridPlot(field,
                              every=cf.io.gridEvery,
                              color='k',
                              dim=cf.io.plotSliceDim,
                              sliceIdx=cf.io.plotSlice,
                              isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)

    imagePanels = ((I0, 'I0'), (It, 'I1'))
    for panel, (image, title) in enumerate(imagePanels, 3):
        plt.subplot(2, 2, panel)
        display.DispImage(image,
                          title,
                          newFig=False,
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice)

    plt.draw()
    # Save before show(): when show() blocks, the window is closed by the
    # time it returns and a later savefig would write an empty figure.
    if cf.io.outputPrefix is not None:
        plt.savefig(cf.io.outputPrefix + 'shooting.pdf')
    plt.show()
Exemple #3
0
def MatchingPlots(p):
    """
    Do some summary plots for image matching.

    Three named figures are produced from the state object `p`:
      'images' -- I0, I0 pulled through ginv, I1, and I1 pulled through g
      'def'    -- grid plots of ginv and g, the result of JacDetH on
                  ginv, and the scaled residual image
      'energy' -- total, vector and image energy curves from p.Energy

    (g, ginv) is the checkpoint state stored for time index len(p.t) - 1.

    Side effects visible here: the geodesic is re-integrated into p's
    buffers, p.I is reused as scratch space, and p.residualIm is passed to
    MulC_I with p.sigma**2 (presumably an in-place scaling, so calling
    this twice without recomputing the residual would scale it twice --
    confirm).  When p.outputPrefix is not None each figure is saved as
    outputPrefix + 'images.pdf' / 'def.pdf' / 'energy.pdf'.
    """
    savePlots = p.outputPrefix is not None

    # reset all variables by shooting once, may have been overwritten
    CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                 p.m, p.g, p.ginv,
                                 p.scratchV1, p.scratchV2, p.scratchV3,
                                 p.checkpointstates, p.checkpointinds,
                                 Ninv=p.nInv, integMethod=p.integMethod)

    # ---- images ----
    fig = plt.figure('images')
    plt.clf()
    fig.patch.set_facecolor('white')
    plt.subplot(2, 2, 1)
    display.DispImage(p.I0,
                      'I0',
                      newFig=False,
                      cmap='gray',
                      sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 2)
    # checkpoint holding the state at the final time point
    indx_of_last_tp = p.checkpointinds.index(len(p.t) - 1)
    (g, ginv) = p.checkpointstates[indx_of_last_tp]
    ca.ApplyH(p.I, p.I0, ginv)
    display.DispImage(p.I,
                      r'\phi.I0',
                      newFig=False,
                      cmap='gray',
                      sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 3)
    display.DispImage(p.I1,
                      'I1',
                      newFig=False,
                      cmap='gray',
                      sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 4)
    ca.ApplyH(p.I, p.I1, g)
    display.DispImage(p.I,
                      r'\phi^{-1}.I1',
                      newFig=False,
                      cmap='gray',
                      sliceIdx=p.plotSlice)

    plt.draw()
    # Save before show(): a blocking show() returns only after the window
    # is closed, after which savefig would write an empty figure.
    if savePlots:
        plt.savefig(p.outputPrefix + 'images.pdf')
    plt.show()

    # ---- deformations ----
    fig = plt.figure('def')
    plt.clf()
    fig.patch.set_facecolor('white')
    gridPanels = ((ginv, r'\phi^{-1}'), (g, r'\phi'))
    for panel, (field, title) in enumerate(gridPanels, 1):
        plt.subplot(2, 2, panel)
        display.GridPlot(field,
                         every=p.quiverEvery,
                         color='k',
                         sliceIdx=p.plotSlice,
                         isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)
    plt.subplot(2, 2, 3)
    ca.JacDetH(p.I, ginv)  # p.I used as scratch variable to compute jacobian
    display.DispImage(p.I,
                      r'|D\phi^{-1}|',
                      newFig=False,
                      sliceIdx=p.plotSlice)
    plt.subplot(2, 2, 4)
    ca.MulC_I(p.residualIm, p.sigma * p.sigma)
    display.DispImage(p.residualIm,
                      r'\phi.I0-I1',
                      newFig=False,
                      sliceIdx=p.plotSlice)
    plt.colorbar()
    plt.draw()
    if savePlots:
        plt.savefig(p.outputPrefix + 'def.pdf')
    plt.show()

    # ---- energy ----
    fig = plt.figure('energy')
    fig.patch.set_facecolor('white')

    TE = [sum(x) for x in p.Energy]
    VE = [row[0] for row in p.Energy]
    IE = [row[1] for row in p.Energy]
    curves = ((TE, 'Total Energy'),
              (VE, 'Vector Energy'),
              (IE, 'Image Energy'))
    for panel, (curve, title) in enumerate(curves, 1):
        plt.subplot(1, 3, panel)
        plt.plot(curve)
        plt.title(title)
        # NOTE(review): pyplot.hold is deprecated since matplotlib 2.0 and
        # removed in 3.0; this figure is not cleared with clf(), so the
        # hold(False) state is what makes a repeated call replace the
        # previous curve.  Revisit when upgrading matplotlib.
        plt.hold(False)
    plt.draw()
    if savePlots:
        plt.savefig(p.outputPrefix + 'energy.pdf')
    plt.show()
def MatchingImageMomentaPlots(cf,
                              geodesicState,
                              tDiscGeodesic,
                              EnergyHistory,
                              m0,
                              J1,
                              n1,
                              writeOutput=True):
    """
    Do some summary plots for MatchingImageMomenta.

    Figures produced (by number):
      1 -- energy curves; columns 0..3 of each EnergyHistory row are
           plotted as total, vector, image-match and momenta-match energy
      2 -- J0, J0 pulled through rhoinv, and grid plots of rhoinv and rho
      3 -- source J0, target J1 and their difference image
      4 -- x and y components of m0, of n1, and of CoAd(rhoinv, m0) - n1

    The geodesic in `geodesicState` is re-integrated before figure 2, and
    geodesicState.J is overwritten by the ApplyH call.

    Figures are saved under cf.io.outputPrefix ('energy.pdf', 'def.pdf',
    'diffImage.pdf', 'diffMomenta.pdf') only when that prefix is not None
    and `writeOutput` is true.
    """
    saveFigures = cf.io.outputPrefix is not None and writeOutput

    # ENERGY
    fig = plt.figure(1)
    plt.clf()
    fig.patch.set_facecolor('white')

    TE = [row[0] for row in EnergyHistory]
    VE = [row[1] for row in EnergyHistory]
    IE = [row[2] for row in EnergyHistory]
    ME = [row[3] for row in EnergyHistory]
    curves = ((TE, 'Total Energy'),
              (VE, 'Vector Energy'),
              (IE, 'Image Match Energy'),
              (ME, 'Momenta Match Energy'))
    for panel, (curve, title) in enumerate(curves, 1):
        plt.subplot(2, 2, panel)
        plt.plot(curve)
        plt.title(title)
        # NOTE(review): pyplot.hold was removed in matplotlib 3.0.
        plt.hold(False)
    plt.draw()
    # Save before show(): a blocking show() returns only after the window
    # is closed, after which savefig would write an empty figure.
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'energy.pdf')
    plt.show()

    # GEODESIC INITIAL CONDITIONS and RHO and RHO inv
    CAvmHGMCommon.HGMIntegrateGeodesic(geodesicState.p0, geodesicState.s,
                                       geodesicState.diffOp, geodesicState.p,
                                       geodesicState.rho, geodesicState.rhoinv,
                                       tDiscGeodesic, geodesicState.Ninv,
                                       geodesicState.integMethod)

    fig = plt.figure(2)
    plt.clf()
    fig.patch.set_facecolor('white')

    plt.subplot(2, 2, 1)
    display.DispImage(geodesicState.J0,
                      'J0',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.subplot(2, 2, 2)
    ca.ApplyH(geodesicState.J, geodesicState.J0, geodesicState.rhoinv)
    display.DispImage(geodesicState.J,
                      'J1',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)

    gridPanels = ((geodesicState.rhoinv, 'rho^{-1}'),
                  (geodesicState.rho, 'rho'))
    for panel, (field, title) in enumerate(gridPanels, 3):
        plt.subplot(2, 2, panel)
        display.GridPlot(field,
                         every=cf.io.quiverEvery,
                         color='k',
                         sliceIdx=cf.io.plotSlice,
                         isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'def.pdf')

    # MATCHING DIFFERENCE IMAGES
    grid = geodesicState.J0.grid()
    mType = geodesicState.J0.memType()
    imdiff = ca.ManagedImage3D(grid, mType)

    # Image matching: imdiff = (deformed J) - J1
    ca.Copy(imdiff, geodesicState.J)
    ca.Sub_I(imdiff, J1)
    fig = plt.figure(3)
    plt.clf()
    fig.patch.set_facecolor('white')

    plt.subplot(1, 3, 1)
    display.DispImage(geodesicState.J0,
                      'Source J0',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.colorbar()

    plt.subplot(1, 3, 2)
    display.DispImage(J1, 'Target J1', newFig=False, sliceIdx=cf.io.plotSlice)
    plt.colorbar()

    plt.subplot(1, 3, 3)
    display.DispImage(imdiff,
                      'rho.J0-J1',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.colorbar()
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'diffImage.pdf')

    # Momenta matching
    # Device memory uses plain fields, anything else the managed variant.
    if mType == ca.MEM_DEVICE:
        scratchV1 = ca.Field3D(grid, mType)
        scratchV2 = ca.Field3D(grid, mType)
        scratchV3 = ca.Field3D(grid, mType)
    else:
        scratchV1 = ca.ManagedField3D(grid, mType)
        scratchV2 = ca.ManagedField3D(grid, mType)
        scratchV3 = ca.ManagedField3D(grid, mType)

    fig = plt.figure(4)
    plt.clf()
    fig.patch.set_facecolor('white')

    # Each field is copied into a scratch buffer and moved to host memory
    # before being converted to numpy components; only x and y are shown.
    ca.Copy(scratchV1, m0)
    scratchV1.toType(ca.MEM_HOST)
    m0_x, m0_y, m0_z = scratchV1.asnp()
    plt.subplot(2, 3, 1)
    plt.imshow(np.squeeze(m0_x))
    plt.colorbar()
    plt.title('X: Source m0 ')
    plt.subplot(2, 3, 4)
    plt.imshow(np.squeeze(m0_y))
    plt.colorbar()
    plt.title('Y: Source m0')

    ca.Copy(scratchV2, n1)
    scratchV2.toType(ca.MEM_HOST)
    n1_x, n1_y, n1_z = scratchV2.asnp()
    plt.subplot(2, 3, 2)
    plt.imshow(np.squeeze(n1_x))
    plt.colorbar()
    plt.title('X: Target n1')
    plt.subplot(2, 3, 5)
    plt.imshow(np.squeeze(n1_y))
    plt.colorbar()
    plt.title('Y: Target n1')

    ca.CoAd(scratchV3, geodesicState.rhoinv, m0)
    ca.Sub_I(scratchV3, n1)
    scratchV3.toType(ca.MEM_HOST)
    diff_x, diff_y, diff_z = scratchV3.asnp()
    plt.subplot(2, 3, 3)
    plt.imshow(np.squeeze(diff_x))
    plt.colorbar()
    plt.title('X: rho.m0-n1')
    plt.subplot(2, 3, 6)
    plt.imshow(np.squeeze(diff_y))
    plt.colorbar()
    plt.title('Y: rho.m0-n1')

    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'diffMomenta.pdf')

    del scratchV1, scratchV2, scratchV3
    del imdiff

    # NOTE(review): everything below builds an interpolation test figure
    # from random data and uses none of this function's arguments; `h` and
    # `imDefNN` are never read afterwards.  It looks like it may belong to
    # a different routine -- confirm before relying on it.  Kept unchanged.
    plt.ion()

    initMax = 5
    randArrSmall = (np.random.rand(10, 10) * initMax).astype(int)
    imSmall = common.ImFromNPArr(randArrSmall)
    imLargeNN = ca.Image3D(50, 50, 1)
    imLargeLinear = ca.Image3D(50, 50, 1)
    imLargeCubic = ca.Image3D(50, 50, 1)
    ca.Resample(imLargeNN, imSmall, ca.BACKGROUND_STRATEGY_CLAMP, ca.INTERP_NN)
    ca.Resample(imLargeLinear, imSmall, ca.BACKGROUND_STRATEGY_CLAMP,
                ca.INTERP_LINEAR)
    ca.Resample(imLargeCubic, imSmall, ca.BACKGROUND_STRATEGY_CLAMP,
                ca.INTERP_CUBIC)
    plt.figure('interp test')
    plt.subplot(2, 3, 1)
    display.DispImage(imLargeNN, 'NN', newFig=False)
    plt.subplot(2, 3, 2)
    display.DispImage(imLargeLinear, 'Linear', newFig=False)
    plt.subplot(2, 3, 3)
    display.DispImage(imLargeCubic, 'Cubic', newFig=False)
    plt.subplot(2, 3, 5)
    display.DispImage(imSmall, 'small', newFig=False)
    plt.show()

    h = common.WavyDef([50, 50],
                       nWaves=1,
                       waveAmp=10,
                       waveDim=0,
                       mType=ca.MEM_HOST,
                       deformation=True)
    imDefNN = imLargeNN.copy()
Exemple #6
0
def GreedyReg(I0Orig,
              I1Orig,
              scales = [1],
              nIters = [1000],
              ustep = [0.25],
              fluidParams = [0.1, 0.1, 0.001],
              plotEvery = 100):
    """
    Multiscale greedy registration of I0Orig to I1Orig.

    Parameters
    ----------
    I0Orig, I1Orig : images; memory type and grid are taken from I0Orig.
    scales : values passed one by one to MultiscaleManager.addScaleLevel.
    nIters : number of iterations to run at each scale (indexed by scale).
    ustep : step multiplier applied to the update field at each scale
        (indexed by scale).
    fluidParams : [alpha, beta, gamma] handed to the differential operator.
    plotEvery : plot on iterations where it % plotEvery == 0 and on the
        last iteration of each scale.  0 (or None) disables plotting;
        previously 0 raised ZeroDivisionError.

    The default lists are never modified here.

    Returns
    -------
    (Idef, h, energy) : the deformed image buffer, the deformation field
        and a list of three lists.  Only energy[0] is filled (with
        Sum2(diff) per iteration); energy[1] and energy[2] stay empty.
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # allocate vars
    I0 = Image3D(origGrid, mType)
    I1 = Image3D(origGrid, mType)
    h = Field3D(origGrid, mType)
    Idef = Image3D(origGrid, mType)
    diff = Image3D(origGrid, mType)
    gI = Field3D(origGrid, mType)
    gU = Field3D(origGrid, mType)
    scratchI = Image3D(origGrid, mType)
    scratchV = Field3D(origGrid, mType)

    # allocate diffOp
    if mType == MEM_HOST:
        diffOp = FluidKernelFFTCPU()
    else:
        diffOp = FluidKernelFFTGPU()

    nScales = len(scales)
    scaleManager = MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initalize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == MEM_HOST:
        resampler = MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        """Switch every buffer and operator to scale level `scale`."""
        # NOTE: curGrid is published as a module-level global; kept so
        # that any code reading it from the module keeps working.
        global curGrid

        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            Copy(I0, I0Orig)
            Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            h.setGrid(curGrid)
            SetToIdentity(h)
        else:
            resampler.updateHField(h)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)
    # end function

    energy = [[] for _ in range(3)]

    for scale in range(nScales):

        setScale(scale)
        lastIter = nIters[scale] - 1

        for it in range(nIters[scale]):
            # single-argument print(): same output on Python 2 and 3
            print('iter %d' % it)

            # compute deformed image
            ApplyH(Idef, I0, h)

            # update gradient
            Gradient(gI, Idef)

            # update u
            Sub(diff, I1, Idef)

            gI *= diff

            diffOp.applyInverseOperator(gU, gI)

            gU *= ustep[scale]
            # ApplyV(scratchV, h, gU, BACKGROUND_STRATEGY_PARTIAL_ID)
            ComposeHV(scratchV, h, gU)
            h.swap(scratchV)

            # compute energy
            energy[0].append(Sum2(diff))

            # A falsy plotEvery turns plotting off instead of dividing by
            # zero in the modulo below.
            if plotEvery and (it % plotEvery == 0 or it == lastIter):
                clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
                plt.figure('energy')
                for i, curve in enumerate(energy):
                    plt.plot(curve, clrlist[i])
                    if i == 0:
                        # keep the first curve while the others are added
                        plt.hold(True)
                plt.hold(False)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(h, every=4, isVF=False)
                plt.draw()
                plt.show()

            # end plot
        # end iteration
    # end scale
    return (Idef, h, energy)
def SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf):
    """
    Write PNG frames of a deformation sequence under
    cf.io.outputPrefix + '/frames/'.

    Frame 0 shows the starting state: I0, an identity grid (written as
    both 'invdef' and 'def') and m0 drawn over I0.  One further frame is
    written per entry of `checkpointinds`; for frame k the state
    checkpointstates[k-1] is used, where element [1] deforms I0 into `It`
    and is plotted as 'invdef', and element [0] is plotted as 'def'
    (presumably (g, ginv) pairs -- confirm against the producer).

    Four files are written per frame, numbered with five digits:
    I#####.png, invdef#####.png, def#####.png and m#####.png.

    `It` and `mt` are output buffers that are overwritten for every
    checkpoint.  Settings read from cf.io: outputPrefix, plotSliceDim,
    plotSlice, gridEvery and quiverEvery.
    """
    momentathresh = 0.00002

    # Create the directory the frames are actually written to.  The
    # previous code created dirname(outputPrefix) + '/frames/', which is a
    # different directory unless outputPrefix ends with a path separator.
    framesDir = cf.io.outputPrefix + '/frames/'
    common.Mkdir_p(framesDir)

    def saveFrame(fig, stem, image_idx):
        """Save the current figure as a 4x4 inch, 100 dpi PNG frame."""
        outfilename = framesDir + stem + str(image_idx).zfill(5) + '.png'
        fig.set_size_inches(4, 4)
        plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

    def showImage(im):
        """Draw one slice of `im` in gray, without a title."""
        display.DispImage(im,
                          '',
                          newFig=False,
                          cmap='gray',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice)

    def saveImageFrame(im, image_idx):
        """Figure 1: the image alone."""
        fig = plt.figure(1, frameon=False)
        plt.clf()
        showImage(im)
        plt.draw()
        saveFrame(fig, 'I', image_idx)

    def saveGridFrame(figNum, field, stem, image_idx):
        """Figures 2 and 3: a bare deformation grid without ticks or axes."""
        fig = plt.figure(figNum, frameon=False)
        plt.clf()
        CAvmCommon.MyGridPlot(field,
                              every=cf.io.gridEvery,
                              color='k',
                              dim=cf.io.plotSliceDim,
                              sliceIdx=cf.io.plotSlice,
                              isVF=False,
                              plotBase=False)
        a = fig.gca()
        a.set_xticks([])
        a.set_yticks([])
        plt.axis('tight')
        plt.axis('image')
        plt.axis('off')
        plt.draw()
        saveFrame(fig, stem, image_idx)

    def saveMomentaFrame(im, momenta, scaleArrows, image_idx):
        """Figure 4: momenta arrows drawn over the image."""
        fig = plt.figure(4, frameon=False)
        plt.clf()
        showImage(im)
        # NOTE(review): hold() receives the strings 'True' / 'False'.
        # Both are non-empty and therefore truthy, so the second call does
        # not switch hold off.  Left unchanged because the grid plots may
        # depend on hold staying on -- confirm before changing.
        plt.hold('True')
        CAvmCommon.MyQuiver(momenta,
                            dim=cf.io.plotSliceDim,
                            sliceIdx=cf.io.plotSlice,
                            every=cf.io.quiverEvery,
                            thresh=momentathresh,
                            scaleArrows=scaleArrows,
                            arrowCol='r',
                            lineWidth=0.5,
                            width=0.005)
        plt.draw()
        plt.hold('False')
        saveFrame(fig, 'm', image_idx)

    # ---- frame 0: undeformed image, identity grids, initial momenta ----
    saveImageFrame(I0, 0)

    temp = ca.Field3D(I0.grid(), I0.memType())
    ca.SetToIdentity(temp)
    # NOTE(review): looks like a leftover debugging hook -- confirm
    # whether it should still run on every call.
    common.DebugHere()
    saveGridFrame(2, temp, 'invdef', 0)
    saveGridFrame(3, temp, 'def', 0)
    saveMomentaFrame(I0, m0, 0.25, 0)

    # ---- one frame per checkpoint ----
    # Only the number of checkpoint indices is used; the states are taken
    # in storage order.
    for image_idx in range(1, len(checkpointinds) + 1):
        state = checkpointstates[image_idx - 1]

        ca.ApplyH(It, I0, state[1])
        saveImageFrame(It, image_idx)

        saveGridFrame(2, state[1], 'invdef', image_idx)
        saveGridFrame(3, state[0], 'def', image_idx)

        ca.CoAd(mt, state[1], m0)
        saveMomentaFrame(It, mt, 0.40, image_idx)
Exemple #8
0
def ElastReg(I0Orig,
             I1Orig,
             scales=[1],
             nIters=[1000],
             maxPert=[0.2],
             fluidParams=[0.1, 0.1, 0.001],
             VFC=0.2,
             Mask=None,
             plotEvery=100):
    """
    Multiscale elastic registration of I0Orig to I1Orig.

    Parameters
    ----------
    I0Orig, I1Orig : images; memory type and grid are taken from I0Orig.
    scales : values passed one by one to MultiscaleManager.addScaleLevel.
    nIters : number of iterations to run at each scale (indexed by scale).
    maxPert : per-scale bound on step * max|gradient|; the step is set to
        maxPert[scale] / max|gradient| on the first iteration of a scale
        and whenever the current step would exceed the bound.
    fluidParams : [alpha, beta, gamma] handed to the differential operator.
    VFC : weight of the displacement-field term in the gradient and in
        energy[1].
    Mask : optional image multiplied into the difference image.  A copy is
        kept, and the object passed in is itself overwritten with the copy
        or a downsampled version at each scale.
    plotEvery : plot every plotEvery-th iteration and on the very last
        iteration; values <= 0 disable plotting.

    The default lists are never modified here.

    Returns
    -------
    (Idef, u, energy) : the deformed image buffer, the displacement field
        and three per-iteration lists: Sum2(diff), 0.5 * VFC * Dot(u, L u)
        and their sum.
    """
    mType = I0Orig.memType()
    origGrid = I0Orig.grid()
    useMask = Mask is not None

    # allocate vars
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # mask
    if useMask:
        MaskOrig = Mask.copy()

    # allocate diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    # initialize some vars
    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initalize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        """Switch every buffer and operator to scale level `scale`."""
        # NOTE: curGrid is published as a module-level global; kept so
        # that any code reading it from the module keeps working.
        global curGrid

        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        if useMask:
            if scaleManager.isLastScale():
                Mask.setGrid(curGrid)
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)

    # end function

    energy = [[] for _ in range(3)]

    for scale in range(nScales):

        setScale(scale)
        # the step size is re-derived from the gradient at every scale
        ustep = None
        # update gradient
        ca.Gradient(gI, I0)

        for it in range(nIters[scale]):
            # single-argument print(): same output on Python 2 and 3
            print('iter %d' % it)

            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # update u
            ca.Sub(diff, I1, Idef)

            if useMask:
                ca.Mul_I(diff, Mask)

            # NOTE(review): the call above passes 1.0 in the fourth
            # position while this one passes a background-strategy
            # constant there -- confirm against the ApplyV signature.
            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)

            diffOp.applyInverseOperator(gU, scratchV)

            # why is this negative necessary?
            ca.MulC_I(gU, -1.0)

            # u =  u*(1-VFC*ustep) + (-2.0*ustep)*gU
            # MulC_Add_MulC_I(u, (1-VFC*ustep),
            #                        gU, 2.0*ustep)

            # u =  u - ustep*(VFC*u + 2.0*gU)
            ca.MulC_I(gU, 2.0)

            # subtract average if gamma is zero (result of nullspace
            # of L for K(L(u)))
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)
            # continue computing gradient
            # NOTE(review): this writes scratchV from u again, so the
            # mean-subtracted field computed just above is discarded;
            # scaling scratchV in place may have been intended for the
            # gamma == 0 case.  Left unchanged pending confirmation.
            ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            # Compare against this scale's bound.  The whole maxPert list
            # was compared before, which never triggers on Python 2 and
            # raises TypeError on Python 3.
            if ustep is None or ustep * gradmax > maxPert[scale]:
                ustep = maxPert[scale] / gradmax
                print('step is %f' % ustep)

            ca.MulC_I(gU, ustep)
            # apply gradient
            ca.Sub_I(u, gU)

            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1] + energy[1][-1])

            isFinalIter = (scale == nScales - 1 and it == nIters[scale] - 1)
            if plotEvery > 0 and ((it + 1) % plotEvery == 0 or isFinalIter):
                print('plotting')
                clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
                plt.figure('energy')
                for i, curve in enumerate(energy):
                    plt.plot(curve, clrlist[i])
                    if i == 0:
                        # keep the first curve while the others are added
                        plt.hold(True)
                plt.hold(False)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

            # end plot
        # end iteration
    # end scale
    return (Idef, u, energy)
Exemple #9
0
def _hgmNewFigure(num):
    """Make figure `num` current, clear it and give it a white background."""
    fig = plt.figure(num)
    plt.clf()
    fig.patch.set_facecolor('white')
    return fig


def _hgmSaveFigure(cf, writeOutput, suffix):
    """Save the current figure to cf.io.outputPrefix + suffix, if enabled."""
    if cf.io.outputPrefix is not None and writeOutput:
        plt.savefig(cf.io.outputPrefix + suffix)


def _hgmShowSlice(cf, im, title):
    """Display slice cf.io.plotSlice of `im` in the current axes."""
    display.DispImage(im, title, newFig=False, sliceIdx=cf.io.plotSlice)


def _hgmShowGrid(cf, h, title):
    """Grid-plot `h` in the current axes with equal scaling and no frame."""
    display.GridPlot(h,
                     every=cf.io.quiverEvery,
                     color='k',
                     sliceIdx=cf.io.plotSlice,
                     isVF=False)
    plt.axis('equal')
    plt.axis('off')
    plt.title(title)


def HGMPlots(cf,
             groupState,
             tDiscGroup,
             residualState,
             tDiscResidual,
             index_individual,
             writeOutput=True):
    """
    Do some summary plots for HGM.

    Four figures are produced (matplotlib figure numbers 1 to 4):

      1. energy history of the group state (total and the three terms),
      2. group image before/after deformation plus the grids of
         groupState.ginv and groupState.g,
      3. residual image before/after deformation plus the grids of
         residualState.rhoinv and residualState.rho,
      4. source, target and difference images for the intercept match.

    Parameters:
      cf               -- config object; cf.io.outputPrefix, cf.io.plotSlice
                          and cf.io.quiverEvery are read.
      groupState       -- group-level state; its EnergyHistory, m0, t,
                          diffOp, m, g, ginv, Ninv, integMethod, I0 and I
                          attributes are used.
      tDiscGroup       -- group time discretization; passed to
                          HGMIntegrateGeodesic and indexed with
                          index_individual for its .p0 and .J.
      residualState    -- residual-level state; see side effects below.
      tDiscResidual    -- residual time discretization; passed to
                          HGMIntegrateGeodesic.
      index_individual -- index into tDiscGroup selecting the individual.
      writeOutput      -- if true and cf.io.outputPrefix is not None, each
                          figure is saved as <outputPrefix>energy.pdf,
                          groupdef.pdf, resdef.pdf and diffintercept.pdf.

    Side effects (besides plotting/saving):
      * groupState.I is overwritten with I0 deformed by groupState.ginv,
      * residualState.J0 is rebound to groupState.I (same object, no copy),
      * residualState.p0 is rebound to tDiscGroup[index_individual].p0,
      * residualState.J is overwritten with J0 deformed by rhoinv,
      * both HGMIntegrateGeodesic calls receive the states' buffers and
        presumably update them in place (not visible here -- confirm).

    Returns None.
    """

    # ENERGY
    _hgmNewFigure(1)

    # The first history entry is excluded from every curve.  Slicing before
    # indexing means a malformed first row cannot break the plots.
    history = list(groupState.EnergyHistory)[1:]
    panels = [('Total Energy', [sum(row) for row in history]),
              ('Vector Energy', [row[0] for row in history]),
              ('Intercept Energy', [row[1] for row in history]),
              ('Slope Energy', [row[2] for row in history])]
    # pyplot.hold was removed in Matplotlib 3.0; only call it where it exists
    # so the function works on both old and current Matplotlib.
    canHold = hasattr(plt, 'hold')
    for position, (title, values) in enumerate(panels, 1):
        plt.subplot(2, 2, position)
        plt.plot(values)
        plt.title(title)
        if canHold:
            plt.hold(False)
    plt.draw()
    plt.show()
    _hgmSaveFigure(cf, writeOutput, 'energy.pdf')

    # GROUP INITIAL CONDITIONS and PSI and PSI inv
    # shoot group geodesic forward
    CAvmHGMCommon.HGMIntegrateGeodesic(groupState.m0, groupState.t,
                                       groupState.diffOp, groupState.m,
                                       groupState.g, groupState.ginv,
                                       tDiscGroup, groupState.Ninv,
                                       groupState.integMethod)

    _hgmNewFigure(2)
    plt.subplot(2, 2, 1)
    _hgmShowSlice(cf, groupState.I0, 'I0')
    plt.subplot(2, 2, 2)
    ca.ApplyH(groupState.I, groupState.I0, groupState.ginv)
    _hgmShowSlice(cf, groupState.I, 'I1')
    plt.subplot(2, 2, 3)
    _hgmShowGrid(cf, groupState.ginv, 'psi^{-1}')
    plt.subplot(2, 2, 4)
    _hgmShowGrid(cf, groupState.g, 'psi')
    _hgmSaveFigure(cf, writeOutput, 'groupdef.pdf')

    # RESIDUAL INITIAL CONDITIONS and RHO and RHO inv
    # groupState.I was already computed for figure 2; it is recomputed here
    # (as before) so this section does not rely on the plotting code above.
    ca.ApplyH(groupState.I, groupState.I0, groupState.ginv)
    residualState.J0 = groupState.I  # aliases groupState.I, not a copy
    residualState.p0 = tDiscGroup[index_individual].p0
    CAvmHGMCommon.HGMIntegrateGeodesic(residualState.p0, residualState.s,
                                       residualState.diffOp, residualState.p,
                                       residualState.rho, residualState.rhoinv,
                                       tDiscResidual, residualState.Ninv,
                                       residualState.integMethod)

    _hgmNewFigure(3)
    plt.subplot(2, 2, 1)
    _hgmShowSlice(cf, residualState.J0, 'J0')
    plt.subplot(2, 2, 2)
    ca.ApplyH(residualState.J, residualState.J0, residualState.rhoinv)
    _hgmShowSlice(cf, residualState.J, 'J1')
    plt.subplot(2, 2, 3)
    _hgmShowGrid(cf, residualState.rhoinv, 'rho^{-1}')
    plt.subplot(2, 2, 4)
    _hgmShowGrid(cf, residualState.rho, 'rho')
    _hgmSaveFigure(cf, writeOutput, 'resdef.pdf')

    # MATCHING DIFFERENCE IMAGES
    grid = groupState.I0.grid()
    mType = groupState.I0.memType()
    imdiff = ca.ManagedImage3D(grid, mType)

    # Intercept matching: imdiff = residualState.J - target J
    targetJ = tDiscGroup[index_individual].J
    ca.Copy(imdiff, residualState.J)
    ca.Sub_I(imdiff, targetJ)

    _hgmNewFigure(4)
    intercept_panels = [(residualState.J0, 'Source J0'),
                        (targetJ, 'Target J1'),
                        (imdiff, 'rho.J0-J1')]
    for position, (im, title) in enumerate(intercept_panels, 1):
        plt.subplot(1, 3, position)
        _hgmShowSlice(cf, im, title)
        plt.colorbar()
    _hgmSaveFigure(cf, writeOutput, 'diffintercept.pdf')

    # Slope matching -- DISABLED.  Kept for reference; it was already inert
    # (wrapped in a string literal).  The scratch field it needs is no longer
    # allocated up front, so its allocation is part of the disabled code.
    #
    # vecdiff = ca.ManagedField3D(grid, mType)
    # ca.CoAd(groupState.m,groupState.ginv,groupState.m0)
    # ca.CoAd(vecdiff,residualState.rhoinv,groupState.m)
    # n0 = ca.Field3D(grid, ca.MEM_HOST)
    # n1 = ca.Field3D(grid, ca.MEM_HOST)
    # ca.Copy(n0,groupState.m)
    # ca.Copy(n1,tDiscGroup[index_individual].n)
    # ca.Sub_I(vecdiff, tDiscGroup[index_individual].n)
    # vecdiff.toType(ca.MEM_HOST)
    # n0_x, n0_y, n0_z = n0.asnp()
    # n1_x, n1_y, n1_z = n1.asnp()
    # diff_x, diff_y, diff_z = vecdiff.asnp()
    #
    # fig = plt.figure(5)
    # plt.clf()
    # fig.patch.set_facecolor('white')
    #
    # plt.subplot(2,3,1)
    # plt.imshow(np.squeeze(n0_x)); plt.colorbar(); plt.title('X: Source n0 ')
    #
    # plt.subplot(2,3,2)
    # plt.imshow(np.squeeze(n1_x)); plt.colorbar(); plt.title('X: Target n1')
    #
    # plt.subplot(2,3,3)
    # plt.imshow(np.squeeze(diff_x)); plt.colorbar(); plt.title('X: rho.n0-n1')
    #
    # plt.subplot(2,3,4)
    # plt.imshow(np.squeeze(n0_y)); plt.colorbar(); plt.title('Y: Source n0')
    #
    # plt.subplot(2,3,5)
    # plt.imshow(np.squeeze(n1_y)); plt.colorbar(); plt.title('Y: Target n1')
    #
    # plt.subplot(2,3,6)
    # plt.imshow(np.squeeze(diff_y)); plt.colorbar(); plt.title('Y: rho.n0-n1')
    #
    # if cf.io.outputPrefix != None and writeOutput: plt.savefig(cf.io.outputPrefix+'diffslope.pdf')
    # del vecdiff

    # drop the reference to the managed scratch image explicitly
    del imdiff
Exemple #10
0
def ElastReg(I0,
             I1,
             nIters=1000,
             ustep=0.25,
             fluidParams=(0.1, 0.1, 0.001),
             VFC=0.2,
             Mask=None,
             plotEvery=100):
    """
    Iteratively estimate a displacement field `u` that deforms I0 towards I1.

    Starting from u = 0, every iteration
      1. deforms I0 by u into Idef,
      2. forms diff = I1 - Idef (multiplied by Mask when one is given),
      3. multiplies the u-deformed gradient of I0 by diff and applies the
         inverse of the fluid-kernel operator to obtain gU,
      4. updates u in place from u and gU (see the update step below).

    Parameters:
      I0, I1      -- source and target images; grid and memory type of all
                     work buffers are taken from I0.
      nIters      -- number of iterations.
      ustep       -- step size used in the update of u.
      fluidParams -- sequence of three values passed to the operator's
                     setAlpha, setBeta and setGamma, in that order.
      VFC         -- weight of the regularization term, used both in the
                     update of u and in the reported energy.
      Mask        -- optional image multiplied into diff; None disables it.
      plotEvery   -- plot every `plotEvery` iterations and after the last
                     one; values <= 0 disable plotting.

    Returns:
      (Idef, u, energy) where energy is a list of three per-iteration lists:
      energy[0] = Sum2(diff), energy[1] = 0.5*VFC*Dot(L u, u) with L the
      operator, and energy[2] their sum.  All three are evaluated for the
      displacement field as it was *before* that iteration's update, which
      is also the field Idef was computed with.  If nIters is 0 the returned
      Idef is never written to.
    """
    mType = I0.memType()
    grid = I0.grid()

    # allocate vars
    u = Field3D(grid, mType)
    SetMem(u, 0.0)
    Idef = Image3D(grid, mType)
    diff = Image3D(grid, mType)
    gI = Field3D(grid, mType)
    gU = Field3D(grid, mType)
    scratchV = Field3D(grid, mType)

    # allocate diffOp (host or device implementation, matching I0)
    if mType == MEM_HOST:
        diffOp = FluidKernelFFTCPU()
    else:
        diffOp = FluidKernelFFTGPU()
    diffOp.setAlpha(fluidParams[0])
    diffOp.setBeta(fluidParams[1])
    diffOp.setGamma(fluidParams[2])
    diffOp.setGrid(grid)

    # image term, regularization term, total
    energy = [[], [], []]
    clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
    # pyplot.hold was removed in Matplotlib 3.0; only call it where it exists.
    canHold = hasattr(plt, 'hold')

    # gradient of the undeformed source image, computed once
    Gradient(gI, I0)

    for it in range(nIters):
        print('iter %d' % it)

        # compute deformed image
        ApplyV(Idef, I0, u, 1.0)

        # residual driving the update of u
        Sub(diff, I1, Idef)

        if Mask is not None:
            Mul_I(diff, Mask)

        # NOTE(review): the 4th positional argument is a background strategy
        # here but 1.0 in the ApplyV call above -- confirm this matches the
        # ApplyV signature for fields.
        ApplyV(scratchV, gI, u, BACKGROUND_STRATEGY_CLAMP)
        Mul_I(scratchV, diff)

        diffOp.applyInverseOperator(gU, scratchV)

        # Energy of the current u.  Both terms are taken BEFORE u is changed
        # so that they describe the same displacement field; previously the
        # dot product was evaluated after the in-place update, pairing L*u
        # of the old field with the new field.
        diffOp.applyOperator(scratchV, u)
        regEnergy = 0.5 * VFC * Dot(scratchV, u)
        imEnergy = Sum2(diff)

        # u = u*(1-VFC*ustep) + (2.0*ustep)*gU
        # (presumed semantics of MulC_Add_MulC_I; the original note said
        # -2.0*ustep while the code has always passed +2.0*ustep -- confirm
        # the intended sign)
        MulC_Add_MulC_I(u, (1 - VFC * ustep),
                        gU, 2.0 * ustep)

        energy[0].append(imEnergy)
        energy[1].append(regEnergy)
        energy[2].append(imEnergy + regEnergy)

        isLast = (it == nIters - 1)
        if plotEvery > 0 and ((it + 1) % plotEvery == 0 or isLast):
            print('plotting')
            plt.figure('energy')
            # start from an empty figure so curves from earlier refreshes do
            # not pile up when Matplotlib has no hold(False) to clear them
            plt.clf()
            for i, (series, clr) in enumerate(zip(energy, clrlist)):
                plt.plot(series, clr)
                if i == 0 and canHold:
                    plt.hold(True)
            if canHold:
                plt.hold(False)
            plt.draw()

            plt.figure('results')
            plt.clf()
            plt.subplot(3, 2, 1)
            display.DispImage(I0, 'I0', newFig=False)
            plt.subplot(3, 2, 2)
            display.DispImage(I1, 'I1', newFig=False)
            plt.subplot(3, 2, 3)
            display.DispImage(Idef, 'def', newFig=False)
            plt.subplot(3, 2, 4)
            display.DispImage(diff, 'diff', newFig=False)
            plt.colorbar()
            plt.subplot(3, 2, 5)
            display.GridPlot(u, every=4)
            plt.draw()
            plt.show()

        # end plot
    # end iteration
    return (Idef, u, energy)