Example #1
def common_processing(caselist):
    imgs, masks = read_caselist(caselist)

    res = []
    pool = multiprocessing.Pool(N_proc)
    for imgPath, maskPath in zip(imgs, masks):
        res.append(
            pool.apply_async(func=preprocessing, args=(imgPath, maskPath)))

    attributes = [r.get() for r in res]

    pool.close()
    pool.join()

    f = open(caselist + '.modified', 'w')
    for i in range(len(imgs)):
        imgs[i] = attributes[i][0]
        masks[i] = attributes[i][1]
        f.write(f'{imgs[i]},{masks[i]}\n')
    f.close()

    # the following imgs, masks are for diagnosing MemoryError, i.e. computing RISH features w/o preprocessing
    # to diagnose, comment out all of the above and uncomment the following line
    # imgs, masks = read_caselist(caselist+'.modified')

    # experimentally, ncpu=4 was found to be memory optimal
    pool = multiprocessing.Pool(4)
    for imgPath, maskPath in zip(imgs, masks):
        pool.apply_async(func=dti_harm, args=(imgPath, maskPath))

    pool.close()
    pool.join()

    return (imgs, masks)
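
The example above relies on the standard multiprocessing idiom of submitting jobs with apply_async and collecting their return values through the AsyncResult objects. Below is a minimal, self-contained sketch of that idiom with a hypothetical stand-in worker (fake_preprocessing is not part of the project); iterating the res list and calling get() yields results in submission order, which is what allows imgs[i] and masks[i] to be updated positionally.

import multiprocessing

def fake_preprocessing(imgPath, maskPath):
    # hypothetical stand-in for preprocessing(): pretend to process and return new paths
    return imgPath + '.prep', maskPath + '.prep'

if __name__ == '__main__':
    imgs = ['case1.nii.gz', 'case2.nii.gz']
    masks = ['case1_mask.nii.gz', 'case2_mask.nii.gz']

    pool = multiprocessing.Pool(2)
    res = [pool.apply_async(func=fake_preprocessing, args=(i, m))
           for i, m in zip(imgs, masks)]

    # get() blocks until each job finishes; iterating res preserves submission order
    attributes = [r.get() for r in res]

    pool.close()
    pool.join()

    print(attributes)
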
Example #2
def sub2tmp2mni(templatePath,
                siteName,
                caselist,
                ref=False,
                tar_unproc=False,
                tar_harm=False):

    # obtain the transform
    moving = pjoin(templatePath, f'Mean_{siteName}_FA_b{bshell_b}.nii.gz')

    outPrefix = pjoin(templatePath, f'TemplateToMNI_{siteName}')
    warp2mni = outPrefix + '1Warp.nii.gz'
    trans2mni = outPrefix + '0GenericAffine.mat'

    # check existence of transforms created with _b{bmax}
    if not exists(warp2mni):
        antsReg(mniTmp, None, moving, outPrefix)

    imgs, _ = read_caselist(caselist)

    pool = multiprocessing.Pool(N_proc)
    for imgPath in imgs:

        if ref:
            pool.apply_async(func=register_reference,
                             args=(
                                 imgPath,
                                 warp2mni,
                                 trans2mni,
                                 templatePath,
                             ))
        elif tar_unproc:
            pool.apply_async(func=register_target,
                             args=(
                                 imgPath,
                                 templatePath,
                             ))
        elif tar_harm:
            pool.apply_async(func=register_harmonized,
                             args=(
                                 imgPath,
                                 warp2mni,
                                 trans2mni,
                                 templatePath,
                                 siteName,
                             ))

    pool.close()
    pool.join()
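
read_caselist is defined elsewhere in the project. Judging from how Example #1 writes its '.modified' caselist (one image,mask pair per line), a plausible reader would look roughly like the hypothetical sketch below; the project's actual implementation may differ.

def read_caselist_sketch(caselist):
    # hypothetical reader for a two-column caselist: imagePath,maskPath per line
    imgs, masks = [], []
    with open(caselist) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            imgPath, maskPath = line.split(',')[:2]
            imgs.append(imgPath)
            masks.append(maskPath)
    return imgs, masks
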
Example #3
def common_processing(caselist):

    imgs, masks = read_caselist(caselist)

    # compute dti_harm of unprocessed data
    pool = multiprocessing.Pool(N_proc)
    for imgPath, maskPath in zip(imgs, masks):
        pool.apply_async(func=dti_harm, args=(imgPath, maskPath))
    pool.close()
    pool.join()

    try:
        copyfile(caselist, caselist + '.modified')
    except SameFileError:
        pass

    # data are not manipulated in multi-shell-dMRIharmonization, i.e. not bvalMapped, resampled, or denoised
    # this block may be uncommented in a future design
    # preprocess data
    # res=[]
    # pool = multiprocessing.Pool(N_proc)
    # for imgPath,maskPath in zip(imgs,masks):
    #     res.append(pool.apply_async(func= preprocessing, args= (imgPath,maskPath)))
    #
    # attributes= [r.get() for r in res]
    #
    # pool.close()
    # pool.join()
    #
    # f = open(caselist + '.modified', 'w')
    # for i in range(len(imgs)):
    #     imgs[i] = attributes[i][0]
    #     masks[i] = attributes[i][1]
    # f.close()
    #
    #
    # # compute dti_harm of preprocessed data
    # pool = multiprocessing.Pool(N_proc)
    # for imgPath,maskPath in zip(imgs,masks):
    #     pool.apply_async(func= dti_harm, args= (imgPath,maskPath))
    # pool.close()
    # pool.join()
    #
    #
    # if debug:
    #     #TODO compute dti_harm for all intermediate data _denoised, _denoised_bmapped, _bmapped
    #     pass

    return (imgs, masks)
    def harmonizeData(self):

        from reconstSignal import reconst

        # check the templatePath
        if not exists(self.templatePath):
            raise NotADirectoryError(f'{self.templatePath} does not exist')
        else:
            if not os.listdir(self.templatePath):
                raise ValueError(f'{self.templatePath} is empty')

        # go through each file listed in the csv, check that it exists, and create dti and harm directories
        check_csv(self.target_csv, self.force)

        # target data are not manipulated in multi-shell-dMRIharmonization, i.e. not bvalMapped, resampled, or denoised
        # this block may be uncommented in a future design
        # from preprocess import dti_harm
        # if self.debug:
        #     # calculate diffusion measures of the target site before any processing so we can compare
        #     # with the ones after harmonization
        #     imgs, masks= read_caselist(self.tar_unproc_csv)
        #     pool = multiprocessing.Pool(self.N_proc)
        #     for imgPath, maskPath in zip(imgs, masks):
        #         imgPath= convertedPath(imgPath)
        #         maskPath= convertedPath(maskPath)
        #         pool.apply_async(func= dti_harm, args= ((imgPath, maskPath, )))
        #
        #     pool.close()
        #     pool.join()

        # reconstSignal steps ------------------------------------------------------------------------------------------

        # read target image list
        moving = pjoin(self.templatePath,
                       f'Mean_{self.target}_FA_b{self.bshell_b}.nii.gz')
        imgs, masks = read_caselist(self.target_csv)

        preFlag = 1  # set to 0 below to skip preprocessing target data that have already been preprocessed
        if self.target_csv.endswith('.modified'):
            preFlag = 0
        else:
            # this file will be used later for debugging
            self.target_csv += '.modified'
            fm = open(self.target_csv, 'w')

        self.harm_csv = self.target_csv + '.harmonized'
        fh = open(self.harm_csv, 'w')
        pool = multiprocessing.Pool(self.N_proc)
        res = []
        for imgPath, maskPath in zip(imgs, masks):
            res.append(
                pool.apply_async(func=reconst,
                                 args=(
                                     imgPath,
                                     maskPath,
                                     moving,
                                     self.templatePath,
                                     preFlag,
                                 )))

        for r in res:
            imgPath, maskPath, harmImg, harmMask = r.get()

            if preFlag:
                fm.write(imgPath + ',' + maskPath + '\n')
            fh.write(harmImg + ',' + harmMask + '\n')

        pool.close()
        pool.join()

        # loop for debugging
        # res= []
        # for imgPath, maskPath in zip(imgs, masks):
        #     res.append(reconst(imgPath, maskPath, moving, self.templatePath, preFlag))
        #
        # for r in res:
        #     imgPath, maskPath, harmImg, harmMask= r
        #
        #     if preFlag:
        #         fm.write(imgPath + ',' + maskPath + '\n')
        #     fh.write(harmImg + ',' + harmMask + '\n')

        if preFlag:
            fm.close()
        fh.close()
        print('\n\nHarmonization completed\n\n')
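
The CSV files written above follow a simple naming chain: the caselist passed in as target_csv gains a '.modified' suffix for the rewritten image/mask list and a further '.harmonized' suffix for the harmonized outputs. A hypothetical downstream consumer of the harmonized list might look like this (the file name is illustrative):

# illustrative name; harmonizeData() writes <target_csv>.modified.harmonized
harm_csv = 'target_caselist.csv.modified.harmonized'

with open(harm_csv) as f:
    for line in f:
        harmImg, harmMask = line.strip().split(',')
        print('harmonized image:', harmImg, 'mask:', harmMask)
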
    def main(self):

        self.templatePath= abspath(self.templatePath)
        self.N_shm= int(self.N_shm)
        self.N_proc= int(self.N_proc)
        if self.N_proc==-1:
            self.N_proc= N_CPU

    
        # derive the unprocessed caselist names by removing the '.modified' suffix
        if self.ref_csv:
            self.ref_unproc_csv = self.ref_csv.replace('.modified', '')
        self.tar_unproc_csv = self.target_csv.replace('.modified', '')


        # check appropriateness of N_shm
        if self.N_shm!=-1 and (self.N_shm<2 or self.N_shm>8):
            raise ValueError('2<= --nshm <=8')



        # determine N_shm in default mode during template creation
        if self.N_shm==-1 and self.create:
            if self.ref_csv:
                ref_nshm_img = read_caselist(self.ref_csv)[0][0]
            elif self.target_csv:
                ref_nshm_img = read_caselist(self.target_csv)[0][0]

            directory= dirname(ref_nshm_img)
            prefix= basename(ref_nshm_img).split('.nii')[0]
            bvalFile= pjoin(directory, prefix+'.bval')
            self.N_shm, _= determineNshm(bvalFile)


        # automatic determination of N_shm during data harmonization is limited by the N_shm used during template creation:
        # only Scale_L{i}_b{bshell_b}.nii.gz files with i <= {N_shm used during template creation} are present
        elif self.N_shm==-1 and self.process:
            for i in range(0, 8+1, 2):  # include L8, the largest allowed N_shm
                if isfile(pjoin(self.templatePath, f'Scale_L{i}_b{self.bshell_b}.nii.gz')):
                    self.N_shm= i
                else:
                    break


        # verify validity of provided/determined N_shm for all subjects
        if self.ref_csv:
            verifyNshmForAll(self.ref_csv, self.N_shm)
        if self.target_csv:
            verifyNshmForAll(self.target_csv, self.N_shm)


        # write config file to temporary directory
        configFile= f'/tmp/harm_config_{getpid()}.ini'
        with open(configFile,'w') as f:
            f.write('[DEFAULT]\n')
            f.write(f'N_shm = {self.N_shm}\n')
            f.write(f'N_proc = {self.N_proc}\n')
            f.write(f'N_zero = {self.N_zero}\n')
            f.write(f'resample = {self.resample if self.resample else 0}\n')
            f.write(f'bvalMap = {self.bvalMap if self.bvalMap else 0}\n')
            f.write(f'bshell_b = {self.bshell_b}\n')
            f.write(f'denoise = {1 if self.denoise else 0}\n')
            f.write(f'travelHeads = {1 if self.travelHeads else 0}\n')
            f.write(f'debug = {1 if self.debug else 0}\n')
            f.write(f'force = {1 if self.force else 0}\n')
            f.write(f'verbose = {1 if self.verbose else 0}\n')
            f.write('diffusionMeasures = {}\n'.format(','.join(self.diffusionMeasures)))


        self.sanityCheck()

        if self.create:
            self.createTemplate()
            import fileinput
            for line in fileinput.input(configFile, inplace=True):
                if 'force' in line:
                    print('force = 0')
                else:
                    print(line, end='')  # line already carries its newline
            self.force= False

        if self.process:
            self.harmonizeData()

        if self.create and self.process and self.debug:
            self.post_debug()


        remove(configFile)
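
main() serializes the run-time settings into a small INI file under /tmp so that other modules can pick them up. A minimal sketch of reading such a file back with the standard-library configparser (the path below is illustrative):

from configparser import ConfigParser

config = ConfigParser()
config.read('/tmp/harm_config_12345.ini')  # illustrative pid-stamped path
defaults = config['DEFAULT']

N_shm = int(defaults['N_shm'])
N_proc = int(defaults['N_proc'])
denoise = bool(int(defaults['denoise']))
diffusionMeasures = defaults['diffusionMeasures'].split(',')
print(N_shm, N_proc, denoise, diffusionMeasures)
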
    def harmonizeData(self):

        from reconstSignal import reconst, approx
        from preprocess import dti_harm, preprocessing, common_processing

        # check the templatePath
        if not exists(self.templatePath):
            raise NotADirectoryError(f'{self.templatePath} does not exist')
        else:
            if not listdir(self.templatePath):
                raise ValueError(f'{self.templatePath} is empty')

        
        # fit spherical harmonics on reference site
        if self.debug and self.ref_csv:
            check_csv(self.ref_unproc_csv, self.force)
            refImgs, refMasks= read_caselist(self.ref_unproc_csv)

            # reference data are not manipulated in multi-shell-dMRIharmonization, i.e. not bvalMapped, resampled, or denoised
            # this block may be uncommented in a future design
            # res= []
            # pool = multiprocessing.Pool(self.N_proc)
            # for imgPath, maskPath in zip(refImgs, refMasks):
            #     res.append(pool.apply_async(func=preprocessing, args=(imgPath, maskPath)))
            #
            # attributes = [r.get() for r in res]
            #
            # pool.close()
            # pool.join()
            #
            # for i in range(len(refImgs)):
            #     refImgs[i] = attributes[i][0]
            #     refMasks[i] = attributes[i][1]

            pool = multiprocessing.Pool(self.N_proc)
            for imgPath, maskPath in zip(refImgs, refMasks):
                pool.apply_async(func= approx, args=(imgPath,maskPath,))

            pool.close()
            pool.join()



        # go through each file listed in the csv, check that it exists, and create dti and harm directories
        check_csv(self.target_csv, self.force)
        targetImgs, targetMasks= common_processing(self.tar_unproc_csv)

        # reconstSignal steps ------------------------------------------------------------------------------------------

        # read target image list
        moving= pjoin(self.templatePath, f'Mean_{self.target}_FA_b{self.bshell_b}.nii.gz')

        if not self.target_csv.endswith('.modified'):
            self.target_csv += '.modified'


        self.harm_csv= self.target_csv+'.harmonized'
        fh= open(self.harm_csv, 'w')
        pool = multiprocessing.Pool(self.N_proc)
        res= []
        for imgPath, maskPath in zip(targetImgs, targetMasks):
            res.append(pool.apply_async(func= reconst, args= (imgPath, maskPath, moving, self.templatePath,)))

        for r in res:
            harmImg, harmMask= r.get()
            fh.write(harmImg + ',' + harmMask + '\n')


        pool.close()
        pool.join()

       
        # loop for debugging
        # res= []
        # for imgPath, maskPath in zip(imgs, masks):
        #     res.append(reconst(imgPath, maskPath, moving, self.templatePath))
        #
        # for r in res:
        #     harmImg, harmMask= r
        #     fh.write(harmImg + ',' + harmMask + '\n')


        fh.close()

        
        if self.debug:
            harmImgs, harmMasks= read_caselist(self.harm_csv)
            pool = multiprocessing.Pool(self.N_proc)
            for imgPath,maskPath in zip(harmImgs,harmMasks):
                pool.apply_async(func= dti_harm, args= (imgPath,maskPath,))
            pool.close()
            pool.join()


        print('\n\nHarmonization completed\n\n')
        pjoin(templatePath,
              basename(maskPath).split('.nii')[0] + 'Warped.nii.gz'))
    '''
    # warping the rish features
    for i in range(0, N_shm+1, 2):
        applyXform(pjoin(directory, 'harm', f'{prefix}_L{i}.nii.gz'),
           pjoin(templatePath, 'template0.nii.gz'),
           warp, trans,
           pjoin(templatePath, f'{prefix}_WarpedL{i}.nii.gz'))


    # warping the diffusion measures
    for dm in diffusionMeasures:
        applyXform(pjoin(directory, 'dti', f'{prefix}_{dm}.nii.gz'),
                   pjoin(templatePath, 'template0.nii.gz'),
                   warp, trans,
                   pjoin(templatePath, f'{prefix}_Warped{dm}.nii.gz'))
    
    '''


if __name__ == '__main__':

    templatePath = '/data/pnl/HarmonizationProject/abcd/site21/site21_cluster/retest_multi/template_April4'

    img_list = '/data/pnl/HarmonizationProject/abcd/site21/site21_cluster/retest_multi/target_b1000.csv.modified'
    imgs, masks = read_caselist(img_list)

    for imgPath, maskPath in zip(imgs, masks):
        warp_bands(imgPath, maskPath, templatePath)
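
applyXform and warp_bands are defined elsewhere in the project. As a rough illustration of the warping step in the commented block above, a wrapper of this kind could shell out to ANTs' antsApplyTransforms with the warp and affine passed as a transform stack; the sketch below is hypothetical, not the project's implementation.

from subprocess import check_call

def applyXform_sketch(inImg, refImg, warp, trans, outImg):
    # hypothetical applyXform-style wrapper around ANTs' antsApplyTransforms;
    # ANTs applies the transform stack in reverse order of listing (last -t first)
    check_call(['antsApplyTransforms',
                '-d', '3',
                '-i', inImg,
                '-r', refImg,
                '-t', warp,
                '-t', trans,
                '-o', outImg])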