Ejemplo n.º 1
0
 def test_GroupAdjointAction(self, disp=False):
     """Check that ca.CoAd is the adjoint of ca.Ad under the ca.Dot pairing.

     Three random host fields (m, v, phi) are drawn with identical
     parameters.  The pairing <m, Ad_phi v> is then compared with
     <Ad^*_phi m, v>; the test passes when they differ by less than 2e-6.

     `disp` is accepted for signature compatibility but is not used here.
     """
     def randomHostField():
         # all three fields share the same RandField parameters
         return common.RandField(self.sz,
                                 nSig=5.0,
                                 gSig=4.0,
                                 mType=ca.MEM_HOST,
                                 sp=self.imSp)

     # fields are drawn in the order m, v, phi
     hM = randomHostField()
     hV = randomHostField()
     hPhi = randomHostField()
     scratch = ca.Field3D(self.grid, ca.MEM_HOST)

     # < m, Ad_\phi v >
     ca.Ad(scratch, hPhi, hV)
     pairingAd = ca.Dot(scratch, hM)

     # < Ad^*_\phi m, v >
     ca.CoAd(scratch, hPhi, hM)
     pairingCoAd = ca.Dot(scratch, hV)

     self.assertLess(abs(pairingAd - pairingCoAd), 2e-6)
Ejemplo n.º 2
0
def Matching(cf):
    """Run vector-momentum matching between two images described by `cf`.

    Loads `cf.study.I0` and `cf.study.I1`, builds a MatchingVariables
    object from the `vectormomentum`, `optim` and `io` sections of the
    config, and runs RunMatching on it.

    When `cf.io.outputPrefix` is not None, the parsed config is written to
    "<prefix>parsedconfig.yaml" before the run, and "<prefix>m0.mhd",
    "<prefix>phiinv.mhd" and "<prefix>phi.mhd" are written afterwards.
    When it is None, nothing is written (previously the function failed on
    os.path.dirname(None) before reaching its own None checks).
    """
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    writeOutput = cf.io.outputPrefix is not None

    if writeOutput:
        # prepare output directory
        common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

        # Output loaded config
        cfstr = Config.ConfigToYAML(MatchingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    I0 = common.LoadITKImage(cf.study.I0, mType)
    I1 = common.LoadITKImage(cf.study.I1, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)

    # TODO: need to work on these
    # uniform time discretization of [0, 1] with nTimeSteps intervals
    t = [x*1./cf.optim.nTimeSteps for x in range(cf.optim.nTimeSteps+1)]
    # one pair of checkpoint fields per time point after t[0]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    p = MatchingVariables(I0, I1,
                          cf.vectormomentum.sigma,
                          t, checkpointinds, checkpointstates,
                          cf.vectormomentum.diffOpParams[0],
                          cf.vectormomentum.diffOpParams[1],
                          cf.vectormomentum.diffOpParams[2],
                          cf.optim.Niter,
                          cf.optim.stepSize,
                          cf.optim.maxPert,
                          cf.optim.nTimeSteps,
                          integMethod=cf.optim.integMethod,
                          optMethod=cf.optim.method,
                          nInv=cf.optim.NIterForInverse,
                          plotEvery=cf.io.plotEvery,
                          plotSlice=cf.io.plotSlice,
                          quiverEvery=cf.io.quiverEvery,
                          outputPrefix=cf.io.outputPrefix)

    RunMatching(p)

    # write output
    if writeOutput:
        # reset all variables by shooting once, may have been overwritten
        CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                     p.m, p.g, p.ginv,
                                     p.scratchV1, p.scratchV2, p.scratchV3,
                                     p.checkpointstates, p.checkpointinds,
                                     Ninv=p.nInv, integMethod=p.integMethod)
        common.SaveITKField(p.m0, cf.io.outputPrefix+"m0.mhd")
        common.SaveITKField(p.ginv, cf.io.outputPrefix+"phiinv.mhd")
        common.SaveITKField(p.g, cf.io.outputPrefix+"phi.mhd")
Ejemplo n.º 3
0
def ComposeDef(V,
               t,
               asVField=False,
               inverse=False,
               scratchV1=None,
               scratchV2=None):
    """
    Compose the fields in the sequence V up to the non-integer time t.

    The first floor(t) fields are composed in full; if t has a fractional
    part, V[floor(t)] scaled by that fraction is composed last.  Times
    greater than len(V) are clamped to len(V).

    With inverse=True each step uses core.ComposeHVInv, otherwise
    core.ComposeVH.  With asVField=True the identity is subtracted from
    the result before it is returned.

    scratchV1 / scratchV2 are optional work fields; they are allocated
    on demand when not supplied and their contents are overwritten.
    Returns a newly allocated Field3D.
    """
    numFields = len(V)
    grid = V[0].grid()
    mType = V[0].memType()

    # just clamp to final time
    t = numFields if t > numFields else t
    wholeSteps = int(math.floor(t))
    remainder = t - wholeSteps

    h = core.Field3D(grid, mType)
    if scratchV1 is None:
        scratchV1 = core.Field3D(grid, mType)
    core.SetToIdentity(h)

    def composeInto(v):
        # writes the composition of the running map h with v into scratchV1
        if inverse:
            core.ComposeHVInv(scratchV1, h, v)
        else:
            core.ComposeVH(scratchV1, v, h)

    for step in range(wholeSteps):
        composeInto(V[step])
        h.swap(scratchV1)

    if remainder != 0.0:
        if scratchV2 is None:
            scratchV2 = core.Field3D(grid, mType)
        # partial step: scale the next field by the fractional time
        core.Copy(scratchV2, V[wholeSteps])
        core.MulC_I(scratchV2, core.Vec3Df(remainder, remainder, remainder))
        composeInto(scratchV2)
        core.Copy(h, scratchV1)

    if asVField:
        # subtract the identity from h
        core.SetToIdentity(scratchV1)
        core.Sub_I(h, scratchV1)
    return h
Ejemplo n.º 4
0
def DefSeriesIter(I, V, t, func, args):
    """Deform I to each time in t and collect the results of func.

    For every time `curt` in `t`, the fields in V are composed up to
    `curt` with common.ComposeDef(..., inverse=True, asVField=False), the
    result is applied to I with core.ApplyH, and
    func(IDef, curt, *args) is called on the deformed image.

    IDef is a single buffer reused across iterations, so func must not
    keep a reference to it and expect the contents to stay unchanged.

    Returns the list of func return values, in the order of `t`.
    """
    grid = I.grid()
    mType = I.memType()
    IDef = core.Image3D(grid, mType)
    # work buffers handed to ComposeDef so it does not reallocate them
    # on every call
    scratchV1 = core.Field3D(grid, mType)
    scratchV2 = core.Field3D(grid, mType)
    rtnarr = []
    for curt in t:
        # ComposeDef returns its own field, so no local field needs to be
        # allocated or initialized up front
        h = common.ComposeDef(V, curt, inverse=True,
                              asVField=False,
                              scratchV1=scratchV1,
                              scratchV2=scratchV2)
        core.ApplyH(IDef, I, h)
        rtnarr.append(func(IDef, curt, *args))
    return rtnarr
Ejemplo n.º 5
0
def geodesic_shooting_diffOp(moving, target, m0, steps, mType, config):
    """Shoot a geodesic from m0 and deform `moving` and `target` with it.

    The grid of `moving` is imposed on `m0` (m0 is modified in place via
    setGrid).  A fluid-kernel operator is configured from
    config['deformation_params']['diffOpParams'], the geodesic is
    integrated with CAvmCommon.IntegrateGeodesic, and the resulting maps
    are applied to the two images.

    `steps` selects the number of time steps; when it is <= 0 the value
    of config['deformation_params']['timeSteps'] is used instead.

    Returns a dict with keys 'I1' (moving deformed by ginv),
    'I1_inv' (target deformed by g) and 'phiinv' (ginv).
    """
    grid = moving.grid()
    m0.setGrid(grid)
    ca.ThreadMemoryManager.init(grid, mType, 1)

    # CPU or GPU kernel depending on where the data lives
    diffOp = (ca.FluidKernelFFTCPU() if mType == ca.MEM_HOST
              else ca.FluidKernelFFTGPU())
    params = config['deformation_params']
    diffOp.setAlpha(params['diffOpParams'][0])
    diffOp.setBeta(params['diffOpParams'][1])
    diffOp.setGamma(params['diffOpParams'][2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    It_inv = ca.Image3D(grid, mType)

    # a non-positive `steps` falls back to the configured step count
    time_steps = steps if steps > 0 else params['timeSteps']

    # uniform time discretization of [0, 1]
    t = [x*1./time_steps for x in range(time_steps+1)]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for _ in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                 scratchV1, scratchV2, scratchV3,
                                 keepstates=checkpointstates,
                                 keepinds=checkpointinds,
                                 Ninv=params['NIterForInverse'],
                                 integMethod=params['integMethod'])
    ca.ApplyH(It, moving, ginv)
    ca.ApplyH(It_inv, target, g)

    return {'I1': It, 'I1_inv': It_inv, "phiinv": ginv}
Ejemplo n.º 6
0
def FieldFromNPArr(arr,
                   mType=core.MEM_HOST,
                   sp=core.Vec3Df(1.0, 1.0, 1.0),
                   orig=core.Vec3Df(0.0, 0.0, 0.0)):
    """
    Build a Field3D from a numpy array.

    Expects X-by-Y-by-Z-by-3 or X-by-Y-by-2 array; the array is split into
    per-component arrays by _ComponentArraysFromField.  The field is
    filled in host memory on a grid with spacing `sp` and origin `orig`,
    then converted to `mType` before being returned.
    """
    compX, compY, compZ = _ComponentArraysFromField(arr)

    # the grid size is taken from the shape of the x component
    fieldSize = core.Vec3Di()
    fieldSize.fromlist(compX.shape)
    field = core.Field3D(core.GridInfo(fieldSize, sp, orig), core.MEM_HOST)

    # copy each component into the field's numpy views
    dstX, dstY, dstZ = field.asnp()
    for dst, src in ((dstX, compX), (dstY, compY), (dstZ, compZ)):
        dst[:, :, :] = src

    field.toType(mType)
    return field
Ejemplo n.º 7
0
def ExtractSliceVF(vf, sliceIdx=None, dim='z'):
    """
    Return a one-slice-thick Field3D cut out of 'vf' along dimension 'dim'.

    'dim' is looked up in DIMMAP to find the axis.  When sliceIdx is None
    the middle slice (size // 2) along that axis is taken.  The result has
    the memory type, spacing and origin of 'vf' and size 1 along the
    chosen axis.
    """
    fullSize = vf.size().tolist()
    axis = DIMMAP[dim]
    if sliceIdx is None:
        # default to the middle slice
        sliceIdx = fullSize[axis] // 2

    # region of interest: everything, except a single slice along `axis`
    start = core.Vec3Di(0, 0, 0)
    start.set(axis, sliceIdx)
    extent = core.Vec3Di(fullSize[0], fullSize[1], fullSize[2])
    extent.set(axis, 1)

    # new GridInfo built from vf's grid values, then resized to the slice
    srcGrid = vf.grid()
    sliceGrid = core.GridInfo(srcGrid.size(),
                              srcGrid.spacing(),
                              srcGrid.origin())
    sliceGrid.setSize(extent)

    sliceVF = core.Field3D(sliceGrid, vf.memType())
    core.SubVol(sliceVF, vf, start)
    return sliceVF
Ejemplo n.º 8
0
def AsNPCopy(indata):
    """Return a numpy copy of an Image3D or Field3D.

    `indata` is temporarily converted to host memory and converted back to
    its original memory type before returning, including when an error is
    raised part-way through.

    Returns:
        For an Image3D, a copy of `indata.asnp()`.
        For a Field3D, a new array of shape grid-size + [3] whose last axis
        holds the x, y and z components.

    Raises:
        Exception: if `indata` is neither an Image3D nor a Field3D.
    """
    origType = indata.memType()
    indata.toType(core.MEM_HOST)

    try:
        if IsImage3D(indata):
            out = indata.asnp().copy()
        elif IsField3D(indata):
            out_x, out_y, out_z = indata.asnp()
            sz = indata.grid().size().tolist()
            sz.append(3)
            # stack the three component arrays along a trailing axis
            out = np.zeros(sz)
            out[:, :, :, 0] = out_x
            out[:, :, :, 1] = out_y
            out[:, :, :, 2] = out_z
        else:
            raise Exception('Expected `indata` to be Image3D or Field3D')
    finally:
        # always hand the object back in the memory type it arrived in
        indata.toType(origType)

    return out
Ejemplo n.º 9
0
def ApplyAffine(Iout, Im, A, bg=ca.BACKGROUND_STRATEGY_PARTIAL_ZERO):
    '''Applies an Affine matrix A to an image Im using the Image3D
    grid (size, spacing, origin) of the two images (Input and Output)

    Iout is overwritten in place: it is zeroed first, filled with the
    transformed data, and its grid is reset to a copy of the grid it had
    on entry.  Nothing is returned.

    A is converted with np.matrix.  A matrix with 3 columns is treated as
    a 2D affine transform and one with 4 columns as a 3D affine transform.
    For any other column count neither branch runs, so the identity map is
    applied.  NOTE(review): confirm whether that silent fallback is
    intended.

    Im may be an Image3D or a Field3D; the temporary output buffer is
    allocated with the matching type.  bg is passed through to ca.ApplyH.
    '''
    # algorithm outline:  Create a temporary large grid, then perform
    # real affine transforms here, then crop to be the size of the out grid
    ca.SetMem(Iout, 0.0)

    A = np.matrix(A)

    # working grid: per-axis maximum of the two image sizes, with unit
    # spacing and zero origin so that positions on it are plain indices
    bigsize = [max(Iout.grid().size().x, Im.grid().size().x),
               max(Iout.grid().size().y, Im.grid().size().y),
               max(Iout.grid().size().z, Im.grid().size().z)]
    idgrid = ca.GridInfo(ca.Vec3Di(bigsize[0], bigsize[1], bigsize[2]),
                         ca.Vec3Df(1, 1, 1),
                         ca.Vec3Df(0, 0, 0))
    # build an independent GridInfo from Iout's values; Iout.grid() itself
    # is not a true copy
    newgrid = ca.GridInfo(Iout.grid().size(),
                          Iout.grid().spacing(),
                          Iout.grid().origin())

    mType = Iout.memType()
    # presumably pads Im up to bigsize -- confirm against cc.PadImage
    Imbig = cc.PadImage(Im, bigsize)
    h = ca.Field3D(idgrid, mType)
    ca.SetToIdentity(h)
    if isinstance(Im, ca.Field3D):
        Ioutbig = ca.Field3D(idgrid, mType)
    else:
        Ioutbig = ca.Image3D(idgrid, mType)

    # note:  x_real' = A*x_real, where x_real' (on the output grid) is given
    # solution: x_real = A^-1 * x_real'
    # where x_real = x_index*spacing + origin      (grid of Im)
    # and x_real' = x_index'*spacing' + origin'    (grid of Iout)
    # x_index' is really given (the identity map), as is both
    # spacings/origins, and we plug the solution for x_index into applyH

    if A.shape[1] == 3:          # 2D affine matrix
        x = ca.Image3D(idgrid, mType)
        y = ca.Image3D(idgrid, mType)
        xnew = ca.Image3D(idgrid, mType)
        ynew = ca.Image3D(idgrid, mType)
        # pull the x and y components out of the identity map
        ca.Copy(x, h, 0)
        ca.Copy(y, h, 1)

        # convert x,y to world coordinates
        x *= Iout.grid().spacing().x
        y *= Iout.grid().spacing().y
        x += Iout.grid().origin().x
        y += Iout.grid().origin().y

        # Matrix Multiply (All in real coords)
        Ainv = A.I
        ca.MulC_Add_MulC(xnew, x, Ainv[0, 0], y, Ainv[0, 1])
        ca.MulC_Add_MulC(ynew, x, Ainv[1, 0], y, Ainv[1, 1])
        xnew += (Ainv[0, 2])
        ynew += (Ainv[1, 2])     # xnew and ynew are now in real coords

        # convert back to index coordinates
        xnew -= Im.grid().origin().x
        ynew -= Im.grid().origin().y
        xnew /= Im.grid().spacing().x
        ynew /= Im.grid().spacing().y

        # h is zeroed first, so its z component stays zero in the 2D case
        ca.SetToZero(h)
        ca.Copy(h, xnew, 0)
        ca.Copy(h, ynew, 1)

    elif A.shape[1] == 4:         # 3D affine matrix
        x = ca.Image3D(idgrid, mType)
        y = ca.Image3D(idgrid, mType)
        z = ca.Image3D(idgrid, mType)
        xnew = ca.Image3D(idgrid, mType)
        ynew = ca.Image3D(idgrid, mType)
        znew = ca.Image3D(idgrid, mType)
        # pull the three components out of the identity map
        ca.Copy(x, h, 0)
        ca.Copy(y, h, 1)
        ca.Copy(z, h, 2)

        # convert x,y,z to world coordinates
        x *= Iout.grid().spacing().x
        y *= Iout.grid().spacing().y
        z *= Iout.grid().spacing().z
        x += Iout.grid().origin().x
        y += Iout.grid().origin().y
        z += Iout.grid().origin().z

        # Matrix Multiply (All in real coords)
        Ainv = A.I
        ca.MulC_Add_MulC(xnew, x, Ainv[0, 0], y, Ainv[0, 1])
        ca.Add_MulC_I(xnew, z, Ainv[0, 2])
        xnew += (Ainv[0, 3])
        ca.MulC_Add_MulC(ynew, x, Ainv[1, 0], y, Ainv[1, 1])
        ca.Add_MulC_I(ynew, z, Ainv[1, 2])
        ynew += (Ainv[1, 3])
        ca.MulC_Add_MulC(znew, x, Ainv[2, 0], y, Ainv[2, 1])
        ca.Add_MulC_I(znew, z, Ainv[2, 2])
        znew += (Ainv[2, 3])

        # convert to index coordinates
        xnew -= Im.grid().origin().x
        ynew -= Im.grid().origin().y
        znew -= Im.grid().origin().z
        xnew /= Im.grid().spacing().x
        ynew /= Im.grid().spacing().y
        znew /= Im.grid().spacing().z

        ca.Copy(h, xnew, 0)
        ca.Copy(h, ynew, 1)
        ca.Copy(h, znew, 2)

    # put the padded input on the same index grid as h and Ioutbig
    Imbig.setGrid(idgrid)

    ca.ApplyH(Ioutbig, Imbig, h, bg)
    # crop Ioutbig -> Iout
    ca.SubVol(Iout, Ioutbig, ca.Vec3Di(0, 0, 0))
    Iout.setGrid(newgrid)   # change back
Ejemplo n.º 10
0
# Switches selecting which deformed output volumes get allocated below.
SaveRGB = False
SaveVE = False
SaveBW = True
SaveVF = False
debug = False

# load BFI slice and MRI_as_BFI slices
# (dir_bf, dir_mri and block are defined outside this fragment)
BFI3D = cc.LoadMHA(dir_bf + 'block' + str(block) + 'as_MRI_bw_256.mha',
                   ca.MEM_HOST)  # B/W
MRI3D = cc.LoadMHA(dir_mri + 'T2Seg.mha', ca.MEM_HOST)

# Initialize Deformed 3D Volumes
# Each enabled output gets a zero-filled volume on the grid and memory type
# of BFI3D.
if SaveRGB:
    BFI_color3D = cc.LoadMHA(dir_bf + 'block' + str(block) + '_reg_rgb.mha',
                             ca.MEM_HOST)
    # the RGB output is stored in a Field3D (three components)
    BFIDef3D_RGB = ca.Field3D(BFI3D.grid(), BFI3D.memType())
    ca.SetMem(BFIDef3D_RGB, 0.0)
if SaveVE:
    BFIDef3D_VE = ca.Image3D(BFI3D.grid(), BFI3D.memType())
    ca.SetMem(BFIDef3D_VE, 0.0)
if SaveBW:
    BFIDef3D_BW = ca.Image3D(BFI3D.grid(), BFI3D.memType())
    ca.SetMem(BFIDef3D_BW, 0.0)
if SaveVF:
    # NOTE(review): allocated as a scalar Image3D like the VE/BW volumes;
    # confirm a Field3D is not what is wanted for the "VF" output
    BFIDef3D_VF = ca.Image3D(BFI3D.grid(), BFI3D.memType())
    ca.SetMem(BFIDef3D_VF, 0.0)

# Initialize the 2D slices
# slice 0 of the BFI volume, moved to device memory with its origin zeroed
BFI = common.ExtractSliceIm(BFI3D, 0)
BFI.toType(ca.MEM_DEVICE)
BFI.setOrigin(ca.Vec3Df(0, 0, 0))
Ejemplo n.º 11
0
def GeoRegIteration(subid, cf, p, t, Imsmts, cpinds, cpstates, msmtinds,
                    gradAtMsmts, EnergyHistory, it):
    # compute gradient for regression
    (grad_m, sumJac, sumSplatI, VEnergy,
     IEnergy) = GeoRegGradient(p, t, Imsmts, cpinds, cpstates, msmtinds,
                               gradAtMsmts)

    # do energy related stuff for printing and bookkeeping
    #if it>0:
    EnergyHistory.append([VEnergy + IEnergy, VEnergy, IEnergy])
    print VEnergy + IEnergy, '(Total) = ', VEnergy, '(Vector)+', IEnergy, '(Image)'
    # plot some stuff
    if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery) == 0
                                or it == cf.optim.Niter - 1):
        GeoRegPlots(subid, cf, p, t, Imsmts, cpinds, cpstates, msmtinds,
                    gradAtMsmts, EnergyHistory)
    # end if

    if cf.optim.method == 'FIXEDGD':
        # automatic stepsize selection in the first three steps
        if it == 1:
            # TODO: BEWARE There are hardcoded numbers here for 2D and 3D
            #first find max absolute value across voxels in gradient
            temp = ca.Field3D(grad_m.grid(), ca.MEM_HOST)
            ca.Copy(temp, grad_m)
            temp_x, temp_y, temp_z = temp.asnp()
            temp1 = np.square(temp_x.flatten()) + np.square(
                temp_y.flatten()) + np.square(temp_z.flatten())
            medianval = np.median(temp1[temp1 > 0.0000000001])
            del temp, temp1, temp_x, temp_y, temp_z
            #2D images for 2000 iters
            #p.stepSize = float(0.000000002*medianval)
            #3D images for 2000 iters
            p.stepSize = float(0.000002 * medianval)

            print 'rank:', Compute.GetMPIInfo(
            )['rank'], ', localRank:', Compute.GetMPIInfo(
            )['local_rank'], 'subid: ', subid, ' Selecting initial step size in the beginning to be ', str(
                p.stepSize)

        if it > 3:
            totalEnergyDiff = EnergyHistory[-1][0] - EnergyHistory[-2][0]
            if totalEnergyDiff > 0.0:
                if cf.optim.maxPert is not None:
                    print 'rank:', Compute.GetMPIInfo(
                    )['rank'], ', localRank:', Compute.GetMPIInfo(
                    )['local_rank'], 'subid: ', subid, ' Reducing stepsize for gradient descent by ', str(
                        cf.optim.maxPert *
                        100), '%. The new step size is ', str(
                            p.stepSize * (1 - cf.optim.maxPert))
                    p.stepSize = p.stepSize * (1 - cf.optim.maxPert)
        # take gradient descent step
        ca.Add_MulC_I(p.m0, grad_m, -p.stepSize)
    else:
        raise Exception("Unknown optimization scheme: " + cf.optim.optMethod)
    # end if

    # now divide to get new base image
    ca.Div(p.I0, sumSplatI, sumJac)

    return (EnergyHistory)
Ejemplo n.º 12
0
imagedir = '/home/sci/crottman/korenberg/results/' + reg_type + '/'
if reg_type in ['landmark', 've_reg']:
    fname_end = '_as_MRI_' + col + '.mha'
else:
    fname_end = '_as_MRI_' + col + '_' + str(sz) + '.mha'

if sz >= 512:
    mType = ca.MEM_HOST
else:
    mType = ca.MEM_DEVICE

if reg_type is not 'best':
    grid = cc.MakeGrid([sz, sz, sz], [256.0 / sz, 256.0 / sz, 256.0 / sz],
                       'center')
    blocks = ca.Field3D(grid, mType)
    ca.SetMem(blocks, 0.0)
    weights = blocks.copy()
    ca.SetMem(weights, 0.0)

    for i in xrange(1, 5):
        fname = imagedir + 'block' + str(i) + fname_end
        try:
            blk = cc.LoadMHA(fname, mType)
        except IOError:
            print 'Warning... block ' + str(i) + ' does not exist'
            continue
        blocks += blk
        weight3 = blk.copy()
        try:
            weight = cc.LoadMHA(imagedir +
def GeodesicShooting(cf):
    """Shoot a geodesic from an initial image and momentum given by `cf`.

    Loads `cf.study.I0` and `cf.study.m0`, scales the momentum in place by
    `cf.study.scaleMomenta` and integrates the geodesic with
    CAvmCommon.IntegrateGeodesic.  When the scale is zero no integration
    is done: the image and momentum are copied unchanged and both
    deformations are set to identity.

    `cf.study.scaleMomenta` is overwritten with its float conversion.

    When `cf.io.outputPrefix` is not None, the parsed config is written to
    "<prefix>parsedconfig.yaml" and the results to "<prefix>I1.mhd",
    "<prefix>m1.mhd", "<prefix>phiinv.mhd" and "<prefix>phi.mhd".  Plots
    are then produced, and per-checkpoint frames are saved when
    `cf.io.saveFrames` is set.  When it is None nothing is written
    (previously the function failed on os.path.dirname(None) before
    reaching its own None checks).
    """
    writeOutput = cf.io.outputPrefix is not None

    if writeOutput:
        # prepare output directory
        common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

        # Output loaded config
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    # uniform time discretization of [0, 1]
    t = [
        x * 1. / cf.integration.nTimeSteps
        for x in range(cf.integration.nTimeSteps + 1)
    ]
    # one pair of checkpoint fields per time point after t[0]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)

    # scale momenta to shoot
    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    scale = cf.study.scaleMomenta
    doShoot = abs(scale) > 0.0
    if doShoot:
        ca.MulC_I(m0, scale)
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        # zero scale: nothing moves
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if writeOutput:
        if doShoot:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            # scale back shotmomenta before writing
            ca.DivC_I(mt, scale)

        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
def _SaveFrameFigure(fig, cf, stem, image_idx):
    """Save the current figure as <prefix>/frames/<stem><5-digit idx>.png."""
    outfilename = cf.io.outputPrefix + '/frames/' + stem + str(
        image_idx).zfill(5) + '.png'
    fig.set_size_inches(4, 4)
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)


def _SaveImageFrame(im, cf, image_idx):
    """Draw a slice of `im` in figure 1 and save it as an 'I' frame."""
    fig = plt.figure(1, frameon=False)
    plt.clf()
    display.DispImage(im,
                      '',
                      newFig=False,
                      cmap='gray',
                      dim=cf.io.plotSliceDim,
                      sliceIdx=cf.io.plotSlice)
    plt.draw()
    _SaveFrameFigure(fig, cf, 'I', image_idx)


def _SaveGridFrame(figNum, h, cf, stem, image_idx):
    """Draw the grid plot of `h` in figure `figNum`, axes hidden, and save."""
    fig = plt.figure(figNum, frameon=False)
    plt.clf()
    CAvmCommon.MyGridPlot(h,
                          every=cf.io.gridEvery,
                          color='k',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice,
                          isVF=False,
                          plotBase=False)
    a = fig.gca()
    a.set_xticks([])
    a.set_yticks([])
    plt.axis('tight')
    plt.axis('image')
    plt.axis('off')
    plt.draw()
    _SaveFrameFigure(fig, cf, stem, image_idx)


def _SaveMomentaFrame(im, m, cf, image_idx, thresh, scaleArrows):
    """Draw `im` with a quiver plot of `m` in figure 4; save an 'm' frame."""
    fig = plt.figure(4, frameon=False)
    plt.clf()
    display.DispImage(im,
                      '',
                      newFig=False,
                      cmap='gray',
                      dim=cf.io.plotSliceDim,
                      sliceIdx=cf.io.plotSlice)
    # NOTE(review): both hold() calls receive non-empty strings, which are
    # truthy, so hold is presumably never switched off.  Left unchanged
    # because later plots may rely on that -- confirm before touching.
    plt.hold('True')
    CAvmCommon.MyQuiver(m,
                        dim=cf.io.plotSliceDim,
                        sliceIdx=cf.io.plotSlice,
                        every=cf.io.quiverEvery,
                        thresh=thresh,
                        scaleArrows=scaleArrows,
                        arrowCol='r',
                        lineWidth=0.5,
                        width=0.005)
    plt.draw()
    plt.hold('False')
    _SaveFrameFigure(fig, cf, 'm', image_idx)


def SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf):
    """Save PNG frames of the image, both deformations and the momentum.

    Frame 0 shows I0, identity grids for both deformations, and m0 over
    I0.  One further frame is written per entry of `checkpointinds`; for
    frame i+1 the pair checkpointstates[i] is used:
      * element [1] deforms I0 into It (ca.ApplyH), is drawn as the
        'invdef' grid and carries m0 into mt (ca.CoAd);
      * element [0] is drawn as the 'def' grid.

    `It` and `mt` are overwritten as scratch output buffers.

    Files are written as <prefix>/frames/{I,invdef,def,m}NNNNN.png with
    <prefix> = cf.io.outputPrefix.
    NOTE(review): the directory created is
    dirname(outputPrefix) + '/frames/', which only coincides with the
    directory the files are written to for some prefix shapes -- confirm
    against how outputPrefix is configured.
    """
    momentathresh = 0.00002
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix) + '/frames/')

    # frame 0: undeformed image, identity deformations, initial momentum
    image_idx = 0
    _SaveImageFrame(I0, cf, image_idx)

    identity = ca.Field3D(I0.grid(), I0.memType())
    ca.SetToIdentity(identity)
    _SaveGridFrame(2, identity, cf, 'invdef', image_idx)
    _SaveGridFrame(3, identity, cf, 'def', image_idx)
    _SaveMomentaFrame(I0, m0, cf, image_idx, momentathresh, 0.25)

    # one frame per checkpoint
    for i in range(len(checkpointinds)):
        image_idx = image_idx + 1
        defField = checkpointstates[i][0]
        invDefField = checkpointstates[i][1]

        ca.ApplyH(It, I0, invDefField)
        _SaveImageFrame(It, cf, image_idx)
        _SaveGridFrame(2, invDefField, cf, 'invdef', image_idx)
        _SaveGridFrame(3, defField, cf, 'def', image_idx)

        ca.CoAd(mt, invDefField, m0)
        _SaveMomentaFrame(It, mt, cf, image_idx, momentathresh, 0.40)
Ejemplo n.º 15
0
def BuildGeoReg(cf):
    """Worker for running geodesic estimation on a subset of individuals
    """
    #common.DebugHere()
    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0

    # load filenames and times for all subjects
    (subjectsIds, subjectsImagePaths,
     subjectsTimes) = GeoRegLoadSubjectsDetails(cf.study.subjectFile)
    cf.study.numSubjects = len(subjectsIds)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(GeoRegConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)

    # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes
    if cf.compute.useMPI and (len(subjectsIds) < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    nodeSubjectsIds = subjectsIds[rank::cf.compute.numProcesses]
    nodeSubjectsImagePaths = subjectsImagePaths[rank::cf.compute.numProcesses]
    nodeSubjectsTimes = subjectsTimes[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeSubjectsImagePaths)
    if cf.study.initializationsFile is not None:
        (subjectsInitialImages,
         subjectsInitialMomenta) = GeoRegLoadSubjectsInitializations(
             cf.study.initializationsFile)
        nodeSubjectsInitialImages = subjectsInitialImages[rank::cf.compute.
                                                          numProcesses]
        nodeSubjectsInitialMomenta = subjectsInitialMomenta[rank::cf.compute.
                                                            numProcesses]

    print 'rank:', rank, ', localRank:', localRank, ', numberSubjects/TotalSubjects:', len(
        nodeSubjectsImagePaths
    ), '/', cf.study.numSubjects, ', nodeSubjectsImagePaths:', nodeSubjectsImagePaths, ', nodeSubjectsTimes:', nodeSubjectsTimes

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # setting gpuid should be handled in gpu
    # if using GPU  set device based on local rank
    #if cf.compute.useCUDA:
    #    ca.SetCUDADevice(localRank)

    # get image size information
    dummyImToGetGridInfo = common.LoadITKImage(nodeSubjectsImagePaths[0][0],
                                               mType)
    imGrid = dummyImToGetGridInfo.grid()
    if cf.study.setUnitSpacing:
        imGrid.setSpacing(ca.Vec3Df(1.0, 1.0, 1.0))
    if cf.study.setZeroOrigin:
        imGrid.setOrigin(ca.Vec3Df(0, 0, 0))
    #del dummyImToGetGridInfo;

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # allocate memory
    p = GeoRegVariables(imGrid,
                        mType,
                        cf.vectormomentum.diffOpParams[0],
                        cf.vectormomentum.diffOpParams[1],
                        cf.vectormomentum.diffOpParams[2],
                        cf.optim.NIterForInverse,
                        cf.vectormomentum.sigma,
                        cf.optim.stepSize,
                        integMethod=cf.optim.integMethod)
    # for each individual run geodesic regression for each subject
    for i in range(numLocalSubjects):

        # initializations for this subject
        if cf.study.initializationsFile is not None:
            # assuming the initializations are already preprocessed, in terms of intensities, origin and voxel scalings.
            p.I0 = common.LoadITKImage(nodeSubjectsInitialImages[i], mType)
            p.m0 = common.LoadITKField(nodeSubjectsInitialMomenta[i], mType)
        else:
            ca.SetMem(p.m0, 0.0)
            ca.SetMem(p.I0, 0.0)

        # allocate memory specific to this subject in steps a, b and c
        # a. create time array with checkpointing info for regression geodesic, allocate checkpoint memory
        (t, msmtinds, cpinds) = GeoRegSetUpTimeArray(cf.optim.nTimeSteps,
                                                     nodeSubjectsTimes[i],
                                                     0.001)
        cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                    for idx in cpinds]
        # b. allocate gradAtMeasurements of the length of msmtindex for storing residuals
        gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

        # c. load timepoint images for this subject
        Imsmts = [
            common.LoadITKImage(f, mType) if isinstance(f, str) else f
            for f in nodeSubjectsImagePaths[i]
        ]
        # reset stepsize if adaptive stepsize changed it inside
        p.stepSize = cf.optim.stepSize
        # preprocessimages
        GeoRegPreprocessInput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                              cpstates, msmtinds, gradAtMsmts)

        # run regression for this subject
        # REMEMBER
        # msmtinds index into cpinds
        # gradAtMsmts is parallel to msmtinds
        # cpinds index into t
        EnergyHistory = RunGeoReg(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                                  cpstates, msmtinds, gradAtMsmts)

        # write output images and fields for this subject
        # TODO: BEWARE There are hardcoded numbers inside preprocessing code specific for ADNI/OASIS brain data.
        GeoRegWriteOuput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                         cpstates, msmtinds, gradAtMsmts, EnergyHistory)

        # clean up memory specific to this subject
        del t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts
Ejemplo n.º 16
0
    def __init__(self,
                 I0,
                 I1,
                 sigma,
                 t,
                 cpinds,
                 cpstates,
                 alpha,
                 beta,
                 gamma,
                 nIter,
                 stepSize,
                 maxPert,
                 nTimeSteps,
                 integMethod='RK4',
                 optMethod='FIXEDGD',
                 nInv=10,
                 plotEvery=None,
                 plotSlice=None,
                 quiverEvery=None,
                 outputPrefix='./'):
        """
        Store the inputs and allocate all working memory.

        Every field and image is allocated on the grid and memory type of
        I0.  The differential operator is the CPU or GPU FFT fluid kernel,
        depending on that memory type, configured with alpha, beta and gamma.
        """
        print(I0)
        self.I0 = I0
        self.I1 = I1
        self.grid = I0.grid()
        print(self.grid)
        self.memtype = I0.memType()

        # weight of the matching term
        self.sigma = sigma

        # initial momenta, start at zero
        self.m0 = ca.Field3D(self.grid, self.memtype)
        ca.SetMem(self.m0, 0.0)

        # state of the forward integration
        self.g = ca.Field3D(self.grid, self.memtype)
        self.ginv = ca.Field3D(self.grid, self.memtype)
        self.m = ca.Field3D(self.grid, self.memtype)
        self.I = ca.Image3D(self.grid, self.memtype)
        self.residualIm = ca.Image3D(self.grid, self.memtype)

        # adjoint variables and their temporaries
        self.madj = ca.Field3D(self.grid, self.memtype)
        self.Iadj = ca.Image3D(self.grid, self.memtype)
        self.madjtmp = ca.Field3D(self.grid, self.memtype)
        self.Iadjtmp = ca.Image3D(self.grid, self.memtype)

        # time discretization and checkpointing
        self.t = t
        self.checkpointinds = cpinds
        self.checkpointstates = cpstates

        # differential operator matching the memory type
        onHost = (self.memtype == ca.MEM_HOST)
        kernelType = ca.FluidKernelFFTCPU if onHost else ca.FluidKernelFFTGPU
        self.diffOp = kernelType()
        self.diffOp.setAlpha(alpha)
        self.diffOp.setBeta(beta)
        self.diffOp.setGamma(gamma)
        self.diffOp.setGrid(self.grid)

        # energy is unknown until the first evaluation
        self.Energy = None

        # optimization settings
        self.nIter = nIter
        self.stepSize = stepSize
        self.maxPert = maxPert
        self.nTimeSteps = nTimeSteps
        self.optMethod = optMethod
        self.integMethod = integMethod
        self.nInv = nInv  # iterations for the update of the inverse deformation

        # plotting and output settings
        self.plotEvery = plotEvery
        self.plotSlice = plotSlice
        self.quiverEvery = quiverEvery
        self.outputPrefix = outputPrefix

        # scratch vector fields scratchV1 ... scratchV11
        for number in range(1, 12):
            setattr(self, 'scratchV%d' % number,
                    ca.Field3D(self.grid, self.memtype))
def BuildAtlas(cf):
    """Worker for running Atlas construction on a subset of individuals.
    Runs Atlas on this subset sequentially. The variations retuned are
    summed up to get update for all individuals
    """

    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectImages)

    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(AtlasConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
    #common.DebugHere()

    # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes

    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this thread to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeImages = cf.study.subjectImages[rank::cf.compute.numProcesses]
    nodeWeights = cf.study.subjectWeights[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeImages)
    print 'rank:', rank, ', localRank:', localRank, ', nodeImages:', nodeImages, ', nodeWeights:', nodeWeights

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J_array = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeImages
    ]

    # get imGrid from data
    imGrid = J_array[0].grid()

    # atlas image
    atlas = ca.Image3D(imGrid, mType)

    # allocate memory to store only the initial momenta for each individual in this thread
    m_array = [ca.Field3D(imGrid, mType) for i in range(numLocalSubjects)]

    # allocate only one copy of scratch memory to be reused for each local individual in this thread in loop
    p = WarpVariables(imGrid,
                      mType,
                      cf.vectormomentum.diffOpParams[0],
                      cf.vectormomentum.diffOpParams[1],
                      cf.vectormomentum.diffOpParams[2],
                      cf.optim.NIterForInverse,
                      cf.vectormomentum.sigma,
                      cf.optim.stepSize,
                      integMethod=cf.optim.integMethod)

    # memory to accumulate numerators and denominators for atlas from
    # local individuals which will be summed across MPI threads
    sumSplatI = ca.Image3D(imGrid, mType)
    sumJac = ca.Image3D(imGrid, mType)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)

    t = [
        x * 1. / (cf.optim.nTimeSteps) for x in range(cf.optim.nTimeSteps + 1)
    ]
    cpinds = range(1, len(t))
    msmtinds = [
        len(t) - 2
    ]  # since t=0 is not in cpinds, thats just identity deformation so not checkpointed
    cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                for idx in cpinds]
    gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

    EnergyHistory = []

    # TODO: better initializations
    # initialize atlas image with zeros.
    ca.SetMem(atlas, 0.0)
    # initialize momenta with zeros

    for m0_individual in m_array:
        ca.SetMem(m0_individual, 0.0)
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)

    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # preprocessinput

    # assign atlas reference to p.I0. This reference will not change.
    p.I0 = atlas

    # run the loop
    for it in range(cf.optim.Niter):
        # run one iteration of warp for each individual and update
        # their own initial momenta and also accumulate SplatI and Jac
        ca.SetMem(sumSplatI, 0.0)
        ca.SetMem(sumJac, 0.0)
        TotalVEnergy = np.array([0.0])
        TotalIEnergy = np.array([0.0])

        for itsub in range(numLocalSubjects):
            # initializations for this subject, this only assigns
            # reference to image variables
            p.m0 = m_array[itsub]
            Imsmts = [J_array[itsub]]

            # run warp iteration
            VEnergy, IEnergy = RunWarpIteration(nodeSubjectIds[itsub], cf, p,
                                                t, Imsmts, cpinds, cpstates,
                                                msmtinds, gradAtMsmts, it)

            # gather relevant results
            ca.Add_I(sumSplatI, p.sumSplatI)
            ca.Add_I(sumJac, p.sumJac)
            TotalVEnergy[0] += VEnergy
            TotalIEnergy[0] += IEnergy

        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(sumSplatI, mpiImageBuff)
            Compute.Reduce(sumJac, mpiImageBuff)

            # also sum up energies of other nodes
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalVEnergy,
                                            op=mpi4py.MPI.SUM)
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalIEnergy,
                                            op=mpi4py.MPI.SUM)

        EnergyHistory.append([TotalVEnergy[0], TotalIEnergy[0]])

        # now divide to get the new atlas image
        ca.Div(atlas, sumSplatI, sumJac)

        # keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and ((
            (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            # plots
            AtlasPlots(cf, p, atlas, m_array, EnergyHistory)

        if isReporter:
            # print out energy
            (VEnergy, IEnergy) = EnergyHistory[-1]
            print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image)'

    # write output images and fields
    AtlasWriteOutput(cf, atlas, m_array, nodeSubjectIds, isReporter)
def MatchingImageMomentaPlots(cf,
                              geodesicState,
                              tDiscGeodesic,
                              EnergyHistory,
                              m0,
                              J1,
                              n1,
                              writeOutput=True):
    """
    Do some summary plots for MatchingImageMomenta.

    Four figures are produced: (1) energy curves, (2) source image, deformed
    image and the deformation grids rho^{-1} and rho, (3) source, target and
    difference image, (4) X and Y components of the source momenta m0, the
    target momenta n1 and their difference after CoAd with rho^{-1}.

    Args:
        cf: configuration; ``cf.io.outputPrefix``, ``cf.io.plotSlice`` and
            ``cf.io.quiverEvery`` are read.
        geodesicState: geodesic workspace.  The geodesic is re-integrated
            through ``HGMIntegrateGeodesic`` and ``geodesicState.J`` is
            overwritten with J0 deformed by ``rhoinv``.
        tDiscGeodesic: time discretization passed to the integrator.
        EnergyHistory: sequence of rows; columns 0..3 are plotted as total,
            vector, image match and momenta match energy.
        m0: source momenta field.
        J1: target image.
        n1: target momenta field.
        writeOutput: when true and ``cf.io.outputPrefix`` is not None, the
            figures are saved as ``energy.pdf``, ``def.pdf``,
            ``diffImage.pdf`` and ``diffMomenta.pdf`` under that prefix.
    """
    saveFigures = cf.io.outputPrefix is not None and writeOutput

    def _showXY(hostField, column, titleX, titleY):
        # Show the X and Y components of a field in the top and bottom row
        # of the given column of the 2x3 momenta figure.
        compX, compY, _ = hostField.asnp()
        plt.subplot(2, 3, column)
        plt.imshow(np.squeeze(compX))
        plt.colorbar()
        plt.title(titleX)
        plt.subplot(2, 3, column + 3)
        plt.imshow(np.squeeze(compY))
        plt.colorbar()
        plt.title(titleY)

    #ENERGY
    fig = plt.figure(1)
    plt.clf()
    fig.patch.set_facecolor('white')

    energyColumns = [[row[k] for row in EnergyHistory] for k in range(4)]
    energyTitles = ['Total Energy', 'Vector Energy', 'Image Match Energy',
                    'Momenta Match Energy']
    for k, (series, title) in enumerate(zip(energyColumns, energyTitles)):
        plt.subplot(2, 2, k + 1)
        plt.plot(series)
        plt.title(title)
        plt.hold(False)
    plt.draw()
    # Save before show(): with a blocking backend the figure may already be
    # gone when show() returns, and savefig would then write an empty page.
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'energy.pdf')
    plt.show()

    # GEODESIC INITIAL CONDITIONS and RHO and RHO inv
    CAvmHGMCommon.HGMIntegrateGeodesic(geodesicState.p0, geodesicState.s,
                                       geodesicState.diffOp, geodesicState.p,
                                       geodesicState.rho, geodesicState.rhoinv,
                                       tDiscGeodesic, geodesicState.Ninv,
                                       geodesicState.integMethod)

    fig = plt.figure(2)
    plt.clf()
    fig.patch.set_facecolor('white')

    plt.subplot(2, 2, 1)
    display.DispImage(geodesicState.J0,
                      'J0',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.subplot(2, 2, 2)
    # deformed source image; also reused for the difference image below
    ca.ApplyH(geodesicState.J, geodesicState.J0, geodesicState.rhoinv)
    display.DispImage(geodesicState.J,
                      'J1',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)

    for position, deformation, title in ((3, geodesicState.rhoinv,
                                          'rho^{-1}'),
                                         (4, geodesicState.rho, 'rho')):
        plt.subplot(2, 2, position)
        display.GridPlot(deformation,
                         every=cf.io.quiverEvery,
                         color='k',
                         sliceIdx=cf.io.plotSlice,
                         isVF=False)
        plt.axis('equal')
        plt.axis('off')
        plt.title(title)
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'def.pdf')

    # MATCHING DIFFERENCE IMAGES
    grid = geodesicState.J0.grid()
    mType = geodesicState.J0.memType()
    imdiff = ca.ManagedImage3D(grid, mType)

    # Image matching: deformed source minus target
    ca.Copy(imdiff, geodesicState.J)
    ca.Sub_I(imdiff, J1)
    fig = plt.figure(3)
    plt.clf()
    fig.patch.set_facecolor('white')

    plt.subplot(1, 3, 1)
    display.DispImage(geodesicState.J0,
                      'Source J0',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.colorbar()

    plt.subplot(1, 3, 2)
    display.DispImage(J1, 'Target J1', newFig=False, sliceIdx=cf.io.plotSlice)
    plt.colorbar()

    plt.subplot(1, 3, 3)
    display.DispImage(imdiff,
                      'rho.J0-J1',
                      newFig=False,
                      sliceIdx=cf.io.plotSlice)
    plt.colorbar()
    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'diffImage.pdf')

    # Momenta matching
    # Plain fields on the device, managed fields on the host.
    if mType == ca.MEM_DEVICE:
        scratchV1 = ca.Field3D(grid, mType)
        scratchV2 = ca.Field3D(grid, mType)
        scratchV3 = ca.Field3D(grid, mType)
    else:
        scratchV1 = ca.ManagedField3D(grid, mType)
        scratchV2 = ca.ManagedField3D(grid, mType)
        scratchV3 = ca.ManagedField3D(grid, mType)

    fig = plt.figure(4)
    plt.clf()
    fig.patch.set_facecolor('white')

    # Each field is moved to host memory before its numpy view is taken.
    ca.Copy(scratchV1, m0)
    scratchV1.toType(ca.MEM_HOST)
    _showXY(scratchV1, 1, 'X: Source m0 ', 'Y: Source m0')

    ca.Copy(scratchV2, n1)
    scratchV2.toType(ca.MEM_HOST)
    _showXY(scratchV2, 2, 'X: Target n1', 'Y: Target n1')

    ca.CoAd(scratchV3, geodesicState.rhoinv, m0)
    ca.Sub_I(scratchV3, n1)
    scratchV3.toType(ca.MEM_HOST)
    _showXY(scratchV3, 3, 'X: rho.m0-n1', 'Y: rho.m0-n1')

    if saveFigures:
        plt.savefig(cf.io.outputPrefix + 'diffMomenta.pdf')

    del scratchV1, scratchV2, scratchV3
    del imdiff
def DefReg(I_src, I_tar, config, memT, idConf):

    I_src.toType(memT)
    I_tar.toType(memT)

    # Convert to 2D spacing (because it really matters)
    sp2D = I_src.spacing().tolist()
    sp2D = ca.Vec3Df(sp2D[0], sp2D[1], 1)

    I_tar.setSpacing(sp2D)
    I_src.setSpacing(sp2D)
    gridReg = I_tar.grid()

    # Blur the images
    I_tar_blur = I_tar.copy()
    I_src_blur = I_src.copy()
    temp = ca.Image3D(I_tar.grid(), memT)
    gausFilt = ca.GaussianFilterGPU()

    scaleList = config.scale

    # Initiate the scale manager
    scaleManager = ca.MultiscaleManager(gridReg)
    for s in scaleList:
        scaleManager.addScaleLevel(s)
    if memT == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(gridReg)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(gridReg)

    # Generate the scratch images
    scratchITar = ca.Image3D(gridReg, memT)
    scratchISrc = ca.Image3D(gridReg, memT)
    scratchI = ca.Image3D(gridReg, memT)
    scratchF = ca.Field3D(gridReg, memT)
    compF = ca.Field3D(gridReg, memT)

    def SetScale(scale):
        '''Scale Management for Multiscale'''
        scaleManager.set(scale)
        resampler.setScaleLevel(scaleManager)
        curGrid = scaleManager.getCurGrid()
        curGrid.spacing().z = 1  # Because only 2D

        print 'Inside setScale(). Current grid is ', curGrid

        if scaleManager.isLastScale():
            print 'Inside setScale(): **Last Scale**'
        if scaleManager.isFirstScale():
            print 'Inside setScale(): **First Scale**'

        scratchISrc.setGrid(curGrid)
        scratchITar.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        compF.setGrid(curGrid)
        idConf.study.I0 = ca.Image3D(curGrid, memT)
        idConf.study.I1 = ca.Image3D(curGrid, memT)

        if scaleManager.isLastScale():
            s = config.sigBlur[scaleList.index(sc)]
            r = config.kerBlur[scaleList.index(sc)]
            gausFilt.updateParams(I_tar.size(), ca.Vec3Df(r, r, r),
                                  ca.Vec3Di(s, s, s))
            gausFilt.filter(scratchITar, I_tar, temp)
            gausFilt.filter(scratchI, I_src, temp)

# ca.Copy(scratchI, I_src)
# ca.Copy(scratchITar, I_tar)

        else:
            s = config.sigBlur[scaleList.index(sc)]
            r = config.kerBlur[scaleList.index(sc)]
            gausFilt.updateParams(I_tar.size(), ca.Vec3Df(r, r, r),
                                  ca.Vec3Di(s, s, s))
            gausFilt.filter(I_tar_blur, I_tar, temp)
            gausFilt.filter(I_src_blur, I_src, temp)
            resampler.downsampleImage(scratchI, I_src_blur)
            resampler.downsampleImage(scratchITar, I_tar_blur)

        if scaleManager.isFirstScale():
            scratchF.setGrid(curGrid)
            scratchITar.setGrid(curGrid)
            ca.SetToIdentity(scratchF)
            ca.ApplyH(scratchISrc, scratchI, scratchF)

        else:
            compF.setGrid(scratchF.grid())
            ca.ComposeHH(compF, scratchF, h)
            resampler.updateHField(scratchF)
            resampler.updateHField(compF)
            ca.Copy(scratchF, compF)
            ca.ApplyH(scratchISrc, scratchI, compF)

    for sc in scaleList:
        SetScale(scaleList.index(sc))

        #Set the optimize parameters in the IDiff configuration object
        idConf.optim.Niter = config.iters[scaleList.index(sc)]
        idConf.optim.stepSize = config.epsReg[scaleList.index(sc)]
        idConf.idiff.regWeight = config.sigReg[scaleList.index(sc)]
        ca.Copy(idConf.study.I0, scratchISrc)
        ca.Copy(idConf.study.I1, scratchITar)
        idConf.io.plotEvery = config.iters[scaleList.index(sc)]

        h = IDiff.Matching.Matching(idConf)
        tempScr = scratchISrc.copy()
        ca.ApplyH(tempScr, scratchISrc, h)

        #Plot the images to see the change
        cd.DispImage(scratchISrc - scratchITar,
                     rng=[-2, 2],
                     title='Orig Diff',
                     colorbar=True)
        cd.DispImage(tempScr - scratchITar,
                     rng=[-2, 2],
                     title='Reg Diff',
                     colorbar=True)

        # common.DebugHere()

        # I_src_def = idConf.study.I0.copy()

        # scratchITar = idConf.study.I1
        # eps = config.epsReg[scaleList.index(sc)]
        # sigma = config.sigReg[scaleList.index(sc)]
        # nIter = config.iters[scaleList.index(sc)]
        # # common.DebugHere()
        # [I_src_def, h, energy] = apps.IDiff(scratchISrc, scratchITar, eps, sigma, nIter, plot=True, verbose=1)
    ca.ComposeHH(scratchF, compF, h)
    I_src_def = idConf.study.I0.copy()

    return I_src_def, scratchF
def main():
    # Extract the Monkey number and section number from the command line
    global frgNum
    global secOb

    mkyNum = sys.argv[1]
    secNum = sys.argv[2]
    frgNum = int(sys.argv[3])
    write = True

    # if not os.path.exists(os.path.expanduser('~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'.format(mkyNum,secNum))):
    #     cf = initial(secNum, mkyNum)

    try:
        secOb = Config.Load(
            secSpec,
            pth.expanduser(
                '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                .format(mkyNum, secNum)))
    except IOError as e:
        try:
            temp = Config.LoadYAMLDict(pth.expanduser(
                '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                .format(mkyNum, secNum)),
                                       include=False)
            secOb = Config.MkConfig(temp, secSpec)
        except IOError:
            print 'It appears there is no configuration file for this section. Please initialize one and restart.'
            sys.exit()
        if frgNum == int(secOb.yamlList[frgNum][-6]):
            Fragmenter()
            try:
                secOb = Config.Load(
                    secSpec,
                    pth.expanduser(
                        '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                        .format(mkyNum, secNum)))
            except IOError:
                print 'It appeas that the include yaml file list does not match your fragmentation number. Please check them and restart.'
                sys.exit()

    if not pth.exists(
            pth.expanduser(secOb.ssiOutPath + 'frag{0}'.format(frgNum))):
        common.Mkdir_p(
            pth.expanduser(secOb.ssiOutPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.bfiOutPath + 'frag{0}'.format(frgNum))):
        common.Mkdir_p(
            pth.expanduser(secOb.bfiOutPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.ssiSrcPath + 'frag{0}'.format(frgNum))):
        os.mkdir(pth.expanduser(secOb.ssiSrcPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.bfiSrcPath + 'frag{0}'.format(frgNum))):
        os.mkdir(pth.expanduser(secOb.bfiSrcPath + 'frag{0}'.format(frgNum)))

    frgOb = Config.MkConfig(secOb.yamlList[frgNum], frgSpec)
    ssiSrc, bfiSrc, ssiMsk, bfiMsk = Loader(frgOb, ca.MEM_HOST)

    #Extract the saturation Image from the color iamge
    bfiHsv = common.FieldFromNPArr(
        matplotlib.colors.rgb_to_hsv(
            np.rollaxis(np.array(np.squeeze(bfiSrc.asnp())), 0, 3)),
        ca.MEM_HOST)
    bfiHsv.setGrid(bfiSrc.grid())
    bfiSat = ca.Image3D(bfiSrc.grid(), bfiHsv.memType())
    ca.Copy(bfiSat, bfiHsv, 1)
    #Histogram equalize, normalize and mask the blockface saturation image
    bfiSat = cb.HistogramEqualize(bfiSat, 256)
    bfiSat.setGrid(bfiSrc.grid())
    bfiSat *= -1
    bfiSat -= ca.Min(bfiSat)
    bfiSat /= ca.Max(bfiSat)
    bfiSat *= bfiMsk
    bfiSat.setGrid(bfiSrc.grid())

    #Write out the blockface region after adjusting the colors with a format that supports header information
    if write:
        common.SaveITKImage(
            bfiSat,
            pth.expanduser(secOb.bfiSrcPath +
                           'frag{0}/M{1}_01_bfi_section_{2}_frag{0}_sat.nrrd'.
                           format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Set the sidescape grid relative to that of the blockface
    ssiSrc.setGrid(ConvertGrid(ssiSrc.grid(), bfiSat.grid()))
    ssiMsk.setGrid(ConvertGrid(ssiMsk.grid(), bfiSat.grid()))
    ssiSrc *= ssiMsk

    #Write out the sidescape masked image in a format that stores the header information
    if write:
        common.SaveITKImage(
            ssiSrc,
            pth.expanduser(secOb.ssiSrcPath +
                           'frag{0}/M{1}_01_ssi_section_{2}_frag{0}.nrrd'.
                           format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Update the image parameters of the sidescape image for future use
    frgOb.imSize = ssiSrc.size().tolist()
    frgOb.imOrig = ssiSrc.origin().tolist()
    frgOb.imSpac = ssiSrc.spacing().tolist()
    updateFragOb(frgOb)

    #Find the affine transform between the two fragments
    bfiAff, ssiAff, aff = Affine(bfiSat, ssiSrc, frgOb)
    updateFragOb(frgOb)

    #Write out the affine transformed images in a format that stores header information
    if write:
        common.SaveITKImage(
            bfiAff,
            pth.expanduser(
                secOb.bfiOutPath +
                'frag{0}/M{1}_01_bfi_section_{2}_frag{0}_aff_ssi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))
        common.SaveITKImage(
            ssiAff,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_aff_bfi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))

    bfiVe = bfiAff.copy()
    ssiVe = ssiSrc.copy()
    cc.VarianceEqualize_I(bfiVe, sigma=frgOb.sigVarBfi, eps=frgOb.epsVar)
    cc.VarianceEqualize_I(ssiVe, sigma=frgOb.sigVarSsi, eps=frgOb.epsVar)

    #As of right now, the largest pre-computed FFT table is 2048, so resample onto that grid for registration
    regGrd = ConvertGrid(
        cc.MakeGrid(ca.Vec3Di(2048, 2048, 1), ca.Vec3Df(1, 1, 1),
                    ca.Vec3Df(0, 0, 0)), ssiSrc.grid())
    ssiReg = ca.Image3D(regGrd, ca.MEM_HOST)
    bfiReg = ca.Image3D(regGrd, ca.MEM_HOST)
    cc.ResampleWorld(ssiReg, ssiVe)
    cc.ResampleWorld(bfiReg, bfiVe)

    #Create the default configuration object for IDiff Matching and then set some parameters
    idCf = Config.SpecToConfig(IDiff.Matching.MatchingConfigSpec)
    idCf.compute.useCUDA = True
    idCf.io.outputPrefix = '/home/sci/blakez/IDtest/'

    #Run the registration
    ssiDef, phi = DefReg(ssiReg, bfiReg, frgOb, ca.MEM_DEVICE, idCf)

    #Turn the deformation into a displacement field so it can be applied to the large tif with C++ code
    affV = phi.copy()
    cc.ApplyAffineReal(affV, phi, np.linalg.inv(frgOb.affine))
    ca.HtoV_I(affV)

    #Apply the found deformation to the input ssi
    ssiSrc.toType(ca.MEM_DEVICE)
    cc.HtoReal(phi)
    affPhi = phi.copy()
    ssiBfi = ssiSrc.copy()
    upPhi = ca.Field3D(ssiSrc.grid(), phi.memType())

    cc.ApplyAffineReal(affPhi, phi, np.linalg.inv(frgOb.affine))
    cc.ResampleWorld(upPhi, affPhi, bg=2)
    cc.ApplyHReal(ssiBfi, ssiSrc, upPhi)

    # ssiPhi = ca.Image3D(ssiSrc.grid(), phi.memType())
    # upPhi = ca.Field3D(ssiSrc.grid(), phi.memType())
    # cc.ResampleWorld(upPhi, phi, bg=2)
    # cc.ApplyHReal(ssiPhi, ssiSrc, upPhi)
    # ssiBfi = ssiSrc.copy()
    # cc.ApplyAffineReal(ssiBfi, ssiPhi, np.linalg.inv(frgOb.affine))

    # #Apply affine to the deformation
    # affPhi = phi.copy()
    # cc.ApplyAffineReal(affPhi, phi, np.linalg.inv(frgOb.affine))

    if write:
        common.SaveITKImage(
            ssiBfi,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_def_bfi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))
        cc.WriteMHA(
            affPhi,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_to_bfi_real.mha'.
                format(frgNum, secOb.mkyNum, secOb.secNum)))
        cc.WriteMHA(
            affV,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_to_bfi_disp.mha'.
                format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Create the list of names that the deformation should be applied to
    # nameList = ['M15_01_0956_SideLight_DimLED_10x_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c1_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c2_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c3_ORG.tif']

    # appLarge(nameList, affPhi)

    common.DebugHere()
Ejemplo n.º 21
0
def LoadITKField(fname, mType=core.MEM_HOST):
    """Load a vector field from the file `fname` and return it.

    A `core.Field3D` is created with memory type `mType` and filled in by
    `core._ITKFileIO.LoadField`.
    """
    field = core.Field3D(mType)
    core._ITKFileIO.LoadField(field, fname)
    return field
Ejemplo n.º 22
0
def RigidReg(
    Is,
    It,
    theta_step=.0001,
    t_step=.01,
    a_step=0,
    maxIter=350,
    plot=True,
    origin=None,
    theta=0,  # only applies for 2D
    t=None,  # only applies for 2D
    Ain=np.matrix(np.identity(3))):

    Idef = ca.Image3D(It.grid(), It.memType())
    gradIdef = ca.Field3D(It.grid(), It.memType())
    h = ca.Field3D(It.grid(), It.memType())
    ca.SetToIdentity(h)
    x = ca.Image3D(It.grid(), It.memType())
    y = ca.Image3D(It.grid(), It.memType())
    DX = ca.Image3D(It.grid(), It.memType())
    DY = ca.Image3D(It.grid(), It.memType())
    diff = ca.Image3D(It.grid(), It.memType())
    scratchI = ca.Image3D(It.grid(), It.memType())

    ca.Copy(x, h, 0)
    ca.Copy(y, h, 1)
    if origin is None:
        origin = [(Is.grid().size().x + 1) / 2.0,
                  (Is.grid().size().y + 1) / 2.0,
                  (Is.grid().size().z + 1) / 2.0]
    x -= origin[0]
    y -= origin[1]

    numel = It.size().x * It.size().y * It.size().z
    immin, immax = ca.MinMax(It)
    imrng = max(immax - immin, .01)
    t_step /= numel * imrng
    theta_step /= numel * imrng
    a_step /= numel * imrng
    energy = []
    a = 1

    if cc.Is3D(Is):
        if theta:
            print "theta is not utilized in 3D registration"
        z = ca.Image3D(It.grid(), It.memType())
        DZ = ca.Image3D(It.grid(), It.memType())
        ca.Copy(z, h, 2)
        z -= origin[2]

        A = np.matrix(np.identity(4))
        cc.ApplyAffineReal(Idef, Is, A)
        #        cc.ApplyAffine(Idef, Is, A, origin)

        t = [0, 0, 0]
        for i in xrange(maxIter):
            ca.Sub(diff, Idef, It)
            ca.Gradient(gradIdef, Idef)
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            ca.Copy(DZ, gradIdef, 2)

            # take gradient step for the translation
            ca.Mul(scratchI, DX, diff)
            t[0] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DY, diff)
            t[1] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DZ, diff)
            t[2] += t_step * ca.Sum(scratchI)

            A[0, 3] = t[0]
            A[1, 3] = t[1]
            A[2, 3] = t[2]
            if a_step > 0:
                DX *= x
                DY *= y
                DZ *= z
                DZ += DX
                DZ += DY
                DZ *= diff
                d_a = a_step * ca.Sum(DZ)
                a_prev = a
                a += d_a
                # multiplying by a/a_prev is equivalent to adding (a-aprev)
                A = A * np.matrix([[a / a_prev, 0, 0, 0], [
                    0, a / a_prev, 0, 0
                ], [0, 0, a / a_prev, 0], [0, 0, 0, 1]])

            # Z rotation
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            DX *= y
            ca.Neg_I(DX)
            DY *= x
            ca.Add(scratchI, DX, DY)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            # % Recalculate A
            A = A * np.matrix(
                [[np.cos(theta), np.sin(theta), 0, 0],
                 [-np.sin(theta), np.cos(theta), 0, 0], [0, 0, 1, 0],
                 [0, 0, 0, 1]])

            # Y rotation
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DZ, gradIdef, 2)
            DX *= z
            ca.Neg_I(DX)
            DZ *= x
            ca.Add(scratchI, DX, DZ)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            # % Recalculate A
            A = A * np.matrix(
                [[np.cos(theta), 0, np.sin(theta), 0], [0, 1, 0, 0],
                 [-np.sin(theta), 0, np.cos(theta), 0], [0, 0, 0, 1]])

            # X rotation
            ca.Copy(DY, gradIdef, 1)
            ca.Copy(DZ, gradIdef, 2)
            DY *= z
            ca.Neg_I(DY)
            DZ *= y
            ca.Add(scratchI, DY, DZ)
            scratchI *= diff
            theta = -theta_step * ca.Sum(scratchI)
            # Recalculate A
            A = A * np.matrix(
                [[1, 0, 0, 0], [0, np.cos(theta),
                                np.sin(theta), 0],
                 [0, -np.sin(theta), np.cos(theta), 0], [0, 0, 0, 1]])

            cc.ApplyAffineReal(Idef, Is, A)
            #        cc.ApplyAffine(Idef, Is, A, origin)

            # % display Energy (and other figures) at the end
            energy.append(ca.Sum2(diff))
            if (i == maxIter -
                    1) or (i > 75 and abs(energy[-1] - energy[-50]) < immax):
                cd.DispImage(diff, title='Difference Image', colorbar=True)
                plt.figure()
                plt.plot(energy)
                cd.DispImage(Idef, title='Deformed Image')
                break

    elif cc.Is2D(Is):
        # theta = 0
        if t is None:
            t = [0, 0]

        # A = np.array([[a*np.cos(theta), np.sin(theta), t[0]],
        #               [-np.sin(theta), a*np.cos(theta), t[1]],
        #               [0, 0, 1]])

        A = np.copy(Ain)
        cc.ApplyAffineReal(Idef, Is, A)
        # ca.Copy(Idef, Is)
        for i in xrange(1, maxIter):
            # [FX,FY] = gradient(Idef)
            ca.Sub(diff, Idef, It)
            ca.Gradient(gradIdef, Idef)
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)

            # take gradient step for the translation
            ca.Mul(scratchI, DX, diff)
            t[0] += t_step * ca.Sum(scratchI)
            ca.Mul(scratchI, DY, diff)
            t[1] += t_step * ca.Sum(scratchI)

            # take gradient step for the rotation theta
            if a_step > 0:
                # d/da
                DX *= x
                DY *= y
                DY += DX
                DY *= diff
                d_a = a_step * ca.Sum(DY)
                a += d_a
            # d/dtheta
            ca.Copy(DX, gradIdef, 0)
            ca.Copy(DY, gradIdef, 1)
            DX *= y
            ca.Neg_I(DX)
            DY *= x
            ca.Add(scratchI, DX, DY)
            scratchI *= diff
            d_theta = theta_step * ca.Sum(scratchI)
            theta -= d_theta

            # Recalculate A, Idef
            A = np.matrix([[a * np.cos(theta),
                            np.sin(theta), t[0]],
                           [-np.sin(theta), a * np.cos(theta), t[1]],
                           [0, 0, 1]])
            A = Ain * A

            cc.ApplyAffineReal(Idef, Is, A)
            #        cc.ApplyAffine(Idef, Is, A, origin)

            # % display Energy (and other figures) at the end
            energy.append(ca.Sum2(diff))
            if (i == maxIter -
                    1) or (i > 75 and abs(energy[-1] - energy[-50]) < immax):
                if i == maxIter - 1:
                    print "not converged in ", maxIter, " Iterations"
                if plot:
                    cd.DispImage(diff, title='Difference Image', colorbar=True)
                    plt.figure()
                    plt.plot(energy)
                    cd.DispImage(Idef, title='Deformed Image')
                break
    return A
Ejemplo n.º 23
0
                                      T2_VE,
                                      Ieps,
                                      Isigma,
                                      InIter,
                                      plot=False,
                                      verbose=1)

# Show the residual between T2_VE and live_VEfilt_ID, the energy curves and
# the deformation grid of the preceding registration step.
# NOTE(review): T2_VE, live_VEfilt_ID, Ienergy, Iphi, phi, liveDef, T2, memT,
# dispslice, write and SaveDir are defined earlier in the script, outside
# this excerpt.
Idiff = T2_VE - live_VEfilt_ID
cd.Disp3Pane(Idiff, rng=[-3, 3], sliceIdx=dispslice)
cd.EnergyPlot(Ienergy, legend=['Reg', 'Data', 'Total'])
cd.DispHGrid(Iphi)
print ca.MinMax(Iphi)

# Compose the deformations and apply the total deformation to the initial live volume
# tempDef = ca.Field3D(phi.grid(), memT)
# totDef receives the composition of phi and Iphi (clamped at the boundary)
totDef = ca.Field3D(phi.grid(), memT)
ca.ComposeHH(totDef, phi, Iphi, ca.BACKGROUND_STRATEGY_CLAMP)
# ca.ComposeHH(totDef, h, tempDef, ca.BACKGROUND_STRATEGY_CLAMP)

#Apply the deformation to the TPS live volume and rotate to the original volume
# live_T2reg_rot starts as a copy of T2 and is overwritten with liveDef
# deformed by totDef
live_T2reg_rot = T2.copy()
# presumably converts totDef to real-world coordinates in place, as needed
# by ApplyHReal -- confirm against the cc module
cc.HtoReal(totDef)
cc.ApplyHReal(live_T2reg_rot, liveDef, totDef)
cd.Disp3Pane(live_T2reg_rot)
# live_T2reg = T2.copy()
# cc.ApplyAffineReal(live_T2reg,live_T2reg_rot,np.linalg.inv(rotMat))
# cd.Disp3Pane(live_T2reg)

# Optionally save the deformed volume under SaveDir
if write:
    cc.WriteMHA(live_T2reg_rot,
                SaveDir + 'M13_01_live_as_MRI_full_bw_256_roty-119_flipy.mha')
def MatchingImageMomenta(cf):
    """Runs matching for image momenta pair."""
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    common.DebugHere()
    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # Output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(MatchingImageMomentaConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    I0 = common.LoadITKImage(cf.study.I, mType)
    m0 = common.LoadITKField(cf.study.m, mType)
    J1 = common.LoadITKImage(cf.study.J, mType)
    n1 = common.LoadITKField(cf.study.n, mType)

    # get imGrid from data
    imGrid = I0.grid()

    # create time array with checkpointing info for this geodesic to be estimated
    (s, scratchInd,
     rCpinds) = CAvmHGM.HGMSetUpTimeArray(cf.optim.nTimeSteps, [1.0], 0.001)
    tDiscGeodesic = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create the state variable for geodesic that is going to hold all info
    p0 = ca.Field3D(imGrid, mType)
    geodesicState = CAvmHGMCommon.HGMResidualState(
        I0,
        p0,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParams[0],
        cf.vectormomentum.diffOpParams[1],
        cf.vectormomentum.diffOpParams[2],
        s,
        cf.optim.NIterForInverse,
        1.0,
        cf.vectormomentum.sigmaM,
        cf.vectormomentum.sigmaI,
        cf.optim.stepSize,
        integMethod=cf.optim.integMethod)
    # initialize with zero
    ca.SetMem(geodesicState.p0, 0.0)
    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)
    EnergyHistory = []
    # run the loop
    for it in range(cf.optim.Niter):
        # shoot the geodesic forward
        CAvmHGMCommon.HGMIntegrateGeodesic(geodesicState.p0, geodesicState.s,
                                           geodesicState.diffOp,
                                           geodesicState.p, geodesicState.rho,
                                           geodesicState.rhoinv, tDiscGeodesic,
                                           geodesicState.Ninv,
                                           geodesicState.integMethod)
        # integrate the geodesic backward
        CAvmHGMCommon.HGMIntegrateAdjointsResidual(geodesicState,
                                                   tDiscGeodesic, m0, J1, n1)

        # TODO: verify it should just be log map/simple image matching when sigmaM=\infty
        # gradient descent step for geodesic.p0
        CAvmHGMCommon.HGMTakeGradientStepResidual(geodesicState)

        # compute and print energy
        (VEnergy, IEnergy,
         MEnergy) = MatchingImageMomentaComputeEnergy(geodesicState, m0, J1,
                                                      n1)
        EnergyHistory.append(
            [VEnergy + IEnergy + MEnergy, VEnergy, IEnergy, MEnergy])
        print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + MEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image Match) + ', MEnergy, '(Momenta Match)'

        # plots
        if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery == 0) or
                                    (it == cf.optim.Niter - 1)):
            MatchingImageMomentaPlots(cf,
                                      geodesicState,
                                      tDiscGeodesic,
                                      EnergyHistory,
                                      m0,
                                      J1,
                                      n1,
                                      writeOutput=True)

    # write output
    MatchingImageMomentaWriteOuput(cf, geodesicState)
Ejemplo n.º 25
0
def BuildHGM(cf):
    """Worker for running Hierarchical Geodesic Model (HGM) 
n    for group geodesic estimation on a subset of individuals. 
    Runs HGM on this subset sequentially. The variations retuned
    are summed up to get update for all individuals"""

    size = Compute.GetMPIInfo()['size']
    rank = Compute.GetMPIInfo()['rank']
    name = Compute.GetMPIInfo()['name']
    localRank = Compute.GetMPIInfo()['local_rank']
    nodename = socket.gethostname()

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectIntercepts)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(HGMConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
    #common.DebugHere()

    # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes

    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this thread to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeIntercepts = cf.study.subjectIntercepts[rank::cf.compute.numProcesses]
    nodeSlopes = cf.study.subjectSlopes[rank::cf.compute.numProcesses]
    nodeBaselineTimes = cf.study.subjectBaselineTimes[rank::cf.compute.
                                                      numProcesses]
    sys.stdout.write(
        "This is process %d of %d with name: %s on machinename: %s and local rank: %d.\nnodeIntercepts: %s\n nodeSlopes: %s\n nodeBaselineTimes: %s\n"
        % (rank, size, name, nodename, localRank, nodeIntercepts, nodeSlopes,
           nodeBaselineTimes))

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeIntercepts
    ]

    # load slopes
    n = [
        common.LoadITKField(f, mType) if isinstance(f, str) else f
        for f in nodeSlopes
    ]

    # get imGrid from data
    imGrid = J[0].grid()

    # create time array with checkpointing info for group geodesic
    (t, Jind, gCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsGroup,
                                           nodeBaselineTimes, 0.0000001)
    tdiscGroup = CAvmHGMCommon.HGMSetupTimeDiscretizationGroup(
        t, J, n, Jind, gCpinds, mType, nodeSubjectIds)

    # create time array with checkpointing info for residual geodesic
    (s, scratchInd, rCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsResidual,
                                                 [1.0], 0.0000001)
    tdiscResidual = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create group state and residual state
    groupState = CAvmHGMCommon.HGMGroupState(
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsGroup[0],
        cf.vectormomentum.diffOpParamsGroup[1],
        cf.vectormomentum.diffOpParamsGroup[2],
        t,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeGroup,
        integMethod=cf.optim.integMethodGroup)

    #ca.Copy(groupState.I0, common.LoadITKImage('/usr/sci/projects/ADNI/nikhil/software/vectormomentumtest/TestData/FlowerData/Longitudinal/GroupGeodesic/I0.mhd', mType))

    # note that residual state is treated a scratch variable in this algorithm and reused for computing residual geodesics of multiple individual
    residualState = CAvmHGMCommon.HGMResidualState(
        None,
        None,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsResidual[0],
        cf.vectormomentum.diffOpParamsResidual[1],
        cf.vectormomentum.diffOpParamsResidual[2],
        s,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeResidual,
        integMethod=cf.optim.integMethodResidual)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)
        mpiFieldBuff = None if mType == ca.MEM_HOST else ca.Field3D(
            imGrid, ca.MEM_HOST)
    for i in range(len(groupState.t) - 1, -1, -1):
        if tdiscGroup[i].J is not None:
            indx_last_individual = i
            break
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)

    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # run the loop

    for it in range(cf.optim.Niter):
        # compute HGM variation for group
        HGMGroupVariation(groupState, tdiscGroup, residualState, tdiscResidual,
                          cf.io.outputPrefix, rank, it)
        common.CheckCUDAError("Error after HGM iteration")
        # compute gradient for momenta (m is used as scratch)
        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(groupState.sumSplatI, mpiImageBuff)
            Compute.Reduce(groupState.sumJac, mpiImageBuff)
            Compute.Reduce(groupState.madj, mpiFieldBuff)
            # also sum up energies of other nodes
            # intercept
            Eintercept = np.array([groupState.EnergyHistory[-1][1]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eintercept,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][1] = Eintercept[0]

            Eslope = np.array([groupState.EnergyHistory[-1][2]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eslope,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][2] = Eslope[0]

        ca.Copy(groupState.m, groupState.m0)
        groupState.diffOp.applyInverseOperator(groupState.m)
        ca.Sub_I(groupState.m, groupState.madj)
        #groupState.diffOp.applyOperator(groupState.m)
        # now take gradient step in momenta for group
        if cf.optim.method == 'FIXEDGD':
            # take fixed stepsize gradient step
            ca.Add_MulC_I(groupState.m0, groupState.m, -cf.optim.stepSizeGroup)
        else:
            raise Exception("Unknown optimization scheme: " + cf.optim.method)
        # end if

        # now divide to get the new base image for group
        ca.Div(groupState.I0, groupState.sumSplatI, groupState.sumJac)

        # keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and ((
            (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            HGMPlots(cf,
                     groupState,
                     tdiscGroup,
                     residualState,
                     tdiscResidual,
                     indx_last_individual,
                     writeOutput=True)

        if isReporter:
            (VEnergy, IEnergy, SEnergy) = groupState.EnergyHistory[-1]
            print datetime.datetime.now().time(
            ), " Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + SEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Intercept) + ', SEnergy, '(Slope)'

    # write output images and fields
    HGMWriteOutput(cf, groupState, tdiscGroup, isReporter)
    def __init__(self,
                 grid,
                 mType,
                 alpha,
                 beta,
                 gamma,
                 nInv,
                 sigma,
                 StepSize,
                 integMethod='EULER'):
        """
        Allocate every state, adjoint and scratch buffer on `grid` with
        memory type `mType`, and configure the fluid kernel operator with
        `alpha`, `beta` and `gamma`.
        """

        def newField():
            # fresh vector field on this object's grid / memory type
            return ca.Field3D(self.grid, self.memtype)

        def newImage():
            # fresh scalar image on this object's grid / memory type
            return ca.Image3D(self.grid, self.memtype)

        self.grid = grid
        self.memtype = mType

        # initial conditions: both are references filled in later
        # (I0 -> atlas image, m0 -> momenta of the current individual)
        self.I0 = None
        self.m0 = None

        # state variables
        self.g = newField()
        self.ginv = newField()
        self.m = newField()
        self.I = newImage()

        # adjoint variables
        self.madj = newField()
        self.Iadj = newImage()
        self.madjtmp = newField()
        self.Iadjtmp = newImage()

        # accumulators for the closed-form template update
        self.sumSplatI = newImage()
        self.sumJac = newImage()

        # differential operator: CPU kernel for host memory, GPU otherwise
        onHost = self.memtype == ca.MEM_HOST
        self.diffOp = ca.FluidKernelFFTCPU() if onHost else ca.FluidKernelFFTGPU()
        self.diffOp.setAlpha(alpha)
        self.diffOp.setBeta(beta)
        self.diffOp.setGamma(gamma)
        self.diffOp.setGrid(self.grid)

        # remaining settings
        self.nInv = nInv  # iteration count for the inverse deformation update
        self.integMethod = integMethod
        self.sigma = sigma
        self.stepSize = StepSize

        # TODO: scratch variables to be changed to using managed memory
        # scratchV1 .. scratchV11
        for idx in range(1, 12):
            setattr(self, 'scratchV%d' % idx, newField())
        # only used for geodesic regression with RK4
        self.scratchI1 = newImage()
Ejemplo n.º 27
0
def ElastReg(I0Orig,
             I1Orig,
             scales=[1],
             nIters=[1000],
             maxPert=[0.2],
             fluidParams=[0.1, 0.1, 0.001],
             VFC=0.2,
             Mask=None,
             plotEvery=100):

    mType = I0Orig.memType()
    origGrid = I0Orig.grid()

    # allocate vars
    I0 = ca.Image3D(origGrid, mType)
    I1 = ca.Image3D(origGrid, mType)
    u = ca.Field3D(origGrid, mType)
    Idef = ca.Image3D(origGrid, mType)
    diff = ca.Image3D(origGrid, mType)
    gI = ca.Field3D(origGrid, mType)
    gU = ca.Field3D(origGrid, mType)
    scratchI = ca.Image3D(origGrid, mType)
    scratchV = ca.Field3D(origGrid, mType)

    # mask
    if Mask != None:
        MaskOrig = Mask.copy()

    # allocate diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()

    # initialize some vars
    nScales = len(scales)
    scaleManager = ca.MultiscaleManager(origGrid)
    for s in scales:
        scaleManager.addScaleLevel(s)

    # Initalize the thread memory manager (needed for resampler)
    # num pools is 2 (images) + 2*3 (fields)
    ca.ThreadMemoryManager.init(origGrid, mType, 8)

    if mType == ca.MEM_HOST:
        resampler = ca.MultiscaleResamplerGaussCPU(origGrid)
    else:
        resampler = ca.MultiscaleResamplerGaussGPU(origGrid)

    def setScale(scale):
        global curGrid

        scaleManager.set(scale)
        curGrid = scaleManager.getCurGrid()
        # since this is only 2D:
        curGrid.spacing().z = 1.0

        resampler.setScaleLevel(scaleManager)

        diffOp.setAlpha(fluidParams[0])
        diffOp.setBeta(fluidParams[1])
        diffOp.setGamma(fluidParams[2])
        diffOp.setGrid(curGrid)

        # downsample images
        I0.setGrid(curGrid)
        I1.setGrid(curGrid)
        if scaleManager.isLastScale():
            ca.Copy(I0, I0Orig)
            ca.Copy(I1, I1Orig)
        else:
            resampler.downsampleImage(I0, I0Orig)
            resampler.downsampleImage(I1, I1Orig)

        if Mask != None:
            if scaleManager.isLastScale():
                Mask.setGrid(curGrid)
                ca.Copy(Mask, MaskOrig)
            else:
                resampler.downsampleImage(Mask, MaskOrig)

        # initialize / upsample deformation
        if scaleManager.isFirstScale():
            u.setGrid(curGrid)
            ca.SetMem(u, 0.0)
        else:
            resampler.updateVField(u)

        # set grids
        gI.setGrid(curGrid)
        Idef.setGrid(curGrid)
        diff.setGrid(curGrid)
        gU.setGrid(curGrid)
        scratchI.setGrid(curGrid)
        scratchV.setGrid(curGrid)

    # end function

    energy = [[] for _ in xrange(3)]

    for scale in range(len(scales)):

        setScale(scale)
        ustep = None
        # update gradient
        ca.Gradient(gI, I0)

        for it in range(nIters[scale]):
            print 'iter %d' % it

            # compute deformed image
            ca.ApplyV(Idef, I0, u, 1.0)

            # update u
            ca.Sub(diff, I1, Idef)

            if Mask != None:
                ca.Mul_I(diff, Mask)

            ca.ApplyV(scratchV, gI, u, ca.BACKGROUND_STRATEGY_CLAMP)
            ca.Mul_I(scratchV, diff)

            diffOp.applyInverseOperator(gU, scratchV)

            vfcEn = VFC * ca.Dot(scratchV, gU)

            # why is this negative necessary?
            ca.MulC_I(gU, -1.0)

            # u =  u*(1-VFC*ustep) + (-2.0*ustep)*gU
            # MulC_Add_MulC_I(u, (1-VFC*ustep),
            #                        gU, 2.0*ustep)

            # u =  u - ustep*(VFC*u + 2.0*gU)
            ca.MulC_I(gU, 2.0)

            # subtract average if gamma is zero (result of nullspace
            # of L for K(L(u)))
            if fluidParams[2] == 0:
                av = ca.SumComp(u)
                av /= scratchI.nVox()
                ca.SubC(scratchV, u, av)
            # continue computing gradient
            ca.MulC(scratchV, u, VFC)
            ca.Add_I(gU, scratchV)

            ca.Magnitude(scratchI, gU)
            gradmax = ca.Max(scratchI)
            if ustep is None or ustep * gradmax > maxPert:
                ustep = maxPert[scale] / gradmax
                print 'step is %f' % ustep

            ca.MulC_I(gU, ustep)
            # apply gradient
            ca.Sub_I(u, gU)

            # compute energy
            energy[0].append(ca.Sum2(diff))
            diffOp.applyOperator(scratchV, u)
            energy[1].append(0.5 * VFC * ca.Dot(u, scratchV))
            energy[2].append(energy[0][-1]+\
                             energy[1][-1])

            if plotEvery > 0 and \
                   ((it+1) % plotEvery == 0 or \
                    (scale == nScales-1 and it == nIters[scale]-1)):
                print 'plotting'
                clrlist = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
                plt.figure('energy')
                for i in range(len(energy)):
                    plt.plot(energy[i], clrlist[i])
                    if i == 0:
                        plt.hold(True)
                plt.hold(False)
                plt.draw()

                plt.figure('results')
                plt.clf()
                plt.subplot(3, 2, 1)
                display.DispImage(I0, 'I0', newFig=False)
                plt.subplot(3, 2, 2)
                display.DispImage(I1, 'I1', newFig=False)
                plt.subplot(3, 2, 3)
                display.DispImage(Idef, 'def', newFig=False)
                plt.subplot(3, 2, 4)
                display.DispImage(diff, 'diff', newFig=False)
                plt.colorbar()
                plt.subplot(3, 2, 5)
                display.GridPlot(u, every=4)
                plt.subplot(3, 2, 6)
                display.JacDetPlot(u)
                plt.colorbar()
                plt.draw()
                plt.show()

            # end plot
        # end iteration
    # end scale
    return (Idef, u, energy)
def _prompt_int_pair(prompt):
    """Ask the user for a comma-separated line and return it as a list of ints."""
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3: input() no longer evaluates its argument
    # A real list (not a lazy ``map`` object) so it can be indexed and
    # serialized by ``json.dump`` regardless of interpreter version.
    return [int(token) for token in read_line(prompt).split(',')]


def main():
    """Deform one high-resolution region (and its confocal stack) into blockface space.

    Command-line arguments (read from ``sys.argv``):
        1: section number (``secNum``)
        2: ``mkyNum`` identifier used in the ``M{mkyNum}`` directory names
        3: region name

    Steps, in order:
      * load (or interactively build) the per-section region dictionary,
      * locate the region inside the low-resolution source image,
      * push that region through the stored deformation field and let the
        user pick its bounding box in the deformed space (only when the
        dictionary has no ``deformation_bbx`` for the region yet),
      * write the updated region dictionary back to disk,
      * resample the deformation to the region's resolution, apply it to the
        region image and save the result as ``.nrrd`` and ``.tiff``,
      * apply the same deformation to every slice of confocal channels 0-3
        and save the slices, the per-channel stack and the output grid.

    Returns nothing; all results are written to disk.
    """
    secNum = sys.argv[1]
    mkyNum = sys.argv[2]
    region = str(sys.argv[3])
    ss_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/side_light_microscope/'
    conf_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/confocal/'

    # The region dictionary is read here and rewritten further down, so build
    # its path once.
    regions_path = (
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
        .format(mkyNum, secNum))
    bbx_prompt = "What are the x indicies of the bounding box (Matlab Format x_start,x_stop? "
    bby_prompt = "What are the y indicies of the bounding box (Matlab Format y_start,y_stop? "

    try:
        # The ``with`` statement closes the file; no explicit close() needed.
        with open(regions_path, 'r') as f:
            region_dict = json.load(f)
    except IOError:
        # No dictionary on disk yet: start a new one and ask for the
        # full-resolution image size first.
        region_dict = {
            'size': _prompt_int_pair(
                "What is the size of the full resolution image x,y? ")
        }

    # Covers both a brand-new dictionary and a known section with a new region.
    if region not in region_dict:
        region_dict[region] = {
            'bbx': _prompt_int_pair(bbx_prompt),
            'bby': _prompt_int_pair(bby_prompt),
        }

    img_region = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_{2}.tiff'.
        format(mkyNum, secNum, region), ca.MEM_HOST)
    ssiSrc = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0.nrrd'
        .format(mkyNum, secNum), ca.MEM_HOST)
    bfi_df = common.LoadITKField(
        ss_dir +
        'Blockface_registered/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0_to_bfi_real.mha'
        .format(mkyNum, secNum), ca.MEM_DEVICE)

    # Map the full-resolution bounding box onto the low-resolution image.
    # The stored box comes from Matlab and is transposed relative to this
    # image, so 'bbx' drives the y range and 'bby' drives the x range.
    low_sz = ssiSrc.size().tolist()
    full_sz = region_dict['size']
    bbx = region_dict[region]['bbx']
    bby = region_dict[region]['bby']
    # Builtin float/int replace the np.float/np.int aliases (which were just
    # the builtins and no longer exist in current NumPy).
    yrng_raw = [(low_sz[1] * bbx[0]) / float(full_sz[0]),
                (low_sz[1] * bbx[1]) / float(full_sz[0])]
    xrng_raw = [(low_sz[0] * bby[0]) / float(full_sz[1]),
                (low_sz[0] * bby[1]) / float(full_sz[1])]
    # Round outwards so the low-resolution window fully contains the region.
    yrng = [int(np.floor(yrng_raw[0])), int(np.ceil(yrng_raw[1]))]
    xrng = [int(np.floor(xrng_raw[0])), int(np.ceil(xrng_raw[1]))]
    low_sub = cc.SubVol(ssiSrc, xrng, yrng)

    # Give the high-resolution region a grid that places it inside the
    # low-resolution image: same physical extent as the window, finer spacing.
    originout = [
        ssiSrc.origin().x + ssiSrc.spacing().x * xrng[0],
        ssiSrc.origin().y + ssiSrc.spacing().y * yrng[0], 0
    ]
    spacingout = [
        (low_sub.size().x * ssiSrc.spacing().x) / (img_region.size().x),
        (low_sub.size().y * ssiSrc.spacing().y) / (img_region.size().y), 1
    ]

    gridout = cc.MakeGrid(img_region.size().tolist(), spacingout, originout)
    img_region.setGrid(gridout)

    # Image the size of the low-resolution source that is zero everywhere
    # except inside the region window.
    only_sub = np.zeros(ssiSrc.size().tolist()[0:2])
    only_sub[xrng[0]:xrng[1], yrng[0]:yrng[1]] = np.squeeze(low_sub.asnp())
    only_sub = common.ImFromNPArr(only_sub)
    only_sub.setGrid(ssiSrc.grid())

    # Deform the window-only image with the stored field so the region can be
    # located in the deformed space.
    only_sub.toType(ca.MEM_DEVICE)
    def_sub = ca.Image3D(bfi_df.grid(), bfi_df.memType())
    cc.ApplyHReal(def_sub, only_sub, bfi_df)
    def_sub.toType(ca.MEM_HOST)

    # Bounding box in the deformed space: picked by hand once, then cached in
    # the region dictionary.
    if 'deformation_bbx' not in region_dict[region]:
        bb_def = np.squeeze(pp.LandmarkPicker([np.squeeze(def_sub.asnp())]))
        # NOTE(review): these are NumPy scalars handed to json.dump below;
        # confirm they serialize on the interpreter/NumPy in use.
        bb_def_y = [bb_def[0][0], bb_def[1][0]]
        bb_def_x = [bb_def[0][1], bb_def[1][1]]
        region_dict[region]['deformation_bbx'] = bb_def_x
        region_dict[region]['deformation_bby'] = bb_def_y

    with open(regions_path, 'w') as f:
        json.dump(region_dict, f)

    # Crop the deformation to the region and build a grid over the same
    # extent at the region image's (finer) spacing.
    deform_sub = cc.SubVol(bfi_df, region_dict[region]['deformation_bbx'],
                           region_dict[region]['deformation_bby'])

    # NOTE(review): looks like a leftover debugging hook — confirm whether it
    # should remain in the pipeline.
    common.DebugHere()
    sizeout = [
        int(
            np.ceil((deform_sub.size().x * deform_sub.spacing().x) /
                    img_region.spacing().x)),
        int(
            np.ceil((deform_sub.size().y * deform_sub.spacing().y) /
                    img_region.spacing().y)), 1
    ]

    region_grid = cc.MakeGrid(sizeout,
                              img_region.spacing().tolist(),
                              deform_sub.origin().tolist())

    def_im_region = ca.Image3D(region_grid, deform_sub.memType())
    up_deformation = ca.Field3D(region_grid, deform_sub.memType())

    img_region.toType(ca.MEM_DEVICE)
    cc.ResampleWorld(up_deformation, deform_sub,
                     ca.BACKGROUND_STRATEGY_PARTIAL_ZERO)
    cc.ApplyHReal(def_im_region, img_region, up_deformation)

    ss_out = pth.expanduser(
        ss_dir + 'Blockface_registered/M{0}/section_{1}/{2}/'.format(
            mkyNum, secNum, region))

    if not pth.exists(ss_out):
        os.mkdir(ss_out)

    # Same deformed image saved in two formats.
    common.SaveITKImage(
        def_im_region,
        ss_out +
        'M{0}_01_section_{1}_{2}_def_to_bfi.nrrd'.format(
            mkyNum, secNum, region))
    common.SaveITKImage(
        def_im_region,
        ss_out +
        'M{0}_01_section_{1}_{2}_def_to_bfi.tiff'.format(
            mkyNum, secNum, region))
    # Drop references that are no longer needed before the confocal loop.
    del img_region, def_im_region, ssiSrc, deform_sub

    # Now apply the same deformation to the confocal images
    conf_grid = cc.LoadGrid(
        conf_dir +
        'sidelight_registered/M{0}/section_{1}/{2}/affine_registration_grid.txt'
        .format(mkyNum, secNum, region))
    cf_out = conf_dir + 'blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)
    # NOTE(review): nothing here creates cf_out or its Ch{n} subdirectories;
    # presumably they already exist — confirm.

    for channel in range(0, 4):
        z_stack = []
        # The slice count is taken from the number of .tiff files present.
        num_slices = len(
            glob.glob(conf_dir +
                      'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/*.tiff'.
                      format(mkyNum, secNum, channel, region)))
        for z in range(0, num_slices):
            z_tag = str(z).zfill(2)
            # NOTE(review): the input file name hard-codes 'LGN_RHS' while the
            # directory and every output name use ``region`` — confirm this
            # is intended for regions other than LGN_RHS.
            src_im = common.LoadITKImage(
                conf_dir +
                'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/M{0}_01_section_{1}_LGN_RHS_Ch{2}_conf_aff_sidelight_z{4}.tiff'
                .format(mkyNum, secNum, channel, region, z_tag))
            # Each slice gets a single-slice grid with the confocal grid's
            # in-plane size, spacing and origin.
            src_im.setGrid(
                cc.MakeGrid(
                    ca.Vec3Di(conf_grid.size().x,
                              conf_grid.size().y, 1), conf_grid.spacing(),
                    conf_grid.origin()))
            src_im.toType(ca.MEM_DEVICE)
            def_im = ca.Image3D(region_grid, ca.MEM_DEVICE)
            cc.ApplyHReal(def_im, src_im, up_deformation)
            def_im.toType(ca.MEM_HOST)
            common.SaveITKImage(
                def_im, cf_out +
                'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.tiff'
                .format(mkyNum, secNum, channel, region, z_tag))
            if z == 0:
                # The first slice is additionally written as .nrrd.
                common.SaveITKImage(
                    def_im, cf_out +
                    'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.nrrd'
                    .format(mkyNum, secNum, channel, region, z_tag))
            z_stack.append(def_im)
            print('==> Done with Ch {0}: {1}/{2}'.format(
                channel, z, num_slices - 1))
        stacked = cc.Imlist_to_Im(z_stack)
        # In-plane spacing comes from the region grid, slice spacing from the
        # confocal grid.
        stacked.setSpacing(
            ca.Vec3Df(region_grid.spacing().x,
                      region_grid.spacing().y,
                      conf_grid.spacing().z))
        common.SaveITKImage(
            stacked, cf_out +
            'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_stack.nrrd'
            .format(mkyNum, secNum, channel, region))
        if channel == 0:
            # Written once, while handling channel 0. The file name has no
            # placeholders, so no .format() call is needed.
            cc.WriteGrid(
                stacked.grid(),
                cf_out + 'deformed_registration_grid.txt')
Ejemplo n.º 29
0
    sigma_I = 0.09
    nIter_I = 200
    [def_ID, theta, energy] = apps.IDiff(def_TPS,
                                         M13_aff,
                                         eps,
                                         sigma_I,
                                         nIter_I,
                                         plot=True,
                                         verbose=1)

    if write:
        cc.WriteMHA(def_ID, M15dir + 'IDiff/M15_01_ID_to_M13.mha')
        cc.WriteMHA(theta, M15dir + 'IDiff/M15_01_ID_Field_to_M13.mha')

    h = cc.LoadMHA(M15dir + 'TPS/M15_01_TPS_Field_to_M13.mha', memT)
    compDef = ca.Field3D(M13_aff.grid(), memT)
    ca.ComposeHH(compDef, h, theta, bg=ca.BACKGROUND_STRATEGY_CLAMP)

    common.DebugHere()

    Final_aff = ca.Image3D(M13_aff.grid(), memT)
    cc.ApplyAffineReal(Final_aff, M15, np.linalg.inv(aff))

    Final = ca.Image3D(M13_aff.grid(), memT)
    cc.ApplyHReal(Final, Final_aff, compDef)

    if write:
        cc.WriteMHA(Final, M15dir + 'FullDef/M15_01_MRI_as_M13.mha')
        cc.WriteMHA(compDef, M15dir + 'FullDef/M15_01_Field_to_M13.mha')

if not M15_to_M13: