Example #1
def Matching(cf):

    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # Output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(MatchingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST


    I0 = common.LoadITKImage(cf.study.I0, mType)
    I1 = common.LoadITKImage(cf.study.I1, mType)
    #ca.DivC_I(I0,255.0)
    #ca.DivC_I(I1,255.0)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    
    #common.DebugHere()
    # TODO: need to work on these
    t = [x * 1. / cf.optim.nTimeSteps for x in range(cf.optim.nTimeSteps + 1)]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    p = MatchingVariables(I0, I1, cf.vectormomentum.sigma, t, checkpointinds,
                          checkpointstates,
                          cf.vectormomentum.diffOpParams[0],
                          cf.vectormomentum.diffOpParams[1],
                          cf.vectormomentum.diffOpParams[2],
                          cf.optim.Niter, cf.optim.stepSize, cf.optim.maxPert,
                          cf.optim.nTimeSteps,
                          integMethod=cf.optim.integMethod,
                          optMethod=cf.optim.method,
                          nInv=cf.optim.NIterForInverse,
                          plotEvery=cf.io.plotEvery,
                          plotSlice=cf.io.plotSlice,
                          quiverEvery=cf.io.quiverEvery,
                          outputPrefix=cf.io.outputPrefix)

    RunMatching(p)

    # write output
    if cf.io.outputPrefix is not None: 
        # reset all variables by shooting once, may have been overwritten
        CAvmCommon.IntegrateGeodesic(p.m0, p.t, p.diffOp,
                                     p.m, p.g, p.ginv,
                                     p.scratchV1, p.scratchV2, p.scratchV3,
                                     p.checkpointstates, p.checkpointinds,
                                     Ninv=p.nInv, integMethod=p.integMethod)
        common.SaveITKField(p.m0, cf.io.outputPrefix+"m0.mhd")
        common.SaveITKField(p.ginv, cf.io.outputPrefix+"phiinv.mhd")
        common.SaveITKField(p.g, cf.io.outputPrefix+"phi.mhd")
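
# A small, hedged illustration of the checkpointing time grid built in
# Matching above: nTimeSteps uniform steps on [0, 1], with every time index
# after t=0 checkpointed (toy nTimeSteps; the real value comes from
# cf.optim.nTimeSteps).
def _demo_checkpoint_time_grid(nTimeSteps=5):
    t = [x * 1. / nTimeSteps for x in range(nTimeSteps + 1)]
    checkpointinds = list(range(1, len(t)))
    print(t)               # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    print(checkpointinds)  # [1, 2, 3, 4, 5]
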
def BuildAtlas(cf):
    """Worker for running Atlas construction on a subset of individuals.
    Runs Atlas on this subset sequentially. The variations retuned are
    summed up to get update for all individuals
    """

    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectImages)

    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(AtlasConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
    #common.DebugHere()

    # if using MPI, check that the number of processes does not exceed the
    # number of subjects; more subjects than processes is fine

    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this thread to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeImages = cf.study.subjectImages[rank::cf.compute.numProcesses]
    nodeWeights = cf.study.subjectWeights[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeImages)
    print 'rank:', rank, ', localRank:', localRank, ', nodeImages:', nodeImages, ', nodeWeights:', nodeWeights

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J_array = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeImages
    ]

    # get imGrid from data
    imGrid = J_array[0].grid()

    # atlas image
    atlas = ca.Image3D(imGrid, mType)

    # allocate memory to store only the initial momenta for each individual in this thread
    m_array = [ca.Field3D(imGrid, mType) for i in range(numLocalSubjects)]

    # allocate only one copy of scratch memory to be reused for each local individual in this thread in loop
    p = WarpVariables(imGrid,
                      mType,
                      cf.vectormomentum.diffOpParams[0],
                      cf.vectormomentum.diffOpParams[1],
                      cf.vectormomentum.diffOpParams[2],
                      cf.optim.NIterForInverse,
                      cf.vectormomentum.sigma,
                      cf.optim.stepSize,
                      integMethod=cf.optim.integMethod)

    # memory to accumulate numerators and denominators for atlas from
    # local individuals which will be summed across MPI threads
    sumSplatI = ca.Image3D(imGrid, mType)
    sumJac = ca.Image3D(imGrid, mType)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)

    t = [
        x * 1. / (cf.optim.nTimeSteps) for x in range(cf.optim.nTimeSteps + 1)
    ]
    cpinds = range(1, len(t))
    msmtinds = [
        len(t) - 2
    ]  # t=0 is not in cpinds since it's just the identity deformation and is not checkpointed
    cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                for idx in cpinds]
    gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

    EnergyHistory = []

    # TODO: better initializations
    # initialize atlas image with zeros.
    ca.SetMem(atlas, 0.0)
    # initialize momenta with zeros

    for m0_individual in m_array:
        ca.SetMem(m0_individual, 0.0)
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)

    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # preprocess input

    # assign atlas reference to p.I0. This reference will not change.
    p.I0 = atlas

    # run the loop
    for it in range(cf.optim.Niter):
        # run one iteration of warp for each individual and update
        # their own initial momenta and also accumulate SplatI and Jac
        ca.SetMem(sumSplatI, 0.0)
        ca.SetMem(sumJac, 0.0)
        TotalVEnergy = np.array([0.0])
        TotalIEnergy = np.array([0.0])

        for itsub in range(numLocalSubjects):
            # initializations for this subject, this only assigns
            # reference to image variables
            p.m0 = m_array[itsub]
            Imsmts = [J_array[itsub]]

            # run warp iteration
            VEnergy, IEnergy = RunWarpIteration(nodeSubjectIds[itsub], cf, p,
                                                t, Imsmts, cpinds, cpstates,
                                                msmtinds, gradAtMsmts, it)

            # gather relevant results
            ca.Add_I(sumSplatI, p.sumSplatI)
            ca.Add_I(sumJac, p.sumJac)
            TotalVEnergy[0] += VEnergy
            TotalIEnergy[0] += IEnergy

        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(sumSplatI, mpiImageBuff)
            Compute.Reduce(sumJac, mpiImageBuff)

            # also sum up energies of other nodes
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalVEnergy,
                                            op=mpi4py.MPI.SUM)
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            TotalIEnergy,
                                            op=mpi4py.MPI.SUM)

        EnergyHistory.append([TotalVEnergy[0], TotalIEnergy[0]])

        # now divide to get the new atlas image
        ca.Div(atlas, sumSplatI, sumJac)

        # keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and ((
            (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            # plots
            AtlasPlots(cf, p, atlas, m_array, EnergyHistory)

        if isReporter:
            # print out energy
            (VEnergy, IEnergy) = EnergyHistory[-1]
            print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image)'

    # write output images and fields
    AtlasWriteOutput(cf, atlas, m_array, nodeSubjectIds, isReporter)
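
# A minimal numpy sketch (an illustration under assumptions, not PyCA code)
# of the atlas update above: each subject contributes a splatted image and a
# splatted Jacobian, and ca.Div forms their pointwise quotient, i.e. a
# Jacobian-weighted average of the deformed subjects.
def _demo_atlas_update():
    import numpy as np
    splats = [np.random.rand(4, 4) + 0.5 for _ in range(3)]
    jacobians = [np.random.rand(4, 4) + 0.5 for _ in range(3)]
    sumSplatI = sum(splats)
    sumJac = sum(jacobians)
    atlas = sumSplatI / sumJac  # pointwise division, as in ca.Div(atlas, ...)
    print(atlas.mean())
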
def MatchingImageMomenta(cf):
    """Runs matching for image momenta pair."""
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    common.DebugHere()
    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # Output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(MatchingImageMomentaConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    I0 = common.LoadITKImage(cf.study.I, mType)
    m0 = common.LoadITKField(cf.study.m, mType)
    J1 = common.LoadITKImage(cf.study.J, mType)
    n1 = common.LoadITKField(cf.study.n, mType)

    # get imGrid from data
    imGrid = I0.grid()

    # create time array with checkpointing info for this geodesic to be estimated
    (s, scratchInd,
     rCpinds) = CAvmHGM.HGMSetUpTimeArray(cf.optim.nTimeSteps, [1.0], 0.001)
    tDiscGeodesic = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create the state variable for geodesic that is going to hold all info
    p0 = ca.Field3D(imGrid, mType)
    geodesicState = CAvmHGMCommon.HGMResidualState(
        I0,
        p0,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParams[0],
        cf.vectormomentum.diffOpParams[1],
        cf.vectormomentum.diffOpParams[2],
        s,
        cf.optim.NIterForInverse,
        1.0,
        cf.vectormomentum.sigmaM,
        cf.vectormomentum.sigmaI,
        cf.optim.stepSize,
        integMethod=cf.optim.integMethod)
    # initialize with zero
    ca.SetMem(geodesicState.p0, 0.0)
    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)
    EnergyHistory = []
    # run the loop
    for it in range(cf.optim.Niter):
        # shoot the geodesic forward
        CAvmHGMCommon.HGMIntegrateGeodesic(geodesicState.p0, geodesicState.s,
                                           geodesicState.diffOp,
                                           geodesicState.p, geodesicState.rho,
                                           geodesicState.rhoinv, tDiscGeodesic,
                                           geodesicState.Ninv,
                                           geodesicState.integMethod)
        # integrate the geodesic backward
        CAvmHGMCommon.HGMIntegrateAdjointsResidual(geodesicState,
                                                   tDiscGeodesic, m0, J1, n1)

        # TODO: verify it should just be log map/simple image matching when sigmaM=\infty
        # gradient descent step for geodesic.p0
        CAvmHGMCommon.HGMTakeGradientStepResidual(geodesicState)

        # compute and print energy
        (VEnergy, IEnergy,
         MEnergy) = MatchingImageMomentaComputeEnergy(geodesicState, m0, J1,
                                                      n1)
        EnergyHistory.append(
            [VEnergy + IEnergy + MEnergy, VEnergy, IEnergy, MEnergy])
        print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + MEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image Match) + ', MEnergy, '(Momenta Match)'

        # plots
        if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery == 0) or
                                    (it == cf.optim.Niter - 1)):
            MatchingImageMomentaPlots(cf,
                                      geodesicState,
                                      tDiscGeodesic,
                                      EnergyHistory,
                                      m0,
                                      J1,
                                      n1,
                                      writeOutput=True)

    # write output
    MatchingImageMomentaWriteOuput(cf, geodesicState)
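
# A toy sketch of the optimization loop structure used above: fixed-stepsize
# gradient descent with an energy history list. A quadratic objective stands
# in for the geodesic shooting energy (illustration only).
def _demo_fixed_step_gradient_descent(stepSize=0.1, Niter=20):
    x = 5.0
    EnergyHistory = []
    for it in range(Niter):
        grad = 2.0 * x            # gradient of E(x) = x^2
        x -= stepSize * grad      # fixed-stepsize gradient step
        EnergyHistory.append(x * x)
    print(EnergyHistory[-1])
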
Example #4
def predict_image(args, moving_images, target_images, output_prefixes):
    if (args.use_CPU_for_shooting):
        mType = ca.MEM_HOST
    else:
        mType = ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size, patch_size).cuda()

    # use correction network if required
    if args.use_correction:
        correction_network_config = torch.load(args.correction_parameter)
        correction_net = create_net(args, correction_network_config)
    else:
        correction_net = None

    # start prediction
    for i in range(0, len(moving_images)):

        common.Mkdir_p(os.path.dirname(output_prefixes[i]))
        if (args.affine_align):
            # Perform affine registration to both moving and target image to the ICBM152 atlas space.
            # Registration is done using Niftireg.
            call(["reg_aladin",
                  "-noSym", "-speeeeed", "-ref", args.atlas ,
                  "-flo", moving_images[i],
                  "-res", output_prefixes[i]+"moving_affine.nii",
                  "-aff", output_prefixes[i]+'moving_affine_transform.txt'])

            call(["reg_aladin",
                  "-noSym", "-speeeeed" ,"-ref", args.atlas ,
                  "-flo", target_images[i],
                  "-res", output_prefixes[i]+"target_affine.nii",
                  "-aff", output_prefixes[i]+'target_affine_transform.txt'])

            moving_image = common.LoadITKImage(output_prefixes[i]+"moving_affine.nii", mType)
            target_image = common.LoadITKImage(output_prefixes[i]+"target_affine.nii", mType)
        else:
            moving_image = common.LoadITKImage(moving_images[i], mType)
            target_image = common.LoadITKImage(target_images[i], mType)

        #preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        #moving_image = ca.Image3D(grid, mType)
        #target_image = ca.Image3D(grid, mType)
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        # Indicating whether we are using the old parameter files for the Neuroimage experiments (use .t7 files from matlab .h5 format)
        predict_transform_space = False
        if 'matlab_t7' in predict_network_config:
            predict_transform_space = True
        # run actual prediction
        prediction_result = util.predict_momentum(moving_image_np,
                                                  target_image_np, input_batch,
                                                  batch_size, patch_size,
                                                  prediction_net,
                                                  predict_transform_space)
        m0 = prediction_result['image_space']
        #convert to registration space and perform registration
        m0_reg = common.FieldFromNPArr(m0, mType)

        #perform correction
        if (args.use_correction):
            registration_result = registration_methods.geodesic_shooting(moving_image_processed, target_image_processed, m0_reg, args.shoot_steps, mType, predict_network_config)
            target_inv_np = common.AsNPCopy(registration_result['I1_inv'])

            correct_transform_space = False
            if 'matlab_t7' in correction_network_config:
                correct_transform_space = True
            correction_result = util.predict_momentum(moving_image_np,
                                                      target_inv_np,
                                                      input_batch, batch_size,
                                                      patch_size,
                                                      correction_net,
                                                      correct_transform_space)
            m0_correct = correction_result['image_space']
            m0 += m0_correct
            m0_reg = common.FieldFromNPArr(m0, mType)

        registration_result = registration_methods.geodesic_shooting(moving_image, target_image, m0_reg, args.shoot_steps, mType, predict_network_config)

        #endif

        write_result(registration_result, output_prefixes[i])
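
# A minimal CPU sketch of the network input layout assumed above: a batch of
# two-channel 3-D patches (moving, target). The real code allocates the batch
# on the GPU with .cuda(); sizes here are toy values.
def _demo_input_batch(batch_size=2, patch_size=15):
    import torch
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size)
    moving = torch.rand(64, 64, 64)   # stand-ins for preprocessed volumes
    target = torch.rand(64, 64, 64)
    input_batch[0, 0] = moving[:patch_size, :patch_size, :patch_size]
    input_batch[0, 1] = target[:patch_size, :patch_size, :patch_size]
    print(input_batch.shape)          # torch.Size([2, 2, 15, 15, 15])
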
def GeodesicShooting(cf):

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # Output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    #common.DebugHere()
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    t = [
        x * 1. / cf.integration.nTimeSteps
        for x in range(cf.integration.nTimeSteps + 1)
    ]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)
    # scale momenta to shoot
    cf.study.scaleMomenta = float(cf.study.scaleMomenta)
    if abs(cf.study.scaleMomenta) > 0.000000:
        ca.MulC_I(m0, float(cf.study.scaleMomenta))
        CAvmCommon.IntegrateGeodesic(m0,t,diffOp, mt, g, ginv,\
                                     scratchV1,scratchV2,scratchV3,\
                                     keepstates=checkpointstates,keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse, integMethod = cf.integration.integMethod)
    else:
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if cf.io.outputPrefix is not None:
        # scale back the shot momenta before writing
        if abs(cf.study.scaleMomenta) > 0.000000:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, float(cf.study.scaleMomenta))

        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
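
# A hedged 1-D numpy sketch of the smoothing role of the fluid kernel
# configured above. Assumption: the operator acts like
# L = -alpha * Laplacian + gamma in the Fourier domain (the beta/divergence
# term only matters in higher dimensions); the actual PyCA FluidKernelFFT
# implementation may differ in detail.
def _demo_fluid_kernel_1d(alpha=0.01, gamma=0.001, n=64):
    import numpy as np
    m = np.random.randn(n)                 # 1-D stand-in for a momentum field
    k = 2 * np.pi * np.fft.fftfreq(n)      # discrete frequencies
    L_hat = alpha * k ** 2 + gamma         # symbol of -alpha*d2/dx2 + gamma
    v = np.fft.ifft(np.fft.fft(m) / L_hat).real  # v = K m, the velocity
    print(v[:4])
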
def SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf):
    momentathresh = 0.00002
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix) + '/frames/')
    image_idx = 0
    fig = plt.figure(1, frameon=False)
    plt.clf()
    display.DispImage(I0,
                      '',
                      newFig=False,
                      cmap='gray',
                      dim=cf.io.plotSliceDim,
                      sliceIdx=cf.io.plotSlice)
    plt.draw()
    outfilename = cf.io.outputPrefix + '/frames/I' + str(image_idx).zfill(
        5) + '.png'
    fig.set_size_inches(4, 4)
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

    fig = plt.figure(2, frameon=False)
    plt.clf()
    temp = ca.Field3D(I0.grid(), I0.memType())
    ca.SetToIdentity(temp)
    common.DebugHere()
    CAvmCommon.MyGridPlot(temp,
                          every=cf.io.gridEvery,
                          color='k',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice,
                          isVF=False,
                          plotBase=False)
    #fig.patch.set_alpha(0)
    #fig.patch.set_visible(False)
    a = fig.gca()
    #a.set_frame_on(False)
    a.set_xticks([])
    a.set_yticks([])
    plt.axis('tight')
    plt.axis('image')
    plt.axis('off')
    plt.draw()
    fig.set_size_inches(4, 4)
    outfilename = cf.io.outputPrefix + '/frames/invdef' + str(image_idx).zfill(
        5) + '.png'
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

    fig = plt.figure(3, frameon=False)
    plt.clf()
    CAvmCommon.MyGridPlot(temp,
                          every=cf.io.gridEvery,
                          color='k',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice,
                          isVF=False,
                          plotBase=False)
    #fig.patch.set_alpha(0)
    #fig.patch.set_visible(False)
    a = fig.gca()
    #a.set_frame_on(False)
    a.set_xticks([])
    a.set_yticks([])
    plt.axis('tight')
    plt.axis('image')
    plt.axis('off')
    plt.draw()
    fig.set_size_inches(4, 4)
    outfilename = cf.io.outputPrefix + '/frames/def' + str(image_idx).zfill(
        5) + '.png'
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

    fig = plt.figure(4, frameon=False)
    plt.clf()
    display.DispImage(I0,
                      '',
                      newFig=False,
                      cmap='gray',
                      dim=cf.io.plotSliceDim,
                      sliceIdx=cf.io.plotSlice)
    plt.hold(True)
    CAvmCommon.MyQuiver(m0,
                        dim=cf.io.plotSliceDim,
                        sliceIdx=cf.io.plotSlice,
                        every=cf.io.quiverEvery,
                        thresh=momentathresh,
                        scaleArrows=0.25,
                        arrowCol='r',
                        lineWidth=0.5,
                        width=0.005)
    plt.draw()

    plt.hold(False)

    outfilename = cf.io.outputPrefix + '/frames/m' + str(image_idx).zfill(
        5) + '.png'
    fig.set_size_inches(4, 4)
    plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

    for i in range(len(checkpointinds)):
        image_idx = image_idx + 1
        ca.ApplyH(It, I0, checkpointstates[i][1])
        fig = plt.figure(1, frameon=False)
        plt.clf()
        display.DispImage(It,
                          '',
                          newFig=False,
                          cmap='gray',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice)
        plt.draw()
        outfilename = cf.io.outputPrefix + '/frames/I' + str(image_idx).zfill(
            5) + '.png'
        fig.set_size_inches(4, 4)
        plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

        fig = plt.figure(2, frameon=False)
        plt.clf()
        CAvmCommon.MyGridPlot(checkpointstates[i][1],
                              every=cf.io.gridEvery,
                              color='k',
                              dim=cf.io.plotSliceDim,
                              sliceIdx=cf.io.plotSlice,
                              isVF=False,
                              plotBase=False)
        #fig.patch.set_alpha(0)
        #fig.patch.set_visible(False)
        a = fig.gca()
        #a.set_frame_on(False)
        a.set_xticks([])
        a.set_yticks([])
        plt.axis('tight')
        plt.axis('image')
        plt.axis('off')
        plt.draw()
        outfilename = cf.io.outputPrefix + '/frames/invdef' + str(
            image_idx).zfill(5) + '.png'
        fig.set_size_inches(4, 4)
        plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

        fig = plt.figure(3, frameon=False)
        plt.clf()
        CAvmCommon.MyGridPlot(checkpointstates[i][0],
                              every=cf.io.gridEvery,
                              color='k',
                              dim=cf.io.plotSliceDim,
                              sliceIdx=cf.io.plotSlice,
                              isVF=False,
                              plotBase=False)
        #fig.patch.set_alpha(0)
        #fig.patch.set_visible(False)
        a = fig.gca()
        #a.set_frame_on(False)
        a.set_xticks([])
        a.set_yticks([])
        plt.axis('tight')
        plt.axis('image')
        plt.axis('off')
        plt.draw()
        outfilename = cf.io.outputPrefix + '/frames/def' + str(
            image_idx).zfill(5) + '.png'
        fig.set_size_inches(4, 4)
        plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)

        ca.CoAd(mt, checkpointstates[i][1], m0)
        fig = plt.figure(4, frameon=False)
        plt.clf()
        display.DispImage(It,
                          '',
                          newFig=False,
                          cmap='gray',
                          dim=cf.io.plotSliceDim,
                          sliceIdx=cf.io.plotSlice)
        plt.hold(True)
        CAvmCommon.MyQuiver(mt,
                            dim=cf.io.plotSliceDim,
                            sliceIdx=cf.io.plotSlice,
                            every=cf.io.quiverEvery,
                            thresh=momentathresh,
                            scaleArrows=0.40,
                            arrowCol='r',
                            lineWidth=0.5,
                            width=0.005)
        plt.draw()
        plt.hold(False)
        outfilename = cf.io.outputPrefix + '/frames/m' + str(image_idx).zfill(
            5) + '.png'
        fig.set_size_inches(4, 4)
        plt.savefig(outfilename, bbox_inches='tight', pad_inches=0, dpi=100)
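
# A minimal headless sketch of the frame-export pattern used in SaveFrames:
# draw into a fixed figure, strip the axes, and save zero-padded PNG names
# (toy image data; the 'frames' output directory is hypothetical).
def _demo_save_frames(num_frames=3):
    import os
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')                  # batch export without a display
    import matplotlib.pyplot as plt
    if not os.path.exists('frames'):
        os.mkdir('frames')
    fig = plt.figure(1, frameon=False)
    for idx in range(num_frames):
        plt.clf()
        plt.imshow(np.random.rand(32, 32), cmap='gray')
        plt.axis('off')
        fig.set_size_inches(4, 4)
        plt.savefig('frames/I' + str(idx).zfill(5) + '.png',
                    bbox_inches='tight', pad_inches=0, dpi=100)
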
Example #7
def BuildHGM(cf):
    """Worker for running Hierarchical Geodesic Model (HGM) 
n    for group geodesic estimation on a subset of individuals. 
    Runs HGM on this subset sequentially. The variations retuned
    are summed up to get update for all individuals"""

    size = Compute.GetMPIInfo()['size']
    rank = Compute.GetMPIInfo()['rank']
    name = Compute.GetMPIInfo()['name']
    localRank = Compute.GetMPIInfo()['local_rank']
    nodename = socket.gethostname()

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectIntercepts)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(HGMConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
    #common.DebugHere()

    # if using MPI, check that the number of processes does not exceed the
    # number of subjects; more subjects than processes is fine

    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this thread to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeIntercepts = cf.study.subjectIntercepts[rank::cf.compute.numProcesses]
    nodeSlopes = cf.study.subjectSlopes[rank::cf.compute.numProcesses]
    nodeBaselineTimes = cf.study.subjectBaselineTimes[rank::cf.compute.
                                                      numProcesses]
    sys.stdout.write(
        "This is process %d of %d with name: %s on machinename: %s and local rank: %d.\nnodeIntercepts: %s\n nodeSlopes: %s\n nodeBaselineTimes: %s\n"
        % (rank, size, name, nodename, localRank, nodeIntercepts, nodeSlopes,
           nodeBaselineTimes))

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeIntercepts
    ]

    # load slopes
    n = [
        common.LoadITKField(f, mType) if isinstance(f, str) else f
        for f in nodeSlopes
    ]

    # get imGrid from data
    imGrid = J[0].grid()

    # create time array with checkpointing info for group geodesic
    (t, Jind, gCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsGroup,
                                           nodeBaselineTimes, 0.0000001)
    tdiscGroup = CAvmHGMCommon.HGMSetupTimeDiscretizationGroup(
        t, J, n, Jind, gCpinds, mType, nodeSubjectIds)

    # create time array with checkpointing info for residual geodesic
    (s, scratchInd, rCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsResidual,
                                                 [1.0], 0.0000001)
    tdiscResidual = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create group state and residual state
    groupState = CAvmHGMCommon.HGMGroupState(
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsGroup[0],
        cf.vectormomentum.diffOpParamsGroup[1],
        cf.vectormomentum.diffOpParamsGroup[2],
        t,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeGroup,
        integMethod=cf.optim.integMethodGroup)

    #ca.Copy(groupState.I0, common.LoadITKImage('/usr/sci/projects/ADNI/nikhil/software/vectormomentumtest/TestData/FlowerData/Longitudinal/GroupGeodesic/I0.mhd', mType))

    # note that the residual state is treated as a scratch variable in this
    # algorithm and reused for computing residual geodesics of multiple
    # individuals
    residualState = CAvmHGMCommon.HGMResidualState(
        None,
        None,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsResidual[0],
        cf.vectormomentum.diffOpParamsResidual[1],
        cf.vectormomentum.diffOpParamsResidual[2],
        s,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeResidual,
        integMethod=cf.optim.integMethodResidual)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)
        mpiFieldBuff = None if mType == ca.MEM_HOST else ca.Field3D(
            imGrid, ca.MEM_HOST)
    for i in range(len(groupState.t) - 1, -1, -1):
        if tdiscGroup[i].J is not None:
            indx_last_individual = i
            break
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)

    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # run the loop

    for it in range(cf.optim.Niter):
        # compute HGM variation for group
        HGMGroupVariation(groupState, tdiscGroup, residualState, tdiscResidual,
                          cf.io.outputPrefix, rank, it)
        common.CheckCUDAError("Error after HGM iteration")
        # compute gradient for momenta (m is used as scratch)
        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(groupState.sumSplatI, mpiImageBuff)
            Compute.Reduce(groupState.sumJac, mpiImageBuff)
            Compute.Reduce(groupState.madj, mpiFieldBuff)
            # also sum up energies of other nodes
            # intercept
            Eintercept = np.array([groupState.EnergyHistory[-1][1]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eintercept,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][1] = Eintercept[0]

            Eslope = np.array([groupState.EnergyHistory[-1][2]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eslope,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][2] = Eslope[0]

        ca.Copy(groupState.m, groupState.m0)
        groupState.diffOp.applyInverseOperator(groupState.m)
        ca.Sub_I(groupState.m, groupState.madj)
        #groupState.diffOp.applyOperator(groupState.m)
        # now take gradient step in momenta for group
        if cf.optim.method == 'FIXEDGD':
            # take fixed stepsize gradient step
            ca.Add_MulC_I(groupState.m0, groupState.m, -cf.optim.stepSizeGroup)
        else:
            raise Exception("Unknown optimization scheme: " + cf.optim.method)
        # end if

        # now divide to get the new base image for group
        ca.Div(groupState.I0, groupState.sumSplatI, groupState.sumJac)

        # keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and ((
            (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            HGMPlots(cf,
                     groupState,
                     tdiscGroup,
                     residualState,
                     tdiscResidual,
                     indx_last_individual,
                     writeOutput=True)

        if isReporter:
            (VEnergy, IEnergy, SEnergy) = groupState.EnergyHistory[-1]
            print datetime.datetime.now().time(
            ), " Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + SEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Intercept) + ', SEnergy, '(Slope)'

    # write output images and fields
    HGMWriteOutput(cf, groupState, tdiscGroup, isReporter)
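
# A minimal sketch of the in-place MPI energy reduction used above: each rank
# holds a 1-element numpy buffer and Allreduce sums it across processes.
# Launch with e.g. `mpirun -n 4 python sketch.py` (mpi4py assumed available).
def _demo_energy_allreduce():
    import numpy as np
    import mpi4py.MPI
    local_energy = np.array([float(mpi4py.MPI.COMM_WORLD.Get_rank())])
    mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, local_energy,
                                    op=mpi4py.MPI.SUM)
    print(local_energy[0])                 # every rank holds the global sum
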
Example #8
def predict_image(args):
    if (args.use_CPU_for_shooting):
        mType = ca.MEM_HOST
    else:
        mType = ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()

    # start prediction
    for i in range(0, len(args.moving_image)):
        common.Mkdir_p(os.path.dirname(args.output_prefix[i]))
        if (args.affine_align):
            # Perform affine registration to both moving and target image to the ICBM152 atlas space.
            # Registration is done using Niftireg.
            call([
                "reg_aladin", "-noSym", "-speeeeed", "-ref", args.atlas,
                "-flo", args.moving_image[i], "-res",
                args.output_prefix[i] + "moving_affine.nii", "-aff",
                args.output_prefix[i] + 'moving_affine_transform.txt'
            ])

            call([
                "reg_aladin", "-noSym", "-speeeeed", "-ref", args.atlas,
                "-flo", args.target_image[i], "-res",
                args.output_prefix[i] + "target_affine.nii", "-aff",
                args.output_prefix[i] + 'target_affine_transform.txt'
            ])

            moving_image = common.LoadITKImage(
                args.output_prefix[i] + "moving_affine.nii", mType)
            target_image = common.LoadITKImage(
                args.output_prefix[i] + "target_affine.nii", mType)
        else:
            moving_image = common.LoadITKImage(args.moving_image[i], mType)
            target_image = common.LoadITKImage(args.target_image[i], mType)

        #preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        predict_transform_space = False
        if 'matlab_t7' in predict_network_config:
            predict_transform_space = True
        # run actual prediction
        prediction_result = util.predict_momentum(moving_image_np,
                                                  target_image_np, input_batch,
                                                  batch_size, patch_size,
                                                  prediction_net,
                                                  predict_transform_space)

        m0 = prediction_result['image_space']
        m0_reg = common.FieldFromNPArr(prediction_result['image_space'], mType)
        registration_result = registration_methods.geodesic_shooting(
            moving_image_processed, target_image_processed, m0_reg,
            args.shoot_steps, mType, predict_network_config)
        phi = common.AsNPCopy(registration_result['phiinv'])
        phi_square = np.power(phi, 2)

        for sample_iter in range(1, args.samples):
            print(sample_iter)
            prediction_result = util.predict_momentum(
                moving_image_np, target_image_np, input_batch, batch_size,
                patch_size, prediction_net, predict_transform_space)
            m0 += prediction_result['image_space']
            m0_reg = common.FieldFromNPArr(prediction_result['image_space'],
                                           mType)
            registration_result = registration_methods.geodesic_shooting(
                moving_image_processed, target_image_processed, m0_reg,
                args.shoot_steps, mType, predict_network_config)
            phi += common.AsNPCopy(registration_result['phiinv'])
            phi_square += np.power(
                common.AsNPCopy(registration_result['phiinv']), 2)

        m0_mean = np.divide(m0, args.samples)
        m0_reg = common.FieldFromNPArr(m0_mean, mType)
        registration_result = registration_methods.geodesic_shooting(
            moving_image_processed, target_image_processed, m0_reg,
            args.shoot_steps, mType, predict_network_config)
        phi_mean = registration_result['phiinv']
        phi_var = np.divide(phi_square, args.samples) - np.power(
            np.divide(phi, args.samples), 2)

        #save result
        common.SaveITKImage(registration_result['I1'],
                            args.output_prefix[i] + "I1.mhd")
        common.SaveITKField(phi_mean,
                            args.output_prefix[i] + "phiinv_mean.mhd")
        common.SaveITKField(common.FieldFromNPArr(phi_var, mType),
                            args.output_prefix[i] + "phiinv_var.mhd")
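
# A minimal numpy sketch of the sampling statistics computed above: accumulate
# sum(phi) and sum(phi^2) over samples, then var = E[phi^2] - E[phi]^2
# (toy arrays stand in for the sampled deformations).
def _demo_sample_variance(samples=10):
    import numpy as np
    phi = np.zeros((4, 4))
    phi_square = np.zeros((4, 4))
    for _ in range(samples):
        draw = np.random.randn(4, 4)       # stand-in for one sampled phiinv
        phi += draw
        phi_square += draw ** 2
    phi_mean = phi / samples
    phi_var = phi_square / samples - phi_mean ** 2
    print(phi_var.mean())
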
def main():
    # Extract the Monkey number and section number from the command line
    global frgNum
    global secOb

    mkyNum = sys.argv[1]
    secNum = sys.argv[2]
    frgNum = int(sys.argv[3])
    write = True

    # if not os.path.exists(os.path.expanduser('~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'.format(mkyNum,secNum))):
    #     cf = initial(secNum, mkyNum)

    try:
        secOb = Config.Load(
            secSpec,
            pth.expanduser(
                '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                .format(mkyNum, secNum)))
    except IOError as e:
        try:
            temp = Config.LoadYAMLDict(pth.expanduser(
                '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                .format(mkyNum, secNum)),
                                       include=False)
            secOb = Config.MkConfig(temp, secSpec)
        except IOError:
            print 'It appears there is no configuration file for this section. Please initialize one and restart.'
            sys.exit()
        if frgNum == int(secOb.yamlList[frgNum][-6]):
            Fragmenter()
            try:
                secOb = Config.Load(
                    secSpec,
                    pth.expanduser(
                        '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/include_configFile.yaml'
                        .format(mkyNum, secNum)))
            except IOError:
                print 'It appears that the include yaml file list does not match your fragmentation number. Please check them and restart.'
                sys.exit()

    if not pth.exists(
            pth.expanduser(secOb.ssiOutPath + 'frag{0}'.format(frgNum))):
        common.Mkdir_p(
            pth.expanduser(secOb.ssiOutPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.bfiOutPath + 'frag{0}'.format(frgNum))):
        common.Mkdir_p(
            pth.expanduser(secOb.bfiOutPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.ssiSrcPath + 'frag{0}'.format(frgNum))):
        os.mkdir(pth.expanduser(secOb.ssiSrcPath + 'frag{0}'.format(frgNum)))
    if not pth.exists(
            pth.expanduser(secOb.bfiSrcPath + 'frag{0}'.format(frgNum))):
        os.mkdir(pth.expanduser(secOb.bfiSrcPath + 'frag{0}'.format(frgNum)))

    frgOb = Config.MkConfig(secOb.yamlList[frgNum], frgSpec)
    ssiSrc, bfiSrc, ssiMsk, bfiMsk = Loader(frgOb, ca.MEM_HOST)

    #Extract the saturation image from the color image
    bfiHsv = common.FieldFromNPArr(
        matplotlib.colors.rgb_to_hsv(
            np.rollaxis(np.array(np.squeeze(bfiSrc.asnp())), 0, 3)),
        ca.MEM_HOST)
    bfiHsv.setGrid(bfiSrc.grid())
    bfiSat = ca.Image3D(bfiSrc.grid(), bfiHsv.memType())
    ca.Copy(bfiSat, bfiHsv, 1)
    #Histogram equalize, normalize and mask the blockface saturation image
    bfiSat = cb.HistogramEqualize(bfiSat, 256)
    bfiSat.setGrid(bfiSrc.grid())
    bfiSat *= -1
    bfiSat -= ca.Min(bfiSat)
    bfiSat /= ca.Max(bfiSat)
    bfiSat *= bfiMsk
    bfiSat.setGrid(bfiSrc.grid())

    #Write out the blockface region after adjusting the colors with a format that supports header information
    if write:
        common.SaveITKImage(
            bfiSat,
            pth.expanduser(secOb.bfiSrcPath +
                           'frag{0}/M{1}_01_bfi_section_{2}_frag{0}_sat.nrrd'.
                           format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Set the sidescape grid relative to that of the blockface
    ssiSrc.setGrid(ConvertGrid(ssiSrc.grid(), bfiSat.grid()))
    ssiMsk.setGrid(ConvertGrid(ssiMsk.grid(), bfiSat.grid()))
    ssiSrc *= ssiMsk

    #Write out the sidescape masked image in a format that stores the header information
    if write:
        common.SaveITKImage(
            ssiSrc,
            pth.expanduser(secOb.ssiSrcPath +
                           'frag{0}/M{1}_01_ssi_section_{2}_frag{0}.nrrd'.
                           format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Update the image parameters of the sidescape image for future use
    frgOb.imSize = ssiSrc.size().tolist()
    frgOb.imOrig = ssiSrc.origin().tolist()
    frgOb.imSpac = ssiSrc.spacing().tolist()
    updateFragOb(frgOb)

    #Find the affine transform between the two fragments
    bfiAff, ssiAff, aff = Affine(bfiSat, ssiSrc, frgOb)
    updateFragOb(frgOb)

    #Write out the affine transformed images in a format that stores header information
    if write:
        common.SaveITKImage(
            bfiAff,
            pth.expanduser(
                secOb.bfiOutPath +
                'frag{0}/M{1}_01_bfi_section_{2}_frag{0}_aff_ssi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))
        common.SaveITKImage(
            ssiAff,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_aff_bfi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))

    bfiVe = bfiAff.copy()
    ssiVe = ssiSrc.copy()
    cc.VarianceEqualize_I(bfiVe, sigma=frgOb.sigVarBfi, eps=frgOb.epsVar)
    cc.VarianceEqualize_I(ssiVe, sigma=frgOb.sigVarSsi, eps=frgOb.epsVar)

    #As of right now, the largest pre-computed FFT table is 2048, so resample onto that grid for registration
    regGrd = ConvertGrid(
        cc.MakeGrid(ca.Vec3Di(2048, 2048, 1), ca.Vec3Df(1, 1, 1),
                    ca.Vec3Df(0, 0, 0)), ssiSrc.grid())
    ssiReg = ca.Image3D(regGrd, ca.MEM_HOST)
    bfiReg = ca.Image3D(regGrd, ca.MEM_HOST)
    cc.ResampleWorld(ssiReg, ssiVe)
    cc.ResampleWorld(bfiReg, bfiVe)

    #Create the default configuration object for IDiff Matching and then set some parameters
    idCf = Config.SpecToConfig(IDiff.Matching.MatchingConfigSpec)
    idCf.compute.useCUDA = True
    idCf.io.outputPrefix = '/home/sci/blakez/IDtest/'

    #Run the registration
    ssiDef, phi = DefReg(ssiReg, bfiReg, frgOb, ca.MEM_DEVICE, idCf)

    #Turn the deformation into a displacement field so it can be applied to the large tif with C++ code
    affV = phi.copy()
    cc.ApplyAffineReal(affV, phi, np.linalg.inv(frgOb.affine))
    ca.HtoV_I(affV)

    #Apply the found deformation to the input ssi
    ssiSrc.toType(ca.MEM_DEVICE)
    cc.HtoReal(phi)
    affPhi = phi.copy()
    ssiBfi = ssiSrc.copy()
    upPhi = ca.Field3D(ssiSrc.grid(), phi.memType())

    cc.ApplyAffineReal(affPhi, phi, np.linalg.inv(frgOb.affine))
    cc.ResampleWorld(upPhi, affPhi, bg=2)
    cc.ApplyHReal(ssiBfi, ssiSrc, upPhi)

    # ssiPhi = ca.Image3D(ssiSrc.grid(), phi.memType())
    # upPhi = ca.Field3D(ssiSrc.grid(), phi.memType())
    # cc.ResampleWorld(upPhi, phi, bg=2)
    # cc.ApplyHReal(ssiPhi, ssiSrc, upPhi)
    # ssiBfi = ssiSrc.copy()
    # cc.ApplyAffineReal(ssiBfi, ssiPhi, np.linalg.inv(frgOb.affine))

    # #Apply affine to the deformation
    # affPhi = phi.copy()
    # cc.ApplyAffineReal(affPhi, phi, np.linalg.inv(frgOb.affine))

    if write:
        common.SaveITKImage(
            ssiBfi,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_def_bfi.nrrd'.format(
                    frgNum, secOb.mkyNum, secOb.secNum)))
        cc.WriteMHA(
            affPhi,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_to_bfi_real.mha'.
                format(frgNum, secOb.mkyNum, secOb.secNum)))
        cc.WriteMHA(
            affV,
            pth.expanduser(
                secOb.ssiOutPath +
                'frag{0}/M{1}_01_ssi_section_{2}_frag{0}_to_bfi_disp.mha'.
                format(frgNum, secOb.mkyNum, secOb.secNum)))

    #Create the list of names that the deformation should be applied to
    # nameList = ['M15_01_0956_SideLight_DimLED_10x_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c1_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c2_ORG.tif',
    #             'M15_01_0956_TyrosineHydroxylase_Ben_10x_Stitching_c3_ORG.tif']

    # appLarge(nameList, affPhi)

    common.DebugHere()
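
# A small numpy/matplotlib sketch of the saturation extraction and
# normalization performed above, on a toy RGB array instead of a blockface
# image (the real pipeline also histogram-equalizes and masks the result).
def _demo_saturation_channel():
    import numpy as np
    import matplotlib.colors
    rgb = np.random.rand(8, 8, 3)                     # toy color image in [0, 1]
    sat = matplotlib.colors.rgb_to_hsv(rgb)[..., 1]   # channel 1 = saturation
    sat *= -1                                         # invert, as for the blockface
    sat -= sat.min()
    sat /= sat.max()                                  # normalize to [0, 1]
    print(sat.min(), sat.max())
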
Example #10
def BuildGeoReg(cf):
    """Worker for running geodesic estimation on a subset of individuals
    """
    #common.DebugHere()
    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0

    # load filenames and times for all subjects
    (subjectsIds, subjectsImagePaths,
     subjectsTimes) = GeoRegLoadSubjectsDetails(cf.study.subjectFile)
    cf.study.numSubjects = len(subjectsIds)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(GeoRegConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)

    # if using MPI, check that the number of processes does not exceed the
    # number of subjects; more subjects than processes is fine
    if cf.compute.useMPI and (len(subjectsIds) < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    nodeSubjectsIds = subjectsIds[rank::cf.compute.numProcesses]
    nodeSubjectsImagePaths = subjectsImagePaths[rank::cf.compute.numProcesses]
    nodeSubjectsTimes = subjectsTimes[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeSubjectsImagePaths)
    if cf.study.initializationsFile is not None:
        (subjectsInitialImages,
         subjectsInitialMomenta) = GeoRegLoadSubjectsInitializations(
             cf.study.initializationsFile)
        nodeSubjectsInitialImages = subjectsInitialImages[rank::cf.compute.
                                                          numProcesses]
        nodeSubjectsInitialMomenta = subjectsInitialMomenta[rank::cf.compute.
                                                            numProcesses]

    print 'rank:', rank, ', localRank:', localRank, ', numberSubjects/TotalSubjects:', len(
        nodeSubjectsImagePaths
    ), '/', cf.study.numSubjects, ', nodeSubjectsImagePaths:', nodeSubjectsImagePaths, ', nodeSubjectsTimes:', nodeSubjectsTimes

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # setting gpuid should be handled in gpu
    # if using GPU  set device based on local rank
    #if cf.compute.useCUDA:
    #    ca.SetCUDADevice(localRank)

    # get image size information
    dummyImToGetGridInfo = common.LoadITKImage(nodeSubjectsImagePaths[0][0],
                                               mType)
    imGrid = dummyImToGetGridInfo.grid()
    if cf.study.setUnitSpacing:
        imGrid.setSpacing(ca.Vec3Df(1.0, 1.0, 1.0))
    if cf.study.setZeroOrigin:
        imGrid.setOrigin(ca.Vec3Df(0, 0, 0))
    #del dummyImToGetGridInfo;

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # allocate memory
    p = GeoRegVariables(imGrid,
                        mType,
                        cf.vectormomentum.diffOpParams[0],
                        cf.vectormomentum.diffOpParams[1],
                        cf.vectormomentum.diffOpParams[2],
                        cf.optim.NIterForInverse,
                        cf.vectormomentum.sigma,
                        cf.optim.stepSize,
                        integMethod=cf.optim.integMethod)
    # for each individual run geodesic regression for each subject
    for i in range(numLocalSubjects):

        # initializations for this subject
        if cf.study.initializationsFile is not None:
            # assuming the initializations are already preprocessed, in terms of intensities, origin and voxel scalings.
            p.I0 = common.LoadITKImage(nodeSubjectsInitialImages[i], mType)
            p.m0 = common.LoadITKField(nodeSubjectsInitialMomenta[i], mType)
        else:
            ca.SetMem(p.m0, 0.0)
            ca.SetMem(p.I0, 0.0)

        # allocate memory specific to this subject in steps a, b and c
        # a. create time array with checkpointing info for regression geodesic, allocate checkpoint memory
        (t, msmtinds, cpinds) = GeoRegSetUpTimeArray(cf.optim.nTimeSteps,
                                                     nodeSubjectsTimes[i],
                                                     0.001)
        cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                    for idx in cpinds]
        # b. allocate gradAtMeasurements of the length of msmtindex for storing residuals
        gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

        # c. load timepoint images for this subject
        Imsmts = [
            common.LoadITKImage(f, mType) if isinstance(f, str) else f
            for f in nodeSubjectsImagePaths[i]
        ]
        # reset stepsize if adaptive stepsize changed it inside
        p.stepSize = cf.optim.stepSize
        # preprocessimages
        GeoRegPreprocessInput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                              cpstates, msmtinds, gradAtMsmts)

        # run regression for this subject
        # REMEMBER
        # msmtinds index into cpinds
        # gradAtMsmts is parallel to msmtinds
        # cpinds index into t
        EnergyHistory = RunGeoReg(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                                  cpstates, msmtinds, gradAtMsmts)

        # write output images and fields for this subject
        # TODO: BEWARE There are hardcoded numbers inside preprocessing code specific for ADNI/OASIS brain data.
        GeoRegWriteOuput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                         cpstates, msmtinds, gradAtMsmts, EnergyHistory)

        # clean up memory specific to this subject
        del t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts
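
# A hedged sketch of what a GeoRegSetUpTimeArray-style helper plausibly does
# (an assumption for illustration; the real helper may differ): build a
# uniform time grid, checkpoint every index after t=0, and snap each subject
# measurement time to the nearest grid point, recording its index into cpinds.
def _demo_setup_time_array(nTimeSteps=10, msmtTimes=(0.5, 1.0)):
    t = [i * 1.0 / nTimeSteps for i in range(nTimeSteps + 1)]
    cpinds = list(range(1, len(t)))        # t=0 is identity, not checkpointed
    msmtinds = []
    for tm in msmtTimes:
        idx = min(range(len(t)), key=lambda j: abs(t[j] - tm))
        msmtinds.append(idx - 1)           # index into cpinds, per the REMEMBER note
    return t, msmtinds, cpinds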