import sys, os, StringIO, ephem
import numpy as np
from fastcluster import *
from find_clusters import degToRad, Vincenty, getLevel, arcminToRad
from scipy.cluster.hierarchy import to_tree
from math import pi

def get_srcs(infns):
    """Load two-column source rows from every non-empty file in *infns*.

    Each usable path is parsed with ``np.loadtxt`` and reshaped to two
    columns, so a file holding a single row still yields shape (1, 2).
    Paths that are not regular files, or that are zero bytes long, are
    skipped without error.

    Returns the row-wise concatenation of all loaded arrays, or an empty
    list (not an empty array) when no file contributed any rows.
    """
    chunks = []
    for path in infns:
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            continue
        chunks.append(np.loadtxt(path).reshape((-1, 2)))
    if not chunks:
        return []
    return np.concatenate(chunks)

def srcDist(src1, src2):
    """Return the ``Vincenty`` distance between two sources.

    Elements 0 and 1 of each source are passed through ``degToRad`` (so
    they are presumably degrees), paired up, and handed to ``Vincenty``.
    NOTE(review): the result is compared against an arcsec->radian
    threshold by the caller, so it is presumably in radians; confirm
    against ``find_clusters.Vincenty``.
    """
    first = (degToRad(src1[0]), degToRad(src1[1]))
    second = (degToRad(src2[0]), degToRad(src2[1]))
    return Vincenty(first, second)

def arcsecToRad(arcsec):
    """Convert an angle from arcseconds to radians.

    648000 is the number of arcseconds in pi radians (180 deg * 3600).
    """
    rad_per_arcsec = pi / 648000
    return rad_per_arcsec * arcsec

if __name__ == "__main__":
	data = get_srcs(sys.argv[2:])
	if len(data) <= 1 : srcs = data
	elif len(data) > 1 :
    		tree = to_tree(linkage(data, metric=srcDist, method='complete'))
		idxs = map(lambda i : i[0], getLevel(tree,arcsecToRad(float(sys.argv[1]))))
		srcs = data[idxs]
	for (RA,DEC) in srcs : print RA,"\t",DEC
    fits = map(lambda x: pyfits.open(x), infns)
    data = np.concatenate(map(lambda f: np.array(f[1].data), fits))
    map(lambda f: f.close(), fits)
    return data

def srcDist(src1, src2):
    """Return the ``Vincenty`` distance between two sources.

    Elements 0 and 1 of each source are converted with ``degToRad``
    (presumably degrees -> radians) before the two coordinate pairs are
    passed to ``Vincenty``.
    NOTE(review): this redefines an identical ``srcDist`` earlier in the
    file; the later definition is the one that stays bound.
    """
    pairs = [(degToRad(src[0]), degToRad(src[1])) for src in (src1, src2)]
    return Vincenty(pairs[0], pairs[1])

def arcsecToRad(arcsec):
    """Convert *arcsec* (arcseconds) to radians; pi rad == 648000 arcsec."""
    return arcsec * (pi / 648000)

if __name__ == "__main__":
    srcs = get_srcs(sys.argv[3:])
    positions = np.array(map(lambda x: (x[0],x[1]), srcs))
    tree = to_tree(linkage(positions, metric=srcDist, method='complete'))
    clustered_srcs = map(lambda x: np.array(map(lambda y: srcs[y], x),dtype=srcs.dtype), getLevel(tree,arcsecToRad(float(sys.argv[1]))))
    if not os.path.exists(sys.argv[2]):
       os.makedirs(sys.argv[2])
    for cls in clustered_srcs:
      g = ephem.Galactic(ephem.Equatorial(cls[0][0]/180 * pi, cls[0][1]/180 * pi))
      lat,lon = StringIO.StringIO(),StringIO.StringIO()
      print >>lat, g.lat ; print >>lon, g.lon
      outfn = '%s/%s-%s-R%g_srcs.fits' % (sys.argv[2],
		      lon.getvalue().strip().replace(":","_"), 
		      lat.getvalue().strip().replace(":","_"),
		      float(sys.argv[1]))
      update_srcs(sys.argv[3],srcs=cls,outfn=outfn) 
      print "Wrote:", outfn 
    #print get_srcs(sys.argv[2:])