Esempio n. 1
0
def test_increment():
    """Check ``increment2d_f`` with an explicit output buffer and without one."""
    # `example_mod` is defined in ../cuvec/src/example_mod/
    from cuvec.example_mod import increment2d_f

    src = cu.zeros((1337, 42), 'f')
    assert (src == 0).all()

    # same buffer passed as input and output: both arrays see the increment
    out = cu.asarray(increment2d_f(src.cuvec, src.cuvec))
    for arr in (src, out):
        assert (arr == 1).all()

    # `out` aliases `src`, so zeroing `src` must also zero `out`
    src[:] = 0
    for arr in (src, out):
        assert (arr == 0).all()

    # no output argument: the result is returned in a separate array
    out = cu.asarray(increment2d_f(src))
    assert (out == 1).all()
Esempio n. 2
0
def test_increment_return():
    """The returned array stays usable after the input array is deleted."""
    from cuvec.example_mod import increment2d_f

    buf = cu.zeros((1337, 42), 'f')
    assert (buf == 0).all()
    ret = cu.asarray(increment2d_f(buf, buf))
    assert (buf == 1).all()
    # drop the original reference; `ret` must remain valid on its own
    del buf
    assert (ret == 1).all()
Esempio n. 3
0
def test_np_types():
    """``increment2d_f`` accepts float32 ('f') arrays and rejects float64 ('d')."""
    from cuvec.example_mod import increment2d_f

    dims = (1337, 42)
    f32, f64 = cu.zeros(dims, 'f'), cu.zeros(dims, 'd')

    # float32 works both without and with an explicit output buffer
    for call_args in ((f32, ), (f32, f32)):
        cu.asarray(increment2d_f(*call_args))

    with raises(TypeError):
        cu.asarray(increment2d_f(f64))
    with raises(SystemError):
        # the TypeError is suppressed since a new output is generated
        cu.asarray(increment2d_f(f32, f64))
Esempio n. 4
0
def test_asarray():
    """Check when `cu.CuVec`/`cu.asarray` reuse the source cuvec and when they copy."""

    def raw(arr):
        # memoryview over the cuvec contents; memoryview `==` compares values
        return np.asarray(arr.cuvec).data

    def check(arr, ref, same_cuvec, same_data):
        if same_cuvec:
            assert arr.cuvec == v.cuvec
        else:
            assert arr.cuvec != v.cuvec
        assert (arr == ref).all()
        if same_data:
            assert raw(arr) == raw(v)
        else:
            assert raw(arr) != raw(v)

    v = cu.asarray(np.random.random(shape))

    # re-wrapping a CuVec, or passing its `.cuvec`, keeps the same cuvec
    w = cu.CuVec(v)
    check(w, v, same_cuvec=True, same_data=True)
    x = cu.asarray(w.cuvec)
    check(x, v, same_cuvec=True, same_data=True)

    # plain Python data and slices produce a new cuvec
    y = cu.asarray(x.tolist())
    check(y, v, same_cuvec=False, same_data=True)
    z = cu.asarray(v[:])
    check(z, v[:], same_cuvec=False, same_data=True)
    s = cu.asarray(v[1:])
    check(s, v[1:], same_cuvec=False, same_data=False)
Esempio n. 5
0
def test_cuda_array_interface():
    """CuVec exposes `__cuda_array_interface__` and interoperates with cupy."""
    cupy = importorskip("cupy")
    v = cu.asarray(np.random.random(shape))
    assert hasattr(v, '__cuda_array_interface__')

    # writes through the cupy array must become visible in `v` after a sync
    c = cupy.asarray(v)
    assert (c == v).all()
    for val in (1, 0):
        c[0, 0, 0] = val
        cu.dev_sync()
        assert c[0, 0, 0] == v[0, 0, 0]

    # the result of arithmetic on `v` keeps shape/dtype but lacks the interface
    ndarr = v + 1
    assert ndarr.shape == v.shape
    assert ndarr.dtype == v.dtype
    with raises(AttributeError):
        ndarr.__cuda_array_interface__
Esempio n. 6
0
def osemone(datain,
            mumaps,
            hst,
            scanner_params,
            recmod=3,
            itr=4,
            fwhm=0.,
            psf=None,
            mask_radius=29.,
            decay_ref_time=None,
            attnsino=None,
            sctsino=None,
            randsino=None,
            normcomp=None,
            emmskS=False,
            frmno='',
            fcomment='',
            outpath=None,
            fout=None,
            store_img=False,
            store_itr=None,
            ret_sinos=False):
    '''
    OSEM image reconstruction with several modes
    (with/without scatter and/or attenuation correction)

    Args:
      datain: dict of inputs; keys read here include 'corepath', 'em_crr'
        and 'lm_bf'.
      mumaps: tuple of (hardware, object) mu-maps.
      hst: histogrammed data; reads 'psino' and, when present, 't0', 't1'
        and 'dur'.
      scanner_params: dict with 'Cnt', 'txLUT' and 'axLUT'.
      recmod: reconstruction mode. 0 disables attenuation correction (any
        other value enables it); 2 uses a scatter sinogram from `sctsino`
        or estimated once from the image at datain['em_crr']; >= 3
        re-estimates scatter between iterations.
      itr: number of OSEM iterations (14 subsets each).
      fwhm: if > 0, also produce a Gaussian-smoothed image with this FWHM
        (converted with voxel size Cnt['SZ_VOXY'] * 10, presumably mm).
      psf: Reconstruction with PSF, passed to `psf_config`
      mask_radius: radius of the cylindrical reconstruction mask.
      decay_ref_time: reference time for decay correction; defaults to
        hst['t0'].
      attnsino, sctsino, randsino: optional precomputed attenuation,
        scatter and randoms sinograms.
      normcomp: optional user-defined normalisation components.
      emmskS: forwarded to `vsm` as `emmsk`.
      frmno, fcomment: strings inserted into the output file names.
      outpath: output folder; default is datain['corepath']/reconstructed.
      fout: output file name; any folders and extension are stripped.
      store_img: save the final (and smoothed) image as NIfTI.
      store_itr: iterable of 1-based iteration numbers whose intermediate
        images are saved.
      ret_sinos: with recmod >= 3 (and more than one iteration), also
        return the scatter and randoms sinograms.

    Returns:
      namedtuple `RecOut` with fields `im, fpet, imsmo, fsmo, affine`:
      the E7-format image, its file name, the smoothed image and its file
      name (None unless produced/stored) and the image affine.
      With `ret_sinos` the extra fields `ssn, sssr, amsk, rsn` hold the
      scatter sinogram, its single-slice-rebinned version, the scatter
      scaling mask (None if `vsm` did not provide one) and the randoms
      sinogram.
    '''

    # > Get particular scanner parameters: Constants, transaxial and axial LUTs
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']

    # ---------- sort out OUTPUT ------------
    # -output file name for the reconstructed image
    if outpath is None:
        opth = os.path.join(datain['corepath'], 'reconstructed')
    else:
        opth = outpath

    # > file output name (the path is ignored if given)
    if fout is not None:
        # > get rid of folders
        fout = os.path.basename(fout)
        # > get rid of extension
        fout = fout.split('.')[0]

    if store_img is True or store_itr is not None:
        mmraux.create_dir(opth)

    return_ssrb, return_mask = ret_sinos, ret_sinos

    # ----------

    log.info('reconstruction in mode: %d', recmod)

    # get object and hardware mu-maps
    muh, muo = mumaps

    # get the GPU version of the image dims
    mus = mmrimg.convert2dev(muo + muh, Cnt)

    # remove gaps from the prompt sino
    psng = mmraux.remgaps(hst['psino'], txLUT, Cnt)

    # ========================================================================
    # GET NORM
    # -------------------------------------------------------------------------
    if normcomp is None:
        ncmp, _ = mmrnorm.get_components(datain, Cnt)
    else:
        ncmp = normcomp
        log.warning('using user-defined normalisation components')
    nsng = mmrnorm.get_norm_sino(datain,
                                 scanner_params,
                                 hst,
                                 normcomp=ncmp,
                                 gpu_dim=True)
    # ========================================================================

    # ========================================================================
    # ATTENUATION FACTORS FOR COMBINED OBJECT AND BED MU-MAP
    # -------------------------------------------------------------------------
    # > combine attenuation and norm together depending on reconstruction mode
    if recmod == 0:
        asng = np.ones(psng.shape, dtype=np.float32)
    else:
        # > check if the attenuation sino is given as an array
        if isinstance(attnsino, np.ndarray) \
                and attnsino.shape==(Cnt['NSN11'], Cnt['NSANGLES'], Cnt['NSBINS']):
            asng = mmraux.remgaps(attnsino, txLUT, Cnt)
            log.info('using provided attenuation factor sinogram')
        elif isinstance(attnsino, np.ndarray) \
                and attnsino.shape==(Cnt['Naw'], Cnt['NSN11']):
            # already in the gap-free (GPU) layout
            asng = attnsino
            log.info('using provided attenuation factor sinogram')
        else:
            # forward-project the combined mu-map to get attenuation factors
            asng = cu.zeros(psng.shape, dtype=np.float32)
            petprj.fprj(asng.cuvec,
                        cu.asarray(mus).cuvec, txLUT, axLUT,
                        np.array([-1], dtype=np.int32), Cnt, 1)
    # > combine attenuation and normalisation
    ansng = asng * nsng
    # ========================================================================

    # ========================================================================
    # Randoms
    # -------------------------------------------------------------------------
    if isinstance(randsino, np.ndarray) \
            and randsino.shape==(Cnt['NSN11'], Cnt['NSANGLES'], Cnt['NSBINS']):
        rsino = randsino
        rsng = mmraux.remgaps(randsino, txLUT, Cnt)
    else:
        rsino, snglmap = randoms(hst, scanner_params)
        rsng = mmraux.remgaps(rsino, txLUT, Cnt)
    # ========================================================================

    # ========================================================================
    # SCAT
    # -------------------------------------------------------------------------
    if recmod == 2:
        if sctsino is not None:
            ssng = mmraux.remgaps(sctsino, txLUT, Cnt)
        elif sctsino is None and os.path.isfile(datain['em_crr']):
            emd = nimpa.getnii(datain['em_crr'])
            ssn = vsm(
                datain,
                mumaps,
                emd['im'],
                scanner_params,
                histo=hst,
                rsino=rsino,
                prcnt_scl=0.1,
                emmsk=False,
            )
            ssng = mmraux.remgaps(ssn, txLUT, Cnt)
        else:
            raise ValueError(
                "No emission image available for scatter estimation! " +
                " Check if it's present or the path is correct.")
    else:
        ssng = np.zeros(rsng.shape, dtype=rsng.dtype)
    # ========================================================================

    log.info('------ OSEM (%d) -------', itr)
    # ------------------------------------
    Sn = 14  # number of subsets

    # -get one subset to get number of projection bins in a subset
    Sprj, s = get_subsets14(0, scanner_params)
    Nprj = len(Sprj)
    # -init subset array and sensitivity image for a given subset
    # (column 0 holds the number of projections, the rest the indices)
    sinoTIdx = np.zeros((Sn, Nprj + 1), dtype=np.int32)
    # -init sensitivity images for each subset
    imgsens = np.zeros((Sn, Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)
    tmpsens = cu.zeros((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)
    for n in range(Sn):
        # first number of projection for the given subset
        sinoTIdx[n, 0] = Nprj
        sinoTIdx[n, 1:], s = get_subsets14(n, scanner_params)
        # sensitivity image: back-projection of attenuation * norm
        petprj.bprj(tmpsens.cuvec,
                    cu.asarray(ansng[sinoTIdx[n, 1:], :]).cuvec, txLUT, axLUT,
                    sinoTIdx[n, 1:], Cnt)
        imgsens[n] = tmpsens
    del tmpsens
    # -------------------------------------

    # -mask for reconstructed image.  anything outside it is set to zero
    msk = mmrimg.get_cylinder(
        Cnt, rad=mask_radius, xo=0, yo=0, unival=1, gpu_dim=True) > 0.9

    # -init image
    img = np.ones((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                  dtype=np.float32)

    # -decay correction
    lmbd = np.log(2) / resources.riLUT[Cnt['ISOTOPE']]['thalf']
    if Cnt['DCYCRR'] and 't0' in hst and 'dur' in hst:
        # > decay correct to the reference time (e.g., injection time) if provided
        # > otherwise correct in reference to the scan start time (using the time
        # > past from the start to the start time frame)
        if decay_ref_time is not None:
            tref = decay_ref_time
        else:
            tref = hst['t0']

        dcycrr = np.exp(
            lmbd * tref) * lmbd * hst['dur'] / (1 - np.exp(-lmbd * hst['dur']))
        # apply quantitative correction to the image
        qf = ncmp['qf'] / resources.riLUT[Cnt['ISOTOPE']]['BF'] / float(
            hst['dur'])
        qf_loc = ncmp['qf_loc']

    elif not Cnt['DCYCRR'] and 't0' in hst and 'dur' in hst:
        dcycrr = 1.
        # apply quantitative correction to the image
        qf = ncmp['qf'] / resources.riLUT[Cnt['ISOTOPE']]['BF'] / float(
            hst['dur'])
        qf_loc = ncmp['qf_loc']

    else:
        dcycrr = 1.
        qf = 1.
        qf_loc = 1.

    # -affine matrix for the reconstructed images
    B = mmrimg.image_affine(datain, Cnt)

    # resolution modelling
    psfkernel = psf_config(psf, Cnt)

    # -time it
    stime = time.time()

    # ========================================================================
    # OSEM RECONSTRUCTION
    # -------------------------------------------------------------------------
    with trange(itr,
                desc="OSEM",
                disable=log.getEffectiveLevel() > logging.INFO,
                leave=log.getEffectiveLevel() <= logging.INFO) as pbar:

        for k in pbar:

            petprj.osem(img, psng, rsng, ssng, nsng, asng, sinoTIdx, imgsens,
                        msk, psfkernel, txLUT, axLUT, Cnt)

            if np.nansum(img) < 0.1:
                log.warning(
                    'it seems there is not enough true data to render reasonable image'
                )
                # record how many iterations actually completed
                itr = k
                break
            if recmod >= 3 and k < itr - 1 and itr > 1:
                # re-estimate scatter from the current image (not after the last iteration)
                sct_time = time.time()
                sct = vsm(datain,
                          mumaps,
                          mmrimg.convert2e7(img, Cnt),
                          scanner_params,
                          histo=hst,
                          rsino=rsino,
                          emmsk=emmskS,
                          return_ssrb=return_ssrb,
                          return_mask=return_mask)

                if isinstance(sct, tuple):
                    # `vsm` returns a bare (sino, ssrb, mask) tuple when the
                    # total scatter is below its threshold; normalise it to
                    # the dict form it returns otherwise
                    sct = dict(zip(('sino', 'ssrb', 'mask'), sct))
                if isinstance(sct, dict):
                    ssn = sct['sino']
                else:
                    ssn = sct

                ssng = mmraux.remgaps(ssn, txLUT, Cnt)
                pbar.set_postfix(scatter="%.3gs" % (time.time() - sct_time))
            # save images during reconstruction if requested
            if store_itr and (k + 1) in store_itr:
                im = mmrimg.convert2e7(img * (dcycrr * qf * qf_loc), Cnt)

                if fout is None:
                    fpet = os.path.join(opth, (
                        os.path.basename(datain['lm_bf'])[:16].replace(
                            '.', '-') +
                        f"{frmno}_t{hst['t0']}-{hst['t1']}sec_itr{k+1}{fcomment}_inrecon.nii.gz"
                    ))
                else:
                    fpet = os.path.join(
                        opth, fout + f'_itr{k+1}{fcomment}_inrecon.nii.gz')

                nimpa.array2nii(im[::-1, ::-1, :], B, fpet)

    log.info('recon time: %.3g', time.time() - stime)
    # ========================================================================

    log.info('applying decay correction of: %r', dcycrr)
    log.info('applying quantification factor: %r to the whole image', qf)
    log.info('for the frame duration of: %r', hst['dur'])

    # additional factor for making it quantitative in absolute terms (derived from measurements)
    img *= dcycrr * qf * qf_loc

    # ---- save images -----
    # -first convert to standard mMR image size
    im = mmrimg.convert2e7(img, Cnt)

    # -description text to NIfTI
    # -attenuation number: if only bed present then it is 0.5
    attnum = (1 * (np.sum(muh) > 0.5) + 1 * (np.sum(muo) > 0.5)) / 2.
    descrip = (f"alg=osem"
               f";sub=14"
               f";att={attnum*(recmod>0)}"
               f";sct={1*(recmod>1)}"
               f";spn={Cnt['SPN']}"
               f";itr={itr}"
               f";fwhm=0"
               f";t0={hst['t0']}"
               f";t1={hst['t1']}"
               f";dur={hst['dur']}"
               f";qf={qf}")

    # > file name of the output reconstructed image
    # > (maybe used later even if not stored now)
    if fout is None:
        fpet = os.path.join(
            opth,
            (os.path.basename(datain['lm_bf']).split('.')[0] +
             f"{frmno}_t{hst['t0']}-{hst['t1']}sec_itr{itr}{fcomment}.nii.gz"))
    else:
        fpet = os.path.join(opth, fout + f'_itr{itr}{fcomment}.nii.gz')

    if store_img:
        log.info('saving image to: %s', fpet)
        nimpa.array2nii(im[::-1, ::-1, :], B, fpet, descrip=descrip)

    im_smo = None
    fsmo = None
    if fwhm > 0:
        im_smo = ndi.gaussian_filter(im,
                                     fwhm2sig(fwhm,
                                              voxsize=Cnt['SZ_VOXY'] * 10),
                                     mode='mirror')

        if store_img:
            fsmo = fpet.split('.nii.gz')[0] + '_smo-' + str(fwhm).replace(
                '.', '-') + 'mm.nii.gz'
            log.info('saving smoothed image to: %s', fsmo)
            # record the actual smoothing FWHM in the NIfTI description
            descrip = descrip.replace(';fwhm=0', f';fwhm={fwhm}')
            nimpa.array2nii(im_smo[::-1, ::-1, :], B, fsmo, descrip=descrip)

    if ret_sinos and recmod >= 3 and itr > 1:
        RecOut = namedtuple(
            'RecOut', 'im, fpet, imsmo, fsmo, affine, ssn, sssr, amsk, rsn')
        # `vsm` does not always provide every key (e.g. no 'mask' when it
        # skips scaling), so missing entries are returned as None
        sct_out = sct if isinstance(sct, dict) else {}
        recout = RecOut(im, fpet, im_smo, fsmo, B, ssn, sct_out.get('ssrb'),
                        sct_out.get('mask'), rsino)
    else:
        RecOut = namedtuple('RecOut', 'im, fpet, imsmo, fsmo, affine')
        recout = RecOut(im, fpet, im_smo, fsmo, B)

    return recout
Esempio n. 7
0
def simulate_recon(
    measured_sino,
    ctim,
    scanner_params,
    simulate_3d=False,
    nitr=60,
    fwhm_rm=0.,
    slice_idx=-1,
    randoms=None,
    scatter=None,
    mu_input=False,
    msk_radius=29.,
    psf=None,
):
    '''
    Reconstruct PET image from simulated input data
    using the EM-ML (2D) or OSEM (3D) algorithm.

    measured_sino  : simulated emission data with photon attenuation
    ctim  : either a 2D CT image or a 3D CT image from which a 2D slice
        is chosen (slice_idx) for estimation of the attenuation factors
        (a mu-map instead of CT when mu_input is True).  The caller's
        array is not modified.
    scanner_params  : scanner parameters containing scanner constants and
        axial and transaxial look up tables (LUTs)
    simulate_3d  : if True run 3D OSEM; ctim must then be 3D with shape
        (SO_IMZ, SO_IMY, SO_IMX).  Otherwise a reduced-ring 2D simulation
        with EM-ML is performed.
    nitr  : number of iterations used for the EM-ML/OSEM reconstruction
    fwhm_rm  : FWHM of the Gaussian resolution model (0 disables it)
    slice_idx  : index to extract one 2D slice for this simulation
        if input image is 3D; a negative value selects the middle slice
    randoms  : randoms sinogram (optional, same shape as measured_sino)
    scatter  : scatter sinogram (optional, same shape as measured_sino)
    mu_input  : if True, ctim is already a mu-map and is not converted
    msk_radius  : radius of the cylindrical reconstruction mask
    psf  : resolution modelling option passed to `mmrrec.psf_config`

    Returns the reconstructed image (converted with `mmrimg.convert2e7`
    in the 3D case).

    Raises ValueError when the image shape does not match the scanner,
    the slice index is out of range, the image is neither 2D nor 3D, or
    the reduced axial FOV parameters are missing for a 2D simulation.
    '''
    # > decompose the scanner constants and LUTs for easier access
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']
    psfkernel = mmrrec.psf_config(psf, Cnt)

    if simulate_3d:
        if ctim.ndim!=3 \
                or ctim.shape!=(Cnt['SO_IMZ'], Cnt['SO_IMY'], Cnt['SO_IMX']):
            raise ValueError(
                'The CT/mu-map image does not match the scanner image shape.')
    else:
        # > 2D case with reduced rings
        if len(ctim.shape) == 3:
            # make sure that the shape of the input image matches the image size of the scanner
            if ctim.shape[1:] != (Cnt['SO_IMY'], Cnt['SO_IMX']):
                raise ValueError(
                    'The input image shape for x and y does not match the scanner image size.'
                )
            # pick the right slice index (slice_idx) if not given or mistaken
            if slice_idx < 0:
                log.warning(
                    'the axial index <slice_idx> is chosen to be in the middle of axial FOV.'
                )
                # integer division: a float cannot be used as an array index
                slice_idx = ctim.shape[0] // 2
            if slice_idx >= ctim.shape[0]:
                raise ValueError(
                    'The axial index for 2D slice selection is outside the image.'
                )
        elif len(ctim.shape) == 2:
            # make sure that the shape of the input image matches the image size of the scanner
            if ctim.shape != (Cnt['SO_IMY'], Cnt['SO_IMX']):
                raise ValueError(
                    'The input image shape for x and y does not match the scanner image size.'
                )
            # add a leading axial axis as a view instead of reshaping the
            # caller's array in place
            ctim = ctim[np.newaxis, ...]
            slice_idx = 0
        else:
            raise ValueError('The input image must be 2D or 3D.')

        if 'rSZ_IMZ' not in Cnt:
            raise ValueError('Missing reduced axial FOV parameters.')

    # --------------------
    if mu_input:
        # copy so that zeroing negative values below leaves the caller's array intact
        mui = ctim.copy()
    else:
        # > get the mu-map [1/cm] from CT [HU]
        mui = nimpa.ct2mu(ctim)

    # > get rid of negative values
    mui[mui < 0] = 0
    # --------------------

    if simulate_3d:
        rmu = mui
        # > number of axial sinograms
        nsinos = Cnt['NSN11']
    else:
        # --------------------
        # > create a number of slides of the same chosen image slice
        # for reduced (fast) 3D simulation
        rmu = mui[slice_idx, :, :]
        rmu.shape = (1, ) + rmu.shape
        rmu = np.repeat(rmu, Cnt['rSZ_IMZ'], axis=0)
        # --------------------
        # > number of axial sinograms
        nsinos = Cnt['rNSN1']

    # > attenuation factor sinogram
    attsino = mmrprj.frwd_prj(rmu,
                              scanner_params,
                              attenuation=True,
                              dev_out=True)

    nrmsino = np.ones(attsino.shape, dtype=np.float32)

    # > randoms and scatter (small positive floor when not provided)
    if isinstance(randoms,
                  np.ndarray) and measured_sino.shape == randoms.shape:
        rsng = mmraux.remgaps(randoms, txLUT, Cnt)
    else:
        rsng = 1e-5 * np.ones((Cnt['Naw'], nsinos), dtype=np.float32)

    if isinstance(scatter,
                  np.ndarray) and measured_sino.shape == scatter.shape:
        ssng = mmraux.remgaps(scatter, txLUT, Cnt)
    else:
        ssng = 1e-5 * np.ones((Cnt['Naw'], nsinos), dtype=np.float32)

    # resolution modelling
    # NOTE(review): this writes into the shared Cnt dict of scanner_params, so
    # the value persists after this call (presumably also read by the
    # projectors) — confirm this side effect is intended
    Cnt['SIGMA_RM'] = mmrrec.fwhm2sig(fwhm_rm, voxsize=Cnt['SZ_VOXZ'] *
                                      10) if fwhm_rm else 0

    if simulate_3d:
        log.debug('------ OSEM (%d) -------', nitr)

        # measured sinogram in GPU-enabled shape
        psng = mmraux.remgaps(measured_sino.astype(np.uint16), txLUT, Cnt)

        # > mask for reconstructed image.  anything outside it is set to zero
        msk = mmrimg.get_cylinder(
            Cnt, rad=msk_radius, xo=0, yo=0, unival=1, gpu_dim=True) > 0.9

        # > init image
        eimg = np.ones((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)

        # ------------------------------------
        Sn = 14  # number of subsets

        # -get one subset to get number of projection bins in a subset
        Sprj, s = mmrrec.get_subsets14(0, scanner_params)
        Nprj = len(Sprj)

        # > init subset array and sensitivity image for a given subset
        sinoTIdx = np.zeros((Sn, Nprj + 1), dtype=np.int32)

        # > init sensitivity images for each subset
        sim = np.zeros((Sn, Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                       dtype=np.float32)
        tmpsim = cu.zeros((Cnt['SZ_IMY'], Cnt['SZ_IMX'], Cnt['SZ_IMZ']),
                          dtype=np.float32)

        for n in trange(Sn,
                        desc="sensitivity",
                        leave=log.getEffectiveLevel() < logging.INFO):
            # first number of projection for the given subset
            sinoTIdx[n, 0] = Nprj
            sinoTIdx[n, 1:], s = mmrrec.get_subsets14(n, scanner_params)

            # > sensitivity image
            petprj.bprj(tmpsim.cuvec,
                        cu.asarray(attsino[sinoTIdx[n, 1:], :]).cuvec, txLUT,
                        axLUT, sinoTIdx[n, 1:], Cnt)
            sim[n] = tmpsim
        del tmpsim
        # -------------------------------------

        for _ in trange(nitr,
                        desc="OSEM",
                        disable=log.getEffectiveLevel() > logging.INFO,
                        leave=log.getEffectiveLevel() < logging.INFO):
            petprj.osem(eimg, psng, rsng, ssng, nrmsino, attsino, sinoTIdx,
                        sim, msk, psfkernel, txLUT, axLUT, Cnt)
        eim = mmrimg.convert2e7(eimg, Cnt)

    else:

        def _resolution_model(x):
            # image-space Gaussian blur; identity when SIGMA_RM is 0
            if Cnt['SIGMA_RM']:
                x = ndi.gaussian_filter(x,
                                        sigma=Cnt['SIGMA_RM'],
                                        mode='constant')
            return x

        # > estimated image, initialised to ones
        eim = np.ones(rmu.shape, dtype=np.float32)

        msk = mmrimg.get_cylinder(
            Cnt, rad=msk_radius, xo=0, yo=0, unival=1, gpu_dim=False) > 0.9

        # > sensitivity image for the EM-ML reconstruction
        sim = mmrprj.back_prj(attsino, scanner_params)
        sim_inv = 1 / _resolution_model(sim)
        sim_inv[~msk] = 0

        # > measured sinogram without gaps; loop-invariant, so computed once
        msng = mmraux.remgaps(measured_sino, txLUT, Cnt)

        rndsct = rsng + ssng
        for _ in trange(nitr,
                        desc="MLEM",
                        disable=log.getEffectiveLevel() > logging.INFO,
                        leave=log.getEffectiveLevel() < logging.INFO):
            # > forward project the estimated image and divide the measured
            # > sinogram by the estimated sinogram (plus randoms and scatter)
            crrsino = (
                msng /
                (mmrprj.frwd_prj(_resolution_model(eim), scanner_params,
                                 dev_out=True) + rndsct))

            # > back project the correction factors sinogram
            bim = mmrprj.back_prj(crrsino, scanner_params)
            bim = _resolution_model(bim)

            # > divide the back-projected image by the sensitivity image
            # > update the estimated image and remove NaNs
            eim *= bim * sim_inv
            eim[np.isnan(eim)] = 0

    return eim
Esempio n. 8
0
def vsm(
    datain,
    mumaps,
    em,
    scanner_params,
    histo=None,
    rsino=None,
    prcnt_scl=0.1,
    fwhm_input=0.42,
    mask_threshlod=0.999,
    snmsk=None,
    emmsk=False,
    interpolate=True,
    return_uninterp=False,
    return_ssrb=False,
    return_mask=False,
    return_scaling=False,
    scaling=True,
    self_scaling=False,
    save_sax=False,
):
    '''
    Voxel-driven scatter modelling (VSM).
    Obtain a scatter sinogram using the mu-maps (hardware and object mu-maps)
    an estimate of emission image, the prompt measured sinogram, an
    estimate of the randoms sinogram and a normalisation sinogram.
    Input:
        - datain:       Contains the data used for scatter-specific detector
                        normalisation.  May also include the non-corrected
                        emission image used for masking, when requested.
        - mumaps:       A tuple of hardware and object mu-maps (in this order).
        - em:           An estimate of the emission image.
        - histo:          Dictionary containing the histogrammed measured data into
                        sinograms.
        - rsino:       Randoms sinogram (3D).  Needed for proper scaling of
                        scatter to the prompt data.
        - scanner_params: Scanner specific parameters.
        - prcnt_scl:    Ratio of the maximum scatter intensities below which the
                        scatter is not used for fitting it to the tails of prompt
                        data.  Default is 10%.
        - emmsk:        When 'True' it will use uncorrected emission image for
                        masking the sources (voxels) of photons to be used in the
                        scatter modelling.
        - scaling:      performs scaling to the data (sinogram)
        - self_scaling: Scaling is performed on span-1 without the help of SSR
                        scaling and using the sax factors (scatter axial factors).
                        If False (default), the sax factors have to be provided.
        - sax:          Scatter axial factors used for scaling with SSR sinograms.

    '''

    # > decompose constants, transaxial and axial LUTs are extracted
    Cnt = scanner_params['Cnt']
    txLUT = scanner_params['txLUT']
    axLUT = scanner_params['axLUT']

    if self_scaling:
        scaling = True

    # > decompose mu-maps
    muh, muo = mumaps

    if emmsk and not os.path.isfile(datain['em_nocrr']):
        log.info(
            'reconstructing emission data without scatter and attenuation corrections'
            ' for mask generation...')
        recnac = mmrrec.osemone(datain,
                                mumaps,
                                histo,
                                scanner_params,
                                recmod=0,
                                itr=3,
                                fwhm=2.0,
                                store_img=True)
        datain['em_nocrr'] = recnac.fpet

    # if rsino is None and not histo is None and 'rsino' in histo:
    #     rsino = histo['rsino']

    # > if histogram data or randoms sinogram not given, then no scaling or normalisation
    if (histo is None) or (rsino is None):
        scaling = False

    # -get the normalisation components
    nrmcmp, nhdr = mmrnorm.get_components(datain, Cnt)

    # -smooth for defining the sino scatter only regions
    if fwhm_input > 0.:
        mu_sctonly = ndi.filters.gaussian_filter(mmrimg.convert2dev(muo, Cnt),
                                                 fwhm2sig(fwhm_input, Cnt),
                                                 mode='mirror')
    else:
        mu_sctonly = muo

    if Cnt['SPN'] == 1:
        snno = Cnt['NSN1']
        snno_ = Cnt['NSN64']
        ssrlut = axLUT['sn1_ssrb']
        saxnrm = nrmcmp['sax_f1']
    elif Cnt['SPN'] == 11:
        snno = Cnt['NSN11']
        snno_ = snno
        ssrlut = axLUT['sn11_ssrb']
        saxnrm = nrmcmp['sax_f11']

    # LUTs for scatter
    sctLUT = get_sctLUT(scanner_params)

    # > smooth before scaling/down-sampling the mu-map and emission images
    if fwhm_input > 0.:
        muim = ndi.filters.gaussian_filter(muo + muh,
                                           fwhm2sig(fwhm_input, Cnt),
                                           mode='mirror')
        emim = ndi.filters.gaussian_filter(em,
                                           fwhm2sig(fwhm_input, Cnt),
                                           mode='mirror')
    else:
        muim = muo + muh
        emim = em

    muim = ndi.interpolation.zoom(muim, Cnt['SCTSCLMU'],
                                  order=3)  # (0.499, 0.5, 0.5)
    emim = ndi.interpolation.zoom(emim, Cnt['SCTSCLEM'],
                                  order=3)  # (0.34, 0.33, 0.33)

    # -smooth the mu-map for mask creation.
    # the mask contains voxels for which attenuation ray LUT is found.
    if fwhm_input > 0.:
        smomu = ndi.filters.gaussian_filter(muim,
                                            fwhm2sig(fwhm_input, Cnt),
                                            mode='mirror')
        mumsk = np.int8(smomu > 0.003)
    else:
        mumsk = np.int8(muim > 0.001)

    # CORE SCATTER ESTIMATION
    NSCRS, NSRNG = sctLUT['NSCRS'], sctLUT['NSRNG']
    sctout = {
        'sct_3d':
        np.zeros((Cnt['TOFBINN'], snno_, NSCRS, NSCRS), dtype=np.float32),
        'sct_val':
        np.zeros((Cnt['TOFBINN'], NSRNG, NSCRS, NSRNG, NSCRS),
                 dtype=np.float32)
    }

    # <<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>>
    nifty_scatter.vsm(sctout, muim, mumsk, emim, sctLUT, axLUT, Cnt)
    # <<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>>

    sct3d = sctout['sct_3d']
    sctind = sctLUT['sct2aw']

    log.debug('total scatter sum: {}'.format(np.sum(sct3d)))

    # -------------------------------------------------------------------
    # > initialise output dictionary
    out = {}

    if return_uninterp:
        out['uninterp'] = sct3d
        out['indexes'] = sctind
    # -------------------------------------------------------------------

    if np.sum(sct3d) < 1e-04:
        log.warning('total scatter below threshold: {}'.format(np.sum(sct3d)))
        sss = np.zeros((snno, Cnt['NSANGLES'], Cnt['NSBINS']),
                       dtype=np.float32)
        asnmsk = np.zeros((snno, Cnt['NSANGLES'], Cnt['NSBINS']),
                          dtype=np.float32)
        sssr = np.zeros((Cnt['NSEG0'], Cnt['NSANGLES'], Cnt['NSBINS']),
                        dtype=np.float32)
        return sss, sssr, asnmsk

    # import pdb; pdb.set_trace()

    # -------------------------------------------------------------------
    if interpolate:
        # > interpolate basic scatter distributions into full size and
        # > transfer them to sinograms

        log.debug('transaxial scatter interpolation...')
        start = time.time()
        ssn, sssr = intrp_bsct(sct3d, Cnt, sctLUT, ssrlut)
        stop = time.time()
        log.debug('scatter interpolation done in {} sec.'.format(stop - start))

        if not scaling:
            out['ssrb'] = sssr
            out['sino'] = ssn
            return out
    else:
        return out
    # -------------------------------------------------------------------

    # -------------------------------------------------------------------
    # import pdb; pdb.set_trace()
    '''
    debugging scatter:
    import matplotlib.pyplot as plt
    ss = np.squeeze(sct3d)
    ss = np.sum(ss, axis=0)
    plt.matshow(ss)
    plt.matshow(sct3d[0,41,...])
    plt.matshow(np.sum(sct3d[0,0:72,...],axis=0))

    plt.plot(np.sum(sct3d, axis=(0,2,3)))

    rslt = sctout['sct_val']
    rslt.shape
    plt.matshow(rslt[0,4,:,4,:])

    debugging scatter:
    plt.matshow(np.sum(sssr, axis=(0,1)))
    plt.matshow(np.sum(ssn, axis=(0,1)))
    plt.matshow(sssr[0,70,...])
    plt.matshow(sssr[0,50,...])
    '''
    # -------------------------------------------------------------------

    # > get SSR for randoms from span-1 or span-11
    rssr = np.zeros((Cnt['NSEG0'], Cnt['NSANGLES'], Cnt['NSBINS']),
                    dtype=np.float32)
    if scaling:
        for i in range(snno):
            rssr[ssrlut[i], :, :] += rsino[i, :, :]

    # ATTENUATION FRACTIONS for scatter only regions, and NORMALISATION for all SCATTER
    # <<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>>
    currentspan = Cnt['SPN']
    Cnt['SPN'] = 1
    atto = cu.zeros((txLUT['Naw'], Cnt['NSN1']), dtype=np.float32)
    petprj.fprj(atto.cuvec,
                cu.asarray(mu_sctonly).cuvec, txLUT, axLUT,
                np.array([-1], dtype=np.int32), Cnt, 1)
    atto = mmraux.putgaps(atto, txLUT, Cnt)
    # --------------------------------------------------------------
    # > get norm components setting the geometry and axial to ones
    # as they are accounted for differently
    nrmcmp['geo'][:] = 1
    nrmcmp['axe1'][:] = 1
    # get sino with no gaps
    nrmg = np.zeros((txLUT['Naw'], Cnt['NSN1']), dtype=np.float32)
    mmr_auxe.norm(nrmg, nrmcmp, histo['buckets'], axLUT, txLUT['aw2ali'], Cnt)
    nrm = mmraux.putgaps(nrmg, txLUT, Cnt)
    # --------------------------------------------------------------

    # > get attenuation + norm in (span-11) and SSR
    attossr = np.zeros((Cnt['NSEG0'], Cnt['NSANGLES'], Cnt['NSBINS']),
                       dtype=np.float32)
    nrmsssr = np.zeros((Cnt['NSEG0'], Cnt['NSANGLES'], Cnt['NSBINS']),
                       dtype=np.float32)

    for i in range(Cnt['NSN1']):
        si = axLUT['sn1_ssrb'][i]
        attossr[si, :, :] += atto[i, :, :] / float(axLUT['sn1_ssrno'][si])
        nrmsssr[si, :, :] += nrm[i, :, :] / float(axLUT['sn1_ssrno'][si])
    if currentspan == 11:
        Cnt['SPN'] = 11
        nrmg = np.zeros((txLUT['Naw'], snno), dtype=np.float32)
        mmr_auxe.norm(nrmg, nrmcmp, histo['buckets'], axLUT, txLUT['aw2ali'],
                      Cnt)
        nrm = mmraux.putgaps(nrmg, txLUT, Cnt)
    # --------------------------------------------------------------

    # <<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>>

    # get the mask for the object from uncorrected emission image
    if emmsk and os.path.isfile(datain['em_nocrr']):
        nim = nib.load(datain['em_nocrr'])
        eim = nim.get_fdata(dtype=np.float32)
        eim = eim[:, ::-1, ::-1]
        eim = np.transpose(eim, (2, 1, 0))

        em_sctonly = ndi.filters.gaussian_filter(eim,
                                                 fwhm2sig(.6, Cnt),
                                                 mode='mirror')
        msk = np.float32(em_sctonly > 0.07 * np.max(em_sctonly))
        msk = ndi.filters.gaussian_filter(msk,
                                          fwhm2sig(.6, Cnt),
                                          mode='mirror')
        msk = np.float32(msk > 0.01)
        msksn = mmrprj.frwd_prj(msk, txLUT, axLUT, Cnt)

        mssr = mmraux.sino2ssr(msksn, axLUT, Cnt)
        mssr = mssr > 0
    else:
        mssr = np.zeros((Cnt['NSEG0'], Cnt['NSANGLES'], Cnt['NSBINS']),
                        dtype=bool)

    # <<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>><<+>>

    # ======= SCALING ========
    # > scale scatter using non-TOF SSRB sinograms

    # > gap mask
    rmsk = (txLUT['msino'] > 0).T
    rmsk.shape = (1, Cnt['NSANGLES'], Cnt['NSBINS'])
    rmsk = np.repeat(rmsk, Cnt['NSEG0'], axis=0)

    # > include attenuating object into the mask (and the emission if selected)
    amsksn = np.logical_and(attossr >= mask_threshlod, rmsk) * ~mssr

    # > scaling factors for SSRB scatter
    scl_ssr = np.zeros((Cnt['NSEG0']), dtype=np.float32)

    for sni in range(Cnt['NSEG0']):
        # > region for scaling defined by the percentage of lowest
        # > but usable/significant scatter
        thrshld = prcnt_scl * np.max(sssr[sni, :, :])
        amsksn[sni, :, :] *= (sssr[sni, :, :] > thrshld)
        amsk = amsksn[sni, :, :]

        # > normalised estimated scatter
        mssn = sssr[sni, :, :] * nrmsssr[sni, :, :]
        vpsn = histo['pssr'][sni, amsk] - rssr[sni, amsk]
        scl_ssr[sni] = np.sum(vpsn) / np.sum(mssn[amsk])

        # > scatter SSRB sinogram output
        sssr[sni, :, :] *= nrmsssr[sni, :, :] * scl_ssr[sni]

    # === scale scatter for the full-size sinogram ===
    sss = np.zeros((snno, Cnt['NSANGLES'], Cnt['NSBINS']), dtype=np.float32)
    for i in range(snno):
        sss[i, :, :] = ssn[i, :, :] * scl_ssr[ssrlut[i]] * saxnrm[i] * nrm[
            i, :, :]
    '''
    # > debug
    si = 60
    ai = 60
    matshow(sssr[si,...])

    figure()
    plot(histo['pssr'][si,ai,:])
    plot(rssr[si,ai,:]+sssr[si,ai,:])

    plot(np.sum(histo['pssr'],axis=(0,1)))
    plot(np.sum(rssr+sssr,axis=(0,1)))
    '''

    # === OUTPUT ===
    if return_uninterp:
        out['uninterp'] = sct3d
        out['indexes'] = sctind

    if return_ssrb:
        out['ssrb'] = sssr
        out['rssr'] = rssr

    if return_mask:
        out['mask'] = amsksn

    if return_scaling:
        out['scaling'] = scl_ssr

    # if self_scaling:
    #     out['scl_sn1'] = scl_ssn

    if not out:
        return sss
    else:
        out['sino'] = sss
        return out