Ejemplo n.º 1
0
def gather_file(args, image_shape=(15, 15, 15, 1, 3)):
    """Stack the volumes listed in ``args.files`` into one tensor and save it.

    Args:
        args: namespace with ``files`` (list of input paths), ``momentum``
            (bool: read ITK vector fields via ``common.LoadITKField`` instead
            of NIfTI images via ``nib.load``) and ``output`` (destination path
            passed to ``torch.save``).
        image_shape: per-file shape used for the NIfTI (non-momentum) path.
            Defaults to the previously hard-coded ``(15, 15, 15, 1, 3)``.

    In momentum mode the per-file shape is taken from the first field, and
    the stacked tensor is permuted with axes ``[0, 4, 1, 2, 3]`` before it is
    saved. In NIfTI mode, an input whose element count equals the product of
    the first three entries of ``image_shape`` is replaced by zeros.
    """
    def _load_momentum(path):
        # Host-memory numpy copy of an ITK vector field.
        return common.AsNPCopy(common.LoadITKField(path, ca.MEM_HOST))

    if args.momentum:
        # The fixed image shape does not apply to momentum fields. With a 5-d
        # per-sample shape the stacked tensor would be 6-d, and the 5-axis
        # permutation below would fail. Size the output from the first field.
        first = _load_momentum(args.files[0])
        sample_shape = tuple(first.shape)
    else:
        first = None
        sample_shape = tuple(image_shape)
        # Inputs holding one value per voxel of the leading 3-d grid
        # (3375 == 15 ** 3 for the default shape) are replaced by zeros.
        placeholder_size = int(np.prod(sample_shape[:3]))

    data = torch.zeros((len(args.files),) + sample_shape)
    for i, path in enumerate(args.files):
        if args.momentum:
            # Reuse the field already loaded to size the output.
            cur_slice = first if i == 0 else _load_momentum(path)
        else:
            cur_slice = nib.load(path).get_data()
            if cur_slice.size == placeholder_size:
                cur_slice = np.zeros(sample_shape)
        data[i] = torch.from_numpy(cur_slice)

    if args.momentum:
        # transpose the dataset to fit the training format
        data = data.permute(0, 4, 1, 2, 3)

    torch.save(data, args.output)
def MatchingImageMomentaWriteOuput(cf, geodesicState, EnergyHistory, m0, n1):
    """Save the final state of an image/momenta matching run.

    Every output file name is ``cf.io.outputPrefix`` followed by a suffix:
      * ``p0.mhd``: ``geodesicState.p0``.
      * ``m1.mhd``: the result of ``ca.CoAd(p, rhoinv, m0)``, stored in
        ``geodesicState.p``.
      * ``testMomentaMatchEnergy.csv``: written only when
        ``cf.vectormomentum.matchImOnly`` is set.
      * ``I1.mhd``: ``geodesicState.J0`` passed through ``ca.ApplyH`` with
        ``rhoinv``.
      * ``energy.csv``: ``EnergyHistory``, written by
        ``MatchingImageMomentaWriteEnergyHistoryToFile``.

    When ``matchImOnly`` is set, the ``m0`` argument is replaced by the field
    loaded from ``cf.study.m``. ``geodesicState.p`` is overwritten as scratch.
    """
    state = geodesicState
    prefix = cf.io.outputPrefix
    imGrid = state.J0.grid()
    memType = state.J0.memType()

    # initial momenta of the geodesic
    common.SaveITKField(state.p0, prefix + "p0.mhd")

    # endpoint momenta: CoAd of m0 through rhoinv, written into state.p
    if cf.vectormomentum.matchImOnly:
        m0 = common.LoadITKField(cf.study.m, memType)
    ca.CoAd(state.p, state.rhoinv, m0)
    common.SaveITKField(state.p, prefix + "m1.mhd")

    if cf.vectormomentum.matchImOnly:
        # state.p <- state.p - n1; keep that difference in `residual`, then
        # apply the inverse operator to state.p in place.
        residual = ca.ManagedField3D(imGrid, memType)
        ca.Sub_I(state.p, n1)
        ca.Copy(residual, state.p)
        state.diffOp.applyInverseOperator(state.p)
        dotValue = ca.Dot(residual, state.p)
        momentaMatchEnergy = dotValue / (
            float(state.p0.nVox()) * state.SigmaSlope * state.SigmaSlope)
        with open(prefix + "testMomentaMatchEnergy.csv", 'w') as energyFile:
            # same text that `print >> f, value` produces: str(value) + newline
            energyFile.write('%s\n' % (momentaMatchEnergy,))

    # matched image at the end of the geodesic
    matchedImage = ca.ManagedImage3D(imGrid, memType)
    ca.ApplyH(matchedImage, state.J0, state.rhoinv)
    common.SaveITKImage(matchedImage, prefix + "I1.mhd")

    MatchingImageMomentaWriteEnergyHistoryToFile(EnergyHistory,
                                                 prefix + "energy.csv")
def main():
    """Warp a side-light region and its registered confocal stack into blockface space.

    Reads ``sys.argv[1:4]`` as the section number, the monkey number and the
    region name. Region bounding boxes are read from the ``..._regions.txt``
    JSON file under ``src_registration``. Missing entries are requested on
    stdin, or picked with ``pp.LandmarkPicker`` for the deformed-space box,
    and the file is rewritten.

    Writes:
      * the deformed region (``.nrrd`` and ``.tiff``) to
        ``Blockface_registered/M<mky>/section_<sec>/<region>/`` under the
        side-light directory;
      * per-channel deformed confocal slices (``.tiff``; slice 0 also as
        ``.nrrd``), a stacked ``.nrrd`` per channel and
        ``deformed_registration_grid.txt`` to
        ``blockface_registered/M<mky>/section_<sec>/<region>/`` under the
        confocal directory.
    """
    secNum = sys.argv[1]
    mkyNum = sys.argv[2]
    region = str(sys.argv[3])
    ss_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/side_light_microscope/'
    conf_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/confocal/'
    regions_path = (
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
        .format(mkyNum, secNum))

    def _ask_ints(prompt):
        # Read a comma-separated list of integers from stdin.
        return [int(v) for v in raw_input(prompt).split(',')]

    def _ask_bounding_box(entry):
        # Prompt for the full-resolution (Matlab-style) x and y index ranges.
        entry['bbx'] = _ask_ints(
            "What are the x indicies of the bounding box (Matlab Format x_start,x_stop? "
        )
        entry['bby'] = _ask_ints(
            "What are the y indicies of the bounding box (Matlab Format y_start,y_stop? "
        )

    try:
        with open(regions_path, 'r') as f:
            region_dict = json.load(f)
    except IOError:
        region_dict = {}
        region_dict[region] = {}
        region_dict['size'] = _ask_ints(
            "What is the size of the full resolution image x,y? ")
        _ask_bounding_box(region_dict[region])

    if region not in region_dict:
        region_dict[region] = {}
        _ask_bounding_box(region_dict[region])

    img_region = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_{2}.tiff'.
        format(mkyNum, secNum, region), ca.MEM_HOST)
    ssiSrc = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0.nrrd'
        .format(mkyNum, secNum), ca.MEM_HOST)
    bfi_df = common.LoadITKField(
        ss_dir +
        'Blockface_registered/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0_to_bfi_real.mha'
        .format(mkyNum, secNum), ca.MEM_DEVICE)

    # Figure out the same region in the low resolution image: There is a
    # transpose from here to matlab so dimensions are flipped.
    # Builtin float/int replace the np.float/np.int aliases, which NumPy 1.24
    # removed; the values are identical.
    low_sz = ssiSrc.size().tolist()
    bbx = region_dict[region]['bbx']
    bby = region_dict[region]['bby']
    full_x = float(region_dict['size'][0])
    full_y = float(region_dict['size'][1])
    yrng_raw = [(low_sz[1] * bbx[0]) / full_x, (low_sz[1] * bbx[1]) / full_x]
    xrng_raw = [(low_sz[0] * bby[0]) / full_y, (low_sz[0] * bby[1]) / full_y]
    yrng = [int(np.floor(yrng_raw[0])), int(np.ceil(yrng_raw[1]))]
    xrng = [int(np.floor(xrng_raw[0])), int(np.ceil(xrng_raw[1]))]
    low_sub = cc.SubVol(ssiSrc, xrng, yrng)

    # Figure out the grid for the sub region in relation to the sidescape
    originout = [
        ssiSrc.origin().x + ssiSrc.spacing().x * xrng[0],
        ssiSrc.origin().y + ssiSrc.spacing().y * yrng[0], 0
    ]
    spacingout = [
        (low_sub.size().x * ssiSrc.spacing().x) / (img_region.size().x),
        (low_sub.size().y * ssiSrc.spacing().y) / (img_region.size().y), 1
    ]

    gridout = cc.MakeGrid(img_region.size().tolist(), spacingout, originout)
    img_region.setGrid(gridout)

    # Low-resolution image that is zero everywhere except the sub region
    only_sub = np.zeros(ssiSrc.size().tolist()[0:2])
    only_sub[xrng[0]:xrng[1], yrng[0]:yrng[1]] = np.squeeze(low_sub.asnp())
    only_sub = common.ImFromNPArr(only_sub)
    only_sub.setGrid(ssiSrc.grid())

    # Deform the only sub region to blockface space
    only_sub.toType(ca.MEM_DEVICE)
    def_sub = ca.Image3D(bfi_df.grid(), bfi_df.memType())
    cc.ApplyHReal(def_sub, only_sub, bfi_df)
    def_sub.toType(ca.MEM_HOST)

    # Now have to find the bounding box in the deformation space (bfi space)
    if 'deformation_bbx' not in region_dict[region]:
        bb_def = np.squeeze(pp.LandmarkPicker([np.squeeze(def_sub.asnp())]))
        region_dict[region]['deformation_bbx'] = [bb_def[0][1], bb_def[1][1]]
        region_dict[region]['deformation_bby'] = [bb_def[0][0], bb_def[1][0]]

    with open(regions_path, 'w') as f:
        json.dump(region_dict, f)

    # Now need to extract the region and create a deformation and image that
    # have the same resolution as the img_region
    deform_sub = cc.SubVol(bfi_df, region_dict[region]['deformation_bbx'],
                           region_dict[region]['deformation_bby'])

    sizeout = [
        int(
            np.ceil((deform_sub.size().x * deform_sub.spacing().x) /
                    img_region.spacing().x)),
        int(
            np.ceil((deform_sub.size().y * deform_sub.spacing().y) /
                    img_region.spacing().y)), 1
    ]

    region_grid = cc.MakeGrid(sizeout,
                              img_region.spacing().tolist(),
                              deform_sub.origin().tolist())

    def_im_region = ca.Image3D(region_grid, deform_sub.memType())
    up_deformation = ca.Field3D(region_grid, deform_sub.memType())

    img_region.toType(ca.MEM_DEVICE)
    cc.ResampleWorld(up_deformation, deform_sub,
                     ca.BACKGROUND_STRATEGY_PARTIAL_ZERO)
    cc.ApplyHReal(def_im_region, img_region, up_deformation)

    ss_out = pth.expanduser(
        ss_dir + 'Blockface_registered/M{0}/section_{1}/{2}/'.format(
            mkyNum, secNum, region))
    if not pth.exists(ss_out):
        os.makedirs(ss_out)

    common.SaveITKImage(
        def_im_region,
        ss_out + 'M{0}_01_section_{1}_{2}_def_to_bfi.nrrd'.format(
            mkyNum, secNum, region))
    common.SaveITKImage(
        def_im_region,
        ss_out + 'M{0}_01_section_{1}_{2}_def_to_bfi.tiff'.format(
            mkyNum, secNum, region))
    del img_region, def_im_region, ssiSrc, deform_sub

    # Now apply the same deformation to the confocal images
    conf_grid = cc.LoadGrid(
        conf_dir +
        'sidelight_registered/M{0}/section_{1}/{2}/affine_registration_grid.txt'
        .format(mkyNum, secNum, region))
    cf_out = conf_dir + 'blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)

    for channel in range(0, 4):
        # The per-channel output directory was never created; make it so the
        # saves below do not fail on a fresh tree.
        ch_out = cf_out + 'Ch{0}/'.format(channel)
        if not pth.exists(ch_out):
            os.makedirs(ch_out)

        z_stack = []
        num_slices = len(
            glob.glob(conf_dir +
                      'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/*.tiff'.
                      format(mkyNum, secNum, channel, region)))
        for z in range(0, num_slices):
            # NOTE(review): the input file name hard-codes "LGN_RHS" instead of
            # the region ({3} is only used in the directory); confirm this is
            # intended for regions other than LGN_RHS.
            src_im = common.LoadITKImage(
                conf_dir +
                'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/M{0}_01_section_{1}_LGN_RHS_Ch{2}_conf_aff_sidelight_z{4}.tiff'
                .format(mkyNum, secNum, channel, region,
                        str(z).zfill(2)))
            src_im.setGrid(
                cc.MakeGrid(
                    ca.Vec3Di(conf_grid.size().x,
                              conf_grid.size().y, 1), conf_grid.spacing(),
                    conf_grid.origin()))
            src_im.toType(ca.MEM_DEVICE)
            def_im = ca.Image3D(region_grid, ca.MEM_DEVICE)
            cc.ApplyHReal(def_im, src_im, up_deformation)
            def_im.toType(ca.MEM_HOST)
            common.SaveITKImage(
                def_im, cf_out +
                'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.tiff'
                .format(mkyNum, secNum, channel, region,
                        str(z).zfill(2)))
            if z == 0:
                common.SaveITKImage(
                    def_im, cf_out +
                    'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.nrrd'
                    .format(mkyNum, secNum, channel, region,
                            str(z).zfill(2)))
            z_stack.append(def_im)
            print('==> Done with Ch {0}: {1}/{2}'.format(
                channel, z, num_slices - 1))
        stacked = cc.Imlist_to_Im(z_stack)
        stacked.setSpacing(
            ca.Vec3Df(region_grid.spacing().x,
                      region_grid.spacing().y,
                      conf_grid.spacing().z))
        common.SaveITKImage(
            stacked, cf_out +
            'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_stack.nrrd'
            .format(mkyNum, secNum, channel, region))
        if channel == 0:
            cc.WriteGrid(stacked.grid(),
                         cf_out + 'deformed_registration_grid.txt')
def MatchingImageMomenta(cf):
    """Runs matching for image momenta pair."""
    if cf.compute.useCUDA and cf.compute.gpuID is not None:
        ca.SetCUDADevice(cf.compute.gpuID)

    common.DebugHere()
    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # Output loaded config
    if cf.io.outputPrefix is not None:
        cfstr = Config.ConfigToYAML(MatchingImageMomentaConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    I0 = common.LoadITKImage(cf.study.I, mType)
    m0 = common.LoadITKField(cf.study.m, mType)
    J1 = common.LoadITKImage(cf.study.J, mType)
    n1 = common.LoadITKField(cf.study.n, mType)

    # get imGrid from data
    imGrid = I0.grid()

    # create time array with checkpointing info for this geodesic to be estimated
    (s, scratchInd,
     rCpinds) = CAvmHGM.HGMSetUpTimeArray(cf.optim.nTimeSteps, [1.0], 0.001)
    tDiscGeodesic = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create the state variable for geodesic that is going to hold all info
    p0 = ca.Field3D(imGrid, mType)
    geodesicState = CAvmHGMCommon.HGMResidualState(
        I0,
        p0,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParams[0],
        cf.vectormomentum.diffOpParams[1],
        cf.vectormomentum.diffOpParams[2],
        s,
        cf.optim.NIterForInverse,
        1.0,
        cf.vectormomentum.sigmaM,
        cf.vectormomentum.sigmaI,
        cf.optim.stepSize,
        integMethod=cf.optim.integMethod)
    # initialize with zero
    ca.SetMem(geodesicState.p0, 0.0)
    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)
    EnergyHistory = []
    # run the loop
    for it in range(cf.optim.Niter):
        # shoot the geodesic forward
        CAvmHGMCommon.HGMIntegrateGeodesic(geodesicState.p0, geodesicState.s,
                                           geodesicState.diffOp,
                                           geodesicState.p, geodesicState.rho,
                                           geodesicState.rhoinv, tDiscGeodesic,
                                           geodesicState.Ninv,
                                           geodesicState.integMethod)
        # integrate the geodesic backward
        CAvmHGMCommon.HGMIntegrateAdjointsResidual(geodesicState,
                                                   tDiscGeodesic, m0, J1, n1)

        # TODO: verify it should just be log map/simple image matching when sigmaM=\infty
        # gradient descent step for geodesic.p0
        CAvmHGMCommon.HGMTakeGradientStepResidual(geodesicState)

        # compute and print energy
        (VEnergy, IEnergy,
         MEnergy) = MatchingImageMomentaComputeEnergy(geodesicState, m0, J1,
                                                      n1)
        EnergyHistory.append(
            [VEnergy + IEnergy + MEnergy, VEnergy, IEnergy, MEnergy])
        print "Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + MEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Image Match) + ', MEnergy, '(Momenta Match)'

        # plots
        if cf.io.plotEvery > 0 and (((it + 1) % cf.io.plotEvery == 0) or
                                    (it == cf.optim.Niter - 1)):
            MatchingImageMomentaPlots(cf,
                                      geodesicState,
                                      tDiscGeodesic,
                                      EnergyHistory,
                                      m0,
                                      J1,
                                      n1,
                                      writeOutput=True)

    # write output
    MatchingImageMomentaWriteOuput(cf, geodesicState)
def GeodesicShooting(cf):
    """Shoot a geodesic from image ``cf.study.I0`` with momenta ``cf.study.m0``.

    If ``cf.study.scaleMomenta`` is non-zero, m0 is scaled in place by it and
    integrated with ``CAvmCommon.IntegrateGeodesic``. Otherwise the identity
    deformation is used and the inputs are copied through. When
    ``cf.io.outputPrefix`` is set, the following are written: the deformed
    image (``I1.mhd``), the transported momenta divided back by the scale
    (``m1.mhd``), ``phiinv.mhd``, ``phi.mhd``, the plots and, optionally, the
    checkpoint frames.

    Side effect: ``cf.study.scaleMomenta`` is converted to ``float``.
    """
    # Prepare the output directory and dump the parsed config. Both steps are
    # skipped when no output prefix is set, because os.path.dirname(None)
    # would raise.
    if cf.io.outputPrefix is not None:
        common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))
        cfstr = Config.ConfigToYAML(GeodesicShootingConfigSpec, cf)
        with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
            f.write(cfstr)

    mType = ca.MEM_DEVICE if cf.useCUDA else ca.MEM_HOST
    I0 = common.LoadITKImage(cf.study.I0, mType)
    m0 = common.LoadITKField(cf.study.m0, mType)
    grid = I0.grid()

    ca.ThreadMemoryManager.init(grid, mType, 1)
    # set up diffOp
    if mType == ca.MEM_HOST:
        diffOp = ca.FluidKernelFFTCPU()
    else:
        diffOp = ca.FluidKernelFFTGPU()
    diffOp.setAlpha(cf.diffOpParams[0])
    diffOp.setBeta(cf.diffOpParams[1])
    diffOp.setGamma(cf.diffOpParams[2])
    diffOp.setGrid(grid)

    g = ca.Field3D(grid, mType)
    ginv = ca.Field3D(grid, mType)
    mt = ca.Field3D(grid, mType)
    It = ca.Image3D(grid, mType)
    nSteps = cf.integration.nTimeSteps
    t = [x * 1. / nSteps for x in range(nSteps + 1)]
    checkpointinds = range(1, len(t))
    checkpointstates = [(ca.Field3D(grid, mType), ca.Field3D(grid, mType))
                        for idx in checkpointinds]

    scratchV1 = ca.Field3D(grid, mType)
    scratchV2 = ca.Field3D(grid, mType)
    scratchV3 = ca.Field3D(grid, mType)
    # scale momenta to shoot
    scale = float(cf.study.scaleMomenta)
    cf.study.scaleMomenta = scale
    shoot = abs(scale) > 0.000000
    if shoot:
        ca.MulC_I(m0, scale)
        CAvmCommon.IntegrateGeodesic(m0, t, diffOp, mt, g, ginv,
                                     scratchV1, scratchV2, scratchV3,
                                     keepstates=checkpointstates,
                                     keepinds=checkpointinds,
                                     Ninv=cf.integration.NIterForInverse,
                                     integMethod=cf.integration.integMethod)
    else:
        ca.Copy(It, I0)
        ca.Copy(mt, m0)
        ca.SetToIdentity(ginv)
        ca.SetToIdentity(g)

    # write output
    if cf.io.outputPrefix is not None:
        # scale back shotmomenta before writing
        if shoot:
            ca.ApplyH(It, I0, ginv)
            ca.CoAd(mt, ginv, m0)
            ca.DivC_I(mt, scale)

        common.SaveITKImage(It, cf.io.outputPrefix + "I1.mhd")
        common.SaveITKField(mt, cf.io.outputPrefix + "m1.mhd")
        common.SaveITKField(ginv, cf.io.outputPrefix + "phiinv.mhd")
        common.SaveITKField(g, cf.io.outputPrefix + "phi.mhd")
        GeodesicShootingPlots(g, ginv, I0, It, cf)
        if cf.io.saveFrames:
            SaveFrames(checkpointstates, checkpointinds, I0, It, m0, mt, cf)
Ejemplo n.º 6
0
def BuildHGM(cf):
    """Worker for running Hierarchical Geodesic Model (HGM) 
n    for group geodesic estimation on a subset of individuals. 
    Runs HGM on this subset sequentially. The variations retuned
    are summed up to get update for all individuals"""

    size = Compute.GetMPIInfo()['size']
    rank = Compute.GetMPIInfo()['rank']
    name = Compute.GetMPIInfo()['name']
    localRank = Compute.GetMPIInfo()['local_rank']
    nodename = socket.gethostname()

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0
    cf.study.numSubjects = len(cf.study.subjectIntercepts)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(HGMConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)
    #common.DebugHere()

    # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes

    if cf.compute.useMPI and (cf.study.numSubjects < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    # subdivide data, create subsets for this thread to work on
    nodeSubjectIds = cf.study.subjectIds[rank::cf.compute.numProcesses]
    nodeIntercepts = cf.study.subjectIntercepts[rank::cf.compute.numProcesses]
    nodeSlopes = cf.study.subjectSlopes[rank::cf.compute.numProcesses]
    nodeBaselineTimes = cf.study.subjectBaselineTimes[rank::cf.compute.
                                                      numProcesses]
    sys.stdout.write(
        "This is process %d of %d with name: %s on machinename: %s and local rank: %d.\nnodeIntercepts: %s\n nodeSlopes: %s\n nodeBaselineTimes: %s\n"
        % (rank, size, name, nodename, localRank, nodeIntercepts, nodeSlopes,
           nodeBaselineTimes))

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # load data in memory
    # load intercepts
    J = [
        common.LoadITKImage(f, mType) if isinstance(f, str) else f
        for f in nodeIntercepts
    ]

    # load slopes
    n = [
        common.LoadITKField(f, mType) if isinstance(f, str) else f
        for f in nodeSlopes
    ]

    # get imGrid from data
    imGrid = J[0].grid()

    # create time array with checkpointing info for group geodesic
    (t, Jind, gCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsGroup,
                                           nodeBaselineTimes, 0.0000001)
    tdiscGroup = CAvmHGMCommon.HGMSetupTimeDiscretizationGroup(
        t, J, n, Jind, gCpinds, mType, nodeSubjectIds)

    # create time array with checkpointing info for residual geodesic
    (s, scratchInd, rCpinds) = HGMSetUpTimeArray(cf.optim.nTimeStepsResidual,
                                                 [1.0], 0.0000001)
    tdiscResidual = CAvmHGMCommon.HGMSetupTimeDiscretizationResidual(
        s, rCpinds, imGrid, mType)

    # create group state and residual state
    groupState = CAvmHGMCommon.HGMGroupState(
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsGroup[0],
        cf.vectormomentum.diffOpParamsGroup[1],
        cf.vectormomentum.diffOpParamsGroup[2],
        t,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeGroup,
        integMethod=cf.optim.integMethodGroup)

    #ca.Copy(groupState.I0, common.LoadITKImage('/usr/sci/projects/ADNI/nikhil/software/vectormomentumtest/TestData/FlowerData/Longitudinal/GroupGeodesic/I0.mhd', mType))

    # note that residual state is treated a scratch variable in this algorithm and reused for computing residual geodesics of multiple individual
    residualState = CAvmHGMCommon.HGMResidualState(
        None,
        None,
        imGrid,
        mType,
        cf.vectormomentum.diffOpParamsResidual[0],
        cf.vectormomentum.diffOpParamsResidual[1],
        cf.vectormomentum.diffOpParamsResidual[2],
        s,
        cf.optim.NIterForInverse,
        cf.vectormomentum.varIntercept,
        cf.vectormomentum.varSlope,
        cf.vectormomentum.varInterceptReg,
        cf.optim.stepSizeResidual,
        integMethod=cf.optim.integMethodResidual)

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # need some host memory in np array format for MPI reductions
    if cf.compute.useMPI:
        mpiImageBuff = None if mType == ca.MEM_HOST else ca.Image3D(
            imGrid, ca.MEM_HOST)
        mpiFieldBuff = None if mType == ca.MEM_HOST else ca.Field3D(
            imGrid, ca.MEM_HOST)
    for i in range(len(groupState.t) - 1, -1, -1):
        if tdiscGroup[i].J is not None:
            indx_last_individual = i
            break
    '''
    # initial template image
    ca.SetMem(groupState.I0, 0.0)
    tmp = ca.ManagedImage3D(imGrid, mType)

    for tdisc in tdiscGroup:
        if tdisc.J is not None:
            ca.Copy(tmp, tdisc.J)
            groupState.I0 += tmp
    del tmp
    if cf.compute.useMPI:
        Compute.Reduce(groupState.I0, mpiImageBuff)
    
    # divide by total num subjects
    groupState.I0 /= cf.study.numSubjects
    '''

    # run the loop

    for it in range(cf.optim.Niter):
        # compute HGM variation for group
        HGMGroupVariation(groupState, tdiscGroup, residualState, tdiscResidual,
                          cf.io.outputPrefix, rank, it)
        common.CheckCUDAError("Error after HGM iteration")
        # compute gradient for momenta (m is used as scratch)
        # if there are multiple nodes we'll need to sum across processes now
        if cf.compute.useMPI:
            # do an MPI sum
            Compute.Reduce(groupState.sumSplatI, mpiImageBuff)
            Compute.Reduce(groupState.sumJac, mpiImageBuff)
            Compute.Reduce(groupState.madj, mpiFieldBuff)
            # also sum up energies of other nodes
            # intercept
            Eintercept = np.array([groupState.EnergyHistory[-1][1]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eintercept,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][1] = Eintercept[0]

            Eslope = np.array([groupState.EnergyHistory[-1][2]])
            mpi4py.MPI.COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE,
                                            Eslope,
                                            op=mpi4py.MPI.SUM)
            groupState.EnergyHistory[-1][2] = Eslope[0]

        ca.Copy(groupState.m, groupState.m0)
        groupState.diffOp.applyInverseOperator(groupState.m)
        ca.Sub_I(groupState.m, groupState.madj)
        #groupState.diffOp.applyOperator(groupState.m)
        # now take gradient step in momenta for group
        if cf.optim.method == 'FIXEDGD':
            # take fixed stepsize gradient step
            ca.Add_MulC_I(groupState.m0, groupState.m, -cf.optim.stepSizeGroup)
        else:
            raise Exception("Unknown optimization scheme: " + cf.optim.method)
        # end if

        # now divide to get the new base image for group
        ca.Div(groupState.I0, groupState.sumSplatI, groupState.sumJac)

        # keep track of energy in this iteration
        if isReporter and cf.io.plotEvery > 0 and ((
            (it + 1) % cf.io.plotEvery == 0) or (it == cf.optim.Niter - 1)):
            HGMPlots(cf,
                     groupState,
                     tdiscGroup,
                     residualState,
                     tdiscResidual,
                     indx_last_individual,
                     writeOutput=True)

        if isReporter:
            (VEnergy, IEnergy, SEnergy) = groupState.EnergyHistory[-1]
            print datetime.datetime.now().time(
            ), " Iter", it, "of", cf.optim.Niter, ":", VEnergy + IEnergy + SEnergy, '(Total) = ', VEnergy, '(Vector) + ', IEnergy, '(Intercept) + ', SEnergy, '(Slope)'

    # write output images and fields
    HGMWriteOutput(cf, groupState, tdiscGroup, isReporter)
Ejemplo n.º 7
0
def BuildGeoReg(cf):
    """Worker for running geodesic estimation on a subset of individuals
    """
    #common.DebugHere()
    localRank = Compute.GetMPIInfo()['local_rank']
    rank = Compute.GetMPIInfo()['rank']

    # prepare output directory
    common.Mkdir_p(os.path.dirname(cf.io.outputPrefix))

    # just one reporter process on each node
    isReporter = rank == 0

    # load filenames and times for all subjects
    (subjectsIds, subjectsImagePaths,
     subjectsTimes) = GeoRegLoadSubjectsDetails(cf.study.subjectFile)
    cf.study.numSubjects = len(subjectsIds)
    if isReporter:
        # Output loaded config
        if cf.io.outputPrefix is not None:
            cfstr = Config.ConfigToYAML(GeoRegConfigSpec, cf)
            with open(cf.io.outputPrefix + "parsedconfig.yaml", "w") as f:
                f.write(cfstr)

    # if MPI check if processes are greater than number of subjects. it is okay if there are more subjects than processes
    if cf.compute.useMPI and (len(subjectsIds) < cf.compute.numProcesses):
        raise Exception("Please don't use more processes " +
                        "than total number of individuals")

    nodeSubjectsIds = subjectsIds[rank::cf.compute.numProcesses]
    nodeSubjectsImagePaths = subjectsImagePaths[rank::cf.compute.numProcesses]
    nodeSubjectsTimes = subjectsTimes[rank::cf.compute.numProcesses]

    numLocalSubjects = len(nodeSubjectsImagePaths)
    if cf.study.initializationsFile is not None:
        (subjectsInitialImages,
         subjectsInitialMomenta) = GeoRegLoadSubjectsInitializations(
             cf.study.initializationsFile)
        nodeSubjectsInitialImages = subjectsInitialImages[rank::cf.compute.
                                                          numProcesses]
        nodeSubjectsInitialMomenta = subjectsInitialMomenta[rank::cf.compute.
                                                            numProcesses]

    print 'rank:', rank, ', localRank:', localRank, ', numberSubjects/TotalSubjects:', len(
        nodeSubjectsImagePaths
    ), '/', cf.study.numSubjects, ', nodeSubjectsImagePaths:', nodeSubjectsImagePaths, ', nodeSubjectsTimes:', nodeSubjectsTimes

    # mem type is determined by whether or not we're using CUDA
    mType = ca.MEM_DEVICE if cf.compute.useCUDA else ca.MEM_HOST

    # setting gpuid should be handled in gpu
    # if using GPU  set device based on local rank
    #if cf.compute.useCUDA:
    #    ca.SetCUDADevice(localRank)

    # get image size information
    dummyImToGetGridInfo = common.LoadITKImage(nodeSubjectsImagePaths[0][0],
                                               mType)
    imGrid = dummyImToGetGridInfo.grid()
    if cf.study.setUnitSpacing:
        imGrid.setSpacing(ca.Vec3Df(1.0, 1.0, 1.0))
    if cf.study.setZeroOrigin:
        imGrid.setOrigin(ca.Vec3Df(0, 0, 0))
    #del dummyImToGetGridInfo;

    # start up the memory manager for scratch variables
    ca.ThreadMemoryManager.init(imGrid, mType, 0)

    # allocate memory
    p = GeoRegVariables(imGrid,
                        mType,
                        cf.vectormomentum.diffOpParams[0],
                        cf.vectormomentum.diffOpParams[1],
                        cf.vectormomentum.diffOpParams[2],
                        cf.optim.NIterForInverse,
                        cf.vectormomentum.sigma,
                        cf.optim.stepSize,
                        integMethod=cf.optim.integMethod)
    # for each individual run geodesic regression for each subject
    for i in range(numLocalSubjects):

        # initializations for this subject
        if cf.study.initializationsFile is not None:
            # assuming the initializations are already preprocessed, in terms of intensities, origin and voxel scalings.
            p.I0 = common.LoadITKImage(nodeSubjectsInitialImages[i], mType)
            p.m0 = common.LoadITKField(nodeSubjectsInitialMomenta[i], mType)
        else:
            ca.SetMem(p.m0, 0.0)
            ca.SetMem(p.I0, 0.0)

        # allocate memory specific to this subject in steps a, b and c
        # a. create time array with checkpointing info for regression geodesic, allocate checkpoint memory
        (t, msmtinds, cpinds) = GeoRegSetUpTimeArray(cf.optim.nTimeSteps,
                                                     nodeSubjectsTimes[i],
                                                     0.001)
        cpstates = [(ca.Field3D(imGrid, mType), ca.Field3D(imGrid, mType))
                    for idx in cpinds]
        # b. allocate gradAtMeasurements of the length of msmtindex for storing residuals
        gradAtMsmts = [ca.Image3D(imGrid, mType) for idx in msmtinds]

        # c. load timepoint images for this subject
        Imsmts = [
            common.LoadITKImage(f, mType) if isinstance(f, str) else f
            for f in nodeSubjectsImagePaths[i]
        ]
        # reset stepsize if adaptive stepsize changed it inside
        p.stepSize = cf.optim.stepSize
        # preprocessimages
        GeoRegPreprocessInput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                              cpstates, msmtinds, gradAtMsmts)

        # run regression for this subject
        # REMEMBER
        # msmtinds index into cpinds
        # gradAtMsmts is parallel to msmtinds
        # cpinds index into t
        EnergyHistory = RunGeoReg(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                                  cpstates, msmtinds, gradAtMsmts)

        # write output images and fields for this subject
        # TODO: BEWARE There are hardcoded numbers inside preprocessing code specific for ADNI/OASIS brain data.
        GeoRegWriteOuput(nodeSubjectsIds[i], cf, p, t, Imsmts, cpinds,
                         cpstates, msmtinds, gradAtMsmts, EnergyHistory)

        # clean up memory specific to this subject
        del t, Imsmts, cpinds, cpstates, msmtinds, gradAtMsmts