# Setup for a single-panel log-log figure of the pulsars returned by
# PulsarCatalogLoader.get_off_peak_psrlist().
# pylab (P) and numpy (np) are used below (P.figure, np.empty_like), so they
# must be imported before first use; pylab is imported first, matching the
# order used by the other script setup in this file.
import pylab as P

import numpy as np

from lande.utilities import pubplot
from lande.fermi.pipeline.pwncat2.interp.bigfile import PulsarCatalogLoader

pubplot.set_latex_defaults()

bw = pubplot.get_bw()


# Figure with figsize 4x4 and a single subplot; both axes use a log scale.
fig = P.figure(None,(4,4))
axes = fig.add_subplot(111)

axes.set_xscale("log")
axes.set_yscale("log")

# Both paths begin with the literal '$lat2pc'; presumably the loader expands
# this environment variable -- confirm in PulsarCatalogLoader.
cat=PulsarCatalogLoader(
    bigfile_filename='$lat2pc/BigFile/Pulsars_BigFile_v20121002103223.fits',
    off_peak_auxiliary_filename='$lat2pc/OffPeak/tables/off_peak_auxiliary_table.fits')
bigfile_fits = cat.bigfile_fits


psrlist = cat.get_off_peak_psrlist()


# Per-pulsar arrays with one slot per entry of psrlist. np.empty_like only
# allocates them, so every slot must be assigned before it is read
# (presumably by the per-pulsar loop that follows).
Edot=np.empty_like(psrlist,dtype=float)
age=np.empty_like(psrlist,dtype=float)
classification=np.empty_like(psrlist,dtype=object)
has_distance=np.empty_like(psrlist,dtype=bool)
distance=np.empty_like(psrlist,dtype=float)
for i,psr in enumerate(psrlist):
import pylab as P

import numpy as np

from lande.utilities.plotting import plot_points
from lande.utilities import pubplot
from lande.fermi.pipeline.pwncat2.interp.bigfile import PulsarCatalogLoader

# Publication styling: LaTeX defaults from the project's pubplot helper.
pubplot.set_latex_defaults()
bw = pubplot.get_bw()

# Pulsar catalog. Both paths start with the literal '$lat2pc'; presumably
# the loader expands that environment variable (confirm in the loader).
cat = PulsarCatalogLoader(
    bigfile_filename='$lat2pc/BigFile/Pulsars_BigFile_v20130214170325.fits',
    off_peak_auxiliary_filename='$lat2pc/OffPeak/auxiliary/off_peak_auxiliary_table.fits',
)
psrlist = cat.get_off_peak_psrlist()

# One 6x6 (figsize) figure holding a single log-log panel.
fig = P.figure(None, figsize=(6, 6))
axes = fig.add_subplot(1, 1, 1)
axes.set_xscale("log")
axes.set_yscale("log")

# Per-pulsar storage, one slot per entry of psrlist. The arrays are only
# allocated here (np.empty_like); their contents are assigned elsewhere.
classification = np.empty_like(psrlist, dtype=object)
# A generator expression (not a list comprehension) keeps the throwaway loop
# variable out of the module namespace under Python 2.
(Edot,
 luminosity,
 luminosity_error_statistical,
 luminosity_lower_error_systematic) = (
    np.empty_like(psrlist, dtype=float) for _ in range(4))
def spatial_spectral_table(pwndata, 
                           phase_shift, 
                           fitdir, savedir, pwn_classification, filebase, table_type,
                           bigfile_filename):
    assert table_type == 'latex'

    format=PWNFormatter(table_type=table_type, precision=2)

    loader = PWNResultsLoader(
        pwndata=pwndata,
        fitdir=fitdir,
        phase_shift=phase_shift
        )

    classifier = PWNManualClassifier(loader=loader, pwn_classification=pwn_classification)

    table = OrderedDefaultDict(list)

    psr_name='PSR'
    classification_name = 'Type'
    ts_point_name=r'$\tspoint$'
    ts_ext_name=r'$\tsext$'
    ts_cutoff_name = r'$\tscutoff$'
    ts_altdiff_name = r'$\tsaltdiff$'
    eflux_name = r'Energy Flux'
    index_name = r'$\Gamma$'
    cutoff_name = r'$\Ecutoff$'

    pwnlist = loader.get_pwnlist()
    #pwnlist = pwnlist[10:20]

    pcl = PulsarCatalogLoader(bigfile_filename=bigfile_filename)

    young = [ psr for psr in pwnlist if 'm' not in pcl.get_pulsar_classification(psr.replace('PSRJ','J'))]
    msps = [ psr for psr in pwnlist if 'm' in pcl.get_pulsar_classification(psr.replace('PSRJ','J'))]
    print 'young',young
    print 'msps',msps

    sorted_pwnlist = young + msps

    first_msp_index = None

    for pwn in sorted_pwnlist:
        print pwn

        try:
            r = classifier.get_results(pwn)

            if r['source_class'] == 'Upper_Limit': 
                continue

            if first_msp_index is None and pwn in msps:
                first_msp_index = len(table[psr_name])

            table[psr_name].append(format.pwn(pwn))
            table[classification_name].append(r['abbreviated_source_class'])

            def david_format_ts(x):
                if x >= 100:
                    return format.value(x,precision=0) + '.'
                else:
                    return format.value(x,precision=1)

            table[ts_point_name].append(david_format_ts(r['ts_point']))
            table[ts_ext_name].append(david_format_ts(r['ts_ext']))
            table[ts_cutoff_name].append(david_format_ts(r['ts_cutoff']))
            table[ts_altdiff_name].append(david_format_ts(r['ts_altdiff']) if r['ts_altdiff'] is not None else format.nodata)

            def david_format_flux(x, y):
                if x >= 10:
                    return format.error(x,y, precision=1)
                else:
                    return  format.error(x,y, precision=2)

            table[eflux_name].append(david_format_flux(r['energy_flux']/1e-11,r['energy_flux_err']/1e-11))
            if r['spectral_model'] in ['PowerLaw','PLSuperExpCutoff']:
                table[index_name].append(format.error(r['index'],r['index_err']))
            elif pwn == 'PSRJ0534+2200':
                table[index_name].append(r'\tablenotemark{a}')
            elif pwn == 'PSRJ0835-4510':
                table[index_name].append(r'\tablenotemark{b}')
            else:
                table[index_name].append(format.nodata)

            if r['spectral_model'] == 'PLSuperExpCutoff':
                table[cutoff_name].append(format.error(r['cutoff']/1e3,r['cutoff_err']/1e3, precision=2))
            else:
                table[cutoff_name].append(format.nodata)

        except PWNClassifierException, ex:
            print 'Skipping %s: %s' % (pwn,ex)
            table[psr_name].append(format.pwn(pwn))
            table[classification_name].append('None')
            table[ts_point_name].append('None')
            table[ts_ext_name].append('None')
            table[ts_cutoff_name].append('None')
            table[eflux_name].append('None')
            table[index_name].append('None')
            table[cutoff_name].append('None')