Example #1
0
File: mmrsim.py  Project: devhliu/NIPET
def simulate_recon(measured_sino,
                   ctim,
                   scanner_params,
                   simulate_3d=False,
                   nitr=60,
                   slice_idx=-1,
                   randoms=None,
                   scatter=None,
                   mu_input=False,
                   msk_radius=29.):
    ''' Reconstruct a PET image from simulated input data.

        The full 3D case uses OSEM with 14 subsets (petprj.osem). The reduced
        (2D) case uses EM-ML on a reduced axial FOV built by replicating one
        image slice.

        Arguments:
        measured_sino -- simulated emission sinogram (with photon attenuation).
        ctim -- CT image [HU], or a mu-map when mu_input is True, used to
            estimate the attenuation factors. It must be 3D for the 3D
            simulation. For the reduced simulation it may be 2D, or 3D from
            which one slice (slice_idx) is taken. The caller's array is not
            modified.
        scanner_params -- dictionary with the scanner constants ('Cnt') and
            the transaxial/axial look-up tables ('txLUT', 'axLUT').
        simulate_3d -- run the full 3D OSEM reconstruction when True,
            otherwise the reduced-ring EM-ML reconstruction.
        nitr -- number of iterations of the reconstruction algorithm.
        slice_idx -- index of the slice taken from a 3D ctim in the reduced
            case; a negative value selects the middle slice.
        randoms -- optional randoms sinogram with the same shape as
            measured_sino; otherwise a small constant (1e-5) is used.
        scatter -- optional scatter sinogram with the same shape as
            measured_sino; otherwise a small constant (1e-5) is used.
        mu_input -- if True, ctim is used directly as the mu-map instead of
            being converted with nimpa.ct2mu.
        msk_radius -- radius passed to mmrimg.get_cylinder for the
            reconstruction mask; voxels outside the mask are set to zero.

        Returns:
        the reconstructed image (converted with mmrimg.convert2e7 in the 3D case).

        Raises:
        ValueError -- if the image rank/shape does not match the scanner, the
            slice index is outside the image, or the reduced axial FOV
            parameters ('rSZ_IMZ') are missing for the reduced simulation.
    '''

    #> decompose the scanner constants and LUTs for easier access
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']

    if simulate_3d:
        if ctim.ndim != 3 \
                or ctim.shape != (Cnt['SO_IMZ'], Cnt['SO_IMY'], Cnt['SO_IMX']):
            raise ValueError(
                'The CT/mu-map image does not match the scanner image shape.')
    else:
        #> 2D case with reduced rings
        if ctim.ndim == 3:
            # make sure that the transaxial shape matches the scanner image size
            if ctim.shape[1:] != (Cnt['SO_IMY'], Cnt['SO_IMX']):
                raise ValueError(
                    'The input image shape for x and y does not match the scanner image size.'
                )

            # pick the middle slice when no valid (non-negative) index is given;
            # floor division keeps the index an integer under Python 3 too
            if slice_idx < 0:
                print('w> the axial index <slice_idx> is chosen to be in the middle of axial FOV.')
                slice_idx = ctim.shape[0] // 2
            if slice_idx >= ctim.shape[0]:
                raise ValueError(
                    'The axial index for 2D slice selection is outside the image.'
                )

        elif ctim.ndim == 2:
            # make sure that the shape of the input image matches the image size of the scanner
            if ctim.shape != (Cnt['SO_IMY'], Cnt['SO_IMX']):
                raise ValueError(
                    'The input image shape for x and y does not match the scanner image size.'
                )
            # add a singleton axial axis via a view, so the caller's array keeps its shape
            ctim = ctim[np.newaxis, :, :]
            slice_idx = 0

        else:
            raise ValueError('The input image must be either 2D or 3D.')

        if 'rSZ_IMZ' not in Cnt:
            raise ValueError('Missing reduced axial FOV parameters.')

    #--------------------
    if mu_input:
        # copy so that clipping negative values below leaves the caller's mu-map intact
        mui = ctim.copy()
    else:
        #> get the mu-map [1/cm] from CT [HU]
        mui = nimpa.ct2mu(ctim)

    #> get rid of negative values
    mui[mui < 0] = 0
    #--------------------

    if simulate_3d:
        rmu = mui
        #> number of axial sinograms
        nsinos = Cnt['NSN11']
    else:
        #> replicate the chosen slice along the reduced axial FOV (fast 3D simulation);
        #> slice_idx is guaranteed to be non-negative at this point
        rmu = np.repeat(mui[slice_idx:slice_idx + 1, :, :], Cnt['rSZ_IMZ'], axis=0)
        #> number of axial sinograms
        nsinos = Cnt['rNSN1']

    #> attenuation factor sinogram
    attsino = mmrprj.frwd_prj(rmu,
                              scanner_params,
                              attenuation=True,
                              dev_out=True)

    nrmsino = np.ones(attsino.shape, dtype=np.float32)

    def _additive_sino(sino):
        '''Remove gaps from a supplied randoms/scatter sinogram, or return a small constant sinogram.'''
        if isinstance(sino, np.ndarray) and measured_sino.shape == sino.shape:
            return mmraux.remgaps(sino, txLUT, Cnt)
        return 1e-5 * np.ones((Cnt['Naw'], nsinos), dtype=np.float32)

    #> randoms and scatter
    rsng = _additive_sino(randoms)
    ssng = _additive_sino(scatter)

    if simulate_3d:

        if Cnt['VERBOSE']:
            print('\n>------ OSEM ( ' + str(nitr) + ' ) -------\n')

        # measured sinogram in GPU-enabled shape
        psng = mmraux.remgaps(measured_sino.astype(np.uint16), txLUT, Cnt)

        #> mask for reconstructed image.  anything outside it is set to zero
        msk = mmrimg.get_cylinder(
            Cnt, rad=msk_radius, xo=0, yo=0, unival=1, gpu_dim=True) > 0.9

        #> init image
        eimg = np.ones((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)

        #------------------------------------
        Sn = 14  # number of subsets
        #-get one subset to get number of projection bins in a subset
        Sprj, _ = mmrrec.get_subsets14(0, scanner_params)
        Nprj = len(Sprj)

        #> subset index array: first column holds the number of projections
        sinoTIdx = np.zeros((Sn, Nprj + 1), dtype=np.int32)

        #> sensitivity images, one per subset
        sim = np.zeros((Sn, Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)

        for n in range(Sn):
            sinoTIdx[n, 0] = Nprj
            sinoTIdx[n, 1:], _ = mmrrec.get_subsets14(n, scanner_params)
            #> sensitivity image (written in place by the back projector)
            petprj.bprj(sim[n, :, :, :], attsino[sinoTIdx[n, 1:], :], txLUT,
                        axLUT, sinoTIdx[n, 1:], Cnt)
        #-------------------------------------

        for _ in trange(nitr, disable=not Cnt['VERBOSE'], desc="OSEM"):
            petprj.osem(eimg, msk, psng, rsng, ssng, nrmsino, attsino, sim,
                        txLUT, axLUT, sinoTIdx, Cnt)

        eim = mmrimg.convert2e7(eimg, Cnt)

    else:

        #> estimated image, initialised to ones
        eim = np.ones(rmu.shape, dtype=np.float32)

        msk = mmrimg.get_cylinder(
            Cnt, rad=msk_radius, xo=0, yo=0, unival=1, gpu_dim=False) > 0.9

        #> sensitivity image for the EM-ML reconstruction
        sim = mmrprj.back_prj(attsino, scanner_params)

        #> additive (randoms + scatter) term in the denominator of the EM ratio.
        #> NOTE(review): measured_sino includes attenuation while the forward
        #> projection below does not; strictly the additive term may need to be
        #> divided by attsino -- confirm against the intended model.
        rndsct = rsng + ssng

        #> the measured data do not change between iterations: remove gaps once
        msng = mmraux.remgaps(measured_sino, txLUT, Cnt)

        for i in range(nitr):
            if Cnt['VERBOSE']:
                print('>---- EM iteration: ' + str(i))
            #> divide the measured sinogram by the forward-projected estimate (+ additive term)
            crrsino = msng / (
                mmrprj.frwd_prj(eim, scanner_params, dev_out=True) + rndsct)

            #> back project the correction factors sinogram
            bim = mmrprj.back_prj(crrsino, scanner_params)

            #> divide the back-projected image by the sensitivity image
            bim[msk] /= sim[msk]
            bim[~msk] = 0

            #> update the estimated image and remove NaNs
            eim *= msk * bim
            eim[np.isnan(eim)] = 0

    return eim
Example #2
0
def back_prj(sino, scanner_params, isub=None):
    ''' Calculate the back projection of the provided sinogram into image space.
        Arguments:
        sino -- input emission sinogram to be back projected to the image space; either
            the 3D Siemens default shape (sinograms x NSANGLES x NSBINS), whose gaps are
            removed with mmraux.remgaps, or a 2D array (transaxial elements x sinograms)
            that is passed to the back projector as is.
        scanner_params -- dictionary of all scanner parameters, containing scanner constants,
            transaxial and axial look up tables (LUT).
        isub -- int32 array of transaxial indices of all sinograms (angles x bins) used for subsets;
            when the first element is negative, all transaxial bins are used (as in pure EM-ML).
            None (default) means np.array([-1], dtype=np.int32), i.e. no subsets.
        Returns:
        the back-projected image converted to the standard Siemens shape (mmrimg.convert2e7).
        Raises:
        ValueError -- for an unsupported span (Cnt['SPN']) or an unexpected sinogram shape.
    '''
    if isub is None:
        # created per call so the array handed to the back projector is never shared
        isub = np.array([-1], dtype=np.int32)

    # Get particular scanner parameters: Constants, transaxial and axial LUTs
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']

    spn = Cnt['SPN']
    if spn == 1:
        # number of rings calculated for the given ring range (optionally we can use only part of the axial FOV)
        NRNG_c = Cnt['RNG_END'] - Cnt['RNG_STRT']
        # number of sinos in span-1
        nsinos = NRNG_c**2
        # correct for the max. ring difference in the full axial extent (don't use ring range (1,63) as for this case no correction)
        if NRNG_c == 64:
            nsinos -= 12
    elif spn == 11:
        nsinos = Cnt['NSN11']
    elif spn == 0:
        nsinos = Cnt['NSEG0']
    else:
        # any other span would otherwise leave nsinos undefined
        raise ValueError("Unsupported span Cnt['SPN']: " + str(spn))

    #> check first the Siemens default sinogram;
    #> for this default shape only full sinograms are expected--no subsets.
    if sino.ndim == 3:
        if sino.shape != (nsinos, Cnt['NSANGLES'], Cnt['NSBINS']):
            raise ValueError(
                'Unexpected sinogram array dimensions/shape for Siemens defaults.'
            )
        sinog = mmraux.remgaps(sino, txLUT, Cnt)

    elif sino.ndim == 2:
        use_subsets = isub[0] >= 0
        if not use_subsets and sino.shape[0] != txLUT["Naw"]:
            raise ValueError(
                'Unexpected number of transaxial elements in the full sinogram.'
            )
        elif use_subsets and sino.shape[0] != len(isub):
            raise ValueError(
                'Unexpected number of transaxial elements in the subset sinogram.'
            )
        #> check if the number of sinograms is correct
        if sino.shape[1] != nsinos:
            raise ValueError('Inconsistent number of sinograms in the array.')
        #> when found the dimensions/shape are fine:
        sinog = sino
    else:
        raise ValueError('Unexpected shape of the input sinogram.')

    #> predefine the output image depending on the number of rings used
    if spn == 1 and 'rSZ_IMZ' in Cnt:
        nvz = Cnt['rSZ_IMZ']
    else:
        nvz = Cnt['SZ_IMZ']
    # NOTE(review): axis order here is (X, Y, Z) whereas other GPU-shaped images in
    # this code use (Y, X, Z); harmless only if SZ_IMX == SZ_IMY -- confirm.
    bimg = np.zeros((Cnt['SZ_IMX'], Cnt['SZ_IMY'], nvz), dtype=np.float32)

    #> run back-projection (fills bimg in place)
    petprj.bprj(bimg, sinog, txLUT, axLUT, isub, Cnt)

    #> change from GPU optimised image dimensions to the standard Siemens shape
    bimg = mmrimg.convert2e7(bimg, Cnt)

    return bimg
Example #3
0
def osemone(datain,
            mumaps,
            hst,
            scanner_params,
            recmod=3,
            itr=4,
            fwhm=0.,
            mask_radius=29.,
            sctsino=np.array([]),
            outpath='',
            store_img=False,
            frmno='',
            fcomment='',
            store_itr=None,
            emmskS=False,
            ret_sinos=False,
            attnsino=None,
            randsino=None,
            normcomp=None):
    ''' Reconstruct a PET image with OSEM (14 subsets) from histogrammed data.

        Arguments:
        datain -- dictionary of input paths (e.g. 'corepath', 'lm_bf', 'em_crr').
        mumaps -- pair (hardware mu-map, object mu-map).
        hst -- histogram dictionary; 'psino' is the prompt sinogram. 't0', 't1'
            and 'dur' are used for decay/quantification and for the output
            description.
        scanner_params -- dictionary with scanner constants ('Cnt') and the
            transaxial/axial look-up tables ('txLUT', 'axLUT').
        recmod -- reconstruction mode:
            0 = no attenuation correction;
            >=1 = attenuation correction;
            2 = scatter from sctsino or from the image at datain['em_crr'];
            >=3 = scatter re-estimated with nipet.vsm between iterations.
        itr -- number of OSEM iterations.
        fwhm -- FWHM of the Gaussian post-filter; no filtering when <= 0.
        mask_radius -- radius passed to mmrimg.get_cylinder for the
            reconstruction mask.
        sctsino -- optional scatter sinogram used in mode 2.
        outpath -- output folder; defaults to <corepath>/reconstructed.
        store_img -- save the final image as NIfTI.
        frmno, fcomment -- strings inserted into the output file names.
        store_itr -- iteration numbers whose intermediate images are saved
            (None means none).
        emmskS -- passed as emmsk to nipet.vsm.
        ret_sinos -- also return the scatter/randoms sinograms (mode >= 3).
        attnsino -- optional attenuation factor sinogram, either in the Siemens
            (NSN11, NSANGLES, NSBINS) shape or the (Naw, NSN11) shape.
        randsino -- optional randoms sinogram; otherwise estimated with
            nipet.randoms.
        normcomp -- optional normalisation components; otherwise read with
            mmrnorm.get_components.

        Returns:
        namedtuple RecOut(im, fpet, affine). When ret_sinos is set, recmod >= 3
        and more than one iteration ran, the tuple also contains ssn, sssr,
        amsk and rsn.
    '''

    #---------- sort out OUTPUT ------------
    if store_itr is None:
        store_itr = []

    #-output file name for the reconstructed image, initially assume n/a
    fout = 'n/a'
    if store_img or store_itr:
        if outpath:
            opth = outpath
        else:
            opth = os.path.join(datain['corepath'], 'reconstructed')
        mmraux.create_dir(opth)

    #> the scatter estimator returns its extra outputs only when sinograms are requested
    return_ssrb = return_mask = bool(ret_sinos)

    #----------

    # Get particular scanner parameters: Constants, transaxial and axial LUTs
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']

    import time
    from niftypet import nipet

    if Cnt['VERBOSE']:
        print('i> reconstruction in mode ' + str(recmod))

    # get object and hardware mu-maps
    muh, muo = mumaps

    # get the GPU version of the image dims
    mus = mmrimg.convert2dev(muo + muh, Cnt)

    # remove gaps from the prompt sino
    psng = mmraux.remgaps(hst['psino'], txLUT, Cnt)

    #=========================================================================
    # GET NORM
    #-------------------------------------------------------------------------
    if normcomp is None:
        ncmp, _ = mmrnorm.get_components(datain, Cnt)
    else:
        ncmp = normcomp
        print('w> using user-defined normalisation components')
    nsng = mmrnorm.get_sinog(datain, hst, axLUT, txLUT, Cnt, normcomp=ncmp)
    #=========================================================================

    #=========================================================================
    # ATTENUATION FACTORS FOR COMBINED OBJECT AND BED MU-MAP
    #-------------------------------------------------------------------------
    if recmod == 0:
        asng = np.ones(psng.shape, dtype=np.float32)
    elif isinstance(attnsino, np.ndarray) \
            and attnsino.shape == (Cnt['NSN11'], Cnt['NSANGLES'], Cnt['NSBINS']):
        asng = mmraux.remgaps(attnsino, txLUT, Cnt)
        print('i> using provided attenuation factor sinogram')
    elif isinstance(attnsino, np.ndarray) \
            and attnsino.shape == (Cnt['Naw'], Cnt['NSN11']):
        asng = attnsino
        print('i> using provided attenuation factor sinogram')
    else:
        # forward project the combined mu-map into attenuation factors (filled in place)
        asng = np.zeros(psng.shape, dtype=np.float32)
        petprj.fprj(asng, mus, txLUT, axLUT,
                    np.array([-1], dtype=np.int32), Cnt, 1)
    #> combine attenuation and normalisation (used for the sensitivity images)
    ansng = asng * nsng
    #=========================================================================

    #=========================================================================
    # Randoms
    #-------------------------------------------------------------------------
    if isinstance(randsino, np.ndarray):
        rsino = randsino
    else:
        rsino, _ = nipet.randoms(hst, scanner_params)
    rsng = mmraux.remgaps(rsino, txLUT, Cnt)
    #=========================================================================

    #=========================================================================
    # SCAT
    #-------------------------------------------------------------------------
    if recmod == 2:
        if sctsino.size > 0:
            ssng = mmraux.remgaps(sctsino, txLUT, Cnt)
        elif os.path.isfile(datain['em_crr']):
            emd = nimpa.getnii(datain['em_crr'])
            ssn = nipet.vsm(datain,
                            mumaps,
                            emd['im'],
                            hst,
                            rsino,
                            scanner_params,
                            prcnt_scl=0.1,
                            emmsk=False)
            ssng = mmraux.remgaps(ssn, txLUT, Cnt)
        else:
            print("e> no emission image available for scatter estimation!  check if it's present or the path is correct.")
            # NOTE(review): exiting the interpreter from library code is harsh;
            # raising an exception would let callers recover.
            sys.exit()
    else:
        ssng = np.zeros(rsng.shape, dtype=rsng.dtype)
    #=========================================================================

    if Cnt['VERBOSE']:
        print('\n>------ OSEM ( ' + str(itr) + ' ) -------\n')
    #------------------------------------
    Sn = 14  # number of subsets
    #-get one subset to get number of projection bins in a subset
    Sprj, _ = get_subsets14(0, scanner_params)
    Nprj = len(Sprj)
    #-subset index array: first column holds the number of projections
    sinoTIdx = np.zeros((Sn, Nprj + 1), dtype=np.int32)
    #-sensitivity images, one per subset
    imgsens = np.zeros((Sn, Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)
    for n in range(Sn):
        sinoTIdx[n, 0] = Nprj
        sinoTIdx[n, 1:], _ = get_subsets14(n, scanner_params)
        # sensitivity image (written in place by the back projector)
        petprj.bprj(imgsens[n, :, :, :], ansng[sinoTIdx[n, 1:], :], txLUT,
                    axLUT, sinoTIdx[n, 1:], Cnt)
    #-------------------------------------

    #-mask for reconstructed image.  anything outside it is set to zero
    msk = mmrimg.get_cylinder(
        Cnt, rad=mask_radius, xo=0, yo=0, unival=1, gpu_dim=True) > 0.9

    #-init image
    img = np.ones((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                  dtype=np.float32)

    #-decay correction and quantification (only when frame timing is known)
    lmbd = np.log(2) / resources.riLUT[Cnt['ISOTOPE']]['thalf']
    if 't0' in hst and 'dur' in hst:
        if Cnt['DCYCRR']:
            # decay to the scan start and average over the frame duration
            dcycrr = np.exp(lmbd * hst['t0']) * lmbd * hst['dur'] / (
                1 - np.exp(-lmbd * hst['dur']))
        else:
            dcycrr = 1.
        # apply quantitative correction to the image
        qf = ncmp['qf'] / resources.riLUT[Cnt['ISOTOPE']]['BF'] / float(
            hst['dur'])
        qf_loc = ncmp['qf_loc']
    else:
        dcycrr = 1.
        qf = 1.
        qf_loc = 1.

    #-affine matrix for the reconstructed images
    B = mmrimg.image_affine(datain, Cnt)

    def _out_fname(itr_no, suffix):
        '''Build the NIfTI output path for the given iteration number and suffix.'''
        return os.path.join(
            opth,
            os.path.basename(datain['lm_bf'])[:8] + frmno + '_t' +
            str(hst['t0']) + '-' + str(hst['t1']) + 'sec' + '_itr' +
            str(itr_no) + fcomment + suffix)

    #-time it
    stime = time.time()

    #=========================================================================
    # OSEM RECONSTRUCTION
    #-------------------------------------------------------------------------
    for k in trange(itr, disable=not Cnt['VERBOSE'], desc="OSEM"):
        petprj.osem(img, msk, psng, rsng, ssng, nsng, asng, imgsens, txLUT,
                    axLUT, sinoTIdx, Cnt)
        if np.nansum(img) < 0.1:
            print('---------------------------------------------------------------------')
            print('w> it seems there is not enough true data to render reasonable image.')
            print('---------------------------------------------------------------------')
            # record the number of completed iterations for naming/description
            itr = k
            break
        # re-estimate scatter between iterations (not after the last one)
        if recmod >= 3 and k < itr - 1:
            sct_time = time.time()

            sct = nipet.vsm(datain,
                            mumaps,
                            mmrimg.convert2e7(img, Cnt),
                            hst,
                            rsino,
                            scanner_params,
                            emmsk=emmskS,
                            return_ssrb=return_ssrb,
                            return_mask=return_mask)

            ssn = sct['sino'] if isinstance(sct, dict) else sct
            ssng = mmraux.remgaps(ssn, txLUT, Cnt)

            if Cnt['VERBOSE']:
                print('i> scatter time: ' + str(time.time() - sct_time))

        # save images during reconstruction if requested
        if store_itr and k in store_itr:
            im = mmrimg.convert2e7(img * (dcycrr * qf * qf_loc), Cnt)
            fout = _out_fname(k, '_inrecon.nii.gz')
            nimpa.array2nii(im[::-1, ::-1, :], B, fout)

    if Cnt['VERBOSE']:
        print('i> recon time: ' + str(time.time() - stime))
    #=========================================================================

    if Cnt['VERBOSE']:
        print('i> applying decay correction of ' + str(dcycrr))
        print('i> applying quantification factor ' + str(qf) +
              ' to the whole image for the frame duration of : ' +
              str(hst['dur']))

    #additional factor for making it quantitative in absolute terms (derived from measurements)
    img *= dcycrr * qf * qf_loc

    #---- save images -----
    #-first convert to standard mMR image size
    im = mmrimg.convert2e7(img, Cnt)

    #-description text to NIfTI
    #-attenuation number: if only bed present then it is 0.5
    attnum = (1 * (np.sum(muh) > 0.5) + 1 * (np.sum(muo) > 0.5)) / 2.
    descrip = ('alg=osem' +
               ';sub=14' +
               ';att=' + str(attnum * (recmod > 0)) +
               ';sct=' + str(1 * (recmod > 1)) +
               ';spn=' + str(Cnt['SPN']) +
               ';itr=' + str(itr) +
               ';fwhm=' + str(fwhm) +
               ';t0=' + str(hst['t0']) +
               ';t1=' + str(hst['t1']) +
               ';dur=' + str(hst['dur']) +
               ';qf=' + str(qf))

    if fwhm > 0:
        # ndi.gaussian_filter replaces the deprecated ndi.filters namespace
        im = ndi.gaussian_filter(im, fwhm2sig(fwhm, Cnt), mode='mirror')
    if store_img:
        fout = _out_fname(itr, '.nii.gz')
        if Cnt['VERBOSE']:
            print('i> saving image to:  ' + fout)
        nimpa.array2nii(im[::-1, ::-1, :], B, fout, descrip=descrip)

    # returning:
    # (0) E7 image [can be smoothed];
    # (1) file name of saved E7 image
    # (2) affine matrix
    # (3) [optional] scatter sino
    # (4) [optional] single slice rebinned scatter
    # (5) [optional] mask for scatter scaling based on attenuation data
    # (6) [optional] random sino
    if ret_sinos and recmod >= 3 and itr > 1:
        RecOut = namedtuple('RecOut', 'im, fpet, affine, ssn, sssr, amsk, rsn')
        recout = RecOut(im, fout, B, ssn, sct['ssrb'], sct['mask'], rsino)
    else:
        RecOut = namedtuple('RecOut', 'im, fpet, affine')
        recout = RecOut(im, fout, B)

    return recout