import numpy as np
import scattering
import dsd
from disser.units import to_dB, to_dBz, angle, exp_to_dB
from disser.scatter import bulk_scatter
import quantities as pq

d = (np.linspace(0.01, 8, 100).reshape(-1, 1, 1).astype(np.float32)
        * pq.mm)
# Rain water content along axis 1: 40 points from 0 to 15 g/m^3.
qr = (np.linspace(0, 15.0, 40).reshape(1, -1, 1).astype(np.float32)
        * pq.g / pq.m**3)
# Number concentration along axis 2: 110 log-spaced points, 1e-2..1e6 m^-3.
n = (np.logspace(-2, 6, 110).reshape(1, 1, -1).astype(np.float32)
        / pq.m**3)
# d (axis 0), qr (axis 1) and n (axis 2) are shaped to broadcast against
# each other, presumably giving one distribution per (qr, n) pair -- the
# exact fitting rules live in the dsd module.
dist = dsd.constrained_gamma_from_moments(n, qr, d)

# Sanity checks: the water content recomputed from the DSD should match the
# requested qr (print the worst absolute error), and report the range of
# implied rain rates.
qr_calc = dsd.lwc(d, dist).rescale('g/m**3')
print(np.abs(qr - qr_calc).max())
rr = dsd.rainrate(d, dist, dsd.rain_fallspeed(d)).rescale('mm/hr')
print("%s %s" % (rr.min(), rr.max()))

# Radar wavelength: 5.3 cm (C-band).
wavelength = 0.053 * pq.m
# Plot color per temperature (presumably deg C -- TODO confirm).
temp_colors = {0: 'b', 10: 'g', 20: 'y', 30: 'r'}
# Sort explicitly: dict iteration order is not guaranteed on older Pythons,
# and on Python 3 np.array(dict.keys()) gives a 0-d object array (size 1)
# instead of an integer array of the temperatures.
temps = np.array(sorted(temp_colors))

d_plot = d.squeeze()

# Output buffers: one slab per temperature over the (qr, n) grid of `dist`.
# Presumably filled further on (reflectivity, Zdr, attenuation, differential
# attenuation, Kdp) -- not visible in this section.
z = np.empty((temps.size,) + dist.shape[1:], dtype=np.float32)
zdr = np.empty_like(z)
atten = np.empty_like(z)
diff_atten = np.empty_like(z)
kdp = np.empty_like(z)
import numpy as np
import quantities as pq
from disser import io
import dsd

# Size and starting offset of the window of masked grid points to test.
num = 100
ind = 1453433

# Load model output and keep only points with rain above small thresholds
# on both qr and nr.
data = io.ModelData('/home/rmay/radar_sim_git/data/commas_wz_3600.nc')
mask = (data.qr > 1e-6) & (data.nr > 1e-3)
qr = data.qr[mask]
nr = data.nr[mask]

# The contiguous window of masked points being checked.
window = slice(ind, ind + num)
qr_win = qr[window]
nr_win = nr[window]

# Diameter grid (axis 0), 0.01..20 mm. Build it unitless and attach units
# afterwards, as the other scripts do, so the mm units do not depend on
# whether np.linspace preserves Quantity subclasses.
d = np.linspace(0.01, 20, 300).reshape(-1, 1) * pq.mm
dist = dsd.constrained_gamma_from_moments(nr_win, qr_win, d)

# Round-trip check: integrating the fitted DSD should reproduce the input
# water content (via dsd.lwc) and number concentration (trapezoidal integral
# over diameter).
test_qr = dsd.lwc(d, dist).simplified
test_nr = np.trapz(dist, axis=0, x=d).simplified
print("%s %s %s" % (qr_win, test_qr, np.allclose(qr_win, test_qr)))
print("%s %s %s" % (nr_win, test_nr, np.allclose(nr_win, test_nr)))
# plt was used below but never imported in this file.
import matplotlib.pyplot as plt

# Diameter grid: 100 points from 0.01 to 8 mm.
d = np.linspace(0.01, 8, 100) * pq.mm

# Each trial: (Lambda [1/mm], N_r [1/m^3], matplotlib line format).
trials = [
    (5 / pq.mm, 1e4 / pq.m ** 3, "k"),
    (10 / pq.mm, 1e4 / pq.m ** 3, "r"),
    (15 / pq.mm, 1e4 / pq.m ** 3, "b"),
    (20 / pq.mm, 1e4 / pq.m ** 3, "g"),
    (5 / pq.mm, 1e5 / pq.m ** 3, "k--"),
    (10 / pq.mm, 1e5 / pq.m ** 3, "r--"),
    (15 / pq.mm, 1e5 / pq.m ** 3, "b--"),
    (20 / pq.mm, 1e5 / pq.m ** 3, "g--"),
]
for lam, Nr, plot in trials:
    # Shape parameter nu is derived from Lambda by the constrained relation.
    nu = dsd.constrained_gamma_shape(lam)
    dist = dsd.modified_gamma(d, lam, Nr, nu)
    label = r"$\Lambda:%.1f\, N_r:%.0e\, \nu:%.1f$" % (lam, Nr, nu)
    plt.semilogy(d, dist, plot, label=label)
    # Report the water content implied by this DSD alongside its nu.
    print("%s %s" % (dsd.lwc(d, dist).simplified, nu))

plt.legend(loc="lower right")
plt.title("DSD Comparison")
plt.xlabel("Diameter (%s)" % d.dimensionality.latex)
# Units of the y axis come from the last DSD computed in the loop.
plt.ylabel("Number (%s)" % dist.dimensionality.latex)
plt.ylim(1e1, None)
plt.grid()
plt.show()