예제 #1
0
def on_press(event):
    """Handle key presses on the results figure.

    The pressed key (``event.key``) selects the action:

    ``0``
        Reset the colour limits of all image panels to the default
        ``vmin``/``vmax``.
    ``m``
        Prompt for new minimum and maximum limits (in Jy/beam) and apply
        them to all image panels.
    ``c``
        Prompt for a source ID and redraw the SED panel(s) for that source.
    ``i``
        Print flux-density statistics for the currently visible region.
    ``n``
        Toggle the visibility of the island-number markers.

    Colour limits are kept as ``log10(value + low)``, so they are converted
    back with ``pow(10, limit) - low`` whenever they are shown to the user.
    All state is shared through module-level globals; ``vmin_cur``,
    ``vmax_cur`` and ``srcid_cur`` are updated here.
    """
    from interface import raw_input_no_history
    import numpy

    global img_ch0, img_rms, img_mean, img_gaus_mod, img_shap_mod
    global pixels_per_beam, vmin, vmax, vmin_cur, vmax_cur, img_pi
    global ch0min, ch0max, low, fig, images, src_list, srcid_cur
    global markers
    if event.key == '0':
        print('Resetting limits to defaults (%.4f -- %.4f Jy/beam)'
              % (pow(10, vmin) - low,
                 pow(10, vmax) - low))
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            # 'wavelets' and 'seds' panels are skipped when rescaling
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                im = ax.get_images()[0]
                im.set_clim(vmin, vmax)
        vmin_cur = vmin
        vmax_cur = vmax
        pl.draw()
    if event.key == 'm':
        # Modify scaling
        # First check that there are images to modify
        if not any(isinstance(im, numpy.ndarray) for im in images):
            return
        # Start from a value that cannot be parsed so that the user is
        # prompted on the first pass. An empty reply keeps the current limit;
        # an unparsable reply prompts again.
        minscl = 'a'
        while isinstance(minscl, str):
            try:
                if minscl == '':
                    minscl = pow(10, vmin_cur) - low
                    break
                minscl = float(minscl)
            except ValueError:
                prompt = "Enter min value (current = %.4f Jy/beam) : " % (
                    pow(10, vmin_cur) - low, )
                try:
                    minscl = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change scaling.')
                    return
        minscl = N.log10(minscl + low)
        maxscl = 'a'
        while isinstance(maxscl, str):
            try:
                if maxscl == '':
                    maxscl = pow(10, vmax_cur) - low
                    break
                maxscl = float(maxscl)
            except ValueError:
                prompt = "Enter max value (current = %.4f Jy/beam) : " % (
                    pow(10, vmax_cur) - low, )
                try:
                    maxscl = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change scaling.')
                    return
        maxscl = N.log10(maxscl + low)
        if maxscl <= minscl:
            print('Max value must be greater than min value!')
            return
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                im = ax.get_images()[0]
                im.set_clim(minscl, maxscl)
        vmin_cur = minscl
        vmax_cur = maxscl
        pl.draw()
    if event.key == 'c':
        # Change source SED
        # First check that SEDs are being plotted
        if 'seds' not in images:
            return
        # Same prompting scheme as for the scaling limits: an empty reply
        # keeps the current source ID.
        srcid = 'a'
        while isinstance(srcid, str):
            try:
                if srcid == '':
                    srcid = srcid_cur
                    break
                srcid = int(srcid)
            except ValueError:
                prompt = "Enter source ID (current = %i) : " % (srcid_cur, )
                try:
                    srcid = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change source.')
                    return
        sed_src = get_src(src_list, srcid)
        if sed_src is None:
            print('Source not found!')
            return
        srcid_cur = srcid
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            if images[axindx] == 'seds':
                plot_sed(sed_src, ax)
        pl.draw()
    if event.key == 'i':
        # Print info about visible region
        has_image = False
        axes_list = fig.get_axes()
        # Get limits of visible region from the first image panel
        for axindx, ax in enumerate(axes_list):
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                has_image = True
                break
        if not has_image:
            return
        # Clip the limits to the array and convert them to integers: the
        # axis limits are floats, which are not valid slice indices.
        xmin = int(max(xmin, 0))
        xmax = int(min(xmax, img_ch0.shape[0]))
        ymin = int(max(ymin, 0))
        ymax = int(min(ymax, img_ch0.shape[1]))
        region = (slice(xmin, xmax), slice(ymin, ymax))

        flux = N.nansum(img_ch0[region]) / pixels_per_beam
        mask = N.isnan(img_ch0[region])
        num_pix_unmasked = float((~mask).sum())
        mean_rms = N.nansum(img_rms[region]) / num_pix_unmasked
        mean_map_flux = N.nansum(img_mean[region]) / pixels_per_beam
        # Identity tests are required here: the models are arrays when
        # present, and "array == None" is an elementwise comparison.
        if img_gaus_mod is None:
            gaus_mod_flux = 0.0
        else:
            gaus_mod_flux = N.nansum(img_gaus_mod[region]) / pixels_per_beam
        print('Visible region (%i:%i, %i:%i) :' % (xmin, xmax, ymin, ymax))
        print('  ch0 flux density from sum of pixels ... : %f Jy'
              % (flux,))
        print('  Background mean map flux density ...... : %f Jy'
              % (mean_map_flux,))
        print('  Gaussian model flux density ........... : %f Jy'
              % (gaus_mod_flux,))
        if img_shap_mod is not None:
            shap_mod_flux = N.nansum(img_shap_mod[region]) / pixels_per_beam
            print('  Shapelet model flux density ........... : %f Jy'
                  % (shap_mod_flux,))
        print('  Mean rms (from rms map) ............... : %f Jy/beam'
              % (mean_rms,))
    if event.key == 'n':
        # Show/Hide island numbers
        if markers:
            for marker in markers:
                marker.set_visible(not marker.get_visible())
            pl.draw()
예제 #2
0
    def __call__(self, img):
        """Decompose the Gaussian residual image into a-trous wavelet scales.

        Nothing is done unless ``img.opts.atrous_do`` is set. If there are
        no islands, the operation is only recorded as completed.

        For each scale ``j`` from 1 to ``jmax`` the previous smoothed image
        is filtered with ``self.atrous()`` and the wavelet image ``w`` is
        formed (the smoothed image itself when ``atrous_sum`` is set,
        otherwise the difference between the previous and new smoothed
        images). When ``img.opts.atrous_bdsm_do`` is set, a new ``Image`` is
        built for ``w`` and run through the op chain returned by
        ``self.setpara_bdsm(img)``; the resulting Gaussians are either
        restricted to the original ch0 islands (``atrous_orig_isl``) or
        merged via ``check_islands_for_overlap``.

        Results are stored on ``img`` (``wavelet_lpf``, ``wavelet_jmax``,
        ``resid_wavelets_arr``, ``pyrank``, ``ngaus``, ``total_flux_gaus``
        and the ``atrous_*`` lists) and ``'wavelet_atrous'`` is appended to
        ``img.completed_Ops``. With ``output_all`` set, FITS images are
        written below ``img.basedir + '/wavelet/'``.
        """

        mylog = mylogger.logging.getLogger("PyBDSM." + img.log + "Wavelet")

        if img.opts.atrous_do:
            if img.nisl == 0:
                mylog.warning(
                    "No islands found. Skipping wavelet decomposition.")
                img.completed_Ops.append('wavelet_atrous')
                return

            mylog.info(
                "Decomposing gaussian residual image into a-trous wavelets")
            bdir = img.basedir + '/wavelet/'
            if img.opts.output_all:
                if not os.path.isdir(bdir): os.makedirs(bdir)
                if not os.path.isdir(bdir + '/residual/'):
                    os.makedirs(bdir + '/residual/')
                if not os.path.isdir(bdir + '/model/'):
                    os.makedirs(bdir + '/model/')
            dobdsm = img.opts.atrous_bdsm_do
            # Available low-pass filters, keyed by the value of opts.atrous_lpf.
            # NOTE(review): the name "filter" shadows the builtin within this method.
            filter = {
                'tr': {
                    'size': 3,
                    'vec': [1. / 4, 1. / 2, 1. / 4],
                    'name': 'Triangle'
                },
                'b3': {
                    'size': 5,
                    'vec': [1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16],
                    'name': 'B3 spline'
                }
            }

            if dobdsm: wchain, wopts = self.setpara_bdsm(img)

            n, m = img.ch0_arr.shape

            # Calculate residual image that results from normal (non-wavelet) Gaussian fitting
            Op_make_residimage()(img)
            resid = img.resid_gaus_arr

            # Fall back to the B3 spline for any unrecognised filter name
            lpf = img.opts.atrous_lpf
            if lpf not in ['b3', 'tr']: lpf = 'b3'
            jmax = img.opts.atrous_jmax
            # l is the number of taps of the chosen filter
            l = len(filter[lpf]['vec']
                    )  # 1st 3 is arbit and 2nd 3 is whats expected for a-trous
            if jmax < 1 or jmax > 15:  # determine jmax
                # Check if largest island size is
                # smaller than 1/3 of image size. If so, use it to determine jmax.
                # (min_size is reassigned in both branches of the if/else below.)
                min_size = min(resid.shape)
                max_isl_shape = (0, 0)
                for isl in img.islands:
                    if isl.image.shape[0] * isl.image.shape[1] > max_isl_shape[
                            0] * max_isl_shape[1]:
                        max_isl_shape = isl.image.shape
                if max_isl_shape != (
                        0, 0) and min(max_isl_shape) < min(resid.shape) / 3.0:
                    min_size = min(max_isl_shape) * 4.0
                else:
                    min_size = min(resid.shape)
                jmax = int(
                    floor(
                        log((min_size / 3.0 * 3.0 - l) /
                            (l - 1) + 1) / log(2.0) + 1.0)) + 1
                # Drop one scale if the filter extent at jmax reaches 55% of min_size
                if min_size * 0.55 <= (l + (l - 1) * (2**(jmax) - 1)):
                    jmax = jmax - 1
            img.wavelet_lpf = lpf
            img.wavelet_jmax = jmax
            mylog.info("Using " + filter[lpf]['name'] +
                       ' filter with J_max = ' + str(jmax))

            img.atrous_islands = []
            img.atrous_gaussians = []
            img.atrous_sources = []
            img.atrous_opts = []
            img.resid_wavelets_arr = cp(img.resid_gaus_arr)

            im_old = img.resid_wavelets_arr
            total_flux = 0.0
            # NOTE(review): ntot_wvgaus is never incremented in this method, so
            # the later "img.ngaus += ntot_wvgaus" adds 0 -- confirm whether the
            # per-scale nwvgaus was meant to be accumulated here.
            ntot_wvgaus = 0
            stop_wav = False
            # Remember the blanked pixels so they can be restored after filtering
            pix_masked = N.where(N.isnan(resid) == True)
            jmin = 1
            if img.opts.ncores is None:
                numcores = 1
            else:
                numcores = img.opts.ncores
            for j in range(jmin, jmax +
                           1):  # extra +1 is so we can do bdsm on cJ as well
                mylogger.userinfo(mylog, "\nWavelet scale #" + str(j))
                im_new = self.atrous(im_old,
                                     filter[lpf]['vec'],
                                     lpf,
                                     j,
                                     numcores=numcores,
                                     use_scipy_fft=img.opts.use_scipy_fft)
                im_new[
                    pix_masked] = N.nan  # since fftconvolve wont work with blanked pixels
                # Wavelet image for this scale: the smoothed image itself when
                # atrous_sum is set, otherwise the detail lost by the smoothing
                if img.opts.atrous_sum:
                    w = im_new
                else:
                    w = im_old - im_new
                im_old = im_new
                # Backticks are the Python 2 spelling of repr(j)
                suffix = 'w' + ` j `
                filename = img.imagename + '.atrous.' + suffix + '.fits'
                if img.opts.output_all:
                    func.write_image_to_file('fits', filename, w, img, bdir)
                    mylog.info('%s %s' % ('Wrote ', img.imagename +
                                          '.atrous.' + suffix + '.fits'))

                # now do bdsm on each wavelet image.
                if dobdsm:
                    wopts['filename'] = filename
                    wopts['basedir'] = bdir
                    box = img.rms_box[0]
                    # y1 is the filter extent in pixels at scale j
                    y1 = (l + (l - 1) * (2**(j - 1) - 1))
                    bs = max(5 * y1, box)  # changed from 10 to 5
                    # If the box exceeds half of the smaller image dimension,
                    # use a constant mean and no rms map instead of a box
                    if bs > min(n, m) / 2:
                        wopts['rms_map'] = False
                        wopts['mean_map'] = 'const'
                        wopts['rms_box'] = None
                    else:
                        wopts['rms_box'] = (bs, bs / 3)
                        if hasattr(img, '_adapt_rms_isl_pos'):
                            bs_bright = max(5 * y1, img.rms_box_bright[0])
                            if bs_bright < bs / 1.5:
                                wopts['adaptive_rms_box'] = True
                                wopts['rms_box_bright'] = (bs_bright,
                                                           bs_bright / 3)
                            else:
                                wopts['adaptive_rms_box'] = False
                    if j <= 3:
                        wopts['ini_gausfit'] = 'default'
                    else:
                        wopts['ini_gausfit'] = 'nobeam'
                    # Same expression as y1 above
                    wid = (l + (l - 1) * (2**(j - 1) - 1))  # / 3.0
                    b1, b2 = img.pixel_beam()[0:2]
                    b1 = b1 * fwsig
                    b2 = b2 * fwsig
                    cdelt = img.wcs_obj.acdelt[:2]

                    # Build the Image for this wavelet scale; its beam adds the
                    # filter width and the original pixel beam in quadrature
                    wimg = Image(wopts)
                    wimg.beam = (sqrt(wid * wid + b1 * b1) * cdelt[0] * 2.0,
                                 sqrt(wid * wid + b2 * b2) * cdelt[1] * 2.0,
                                 0.0)
                    wimg.orig_beam = img.beam
                    wimg.pixel_beam = img.pixel_beam
                    wimg.pixel_beamarea = img.pixel_beamarea
                    wimg.log = 'Wavelet.'
                    wimg.basedir = img.basedir
                    wimg.extraparams['bbsprefix'] = suffix
                    wimg.extraparams['bbsname'] = img.imagename + '.wavelet'
                    wimg.extraparams['bbsappend'] = True
                    wimg.bbspatchnum = img.bbspatchnum
                    wimg.waveletimage = True
                    wimg.j = j
                    if hasattr(img, '_adapt_rms_isl_pos'):
                        wimg._adapt_rms_isl_pos = img._adapt_rms_isl_pos

                    self.init_image_simple(wimg, img, w, '.atrous.' + suffix)
                    for op in wchain:
                        op(wimg)
                        gc.collect()
                        if isinstance(op,
                                      Op_islands) and img.opts.atrous_orig_isl:
                            if wimg.nisl > 0:

                                # Find islands that do not share any pixels with
                                # islands in original ch0 image.
                                good_isl = []

                                # Make original rank image boolean; rank counts from 0, with -1 being
                                # outside any island
                                orig_rankim_bool = N.array(img.pyrank + 1,
                                                           dtype=bool)

                                # Multiply rank images
                                old_islands = orig_rankim_bool * (wimg.pyrank +
                                                                  1) - 1

                                # Exclude islands that don't overlap with a ch0 island.
                                valid_ids = set(old_islands.flatten())
                                for idx, wvisl in enumerate(wimg.islands):
                                    if idx in valid_ids:
                                        wvisl.valid = True
                                        good_isl.append(wvisl)
                                    else:
                                        wvisl.valid = False

                                wimg.islands = good_isl
                                wimg.nisl = len(good_isl)
                                mylogger.userinfo(mylog,
                                                  "Number of islands found",
                                                  '%i' % wimg.nisl)

                                # Renumber islands:
                                for wvindx, wvisl in enumerate(wimg.islands):
                                    wvisl.island_id = wvindx

                        if isinstance(op, Op_gausfit):
                            # If opts.atrous_orig_isl then exclude Gaussians outside of
                            # the original ch0 islands
                            nwvgaus = 0
                            if img.opts.atrous_orig_isl:
                                gaul = wimg.gaussians
                                tot_flux = 0.0

                                # New Gaussians continue the numbering of the
                                # last Gaussian already stored on img
                                if img.ngaus == 0:
                                    gaus_id = -1
                                else:
                                    gaus_id = img.gaussians[-1].gaus_num
                                # NOTE(review): wvgaul is never used after this assignment
                                wvgaul = []
                                for g in gaul:
                                    if not hasattr(g, 'valid'):
                                        g.valid = False
                                    if not g.valid:
                                        # Look up the original island at the Gaussian's
                                        # centre; an out-of-range centre counts as no island
                                        try:
                                            isl_id = img.pyrank[
                                                int(g.centre_pix[0] + 1),
                                                int(g.centre_pix[1] + 1)]
                                        except IndexError:
                                            isl_id = -1
                                        if isl_id >= 0:
                                            isl = img.islands[isl_id]
                                            gcenter = (g.centre_pix[0] -
                                                       isl.origin[0],
                                                       g.centre_pix[1] -
                                                       isl.origin[1])
                                            # Keep the Gaussian only if its centre is on an
                                            # unmasked pixel of that island; a copy is then
                                            # added to the island and to img
                                            if not isl.mask_active[gcenter]:
                                                gaus_id += 1
                                                gcp = Gaussian(
                                                    img, g.parameters[:],
                                                    isl.island_id, gaus_id)
                                                gcp.gaus_num = gaus_id
                                                gcp.wisland_id = g.island_id
                                                gcp.jlevel = j
                                                g.valid = True
                                                isl.gaul.append(gcp)
                                                isl.ngaus += 1
                                                img.gaussians.append(gcp)
                                                nwvgaus += 1
                                                tot_flux += gcp.total_flux
                                            else:
                                                g.valid = False
                                                g.jlevel = 0
                                        else:
                                            g.valid = False
                                            g.jlevel = 0
                                # Keep only the valid Gaussians on the wavelet image
                                vg = []
                                for g in wimg.gaussians:
                                    if g.valid:
                                        vg.append(g)
                                wimg.gaussians = vg
                                mylogger.userinfo(
                                    mylog, "Number of valid wavelet Gaussians",
                                    str(nwvgaus))
                            else:
                                # Keep all Gaussians and merge islands that overlap
                                tot_flux = check_islands_for_overlap(img, wimg)

                                # Now renumber the islands and adjust the rank image before going to next wavelet image
                                renumber_islands(img)

                    # NOTE(review): tot_flux is only assigned in the Op_gausfit
                    # branch above; if wchain has no Op_gausfit this is unset on
                    # the first scale -- confirm against setpara_bdsm.
                    total_flux += tot_flux
                    if img.opts.interactive and has_pl:
                        # ANSI escape codes: dc switches to bold blue, nc resets
                        dc = '\033[34;1m'
                        nc = '\033[0m'
                        print dc + '--> Displaying islands and rms image...' + nc
                        if max(wimg.ch0_arr.shape) > 4096:
                            print dc + '--> Image is large. Showing islands only.' + nc
                            wimg.show_fit(rms_image=False,
                                          mean_image=False,
                                          ch0_image=False,
                                          ch0_islands=True,
                                          gresid_image=False,
                                          sresid_image=False,
                                          gmodel_image=False,
                                          smodel_image=False,
                                          pyramid_srcs=False)
                        else:
                            wimg.show_fit()
                        prompt = dc + "Press enter to continue or 'q' stop fitting wavelet images : " + nc
                        answ = raw_input_no_history(prompt)
                        # Keep prompting until the reply is empty (continue) or 'q' (stop)
                        while answ != '':
                            if answ == 'q':
                                img.wavelet_jmax = j
                                stop_wav = True
                                break
                            answ = raw_input_no_history(prompt)
                    # Subtract the Gaussians found at this scale from the running
                    # wavelet residual (and from the smoothed image when summing)
                    if len(wimg.gaussians) > 0:
                        img.resid_wavelets_arr = self.subtract_wvgaus(
                            img.opts, img.resid_wavelets_arr, wimg.gaussians,
                            wimg.islands)
                        if img.opts.atrous_sum:
                            im_old = self.subtract_wvgaus(
                                img.opts, im_old, wimg.gaussians, wimg.islands)
                    if stop_wav == True:
                        break

            # Renumber the islands and their Gaussians, then rebuild the rank image
            pyrank = N.zeros(img.pyrank.shape, dtype=N.int32)
            for i, isl in enumerate(img.islands):
                isl.island_id = i
                for g in isl.gaul:
                    g.island_id = i
                for dg in isl.dgaul:
                    dg.island_id = i
                pyrank[isl.bbox] += N.invert(isl.mask_active) * (i + 1)
            pyrank -= 1  # align pyrank values with island ids and set regions outside of islands to -1
            img.pyrank = pyrank

            # NOTE(review): pdir is not used anywhere below
            pdir = img.basedir + '/misc/'
            img.ngaus += ntot_wvgaus
            img.total_flux_gaus += total_flux
            mylogger.userinfo(mylog,
                              "Total flux density in model on all scales",
                              '%.3f Jy' % img.total_flux_gaus)
            if img.opts.output_all:
                # im_new here is the smoothed image left from the last scale processed
                func.write_image_to_file('fits',
                                         img.imagename + '.atrous.cJ.fits',
                                         im_new, img, bdir)
                mylog.info('%s %s' %
                           ('Wrote ', img.imagename + '.atrous.cJ.fits'))
                func.write_image_to_file(
                    'fits', img.imagename + '.resid_wavelets.fits',
                    (img.ch0_arr - img.resid_gaus_arr +
                     img.resid_wavelets_arr), img, bdir + '/residual/')
                mylog.info('%s %s' %
                           ('Wrote ', img.imagename + '.resid_wavelets.fits'))
                func.write_image_to_file(
                    'fits', img.imagename + '.model_wavelets.fits',
                    (img.resid_gaus_arr - img.resid_wavelets_arr), img,
                    bdir + '/model/')
                mylog.info('%s %s' %
                           ('Wrote ', img.imagename + '.model_wavelets.fits'))
            img.completed_Ops.append('wavelet_atrous')
예제 #3
0
def _run_op_list(img, chain):
    """Runs an Image object through chain of op's.

    This is separate from execute() to allow other modules (such as
    interface.py) to use it as well.

    Parameters
    ----------
    img
        Image object to process. Its ``opts`` supply ``stop_at``,
        ``interactive``, ``print_timing``, ``shapelet_do`` and
        ``spectralindex_do``.
    chain
        Sequence of operations. Entries that are classes are instantiated;
        instances are used as they are.

    Returns
    -------
    bool
        False if the user quit at the interactive prompt shown before
        Gaussian fitting, True otherwise.
    """
    from time import time
    from interface import raw_input_no_history
    from gausfit import Op_gausfit
    import inspect
    import mylogger
    import gc

    ops = []
    stopat = img.opts.stop_at
    # Make sure all op's are instances
    for op in chain:
        if inspect.isclass(op):
            op = op()
        ops.append(op)
        # Test the instance rather than the raw chain entry, so that stop_at
        # also takes effect when the chain holds classes.
        if stopat == 'read' and isinstance(op, Op_readimage):
            break
        if stopat == 'isl' and isinstance(op, Op_islands):
            break

    # Log all non-default parameters
    mylog = mylogger.logging.getLogger("PyBDSF.Init")
    mylog.info("PyBDSF version %s"
               % (__version__, ))
    par_msg = "Non-default input parameters:\n"
    for k, v in img.opts.to_list():
        val = getattr(img.opts, k)
        if val != v._default and v.group() != 'hidden':
            par_msg += '    %-20s = %s\n' % (k, repr(val))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline

    # Run all op's
    # ANSI escape codes: dc switches to bold blue, nc resets
    dc = '\033[34;1m'
    nc = '\033[0m'
    for op in ops:
        if isinstance(op, Op_gausfit) and img.opts.interactive:
            # Show the islands before fitting and let the user abort
            print(dc + '--> Displaying islands and rms image...' + nc)
            if max(img.ch0_arr.shape) > 4096:
                print(dc + '--> Image is large. Showing islands only.' + nc)
                img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            else:
                img.show_fit(rms_image=True, mean_image=True,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            prompt = dc + "Press enter to continue or 'q' to quit .. : " + nc
            answ = raw_input_no_history(prompt)
            while answ != '':
                if answ == 'q':
                    return False
                answ = raw_input_no_history(prompt)
        op.__start_time = time()
        op(img)
        op.__stop_time = time()
        gc.collect()

    if img.opts.interactive and not img._pi:
        print(dc + 'Fitting complete. Displaying results...' + nc)
        show_smod = show_sres = bool(img.opts.shapelet_do)
        show_spec = bool(img.opts.spectralindex_do)
        if max(img.ch0_arr.shape) > 4096:
            print(dc + '--> Image is large. Showing Gaussian residual image only.' + nc)
            img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                         ch0_islands=False, gresid_image=True, sresid_image=False,
                         gmodel_image=False, smodel_image=False, pyramid_srcs=False,
                         source_seds=show_spec)
        else:
            img.show_fit(smodel_image=show_smod, sresid_image=show_sres,
                         source_seds=show_spec)

    if img.opts.print_timing:
        print("=" * 36)
        print("%18s : %10s" % ("Module", "Time (sec)"))
        print("-" * 36)
        # Report on the ops that were actually run (the instances in ops, on
        # which the timestamps were set), not on the raw chain entries.
        timed_ops = [op for op in ops if hasattr(op, '__start_time')]
        for op in timed_ops:
            print("%18s : %f" % (op.__class__.__name__,
                                 (op.__stop_time - op.__start_time)))
        print("=" * 36)
        if timed_ops:
            print("%18s : %f" % ("Total",
                                 (timed_ops[-1].__stop_time - timed_ops[0].__start_time)))

    # Log all internally derived parameters
    mylog = mylogger.logging.getLogger("PyBDSF.Final")
    par_msg = "Internally derived parameters:\n"
    simple_types = (int, str, bool, float, type(None), tuple, list)
    for name, opt_val in inspect.getmembers(img.opts):
        if name[0] == '_' or not isinstance(opt_val, simple_types):
            continue
        if not hasattr(img, name):
            continue
        used = getattr(img, name)
        # Check the type before comparing so that values of other types
        # (for which "!=" may not give a plain boolean) are never compared.
        if isinstance(used, simple_types) and used != opt_val:
            par_msg += '    %-20s : %s\n' % (name, repr(used))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline

    return True
예제 #4
0
파일: plotresults.py 프로젝트: jjdmol/LOFAR
def on_press(event):
    """Handle key presses on the results figure.

    The pressed key (``event.key``) selects the action:

    ``0``
        Reset the colour limits of all image panels to the default
        ``vmin``/``vmax``.
    ``m``
        Prompt for new minimum and maximum limits (in Jy/beam) and apply
        them to all image panels.
    ``c``
        Prompt for a source ID and redraw the SED panel(s) for that source.
    ``i``
        Print flux-density statistics for the currently visible region.
    ``n``
        Toggle the visibility of the island-number markers.

    Colour limits are kept as ``log10(value + low)``, so they are converted
    back with ``pow(10, limit) - low`` whenever they are shown to the user.
    All state is shared through module-level globals; ``vmin_cur``,
    ``vmax_cur`` and ``srcid_cur`` are updated here.
    """
    from interface import raw_input_no_history
    import numpy

    global img_ch0, img_rms, img_mean, img_gaus_mod, img_shap_mod
    global pixels_per_beam, vmin, vmax, vmin_cur, vmax_cur, img_pi
    global ch0min, ch0max, low, fig, images, src_list, srcid_cur
    global markers
    if event.key == '0':
        print('Resetting limits to defaults (%.4f -- %.4f Jy/beam)'
              % (pow(10, vmin) - low,
                 pow(10, vmax) - low))
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            # 'wavelets' and 'seds' panels are skipped when rescaling
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                im = ax.get_images()[0]
                im.set_clim(vmin, vmax)
        vmin_cur = vmin
        vmax_cur = vmax
        pl.draw()
    if event.key == 'm':
        # Modify scaling
        # First check that there are images to modify
        if not any(isinstance(im, numpy.ndarray) for im in images):
            return
        # Start from a value that cannot be parsed so that the user is
        # prompted on the first pass. An empty reply keeps the current limit;
        # an unparsable reply prompts again.
        minscl = 'a'
        while isinstance(minscl, str):
            try:
                if minscl == '':
                    minscl = pow(10, vmin_cur) - low
                    break
                minscl = float(minscl)
            except ValueError:
                prompt = "Enter min value (current = %.4f Jy/beam) : " % (pow(10, vmin_cur)-low,)
                try:
                    minscl = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change scaling.')
                    return
        minscl = N.log10(minscl + low)
        maxscl = 'a'
        while isinstance(maxscl, str):
            try:
                if maxscl == '':
                    maxscl = pow(10, vmax_cur) - low
                    break
                maxscl = float(maxscl)
            except ValueError:
                prompt = "Enter max value (current = %.4f Jy/beam) : " % (pow(10, vmax_cur)-low,)
                try:
                    maxscl = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change scaling.')
                    return
        maxscl = N.log10(maxscl + low)
        if maxscl <= minscl:
            print('Max value must be greater than min value!')
            return
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                im = ax.get_images()[0]
                im.set_clim(minscl, maxscl)
        vmin_cur = minscl
        vmax_cur = maxscl
        pl.draw()
    if event.key == 'c':
        # Change source SED
        # First check that SEDs are being plotted
        if 'seds' not in images:
            return
        # Same prompting scheme as for the scaling limits: an empty reply
        # keeps the current source ID.
        srcid = 'a'
        while isinstance(srcid, str):
            try:
                if srcid == '':
                    srcid = srcid_cur
                    break
                srcid = int(srcid)
            except ValueError:
                prompt = "Enter source ID (current = %i) : " % (srcid_cur,)
                try:
                    srcid = raw_input_no_history(prompt)
                except RuntimeError:
                    print('Sorry, unable to change source.')
                    return
        sed_src = get_src(src_list, srcid)
        if sed_src is None:
            print('Source not found!')
            return
        srcid_cur = srcid
        axes_list = fig.get_axes()
        for axindx, ax in enumerate(axes_list):
            if images[axindx] == 'seds':
                plot_sed(sed_src, ax)
        pl.draw()
    if event.key == 'i':
        # Print info about visible region
        has_image = False
        axes_list = fig.get_axes()
        # Get limits of visible region from the first image panel
        for axindx, ax in enumerate(axes_list):
            if images[axindx] != 'wavelets' and images[axindx] != 'seds':
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                has_image = True
                break
        if not has_image:
            return
        # Clip the limits to the array and convert them to integers: the
        # axis limits are floats, which are not valid slice indices.
        xmin = int(max(xmin, 0))
        xmax = int(min(xmax, img_ch0.shape[0]))
        ymin = int(max(ymin, 0))
        ymax = int(min(ymax, img_ch0.shape[1]))
        region = (slice(xmin, xmax), slice(ymin, ymax))

        flux = N.nansum(img_ch0[region])/pixels_per_beam
        mask = N.isnan(img_ch0[region])
        num_pix_unmasked = float((~mask).sum())
        mean_rms = N.nansum(img_rms[region])/num_pix_unmasked
        mean_map_flux = N.nansum(img_mean[region])/pixels_per_beam
        # The model images are None when not available
        if img_gaus_mod is None:
            gaus_mod_flux = 0.0
        else:
            gaus_mod_flux = N.nansum(img_gaus_mod[region])/pixels_per_beam
        print('Visible region (%i:%i, %i:%i) :' % (xmin, xmax, ymin, ymax))
        print('  ch0 flux density from sum of pixels ... : %f Jy'
              % (flux,))
        print('  Background mean map flux density ...... : %f Jy'
              % (mean_map_flux,))
        print('  Gaussian model flux density ........... : %f Jy'
              % (gaus_mod_flux,))
        if img_shap_mod is not None:
            shap_mod_flux = N.nansum(img_shap_mod[region])/pixels_per_beam
            print('  Shapelet model flux density ........... : %f Jy'
                  % (shap_mod_flux,))
        print('  Mean rms (from rms map) ............... : %f Jy/beam'
              % (mean_rms,))
    if event.key == 'n':
        # Show/Hide island numbers
        if markers:
            for marker in markers:
                marker.set_visible(not marker.get_visible())
            pl.draw()
예제 #5
0
def _run_op_list(img, chain):
    """Runs an Image object through chain of op's.

    This is separate from execute() to allow other modules (such as
    interface.py) to use it as well.
    """
    from time import time
    from types import ClassType, TypeType
    from interface import raw_input_no_history
    from gausfit import Op_gausfit
    import mylogger
    import gc

    ops = []
    stopat = img.opts.stop_at
    # Make sure all op's are instances
    for op in chain:
        if isinstance(op, (ClassType, TypeType)):
            ops.append(op())
        else:
            ops.append(op)
        if stopat == 'read' and isinstance(op, Op_readimage): break
        if stopat == 'isl' and isinstance(op, Op_islands): break

    # Log all non-default parameters
    mylog = mylogger.logging.getLogger("PyBDSM.Init")
    mylog.info("PyBDSM version %s (LUS revision %s)"
                             % (__version__, __revision__))
    par_msg = "Non-default input parameters:\n"
    user_opts = img.opts.to_list()
    for user_opt in user_opts:
        k, v = user_opt
        val = img.opts.__getattribute__(k)
        if val != v._default and v.group() != 'hidden':
            par_msg += '    %-20s = %s\n' % (k, repr(val))
    mylog.info(par_msg[:-1]) # -1 is to trim final newline

    # Run all op's
    dc = '\033[34;1m'
    nc = '\033[0m'
    for op in ops:
        if isinstance(op, Op_gausfit) and img.opts.interactive:
            print dc + '--> Displaying islands and rms image...' + nc
            if max(img.ch0_arr.shape) > 4096:
                print dc + '--> Image is large. Showing islands only.' + nc
                img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                    ch0_islands=True, gresid_image=False, sresid_image=False,
                    gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            else:
                img.show_fit(rms_image=True, mean_image=True,
                    ch0_islands=True, gresid_image=False, sresid_image=False,
                    gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            prompt = dc + "Press enter to continue or 'q' to quit .. : " + nc
            answ = raw_input_no_history(prompt)
            while answ != '':
                if answ == 'q':
                    return False
                answ = raw_input_no_history(prompt)
        op.__start_time = time()
        op(img)
        op.__stop_time = time()
        gc.collect()

    if img.opts.interactive and not img._pi:
        print dc + 'Fitting complete. Displaying results...' + nc
        if img.opts.shapelet_do:
            show_smod = True
            show_sres = True
        else:
            show_smod = False
            show_sres = False
        if img.opts.spectralindex_do:
            show_spec = True
        else:
            show_spec = False
        if max(img.ch0_arr.shape) > 4096:
            print dc + '--> Image is large. Showing Gaussian residual image only.' + nc
            img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                ch0_islands=False, gresid_image=True, sresid_image=False,
                gmodel_image=False, smodel_image=False, pyramid_srcs=False,
                source_seds=show_spec)
        else:
            img.show_fit(smodel_image=show_smod, sresid_image=show_sres,
                     source_seds=show_spec)

    if img.opts.print_timing:
        print "="*36
        print "%18s : %10s" % ("Module", "Time (sec)")
        print "-"*36
        for i, op in enumerate(chain):
            if hasattr(op, '__start_time'):
                print "%18s : %f" % (op.__class__.__name__,
                                 (op.__stop_time - op.__start_time))
                indx_stop = i
        print "="*36
        print "%18s : %f" % ("Total",
                             (chain[indx_stop].__stop_time - chain[0].__start_time))

    # Log all internally derived parameters
    mylog = mylogger.logging.getLogger("PyBDSM.Final")
    par_msg = "Internally derived parameters:\n"
    import inspect
    import types

    for attr in inspect.getmembers(img.opts):
        if attr[0][0] != '_':
            if isinstance(attr[1], (int, str, bool, float, types.NoneType, tuple, list)):
                if hasattr(img, attr[0]):
                    used = img.__getattribute__(attr[0])
                    if used != attr[1] and isinstance(used, (int, str, bool, float,
                                                             types.NoneType, tuple,
                                                             list)):

                        par_msg += '    %-20s : %s\n' % (attr[0], repr(used))
    mylog.info(par_msg[:-1]) # -1 is to trim final newline

    return True
예제 #6
0
    def __call__(self, img):
        # Decompose the Gaussian residual image of ``img`` into a-trous
        # wavelet scales (only when img.opts.atrous_do is set).
        #
        # For each scale j = 1..jmax the current image is smoothed with the
        # selected low-pass filter and a wavelet image ``w`` is formed. If
        # img.opts.atrous_bdsm_do is set, a new Image is built from ``w`` and
        # run through the op chain returned by self.setpara_bdsm(); the
        # Gaussians found are then either attached to existing islands of
        # ``img`` (img.opts.atrous_orig_isl) or handed to
        # check_islands_for_overlap().
        #
        # Attributes of ``img`` written here: wavelet_lpf, wavelet_jmax,
        # atrous_islands, atrous_gaussians, atrous_sources, atrous_opts,
        # resid_wavelets_arr, pyrank, ngaus, total_flux_gaus, completed_Ops
        # (plus island/Gaussian bookkeeping on img.islands / img.gaussians).
        # Returns None.

        mylog = mylogger.logging.getLogger("PyBDSM." + img.log + "Wavelet")

        if img.opts.atrous_do:
          # Nothing to decompose without islands: mark the op as done and stop.
          if img.nisl == 0:
            mylog.warning("No islands found. Skipping wavelet decomposition.")
            img.completed_Ops.append('wavelet_atrous')
            return

          mylog.info("Decomposing gaussian residual image into a-trous wavelets")
          bdir = img.basedir + '/wavelet/'
          # Output directories are only created when all products are written.
          if img.opts.output_all:
              if not os.path.isdir(bdir): os.makedirs(bdir)
              if not os.path.isdir(bdir + '/residual/'): os.makedirs(bdir + '/residual/')
              if not os.path.isdir(bdir + '/model/'): os.makedirs(bdir + '/model/')
          dobdsm = img.opts.atrous_bdsm_do
          # Available low-pass filters, keyed by short name: kernel size,
          # kernel coefficients and a display name used in the log.
          # NOTE(review): the local name `filter` shadows the builtin.
          filter = {'tr':{'size':3, 'vec':[1. / 4, 1. / 2, 1. / 4], 'name':'Triangle'},
                    'b3':{'size':5, 'vec':[1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16], 'name':'B3 spline'}}

          # Op chain and option dict used for fitting each wavelet image.
          if dobdsm: wchain, wopts = self.setpara_bdsm(img)

          n, m = img.ch0_arr.shape

          # Calculate residual image that results from normal (non-wavelet) Gaussian fitting
          Op_make_residimage()(img)
          resid = img.resid_gaus_arr

          lpf = img.opts.atrous_lpf
          # Unknown filter names silently fall back to the B3 spline.
          if lpf not in ['b3', 'tr']: lpf = 'b3'
          jmax = img.opts.atrous_jmax
          # l is the number of coefficients in the chosen filter kernel.
          l = len(filter[lpf]['vec'])             # 1st 3 is arbit and 2nd 3 is whats expected for a-trous
          # A jmax outside 1..15 means "determine it automatically".
          if jmax < 1 or jmax > 15:                   # determine jmax
            # Check if largest island size is
            # smaller than 1/3 of image size. If so, use it to determine jmax.
            # NOTE(review): this first assignment is always overwritten by
            # the if/else a few lines below.
            min_size = min(resid.shape)
            max_isl_shape = (0, 0)
            # Find the island whose bounding image has the largest area.
            for isl in img.islands:
                if isl.image.shape[0] * isl.image.shape[1] > max_isl_shape[0] * max_isl_shape[1]:
                    max_isl_shape = isl.image.shape
            if max_isl_shape != (0, 0) and min(max_isl_shape) < min(resid.shape) / 3.0:
                # Use 4x the shorter side of the largest island.
                min_size = min(max_isl_shape) * 4.0
            else:
                # Otherwise use the shorter side of the whole image.
                min_size = min(resid.shape)
            jmax = int(floor(log((min_size / 3.0 * 3.0 - l) / (l - 1) + 1) / log(2.0) + 1.0)) + 1
            # Drop one scale if l + (l-1)*(2**jmax - 1) reaches 55% of min_size.
            if min_size * 0.55 <= (l + (l - 1) * (2 ** (jmax) - 1)): jmax = jmax - 1
          img.wavelet_lpf = lpf
          img.wavelet_jmax = jmax
          mylog.info("Using " + filter[lpf]['name'] + ' filter with J_max = ' + str(jmax))

          img.atrous_islands = []
          img.atrous_gaussians = []
          img.atrous_sources = []
          img.atrous_opts = []
          # `cp` is defined elsewhere in the file (presumably a copy function;
          # TODO confirm) so the wavelet residual starts from the Gaussian one.
          img.resid_wavelets_arr = cp(img.resid_gaus_arr)

          im_old = img.resid_wavelets_arr
          total_flux = 0.0
          ntot_wvgaus = 0
          stop_wav = False
          # Indices of blanked (NaN) pixels in the Gaussian residual image.
          pix_masked = N.where(N.isnan(resid) == True)
          jmin = 1
          # Default to a single core when ncores is not set.
          if img.opts.ncores is None:
              numcores = 1
          else:
              numcores = img.opts.ncores
          for j in range(jmin, jmax + 1):  # extra +1 is so we can do bdsm on cJ as well
            mylogger.userinfo(mylog, "\nWavelet scale #" + str(j))
            # Smooth the previous image with the filter at scale j.
            im_new = self.atrous(im_old, filter[lpf]['vec'], lpf, j, numcores=numcores, use_scipy_fft=img.opts.use_scipy_fft)
            im_new[pix_masked] = N.nan  # since fftconvolve wont work with blanked pixels
            # With atrous_sum the smoothed image itself is used; otherwise the
            # wavelet image is the difference between successive smoothings.
            if img.opts.atrous_sum:
                w = im_new
            else:
                w = im_old - im_new
            im_old = im_new
            # Backticks are Python 2 syntax for repr(j).
            suffix = 'w' + `j`
            filename = img.imagename + '.atrous.' + suffix + '.fits'
            if img.opts.output_all:
                func.write_image_to_file('fits', filename, w, img, bdir)
                mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.' + suffix + '.fits'))

            # now do bdsm on each wavelet image.
            if dobdsm:
              wopts['filename'] = filename
              wopts['basedir'] = bdir
              box = img.rms_box[0]
              # y1 grows with scale j; the rms box is at least 5*y1 and at
              # least the box size used for the original image.
              y1 = (l + (l - 1) * (2 ** (j - 1) - 1))
              bs = max(5 * y1, box)  # changed from 10 to 5
              if bs > min(n, m) / 2:
                # Box larger than half the shorter image side: use no rms
                # map and a constant mean instead.
                wopts['rms_map'] = False
                wopts['mean_map'] = 'const'
                wopts['rms_box'] = None
              else:
                wopts['rms_box'] = (bs, bs/3)
                # Adaptive rms box settings are only touched when img carries
                # the _adapt_rms_isl_pos attribute.
                if hasattr(img, '_adapt_rms_isl_pos'):
                    bs_bright = max(5 * y1, img.rms_box_bright[0])
                    if bs_bright < bs/1.5:
                        wopts['adaptive_rms_box'] = True
                        wopts['rms_box_bright'] = (bs_bright, bs_bright/3)
                    else:
                        wopts['adaptive_rms_box'] = False
              # Initial-guess mode for Gaussian fitting depends on the scale.
              if j <= 3:
                wopts['ini_gausfit'] = 'default'
              else:
                wopts['ini_gausfit'] = 'nobeam'
              # Same expression as y1 above.
              wid = (l + (l - 1) * (2 ** (j - 1) - 1))# / 3.0
              # fwsig is a module-level constant defined outside this method
              # (presumably a FWHM/sigma conversion factor; TODO confirm).
              b1, b2 = img.pixel_beam()[0:2]
              b1 = b1 * fwsig
              b2 = b2 * fwsig
              cdelt = img.wcs_obj.acdelt[:2]

              # Build the Image object for this wavelet scale and copy over
              # the settings it shares with the parent image.
              wimg = Image(wopts)
              # Beam of the wavelet image: wid and the scaled beam added in
              # quadrature, times cdelt and a factor of 2; third element is 0.
              wimg.beam = (sqrt(wid * wid + b1 * b1) * cdelt[0] * 2.0, sqrt(wid * wid + b2 * b2) * cdelt[1] * 2.0, 0.0)
              wimg.orig_beam = img.beam
              # img.pixel_beam is called as a function above, so this hands
              # wimg the same callable rather than a value.
              wimg.pixel_beam = img.pixel_beam
              wimg.pixel_beamarea = img.pixel_beamarea
              wimg.log = 'Wavelet.'
              wimg.basedir = img.basedir
              wimg.extraparams['bbsprefix'] = suffix
              wimg.extraparams['bbsname'] = img.imagename + '.wavelet'
              wimg.extraparams['bbsappend'] = True
              wimg.bbspatchnum = img.bbspatchnum
              wimg.waveletimage = True
              wimg.j = j
              if hasattr(img, '_adapt_rms_isl_pos'):
                  wimg._adapt_rms_isl_pos = img._adapt_rms_isl_pos


              self.init_image_simple(wimg, img, w, '.atrous.' + suffix)
              # Run the wavelet op chain, post-processing after the islands
              # and Gaussian-fitting steps.
              for op in wchain:
                op(wimg)
                gc.collect()
                if isinstance(op, Op_islands) and img.opts.atrous_orig_isl:
                    if wimg.nisl > 0:

                        # Find islands that do not share any pixels with
                        # islands in original ch0 image.
                        good_isl = []

                        # Make original rank image boolean; rank counts from 0, with -1 being
                        # outside any island
                        orig_rankim_bool = N.array(img.pyrank + 1, dtype = bool)

                        # Multiply rank images
                        old_islands = orig_rankim_bool * (wimg.pyrank + 1) - 1

                        # Exclude islands that don't overlap with a ch0 island.
                        valid_ids = set(old_islands.flatten())
                        # An island is kept when its index in wimg.islands
                        # appears among the overlapping rank values.
                        for idx, wvisl in enumerate(wimg.islands):
                            if idx in valid_ids:
                                wvisl.valid = True
                                good_isl.append(wvisl)
                            else:
                                wvisl.valid = False

                        wimg.islands = good_isl
                        wimg.nisl = len(good_isl)
                        mylogger.userinfo(mylog, "Number of islands found", '%i' %
                                  wimg.nisl)

                        # Renumber islands:
                        for wvindx, wvisl in enumerate(wimg.islands):
                            wvisl.island_id = wvindx

                if isinstance(op, Op_gausfit):
                  # If opts.atrous_orig_isl then exclude Gaussians outside of
                  # the original ch0 islands
                  nwvgaus = 0
                  if img.opts.atrous_orig_isl:
                      gaul = wimg.gaussians
                      tot_flux = 0.0

                      # New Gaussians continue numbering after the last
                      # Gaussian already in img (or start at 0 if none).
                      if img.ngaus == 0:
                          gaus_id = -1
                      else:
                          gaus_id = img.gaussians[-1].gaus_num
                      # NOTE(review): wvgaul is not used again in this method.
                      wvgaul = []
                      for g in gaul:
                          # Gaussians without a `valid` flag start as invalid.
                          if not hasattr(g, 'valid'):
                              g.valid = False
                          if not g.valid:
                              # Island id in the original rank image at the
                              # Gaussian centre (indices offset by +1); an
                              # IndexError counts as "outside any island".
                              try:
                                  isl_id = img.pyrank[int(g.centre_pix[0] + 1), int(g.centre_pix[1] + 1)]
                              except IndexError:
                                  isl_id = -1
                              if isl_id >= 0:
                                  isl = img.islands[isl_id]
                                  # Centre position relative to the island origin.
                                  gcenter = (g.centre_pix[0] - isl.origin[0],
                                             g.centre_pix[1] - isl.origin[1])
                                  # Accept only if mask_active is False at the
                                  # centre (mask_active is inverted when the
                                  # rank image is rebuilt further below).
                                  if not isl.mask_active[gcenter]:
                                      gaus_id += 1
                                      # Copy the Gaussian into the parent
                                      # image and its island.
                                      gcp = Gaussian(img, g.parameters[:], isl.island_id, gaus_id)
                                      gcp.gaus_num = gaus_id
                                      gcp.wisland_id = g.island_id
                                      gcp.jlevel = j
                                      g.valid = True
                                      isl.gaul.append(gcp)
                                      isl.ngaus += 1
                                      img.gaussians.append(gcp)
                                      nwvgaus += 1
                                      tot_flux += gcp.total_flux
                                  else:
                                      g.valid = False
                                      g.jlevel = 0
                              else:
                                  g.valid = False
                                  g.jlevel = 0
                      # Keep only the accepted Gaussians on the wavelet image.
                      vg = []
                      for g in wimg.gaussians:
                          if g.valid:
                              vg.append(g)
                      wimg.gaussians = vg
                      mylogger.userinfo(mylog, "Number of valid wavelet Gaussians", str(nwvgaus))
                  else:
                      # Keep all Gaussians and merge islands that overlap
                      tot_flux = check_islands_for_overlap(img, wimg)

                      # Now renumber the islands and adjust the rank image before going to next wavelet image
                      renumber_islands(img)

              # NOTE(review): tot_flux is only assigned in the Op_gausfit
              # branch above; this relies on wchain containing that op.
              total_flux += tot_flux
              # has_pl and raw_input_no_history come from outside this method.
              if img.opts.interactive and has_pl:
                  dc = '\033[34;1m'
                  nc = '\033[0m'
                  print dc + '--> Displaying islands and rms image...' + nc
                  if max(wimg.ch0_arr.shape) > 4096:
                      print dc + '--> Image is large. Showing islands only.' + nc
                      wimg.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                        ch0_islands=True, gresid_image=False, sresid_image=False,
                        gmodel_image=False, smodel_image=False, pyramid_srcs=False)
                  else:
                      wimg.show_fit()
                  prompt = dc + "Press enter to continue or 'q' stop fitting wavelet images : " + nc
                  answ = raw_input_no_history(prompt)
                  # Re-prompt until the answer is empty (continue) or 'q'
                  # (record the current scale as jmax and stop after it).
                  while answ != '':
                      if answ == 'q':
                          img.wavelet_jmax = j
                          stop_wav = True
                          break
                      answ = raw_input_no_history(prompt)
              # Subtract the fitted wavelet Gaussians from the running
              # residual (and from the smoothed image when atrous_sum is set).
              if len(wimg.gaussians) > 0:
                img.resid_wavelets_arr = self.subtract_wvgaus(img.opts, img.resid_wavelets_arr, wimg.gaussians, wimg.islands)
                if img.opts.atrous_sum:
                    im_old = self.subtract_wvgaus(img.opts, im_old, wimg.gaussians, wimg.islands)
              if stop_wav == True:
                  break

          # Rebuild the rank image and renumber islands (and the island_id
          # of their Gaussians) consecutively.
          pyrank = N.zeros(img.pyrank.shape, dtype=N.int32)
          for i, isl in enumerate(img.islands):
              isl.island_id = i
              for g in isl.gaul:
                  g.island_id = i
              for dg in isl.dgaul:
                  dg.island_id = i
              pyrank[isl.bbox] += N.invert(isl.mask_active) * (i + 1)
          pyrank -= 1 # align pyrank values with island ids and set regions outside of islands to -1
          img.pyrank = pyrank

          # NOTE(review): pdir is not used again in this method.
          pdir = img.basedir + '/misc/'
          # NOTE(review): ntot_wvgaus is never changed from its initial 0 in
          # this method, so this adds nothing to img.ngaus here.
          img.ngaus += ntot_wvgaus
          img.total_flux_gaus += total_flux
          mylogger.userinfo(mylog, "Total flux density in model on all scales" , '%.3f Jy' % img.total_flux_gaus)
          if img.opts.output_all:
              # im_new is the last smoothed image produced by the scale loop
              # (it is unbound if the loop never ran).
              func.write_image_to_file('fits', img.imagename + '.atrous.cJ.fits',
                                       im_new, img, bdir)
              mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.cJ.fits'))
              func.write_image_to_file('fits', img.imagename + '.resid_wavelets.fits',
                                       (img.ch0_arr - img.resid_gaus_arr + img.resid_wavelets_arr), img, bdir + '/residual/')
              mylog.info('%s %s' % ('Wrote ', img.imagename + '.resid_wavelets.fits'))
              func.write_image_to_file('fits', img.imagename + '.model_wavelets.fits',
                                       (img.resid_gaus_arr - img.resid_wavelets_arr), img, bdir + '/model/')
              mylog.info('%s %s' % ('Wrote ', img.imagename + '.model_wavelets.fits'))
          img.completed_Ops.append('wavelet_atrous')