def reconstruct_size_distribution(r,a,t,sig_g,sig_d,alpha,rho_s,T,M_star,v_f,a_0=1e-4,fix_pd=None,ir_0=2,return_a=False):
    """
    Reconstruct the approximate size distribution based on the recipe of Birnstiel et al. 2015, ApJ.

    Arguments:
    ----------

    r : array
    :    radial grid [cm]

    a : array
    :    grain size grid, should have plenty of range and bins to work [cm]

    t : float
    :    time of the snapshot [s]

    sig_g : array
    :    gas surface densities on grid r [g cm^-2]

    sig_d : array
    :    dust surface densities on grid r [g cm^-2]

    alpha : array
    :    alpha parameter on grid r [-]; a scalar is broadcast to the length of r

    rho_s : float
    :    material (bulk) density of the dust grains [g cm^-3]

    T : array
    :    temperature on grid r [K]

    M_star : float
    :    stellar mass [g]

    v_f : float
    :    fragmentation velocity [cm s^-1]

    Keywords:
    ---------

    a_0 : float
    :    initial particle size [cm]

    fix_pd : None | float
    :    float: set the inward diffusion slope to this value
         None:  calculate it

    ir_0 : None | int
    :    int:  first radial index that is processed; drift-limited cells with
               index <= ir_0 are skipped as possibly problematic
         None: use the first index at which a_max decreases outward
               (capped at 10), or 0 if a_max never decreases

    return_a : bool
    :    if True, then in addition to the default output, also the
         three size limits are returned: drift, fragmentation, growth

    Output:
    -------

    sig_sol : array
    :    smoothed 2D grain size distribution, shape (len(a), len(r))

    a_max : array
    :    maximum particle size on grid r [cm]

    r_f : float
    :    critical fragmentation radius (see paper) [cm]; this is the radius of
         the outermost fragmentation limited cell, or r[0] if there is none

    sig_1, sig_2, sig_3 : array
    :    grain size distributions corresponding to the regions discussed in the paper

    a_dr, a_fr, a_grow : array
    :    only if return_a is True: drift, fragmentation and growth size limits [cm]

    Raises:
    -------

    ValueError
    :    if the maximum grain size exceeds the upper end of the size grid a
    """
    import warnings
    import numpy as np
    from aux_functions     import dlydlx, k_b, mu, m_p, Grav, pi, sig_h2, trace_line_though_grid
    from scipy.interpolate import interp1d
    # cumtrapz was renamed to cumulative_trapezoid and later removed from scipy
    try:
        from scipy.integrate import cumulative_trapezoid as cumtrapz
    except ImportError:
        from scipy.integrate import cumtrapz
    # np.trapz is deprecated in favour of np.trapezoid in recent numpy versions
    trapz = np.trapezoid if hasattr(np,'trapezoid') else np.trapz

    floor = 1e-100
    if fix_pd is not None: print('WARNING: fixing the inward diffusion slope')
    #
    # calculate derived quantities
    #
    alpha  = alpha*np.ones(len(r))
    cs     = np.sqrt(k_b*T/mu/m_p)
    om     = np.sqrt(Grav*M_star/r**3)
    vk     = r*om
    gamma  = dlydlx(r,sig_g*np.sqrt(T)*om)
    p      = -dlydlx(r,sig_g)
    q      = -dlydlx(r,T)
    #
    # fragmentation size; where the square root has no real solution
    # there is no fragmentation limit, so the NaNs are replaced by infinity
    #
    b      = 3.*alpha*cs**2/v_f**2
    with np.errstate(invalid='ignore'):
        a_fr = sig_g/(pi*rho_s)*(b-np.sqrt(b**2-4.))
    a_fr[np.isnan(a_fr)] = np.inf
    #
    # drift size
    #
    a_dr   = 0.55*2/pi*sig_d/rho_s*r**2.*(Grav*M_star/r**3)/(abs(gamma)*cs**2)
    #
    # time dependent growth
    #
    t_grow = sig_g/(om*sig_d)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', r'overflow encountered in exp')
        a_grow = a_0*np.exp(t/t_grow)
    #
    # the minimum of all of those
    #
    a_max  = np.minimum(np.minimum(a_fr,a_dr),a_grow)
    if a_max.max()>a[-1]: raise ValueError('Maximum grain size larger than size grid. Increase upper end of size grid.')
    #
    # transition to turbulent velocities
    #
    Re   = alpha*sig_g*sig_h2/(2*mu*m_p)
    a_bt = (8*sig_g/(pi*rho_s)*Re**-0.25*np.sqrt(mu*m_p/(3*pi*alpha))*(4*pi/3*rho_s)**-0.5)**0.4
    #
    # apply the reconstruction recipes
    #
    n_r       = len(r)
    n_a       = len(a)
    sig_1     = floor*np.ones([n_a,n_r])    # the fragment distribution
    sig_2     = floor*np.ones(sig_1.shape)  # outward mixed fragments
    sig_3     = floor*np.ones(sig_1.shape)  # drift distribution diffused
    prev_frag = None
    #
    # Radial loop: to fill in all fragmentation regions
    #
    frag_mask = a_max==a_fr
    frag_idx  = np.where(frag_mask)[0]
    #
    # select which cells to avoid
    #
    if ir_0 is None:
        ir_0 = np.where(a_max[:-1]>a_max[1:])[0]
        if len(ir_0)==0:
            ir_0 = 0
        else:
            ir_0 = min(10,ir_0[0])
    #
    # calculate fragmentation effects
    #
    for ir in range(ir_0,n_r):
        ia_max = abs(a-a_max[ir]).argmin()
        if frag_mask[ir]:
            #
            # fragmentation case
            #
            i_bt = abs(a-a_bt[ir]).argmin()
            dist = (a/a[0])**1.5
            dist[i_bt:] = dist[i_bt]*(a[i_bt:]/a[i_bt])**0.375 # can be 1/4, or 1/2, or average 0.375, OO13 used 1/4
            dist[ia_max+1:] = floor
            sig_1[:,ir] = dist
            prev_frag = ir
        elif sig_g[ir]>1e-5:
            #
            # outward diffusion of fragments from index prev_frag
            #
            if prev_frag is not None:
                sig_2[:,ir] = sig_1[:,prev_frag]*sig_g[ir]/sig_g[prev_frag]*(r[ir]/r[prev_frag])**-1.5
                #
                # limit outward diffusion by drift
                #
                St    = a*rho_s/sig_g[ir]*pi/2.
                vd    = 1./(St+1./St)*cs[ir]**2/vk[ir]*gamma[ir]
                t_dri = r[ir]/abs(vd)
                t_dif = (r[ir]-r[prev_frag])**2/(alpha[ir]*cs[ir]**2/om[ir]/(1.+St**2))
                #
                # the factor of 3 below changes how far out we take the
                # diffusion, but doesn't affect the results too much
                #
                sig_2[t_dif>3*t_dri,ir] = floor
                #
                # we also need to make sure we don't diffuse over
                # the drift limit (too much, at least)
                #
                sig_2[a>a_max[ir],ir] = floor
    # ---------------------
    # add up all the radial approximations from all cells crossed by the drift limit
    # ---------------------
    #
    # find all intersected grid cells
    # we will assume that the interfaces are in the middle of the grid center
    # usually it's the other way around, but this doesn't really matter here
    #
    ri  = 0.5*(r[1:]+r[:-1])
    ai  = 0.5*(a[1:]+a[:-1])
    ri  = np.hstack((r[0]-(ri[0]-r[0]),ri,r[-1]+(r[-1]-ri[-1])))
    ai  = np.hstack((a[0]-(ai[0]-a[0]),ai,a[-1]+(a[-1]-ai[-1])))
    f   = interp1d(np.log10(np.hstack((0.5*ri[0],r))),np.log10(np.hstack((a_max[0],a_max))),bounds_error=False,fill_value=np.log10(a_max[-1]))
    res = trace_line_though_grid(np.log10(ri),np.log10(ai),f)
    #
    # loop through every cell
    #
    for cell in res:
        ir,ia = cell
        #
        # if fragmentation limited: skip this cell
        #
        if frag_mask[ir]: continue
        #
        # skip inner possibly problematic cells
        #
        if ir <= ir_0: continue
        #
        # find ra, our starting point, and dust and gas densities there
        #
        mask  = np.arange(max(ir-2,0),min(ir+2,len(r)-1)+1)
        mask  = mask[a_max[mask].argsort()] # np.interp requires increasing sample points
        ra    = 10.**np.interp(np.log10(a[ia]),np.log10(a_max[mask]),np.log10(r[mask]))
        _sigd = 10.**np.interp(np.log10(ra),np.log10(r),np.log10(sig_d+1e-100))
        _sigg = 10.**np.interp(np.log10(ra),np.log10(r),np.log10(sig_g+1e-100))
        #
        # find previous and next fragmentation index,
        # fall back to the grid limits if there is none
        #
        prev_frag = frag_idx[frag_idx<ir]
        if len(prev_frag)==0:
            prev_frag = ir_0
        else:
            prev_frag = prev_frag[-1]
        next_frag = frag_idx[frag_idx>ir]
        if len(next_frag)==0:
            next_frag = n_r-1
        else:
            next_frag = next_frag[0]
        #
        # get semi-analytical estimate of the dust power-law (assuming everything is a power-law)
        # v is approximated as v0 * (r/ra)**d
        # old way: d = (p-q+0.5)
        #
        mask   = np.zeros(n_r,dtype=bool)
        St     = a[ia]*rho_s*pi/(2*sig_g)
        v0     = -1./(St+1./St)*cs**2/vk*(p+(q+3.)/2.)
        d      = dlydlx(r,v0)
        d[np.isnan(d)] = (p-q+0.5)[np.isnan(d)]
        v      = v0[ir]*(r/ra)**d[ir]
        pd_est = 1./(2*alpha)*(v/cs*vk/cs-2*p*alpha+vk/cs*np.sqrt(   (v/cs)**2+4*(1+d-p)*v/vk*alpha+4*alpha*sig_d/sig_g  ))
        if fix_pd is not None: pd_est = fix_pd*np.ones(n_r)
        #
        # now decide where to apply the solution: inward or outward
        #
        if v0[ir]<=0:
            #
            # inward drift
            #
            mask[prev_frag+1:min(ir+1,n_r)] = True
        else:
            #
            # outward drift
            #
            mask[ir+1:next_frag+1] = True
        #
        # mask away NaNs
        #
        mask &= np.invert(np.isnan(pd_est))
        if len(pd_est[mask])!=0:
            inte = cumtrapz(pd_est[mask],x=np.log10(r[mask]),initial=0)
            inte /= np.abs(inte).max()
            sol = np.exp(inte)
            sol = sol/np.interp(ra,r[mask],sol)*_sigd
            sig_3[ia,mask] = np.maximum(sol,sig_3[ia,mask])
        else:
            #
            # in case there is no place to apply a solution
            # at least fill the initial cell
            #
            sig_3[ia,ir] = np.maximum(sig_d[ir],sig_3[ia,ir])
        if v0[ir]<=0:
            #
            # outward diffusion
            #
            mask = np.zeros(n_r,dtype=bool)
            mask[ir+1:next_frag+1] = True
            #
            # nothing to do if there is no cell outside of ir
            # (indexing the empty selection below would fail)
            #
            if mask.any():
                A = a[ia]*rho_s*pi*gamma[ir]/(2*alpha[ir]*_sigg*p[ir])
                sol2 = np.maximum(sig_3[ia,mask] , _sigd*np.exp(A*((r[mask]/ra)**p[mask]-1.)))
                #
                # make sure this diffusion is not increasing: cut the
                # solution after the first cell where it starts to rise
                # NOTE(review): if it never rises, only the first cell is
                # filled (mask_idx = 0) -- confirm that this is intended
                #
                mask_idx = np.where(sol2[:-1]<sol2[1:])[0]
                if len(mask_idx)==0:
                    mask_idx = 0
                else:
                    mask_idx = mask_idx[0]
                mask &= (np.arange(n_r)<=np.arange(n_r)[mask][mask_idx])
                sig_3[ia,mask] = sol2[:mask_idx+1]
    sig_3[np.isnan(sig_3)] = floor
    #
    # extrapolate empty small-grain bins in the drift limit
    #
    for ir in range(n_r):
        if a_max[ir] == a_fr[ir]: continue
        #
        # if all bins are empty, fill in smallest sizes
        #
        if sum(sig_3[:,ir])<=10*n_a*floor:
            mask = a<=a_0
            sig_3[mask,ir] = a[mask]**0.5
            sig_3[mask,ir] = sig_3[mask,ir]/sum(sig_3[mask,ir])*sig_d[ir]
            sig_3[np.invert(mask),ir] = floor/(1e2*n_a)
        #
        # if there are some empty bins at the bottom, find the smallest
        # filled bin; if the lowest bin is already filled, or if no bin
        # is filled at all, there is nothing to extrapolate
        #
        if sig_3[0,ir]>1e-70: continue
        i_full = np.where(sig_3[:,ir]>1e-70)[0]
        if len(i_full)==0: continue
        i_full = i_full[0]
        #
        # now get the slope and continue it downward; if the peak is the
        # smallest filled bin, use the largest bin above the floor instead
        #
        imax = sig_3[:,ir].argmax()
        if imax == i_full:
            imax = np.where(sig_3[:,ir]>floor)[0][-1]
        #
        # a single filled bin defines no slope (0/0 would fill the
        # lower bins with NaN), so leave the empty bins as they are
        #
        if imax == i_full: continue
        pwl  = np.log(sig_3[imax,ir]/sig_3[i_full,ir])/np.log(a[imax]/a[i_full])
        sig_3[:i_full,ir] = sig_3[i_full,ir]*(a[:i_full]/a[i_full])**pwl
    #
    # Now the problem is how to stitch the 3 distributions together,
    # particularly sig_2 and sig_3. sig_2 near the fragmentation zone
    # is pretty much normalized, but further away, the drift zone might
    # dominate. So for now, we assume sig_2 to be normalized and we
    # adapt the normalization of sig_3 to account for that. It can happen
    # that the surface density in sig_2 is actually larger than the total
    # dust surface density. For this reason, we make sure sig_3 is not
    # normalized to a negative value and we renormalize the total
    # distribution then once more.
    #
    sig_3_sum = sig_3.sum(0)
    mask      = sig_3_sum>1e-50
    sig_3[:,mask] = sig_3[:,mask]/sig_3_sum[mask]*np.maximum(1e-100,sig_d[mask]-sig_2.sum(0)[mask])
    #
    # add up all of them and normalize
    #
    sig_dr = sig_1+sig_2+sig_3
    sig_dr = sig_dr/sig_dr.sum(0)*sig_d
    #
    # as some aspects of diffusion are not included it might be useful to do a smoothing of the resulting
    # distribution.
    #
    sig_s = np.zeros(sig_dr.shape)
    X     = 20. # radial width of the kernel: how far can particles diffuse in X orbits
    dmf   = 2.  # log-normal width factor (from m/dmf to m*dmf)
    for ir in range(n_r):
        St    = a*rho_s/sig_g[ir]*pi/2.
        #
        # size dependent radial width of the kernel; an alternative would
        # be the distance particles diffuse within X drift time scales
        #
        sig_r = np.sqrt(2*pi*X/om[ir] * (alpha[ir]*cs[ir]**2/om[ir]/(1.+St**2)))
        #
        # now define a radial stencil over which to smooth (2*sigma left and 2*sigma
        # right), but at least 1 grid left and 1 grid right, apart from the boundaries
        #
        dr_max = 2*sig_r.max()
        ir0 = max(0,min(ir-1,np.searchsorted(r,r[ir]-dr_max)))
        ir1 = min(n_r-1,max(ir+1,np.searchsorted(r,r[ir]+dr_max)))
        r_st = r[ir0:ir1+1]
        for ia in range(n_a):
            #
            # define the size range, go from half-size to double-size,
            # but at least one bin down and one bin up
            #
            ia0 = max(0,min(ia-1,np.searchsorted(a,a[ia]/2)))
            ia1 = min(n_a-1,max(ia+1,np.searchsorted(a,a[ia]*2)))
            #
            # construct the kernel: log-normal in size, gaussian in radius,
            # normalized to unity when integrated over the disk surface
            #
            kernel_a = np.exp(-np.log(a[ia0:ia1+1]/a[ia])**2/(2*np.log(dmf**0.33)**2))
            kernel_r = np.exp(-(r_st-r[ir])**2/(2*sig_r[ia]**2))
            kernel   = np.outer(kernel_a,kernel_r)
            kernel   = kernel/trapz(2*pi*r_st*kernel.sum(0),x=r_st)
            #
            # apply the smoothing
            #
            sig_s[ia,ir] = trapz(2*pi*r_st*(kernel*sig_dr[ia0:ia1+1,ir0:ir1+1]).sum(0),x=r_st)
    #
    # without any fragmentation limited cell, report the innermost radius
    #
    if len(frag_idx)==0: frag_idx=[0]
    if return_a:
        return sig_s,a_max,r[frag_idx[-1]],sig_1,sig_2,sig_3,a_dr,a_fr,a_grow
    else:
        return sig_s,a_max,r[frag_idx[-1]],sig_1,sig_2,sig_3
def test_reconstruction():
    import numpy as np
    from aux_functions     import trace_line_though_grid, M_sun, AU, year, pi
    from scipy.interpolate import interp1d
    # ================
    # set up the model
    # ================
    nr = 100
    na = 200
    ri = np.logspace(-1,3,nr+1)*AU
    ai = np.logspace(-4,2,na+1)
    r = 0.5*(ri[1:]+ri[:-1])
    a = 0.5*(ai[1:]+ai[:-1])
    M_star = M_sun
    M_disk = 0.05 * M_star
    rc     = 60 * AU
    eps    = 0.01
    rho_s  = 1.2
    v_f    = 1e3
    alpha  = 1e-3
    sig_g  = r**-1*np.exp(-r/rc)
    sig_g  = M_disk*sig_g/np.trapz(2*pi*r*sig_g,x=r)
    sig_d  = sig_g*eps
    T      = 200*(r/AU)**-0.5
    # call the reconstruction routine
    sig_dr,a_max,_,_,_,_ = reconstruct_size_distribution(r,a,1e6*year,
    # ========
    # ========
    _,ax = plt.subplots()
    # plot the grid
    for _ai in ai: ax.plot(ri/AU,_ai*np.ones(len(ri)),'k')
    for _ri in ri: ax.plot(_ri/AU*np.ones(len(ai)),ai,'k')
    for _a in a: ax.plot(r/AU,_a*np.ones(len(r)),'kx')
    # find all intersected grid cells
    f   = interp1d(np.log10(r),np.log10(a_max),bounds_error=False,fill_value=np.log10(a_max[-1]))
    res = trace_line_though_grid(np.log10(ri),np.log10(ai),f)
    # draw all intersected cells
    for j,i in res:
    # plot the distribution
    mx = np.ceil(np.log10(sig_dr.max()))
    # plot maximum particle size
