import os
import sys

import matplotlib.pyplot as plt
import numpy.random as npr
import seaborn as sns
from funkyyak import grad, numpy_wrapper as np

from redshift_utils import load_data_clean_split, project_to_bands
from slicesample import slicesample
sns.set_style("white")
current_palette = sns.color_palette()
# Seed numpy.random's global generator so runs are reproducible.
# (`npr` was previously used here without being imported -> NameError.)
npr.seed(42)

## save figure output files
# The directory can be overridden with the QUASAR_FIG_DIR environment
# variable; the default is the machine-specific absolute path used before.
out_dir = os.environ.get(
    "QUASAR_FIG_DIR",
    "/Users/acm/Dropbox/Proj/astro/DESIMCMC/tex/quasar_z/figs/")

## load a handful of quasar spectra
lam_obs, qtrain, qtest = \
    load_data_clean_split(spec_fits_file = 'quasar_data.fits', Ntrain = 400)

## load in basis
# th is a 2-D matrix (indexed th[:, ...] below); lls is loaded but not used
# in this section; lam0's length fixes how many trailing columns of th
# belong to the basis (presumably lam0 is the rest-frame wavelength grid).
th = np.load("cache/basis_th.npy")
lls = np.load("cache/lls.npy")
lam0 = np.load("cache/lam0.npy")

# Each row of th is split into its first N columns (`omegas`, exponentiated
# into W) and its last len(lam0) columns (`betas`, exponentiated and
# row-normalized into B).
N = th.shape[1] - lam0.shape[0]
if N <= 0:
    # A non-positive N would silently mis-split th (a negative slice bound
    # counts from the end), so fail loudly on mismatched cache files.
    raise ValueError(
        "cache/basis_th.npy has %d columns but cache/lam0.npy has %d "
        "entries; expected more columns than wavelengths"
        % (th.shape[1], lam0.shape[0]))
omegas = th[:, :N]
betas = th[:, N:]
W = np.exp(omegas)
B = np.exp(betas)
# Normalize each basis row so it sums to one.
B = B / B.sum(axis=1, keepdims=True)

## compute all marginal expected z's and compare
# Preallocated with the same shape as qtest['Z']; not filled within this view.
z_pred = np.zeros(qtest['Z'].shape)
z_lo = np.zeros(qtest['Z'].shape)
    # NOTE(review): this run of statements is indented, but no enclosing
    # `def`/`if` header is visible; directly after the top-level statements
    # above it would raise IndentationError. It looks spliced in from another
    # script (e.g. a `__main__` block) -- confirm and restore the header.
    # NOTE(review): get_lam0 and resample_rest_frame are neither imported nor
    # defined in the visible source -- verify where they should come from.

    # Command-line configuration; each value falls back to a default when the
    # corresponding argument is absent:
    #   argv[1] chain_idx    (int,   default 0)
    #   argv[2] Nsamps       (int,   default 100)
    #   argv[3] length_scale (float, default 40.)
    #   argv[4] init_iter    (int,   default 100)
    chain_idx    = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    Nsamps       = int(sys.argv[2]) if len(sys.argv) > 2 else 100
    length_scale = float(sys.argv[3]) if len(sys.argv) > 3 else 40.
    init_iter    = int(sys.argv[4]) if len(sys.argv) > 4 else 100
    K            = 4    # hard-coded; not settable from the command line
    # Echo the run configuration (Python 2 print statements).
    print "==== SAMPLING CHAIN ID = %d ============== "%chain_idx
    print "    Nsamps          = %d "%Nsamps
    print "    length_scale    = %2.2f"%length_scale
    print "    num init_iters  = %d   "%init_iter
    print "    K               = %d   "%K

    ##################################################################
    ## load a handful of quasar spectra and resample
    ##################################################################
    # NOTE(review): same file and Ntrain as the top-level load above; this
    # rebinds lam_obs, qtrain and qtest.
    lam_obs, qtrain, qtest = \
        load_data_clean_split(spec_fits_file = 'quasar_data.fits',
                              Ntrain = 400)
    # N is rebound to the number of training spectra (first axis of
    # qtrain['spectra']), replacing any earlier value of N.
    N = qtrain['spectra'].shape[0]

    ## resample to lam0 => rest frame basis 
    # get_lam0 returns two values bound to lam0 / lam0_delta (presumably the
    # rest-frame grid and its spacing -- confirm against get_lam0).
    lam0, lam0_delta = get_lam0(lam_subsample=10)
    print "    resampling de-redshifted data"
    # Resample the training spectra and their inverse variances onto lam0,
    # given each quasar's redshift qtrain['Z']; the meaning of lam_mat is not
    # visible here.
    spectra_resampled, spectra_ivar_resampled, lam_mat = \
        resample_rest_frame(qtrain['spectra'], 
                            qtrain['spectra_ivar'],
                            qtrain['Z'], 
                            lam_obs, 
                            lam0)
    # clean nans
    # X is bound to the same object as spectra_resampled (no copy), so zeroing
    # NaNs here also modifies spectra_resampled in place.
    X                  = spectra_resampled
    X[np.isnan(X)]     = 0
    # Lam is likewise bound to spectra_ivar_resampled without a copy.
    Lam                = spectra_ivar_resampled