    def get_shapelet_params(self, image, mask, basis, beam_pix, fixed, ori, mode, beta=None, cen=None, nmax=None):
        """Estimate the shapelet scale, centre and maximum order for an image.

        Parameters
        ----------
        image : 2-D array
            Image to be decomposed.
        mask : 2-D boolean array
            Mask of the same shape as ``image`` (False = valid pixel).
        basis : str
            "cartesian" or "polar"; passed through to the centre/beta finders.
        beam_pix : sequence
            Beam as (pix_fwhm, pix_fwhm, deg); passed to ``shape_findcen``.
        fixed : sequence of three 0/1 flags
            For (beta, centre, nmax): 1 = take the supplied value as fixed,
            0 = calculate it.  A flag of 1 whose value was not supplied is
            treated as 0.  The caller's sequence is not modified.
        ori : unused
            Kept only so that existing callers keep working.
        mode : str
            When equal to 'fit', nmax is capped so that the number of
            parameters does not exceed the number of unmasked pixels.
        beta, cen, nmax : optional
            Scale, centre of the basis expansion and maximum order.

        Returns
        -------
        (beta, cen, nmax) : the updated set of values.
        """
        from math import sqrt, floor
        import numpy as N

        # Work on a private copy: the flags are adjusted below and the
        # documented input form is a tuple, which cannot be assigned to.
        fixed = list(fixed)
        if fixed[0] == 1 and beta is None:
            fixed[0] = 0
        if fixed[1] == 1 and cen is None:
            fixed[1] = 0
        if fixed[2] == 1 and nmax is None:
            fixed[2] = 0

        # Needed for betarange below whether or not nmax is being estimated.
        n, m = image.shape

        if fixed[0] == 0 or fixed[1] == 0:
            # Image moments seed both the scale and the centre.
            import functions as func
            (m1, m2, m3) = func.moment(image, mask)

        if fixed[0] == 0:
            # Initial scale from the second moments; never start from zero.
            beta = sqrt(m3[0] * m3[1]) * 2.0
            if beta == 0.0:
                beta = 0.5
        if fixed[1] == 0:
            # Initial centre is the first moment; refined further down.
            cen = m2
        if fixed[2] == 0:
            # totally ad hoc: the previous clamp min(max(nmax*2+2, 10), 10)
            # evaluates to 10 for every numeric seed, so state it directly.
            nmax = 10
            npix = N.prod(image.shape) - N.sum(mask)
            if nmax * nmax >= n * m:
                # -1 is for when n*m is a perfect square
                nmax = int(floor(sqrt(npix - 1)))
            if mode == 'fit':   # make sure npara <= npix
                nmax_max = int(round(0.5 * (-3 + sqrt(1 + 8 * npix))))
                nmax = min(nmax, nmax_max)

        betarange = [0.5, sqrt(beta * max(n, m))]  # min, max

        if fixed[1] == 0:
            cen = shape_findcen(image, mask, basis, beta, nmax, beam_pix)  # + check_cen_shapelet

        if fixed[0] == 0:
            beta, err = shape_varybeta(image, mask, basis, beta, cen, nmax, betarange, plot=False)

        # The centre depends on beta, so redo it once beta has been varied.
        if fixed[1] == 0 and fixed[0] == 0:
            cen = shape_findcen(image, mask, basis, beta, nmax, beam_pix)  # + check_cen_shapelet

        return beta, cen, nmax
    def __call__(self, img):
      # Estimate how the PSF varies across the image `img` and store the result on it.
      #
      # Everything below runs only when img.opts.psf_vary_do is set. Outline, as far as
      # the visible code shows:
      #   1. read the psf_* options and clamp the SNR cuts,
      #   2. collect the brightest Gaussian of every source and filter by SNR / code,
      #   3. keep the unresolved ones, derive Voronoi generators and tessellate the image,
      #   4. build a stacked PSF image per tile (self.psf_in_tile),
      #   5. either decompose those into shapelets (opts.psf_use_shap) or fit 2-D Gaussians,
      #      interpolate the parameters over the image and store img.psf_vary_*_arr,
      #   6. re-deconvolve source and Gaussian sizes with the local beam.
      #
      # NOTE(review): this block is Python 2 (`except ImportError, err`, print statement)
      # and, as it stands, several structural lines (`try:`, `else:`, early exits, bodies of
      # `if img.opts.quiet == False:`) appear to be missing; the individual spots are
      # flagged below. Names such as mylogger, N, func, sh, mp, itertools, statusbar,
      # _cbdsm, sqrt, fwsig and factor are not bound in this block - presumably
      # module-level; confirm.

      if img.opts.psf_vary_do:
        mylog = mylogger.logging.getLogger("PyBDSM."+img.log+"Psf_Vary")
        mylogger.userinfo(mylog, '\nEstimating PSF variations')
        opts = img.opts
        dir = img.basedir + '/misc/'
        plot = False # debug figures
        image = img.ch0_arr

        # NOTE(review): the two lines below are indented like the body of a `try:` and are
        # followed by an `except` clause, but no `try:` line is present.
            from astropy.io import fits as pyfits
            old_pyfits = False
        except ImportError, err:
            from distutils.version import StrictVersion
            import pyfits
            if StrictVersion(pyfits.__version__) < StrictVersion('2.2'):
                # NOTE(review): as written the second assignment overwrites the first and
                # old_pyfits is unbound when the version test fails; an `else:` looks missing.
                old_pyfits = True
                old_pyfits = False

        if old_pyfits:
            # NOTE(review): the message says the module is skipped, but nothing here leaves
            # the method.
            mylog.warning('PyFITS version is too old: psf_vary module skipped')

        # Local copies of the PSF-variation options.
        over = 2
        generators = opts.psf_generators; nsig = opts.psf_nsig; kappa2 = opts.psf_kappa2
        snrtop = opts.psf_snrtop; snrbot = opts.psf_snrbot; snrcutstack = opts.psf_snrcutstack
        gencode = opts.psf_gencode; primarygen = opts.psf_primarygen; itess_method = opts.psf_itess_method
        tess_sc = opts.psf_tess_sc; tess_fuzzy= opts.psf_tess_fuzzy
        bright_snr_cut = opts.psf_high_snr
        s_only = opts.psf_stype_only
        if opts.psf_snrcut < 5.0:
            mylogger.userinfo(mylog, "Value of psf_snrcut too low; increasing to 5")
            # NOTE(review): the raised value is immediately overwritten by the option value;
            # an `else:` looks missing, and snrcut is unbound when the test is false.
            snrcut = 5.0
            snrcut = opts.psf_snrcut
        img.psf_snrcut = snrcut
        if opts.psf_high_snr != None:
            if opts.psf_high_snr < 10.0:
                mylogger.userinfo(mylog, "Value of psf_high_snr too low; increasing to 10")
                # NOTE(review): same pattern - the clamp to 10 is overwritten on the next lines.
                high_snrcut = 10.0
                high_snrcut = opts.psf_high_snr
            high_snrcut = opts.psf_high_snr
        img.psf_high_snr = high_snrcut

        # Map the integer tessellation method onto a weighting-function name;
        # anything outside 0..3 falls back to 'unity'.
        wtfns=['unity', 'roundness', 'log10', 'sqrtlog10']
        if 0 <= itess_method < 4: tess_method=wtfns[itess_method]
        else: tess_method='unity'

        ### now put all relevant gaussian parameters into a list
        ngaus = img.ngaus
        nsrc = img.nsrc
        num = N.zeros(nsrc, dtype=N.int32)
        peak = N.zeros(nsrc)
        xc = N.zeros(nsrc)
        yc = N.zeros(nsrc)
        bmaj = N.zeros(nsrc)
        bmin = N.zeros(nsrc)
        bpa = N.zeros(nsrc)
        code = N.array(['']*nsrc);
        rms = N.zeros(nsrc)
        src_id_list = []
        for i, src in enumerate(img.sources):
            src_max = 0.0
            for gmax in src.gaussians:
                # Take only brightest Gaussian per source
                if gmax.peak_flux > src_max:
                    src_max = gmax.peak_flux
                    g = gmax
            # NOTE(review): if no Gaussian of this source has peak_flux > 0, `g` keeps the
            # value from the previous source (or is unbound for the first one).
            num[i] = i
            peak[i] = g.peak_flux
            xc[i] = g.centre_pix[0]
            yc[i] = g.centre_pix[1]
            bmaj[i] = g.size_pix[0]
            bmin[i] = g.size_pix[1]
            bpa[i] = g.size_pix[2]
            code[i] = img.sources[g.source_id].code
            rms[i] = img.islands[g.island_id].rms
        gauls = (num, peak, xc, yc, bmaj, bmin, bpa, code, rms)
        tr_gauls = self.trans_gaul(gauls)

        # takes gaussians with code=S and snr > snrcut.
        # (n[1] is the peak, n[8] the rms and n[7] the code, following the order of `gauls`.)
        if s_only:
            # NOTE(review): the second list overwrites the first and `tr` is unbound when
            # s_only is false; an `else:` looks missing.
            tr = [n for n in tr_gauls if n[1]/n[8]>snrcut and n[7] == 'S']
            tr = [n for n in tr_gauls if n[1]/n[8]>snrcut]
        g_gauls = self.trans_gaul(tr)

        # computes statistics of fitted sizes. Same as psfvary_fullstat.f in fBDSM.
        bmaj_a, bmaj_r, bmaj_ca, bmaj_cr, ni = _cbdsm.bstat(bmaj, None, nsig)
        bmin_a, bmin_r, bmin_ca, bmin_cr, ni = _cbdsm.bstat(bmin, None, nsig)
        bpa_a, bpa_r, bpa_ca, bpa_cr, ni = _cbdsm.bstat(bpa, None, nsig)

        # get subset of sources deemed to be unresolved. Same as size_ksclip_wenss.f in fBDSM.
        flag_unresolved = self.get_unresolved(g_gauls, img.beam, nsig, kappa2, over, img.psf_high_snr, plot)

        # see how much the SNR-weighted sizes of unresolved sources differ from the synthesized beam.
        wtsize_beam_snr = self.av_psf(g_gauls, img.beam, flag_unresolved)

        # filter out resolved sources
        tr_gaul = self.trans_gaul(g_gauls)
        tr = [n for i, n in enumerate(tr_gaul) if flag_unresolved[i]]
        g_gauls = self.trans_gaul(tr)
        mylogger.userinfo(mylog, 'Number of unresolved sources', str(len(g_gauls[0])))

        # get a list of voronoi generators. vorogenS has values (and not None) if generators='field'.
        vorogenP, vorogenS = self.get_voronoi_generators(g_gauls, generators, gencode, snrcut, snrtop, snrbot, snrcutstack)
        mylogger.userinfo(mylog, 'Number of generators for PSF variation', str(len(vorogenP[0])))
        if len(vorogenP[0]) < 3:
            # NOTE(review): only a warning is issued; processing continues with too few generators.
            mylog.warning('Insufficient number of generators')

        mylogger.userinfo(mylog, 'Tesselating image')
        # group generators into tiles
        tile_prop = self.edit_vorogenlist(vorogenP, frac=0.9)

        # tesselate the image
        volrank, vorowts = self.tesselate(vorogenP, vorogenS, tile_prop, tess_method, tess_sc, tess_fuzzy, \
                  generators, gencode, image.shape)
        if opts.output_all:
            func.write_image_to_file(img.use_io, img.imagename + '.volrank.fits', volrank, img, dir)

        tile_list, tile_coord, tile_snr = tile_prop
        ntile = len(tile_list)
        bar = statusbar.StatusBar('Determining PSF variation ............... : ', 0, ntile)
        mylogger.userinfo(mylog, 'Number of tiles for PSF variation', str(ntile))

        # For each tile, calculate the weighted averaged psf image. Also for all the sources in the image.
        # NOTE(review): `factor` is passed below but never assigned in this block.
        cdelt = list(img.wcs_obj.acdelt[0:2])
        psfimages, psfcoords, totpsfimage, psfratio, psfratio_aper = self.psf_in_tile(image, img.beam, g_gauls, \
                   cdelt, factor, snrcutstack, volrank, tile_prop, plot, img)
        npsf = len(psfimages)

        if opts.psf_use_shap:
            # use totpsfimage to get beta, centre and nmax for shapelet decomposition. Use nmax=5 or 6
            mask=N.zeros(totpsfimage.shape, dtype=bool)
            (m1, m2, m3)=func.moment(totpsfimage, mask)
            betainit=sqrt(m3[0]*m3[1])*2.0  * 1.4
            tshape = totpsfimage.shape
            # Centre guess: position of the brightest pixel of the stacked PSF, offset by (1, 1).
            cen = N.array(N.unravel_index(N.argmax(totpsfimage), tshape))+[1,1]
            cen = tuple(cen)
            nmax = 12
            basis = 'cartesian'
            betarange = [0.5,sqrt(betainit*max(tshape))]
            beta, error  = sh.shape_varybeta(totpsfimage, mask, basis, betainit, cen, nmax, betarange, plot)
            if error == 1: print '  Unable to find minimum in beta'

            # decompose all the psf images using the beta from above
            nmax=12; psf_cf=[]
            for i in range(npsf):
                psfim = psfimages[i]
                cf = sh.decompose_shapelets(psfim, mask, basis, beta, cen, nmax, mode='')
                # NOTE(review): `cf` is never appended to psf_cf in the visible code, and the
                # `if` below has no body.
                if img.opts.quiet == False:

            # transpose the psf image list
            xt, yt = N.transpose(tile_coord)
            tr_psf_cf = N.transpose(N.array(psf_cf))

            # interpolate the coefficients across the image. Ok, interpolate in scipy for
            # irregular grids is crap. doesnt even pass through some of the points.
            # for now, fit polynomial.
            compress = 100.0
            x, y = N.transpose(psfcoords)
            if len(x) < 3:
                # NOTE(review): warning only; the interpolation below still runs.
                mylog.warning('Insufficient number of tiles to do interpolation of PSF variation')

            psf_coeff_interp, xgrid, ygrid = self.interp_shapcoefs(nmax, tr_psf_cf, psfcoords, image.shape, \
                     compress, plot)

            psfshape = psfimages[0].shape
            skip = 5
            aa = self.create_psf_grid(psf_coeff_interp, image.shape, xgrid, ygrid, skip, nmax, psfshape, \
                 basis, beta, cen, totpsfimage, plot)
            img.psf_images = aa
            # NOTE(review): as indented, the whole Gaussian-fitting path below is nested under
            # `if opts.psf_use_shap:` and under `if ntile < 4:`, i.e. it would run only when
            # there are too few tiles. It reads like the alternative to the shapelet path with
            # the `else:` lines missing - confirm against the upstream file.
            if ntile < 4:
                mylog.warning('Insufficient number of tiles to do interpolation of PSF variation')
                # Fit stacked PSFs with Gaussians and measure aperture fluxes
                bm_pix = N.array([img.pixel_beam()[0]*fwsig, img.pixel_beam()[1]*fwsig, img.pixel_beam()[2]])
                psf_maj = N.zeros(npsf)
                psf_min = N.zeros(npsf)
                psf_pa = N.zeros(npsf)
                # NOTE(review): `if` with no body (progress-bar call presumably missing).
                if img.opts.quiet == False:
                for i in range(ntile):
                    psfim = psfimages[i]
                    mask = N.zeros(psfim.shape, dtype=bool)
                    x_ax, y_ax = N.indices(psfim.shape)
                    maxv = N.max(psfim)
                    # Initial guess: peak value, a position slightly off the image centre and
                    # the beam parameters scaled by small factors.
                    p_ini = [maxv, (psfim.shape[0]-1)/2.0*1.1, (psfim.shape[1]-1)/2.0*1.1, bm_pix[0]/fwsig*1.3,
                             bm_pix[1]/fwsig*1.1, bm_pix[2]*2]
                    para, ierr = func.fit_gaus2d(psfim, p_ini, x_ax, y_ax, mask)
                    ### first extent is major
                    if para[3] < para[4]:
                        para[3:5] = para[4:2:-1]
                        para[5] += 90
                    ### clip position angle
                    para[5] = divmod(para[5], 180)[1]

                    psf_maj[i] = para[3]
                    psf_min[i] = para[4]
                    posang = para[5]
                    # After the divmod above posang is already in [0, 180), so this loop
                    # does not iterate for finite values.
                    while posang >= 180.0:
                        posang -= 180.0
                    psf_pa[i] = posang

                    # NOTE(review): `if` with no body.
                    if img.opts.quiet == False:

                # Interpolate Gaussian parameters
                if img.aperture == None:
                    # NOTE(review): second assignment overwrites the first; `else:` looks missing.
                    psf_maps = [psf_maj, psf_min, psf_pa, psfratio]
                    psf_maps = [psf_maj, psf_min, psf_pa, psfratio, psfratio_aper]
                nimgs = len(psf_maps)
                bar = statusbar.StatusBar('Interpolating PSF images ................ : ', 0, nimgs)
                if img.opts.quiet == False:
                # NOTE(review): the call below has unbalanced parentheses and ends in a comma;
                # part of its argument list appears to be missing.
                map_list = mp.parallel_map(func.eval_func_tuple,
                    psf_maps, itertools.repeat(psfcoords),
                    itertools.repeat(image.shape)), numcores=opts.ncores,
                if img.aperture == None:
                    psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int = map_list
                    psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int, psf_ratio_aper_int = map_list

                # Smooth if desired
                if img.opts.psf_smooth != None:
                    sm_scale = img.opts.psf_smooth / img.pix2beam([1.0, 1.0, 0.0])[0] / 3600.0 # pixels
                    # NOTE(review): this test reads img.opts.aperture while the others read
                    # img.aperture - confirm which is intended.
                    if img.opts.aperture == None:
                        psf_maps = [psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int]
                        psf_maps = [psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int, psf_ratio_aper_int]
                    nimgs = len(psf_maps)
                    bar = statusbar.StatusBar('Smoothing PSF images .................... : ', 0, nimgs)
                    if img.opts.quiet == False:
                    # NOTE(review): unbalanced parentheses here as well.
                    map_list = mp.parallel_map(func.eval_func_tuple,
                        psf_maps, itertools.repeat(sm_scale)), numcores=opts.ncores,
                    if img.aperture == None:
                        psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int = map_list
                        psf_maj_int, psf_min_int, psf_pa_int, psf_ratio_int, psf_ratio_aper_int = map_list

                # Make sure all smoothed, interpolated images are ndarrays
                psf_maj_int = N.array(psf_maj_int)
                psf_min_int = N.array(psf_min_int)
                psf_pa_int = N.array(psf_pa_int)
                psf_ratio_int = N.array(psf_ratio_int)
                if img.aperture == None:
                    # NOTE(review): the zero-filled placeholder is overwritten on the next line,
                    # which also reads psf_ratio_aper_int; `else:` looks missing.
                    psf_ratio_aper_int = N.zeros(psf_maj_int.shape, dtype=N.float32)
                    psf_ratio_aper_int = N.array(psf_ratio_aper_int, dtype=N.float32)

                # Blank with NaNs if needed
                mask = img.mask_arr
                if isinstance(mask, N.ndarray):
                    pix_masked = N.where(mask == True)
                    psf_maj_int[pix_masked] = N.nan
                    psf_min_int[pix_masked] = N.nan
                    psf_pa_int[pix_masked] = N.nan
                    psf_ratio_int[pix_masked] = N.nan
                    psf_ratio_aper_int[pix_masked] = N.nan

                # Store interpolated images. The major and minor axis images are
                # the sigma in units of arcsec, the PA image in units of degrees east of
                # north, the ratio images in units of 1/beam.
                img.psf_vary_maj_arr = psf_maj_int * img.pix2beam([1.0, 1.0, 0.0])[0] * 3600.0 # sigma in arcsec
                img.psf_vary_min_arr = psf_min_int * img.pix2beam([1.0, 1.0, 0.0])[0] * 3600.0 # sigma in arcsec
                img.psf_vary_pa_arr = psf_pa_int
                img.psf_vary_ratio_arr = psf_ratio_int # in 1/beam
                img.psf_vary_ratio_aper_arr = psf_ratio_aper_int # in 1/beam

                # The axis images are written multiplied by fwsig (presumably the
                # sigma-to-FWHM factor - confirm), the others as stored.
                if opts.output_all:
                    func.write_image_to_file(img.use_io, img.imagename + '.psf_vary_maj.fits', img.psf_vary_maj_arr*fwsig, img, dir)
                    func.write_image_to_file(img.use_io, img.imagename + '.psf_vary_min.fits', img.psf_vary_min_arr*fwsig, img, dir)
                    func.write_image_to_file(img.use_io, img.imagename + '.psf_vary_pa.fits', img.psf_vary_pa_arr, img, dir)
                    func.write_image_to_file(img.use_io, img.imagename + '.psf_vary_ratio.fits', img.psf_vary_ratio_arr, img, dir)
                    func.write_image_to_file(img.use_io, img.imagename + '.psf_vary_ratio_aper.fits', img.psf_vary_ratio_aper_arr, img, dir)

                # Loop through source and Gaussian lists and deconvolve the sizes using appropriate beam
                bar2 = statusbar.StatusBar('Correcting deconvolved source sizes ..... : ', 0, img.nsrc)
                if img.opts.quiet == False:
                for src in img.sources:
                    src_pos = img.sky2pix(src.posn_sky_centroid)
                    src_pos_int = (int(src_pos[0]), int(src_pos[1]))
                    gaus_c = img.gaus2pix(src.size_sky, src.posn_sky_centroid)
                    # NOTE(review): the position angle (third element) is multiplied by fwsig
                    # like the two axes; that looks unintended for an angle - confirm.
                    gaus_bm = [psf_maj_int[src_pos_int]*fwsig, psf_min_int[src_pos_int]*fwsig, psf_pa_int[src_pos_int]*fwsig]
                    gaus_dc, err = func.deconv2(gaus_bm, gaus_c)
                    src.deconv_size_sky = img.pix2gaus(gaus_dc, src_pos)
                    src.deconv_size_skyE = [0.0, 0.0, 0.0]
                    # Every Gaussian of the source is deconvolved with the beam taken at the
                    # source position, not at the Gaussian's own position.
                    for g in src.gaussians:
                        gaus_c = img.gaus2pix(g.size_sky, src.posn_sky_centroid)
                        gaus_dc, err = func.deconv2(gaus_bm, gaus_c)
                        g.deconv_size_sky = img.pix2gaus(gaus_dc, g.centre_pix)
                        g.deconv_size_skyE = [0.0, 0.0, 0.0]
                        # NOTE(review): the two `if` lines below have no body.
                        if img.opts.quiet == False:
                    if img.opts.quiet == False:
def shape_findcen(image, mask, basis, beta, nmax, beam_pix):  # + check_cen_shapelet
    """Find the optimal centre for a shapelet decomposition.

    Minimising various combinations of c12 and c21, as in the literature,
    does not work for all cases. Hence, for the c1 image we find the zero
    crossing for every vertical line and for the c2 image the zero crossing
    for every horizontal line, and then take the intersection point of these
    two. This seems to work even for highly non-gaussian cases.

    Parameters
    ----------
    image : 2-D array
    mask : 2-D boolean array of the same shape (False = valid pixel)
    basis, beta, nmax : shapelet basis name, scale and maximum order
    beam_pix : beam in pixels, passed to ``shapelet_check_centre``

    Returns
    -------
    cen : the estimated centre. When the zero-crossing method fails for any
        reason, or its result is rejected by ``shapelet_check_centre``, the
        first moment of the image is returned instead.
    """
    import functions as func

    hc = shapelet_coeff(nmax, basis)

    # True where the pixel is usable.
    valid = N.logical_not(mask)

    # Coefficient images of the two first-order shapelets, evaluated with the
    # basis centred on each valid pixel in turn; masked pixels are left as NaN.
    n, m = image.shape
    cf12 = N.zeros(image.shape, dtype=N.float32)
    cf21 = N.zeros(image.shape, dtype=N.float32)
    for coord in [(i, j) for i in range(n) for j in range(m)]:
        if valid[coord]:
            B12 = shapelet_image(basis, beta, coord, hc, 0, 1, image.shape)
            cf12[coord] = N.sum(image * B12 * valid)

            B21 = shapelet_image(basis, beta, coord, hc, 1, 0, image.shape)
            cf21[coord] = N.sum(image * B21 * valid)
        else:
            cf12[coord] = N.nan
            cf21[coord] = N.nan

    (xmax, ymax) = N.unravel_index(image.argmax(), image.shape)  #  FIX  with mask
    # NOTE(review): argmax indices run 0..n-1, so the tests against n and m
    # can never match; the edge test may have been meant as (0, n-1).
    if xmax in [1, n] or ymax in [1, m]:
        # Peak sits at the border: start from the image centroid instead.
        # (Called with (image, mask) like every other moment() call here.)
        (m1, m2, m3) = func.moment(image, mask)
        xmax, ymax = N.round(m2)

    # in high snr area, get zero crossings for each horizontal and vertical line for c1, c2 resp
    tr_mask = mask.transpose()
    tr_cf21 = cf21.transpose()
    try:
        (x1, y1) = getzeroes_matrix(mask, cf12, ymax, xmax)  # y1 is array of zero crossings
        (y2, x2) = getzeroes_matrix(tr_mask, tr_cf21, xmax, ymax)  # x2 is array of zero crossings

        # find nominal intersection pt as integers
        xind = N.where(x1 == xmax)[0][0]
        yind = N.where(y2 == ymax)[0][0]

        # now take 2 before and 2 after, fit straight lines, get proper intersection
        ninter = 5
        if xind < 3 or yind < 3 or xind > n - 2 or yind > m - 2:
            ninter = 3
        half = (ninter - 1) // 2  # integer half-width; slice bounds must be ints
        xft1 = x1[xind - half:xind + half + 1]
        yft1 = y1[xind - half:xind + half + 1]
        xft2 = x2[yind - half:yind + half + 1]
        yft2 = y2[yind - half:yind + half + 1]
        sig = N.ones(ninter, dtype=float)
        # A zero entry marks a line without a usable crossing; mask it in the fit.
        smask1 = N.array([r == 0 for r in yft1])
        smask2 = N.array([r == 0 for r in xft2])

        # (0, 0) stands for "no intersection found" and is left to
        # shapelet_check_centre to accept or reject.
        cen = [0.0] * 2
        if sum(smask1) < len(yft1) and sum(smask2) < len(xft2):
            [c1, slope1], errors = func.fit_mask_1d(xft1, yft1, sig, smask1, func.poly, do_err=False, order=1)
            [c2, slope2], errors = func.fit_mask_1d(xft2, yft2, sig, smask2, func.poly, do_err=False, order=1)
            if slope2 - slope1 != 0:  # parallel lines have no intersection
                cen[0] = (c1 - c2) / (slope2 - slope1)
                cen[1] = c1 + slope1 * cen[0]

        # check if estimated centre makes sense
        error = shapelet_check_centre(image, mask, cen, beam_pix)
    except Exception:
        # Any failure in the zero-crossing stage falls back to the moments.
        error = 1
    if error > 0:
        # Could not find a sensible centre; take the 1st moment instead.
        (m1, m2, m3) = func.moment(image, mask)
        cen = m2

    return cen
