def on_press(event):
    """Handle keypresses in the interactive plot window.

    Keys handled:

    * ``'0'`` -- reset the colour limits of all image panels to the
      defaults (``vmin``, ``vmax``).
    * ``'m'`` -- prompt for new min/max limits (Jy/beam) and apply them to
      all image panels.
    * ``'c'`` -- prompt for a source ID and redraw the SED panel(s) for it.
    * ``'i'`` -- print flux-density and rms statistics for the region
      visible in the first image panel.
    * ``'n'`` -- toggle the visibility of the island-number markers.

    State is shared with the plotting code through module-level globals.
    ``images[k]`` describes the k-th axes of ``fig``: either an image array
    or one of the string tags 'wavelets' / 'seds'. Colour limits are kept
    in log space: a limit ``v`` corresponds to a pixel value of
    ``pow(10, v) - low``.
    """
    from interface import raw_input_no_history
    import numpy
    global img_ch0, img_rms, img_mean, img_gaus_mod, img_shap_mod
    global pixels_per_beam, vmin, vmax, vmin_cur, vmax_cur, img_pi
    global ch0min, ch0max, low, fig, images, src_list, srcid_cur
    global markers

    # Sentinel returned by the prompt helpers when no usable value was read.
    cancelled = object()

    def is_tag(entry, tag):
        # ``images`` mixes ndarrays with string tags; comparing an ndarray
        # to a string with ==/!= is elementwise (or warns) in numpy, so
        # check the type before comparing.
        return isinstance(entry, str) and entry == tag

    def is_image_panel(entry):
        # Every panel that is not a wavelet or SED panel shows an image.
        return not (is_tag(entry, 'wavelets') or is_tag(entry, 'seds'))

    def ask(prompt, default, convert, fail_msg):
        # Prompt until ``convert`` accepts the answer; an empty answer keeps
        # ``default``. Returns ``cancelled`` if prompting is not possible.
        answer = 'a'  # deliberately invalid so that the first pass prompts
        while isinstance(answer, str):
            try:
                if answer == '':
                    return default
                answer = convert(answer)
            except ValueError:
                try:
                    answer = raw_input_no_history(prompt)
                except RuntimeError:
                    print(fail_msg)
                    return cancelled
        return answer

    def ask_scale(which, current):
        # Ask for a new 'min'/'max' limit in Jy/beam and return it in the
        # log10(value + low) space used for the colour limits.
        default = pow(10, current) - low
        value = ask("Enter %s value (current = %.4f Jy/beam) : " % (which, default),
                    default, float, 'Sorry, unable to change scaling.')
        if value is cancelled:
            return cancelled
        if value + low <= 0:
            # log10 of a non-positive number is -inf/nan, which would break
            # the colour scaling.
            print('%s value must be greater than %.4f Jy/beam!'
                  % (which.capitalize(), -low))
            return cancelled
        return N.log10(value + low)

    if event.key == '0':
        print('Resetting limits to defaults (%.4f -- %.4f Jy/beam)'
              % (pow(10, vmin) - low, pow(10, vmax) - low))
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                ax.get_images()[0].set_clim(vmin, vmax)
        vmin_cur = vmin
        vmax_cur = vmax
        pl.draw()

    elif event.key == 'm':
        # Modify scaling; only meaningful when at least one panel is an image.
        if not any(isinstance(im, numpy.ndarray) for im in images):
            return
        minscl = ask_scale('min', vmin_cur)
        if minscl is cancelled:
            return
        maxscl = ask_scale('max', vmax_cur)
        if maxscl is cancelled:
            return
        if maxscl <= minscl:
            print('Max value must be greater than min value!')
            return
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                ax.get_images()[0].set_clim(minscl, maxscl)
        vmin_cur = minscl
        vmax_cur = maxscl
        pl.draw()

    elif event.key == 'c':
        # Change the source whose SED is plotted; requires an SED panel.
        if not any(is_tag(entry, 'seds') for entry in images):
            return
        srcid = ask("Enter source ID (current = %i) : " % (srcid_cur,),
                    srcid_cur, int, 'Sorry, unable to change source.')
        if srcid is cancelled:
            return
        sed_src = get_src(src_list, srcid)
        if sed_src is None:
            print('Source not found!')
            return
        srcid_cur = srcid
        for axindx, ax in enumerate(fig.get_axes()):
            if is_tag(images[axindx], 'seds'):
                plot_sed(sed_src, ax)
        pl.draw()

    elif event.key == 'i':
        # Print info about the region visible in the first image panel.
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                break
        else:
            return  # no image panel

        # Clip to the image and truncate to int: axis limits are floats and
        # numpy rejects float slice indices.
        xmin = int(max(xmin, 0))
        xmax = int(min(xmax, img_ch0.shape[0]))
        ymin = int(max(ymin, 0))
        ymax = int(min(ymax, img_ch0.shape[1]))
        region = (slice(xmin, xmax), slice(ymin, ymax))

        ch0 = img_ch0[region]
        flux = N.nansum(ch0) / pixels_per_beam
        num_pix_unmasked = float(N.count_nonzero(~N.isnan(ch0)))
        mean_rms = N.nansum(img_rms[region]) / num_pix_unmasked
        mean_map_flux = N.nansum(img_mean[region]) / pixels_per_beam
        if img_gaus_mod is None:
            gaus_mod_flux = 0.0
        else:
            gaus_mod_flux = N.nansum(img_gaus_mod[region]) / pixels_per_beam
        print('Visible region (%i:%i, %i:%i) :' % (xmin, xmax, ymin, ymax))
        print(' ch0 flux density from sum of pixels ... : %f Jy' % (flux,))
        print(' Background mean map flux density ...... : %f Jy' % (mean_map_flux,))
        print(' Gaussian model flux density ........... : %f Jy' % (gaus_mod_flux,))
        if img_shap_mod is not None:
            shap_mod_flux = N.nansum(img_shap_mod[region]) / pixels_per_beam
            print(' Shapelet model flux density ........... : %f Jy' % (shap_mod_flux,))
        print(' Mean rms (from rms map) ............... : %f Jy/beam' % (mean_rms,))

    elif event.key == 'n':
        # Show/Hide island numbers
        if markers:
            for marker in markers:
                marker.set_visible(not marker.get_visible())
            pl.draw()
def __call__(self, img):
    """Decompose the Gaussian residual image into a-trous wavelet planes.

    Does nothing unless ``img.opts.atrous_do`` is set. Otherwise
    ``img.resid_gaus_arr`` is recomputed with ``Op_make_residimage`` and
    smoothed scale by scale with ``self.atrous`` using the 'b3' (default)
    or 'tr' low-pass filter. The plane ``w`` for scale ``j`` is the
    smoothed image itself when ``opts.atrous_sum`` is set, otherwise the
    difference between consecutive smoothings.

    When ``opts.atrous_bdsm_do`` is set, each plane is wrapped in a new
    ``Image`` and run through the op chain returned by
    ``self.setpara_bdsm``. Gaussians fitted there are added to ``img``
    (directly when ``opts.atrous_orig_isl`` is set, otherwise via
    ``check_islands_for_overlap``) and passed to ``self.subtract_wvgaus``
    to update ``img.resid_wavelets_arr``.

    Attributes of ``img`` written here include ``wavelet_lpf``,
    ``wavelet_jmax``, ``resid_wavelets_arr``, ``pyrank``, ``ngaus``,
    ``total_flux_gaus`` and ``completed_Ops``. FITS files are written
    under ``img.basedir + '/wavelet/'`` when ``opts.output_all`` is set.
    """
    mylog = mylogger.logging.getLogger("PyBDSM." + img.log + "Wavelet")

    if img.opts.atrous_do:
        if img.nisl == 0:
            # Nothing to decompose around; still mark the op as completed.
            mylog.warning("No islands found. Skipping wavelet decomposition.")
            img.completed_Ops.append('wavelet_atrous')
            return

        mylog.info("Decomposing gaussian residual image into a-trous wavelets")
        bdir = img.basedir + '/wavelet/'
        if img.opts.output_all:
            if not os.path.isdir(bdir): os.makedirs(bdir)
            if not os.path.isdir(bdir + '/residual/'): os.makedirs(bdir + '/residual/')
            if not os.path.isdir(bdir + '/model/'): os.makedirs(bdir + '/model/')
        dobdsm = img.opts.atrous_bdsm_do
        # Low-pass filter kernels (1-D taps). NOTE: this local name shadows
        # the builtin filter() for the rest of the method.
        filter = {'tr': {'size': 3, 'vec': [1. / 4, 1. / 2, 1. / 4], 'name': 'Triangle'},
                  'b3': {'size': 5, 'vec': [1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16], 'name': 'B3 spline'}}

        # wchain/wopts: op chain and options used to fit each wavelet image.
        if dobdsm: wchain, wopts = self.setpara_bdsm(img)

        n, m = img.ch0_arr.shape

        # Calculate residual image that results from normal (non-wavelet) Gaussian fitting
        Op_make_residimage()(img)
        resid = img.resid_gaus_arr

        lpf = img.opts.atrous_lpf
        if lpf not in ['b3', 'tr']: lpf = 'b3'  # unknown filter names fall back to 'b3'
        jmax = img.opts.atrous_jmax
        l = len(filter[lpf]['vec'])  # number of filter taps: 3 ('tr') or 5 ('b3')
        if jmax < 1 or jmax > 15:
            # atrous_jmax outside 1..15: derive jmax from the image size.
            # Check if largest island size is
            # smaller than 1/3 of image size. If so, use it to determine jmax.
            min_size = min(resid.shape)
            max_isl_shape = (0, 0)
            for isl in img.islands:
                if isl.image.shape[0] * isl.image.shape[1] > max_isl_shape[0] * max_isl_shape[1]:
                    max_isl_shape = isl.image.shape
            if max_isl_shape != (0, 0) and min(max_isl_shape) < min(resid.shape) / 3.0:
                min_size = min(max_isl_shape) * 4.0
            else:
                min_size = min(resid.shape)
            # NOTE(review): "/ 3.0 * 3.0" cancels out; kept unchanged.
            jmax = int(floor(log((min_size / 3.0 * 3.0 - l) / (l - 1) + 1) / log(2.0) + 1.0)) + 1
            # Same expression as y1/wid below, evaluated at j = jmax + 1;
            # drop one scale if it reaches 55% of min_size.
            if min_size * 0.55 <= (l + (l - 1) * (2 ** (jmax) - 1)): jmax = jmax - 1
        img.wavelet_lpf = lpf
        img.wavelet_jmax = jmax
        mylog.info("Using " + filter[lpf]['name'] + ' filter with J_max = ' + str(jmax))

        img.atrous_islands = []
        img.atrous_gaussians = []
        img.atrous_sources = []
        img.atrous_opts = []
        # Working copy of the Gaussian residual; wavelet Gaussians are
        # handed to subtract_wvgaus against it further down.
        img.resid_wavelets_arr = cp(img.resid_gaus_arr)

        im_old = img.resid_wavelets_arr
        total_flux = 0.0
        # NOTE(review): ntot_wvgaus is never incremented in this method, so
        # "img.ngaus += ntot_wvgaus" below adds 0 even when Gaussians were
        # appended to img.gaussians -- confirm this is intended.
        ntot_wvgaus = 0
        stop_wav = False
        # Blanked (NaN) pixels of the residual; re-blanked after each smoothing.
        pix_masked = N.where(N.isnan(resid) == True)
        jmin = 1
        if img.opts.ncores is None:
            numcores = 1
        else:
            numcores = img.opts.ncores
        for j in range(jmin, jmax + 1):  # extra +1 is so we can do bdsm on cJ as well
            mylogger.userinfo(mylog, "\nWavelet scale #" + str(j))
            im_new = self.atrous(im_old, filter[lpf]['vec'], lpf, j, numcores=numcores, use_scipy_fft=img.opts.use_scipy_fft)
            im_new[pix_masked] = N.nan  # since fftconvolve wont work with blanked pixels
            if img.opts.atrous_sum:
                w = im_new
            else:
                w = im_old - im_new
            im_old = im_new  # the next scale smooths the current smoothed image
            suffix = 'w' + `j`  # Python 2 backticks, i.e. repr(j)
            filename = img.imagename + '.atrous.' + suffix + '.fits'
            if img.opts.output_all:
                func.write_image_to_file('fits', filename, w, img, bdir)
                mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.' + suffix + '.fits'))

            # now do bdsm on each wavelet image.
            if dobdsm:
                wopts['filename'] = filename
                wopts['basedir'] = bdir
                box = img.rms_box[0]
                # y1 grows with scale j (same expression as wid below).
                y1 = (l + (l - 1) * (2 ** (j - 1) - 1))
                bs = max(5 * y1, box)  # changed from 10 to 5
                if bs > min(n, m) / 2:
                    # rms box would exceed half the image: no rms map, constant mean.
                    wopts['rms_map'] = False
                    wopts['mean_map'] = 'const'
                    wopts['rms_box'] = None
                else:
                    wopts['rms_box'] = (bs, bs / 3)
                    if hasattr(img, '_adapt_rms_isl_pos'):
                        bs_bright = max(5 * y1, img.rms_box_bright[0])
                        if bs_bright < bs / 1.5:
                            wopts['adaptive_rms_box'] = True
                            wopts['rms_box_bright'] = (bs_bright, bs_bright / 3)
                        else:
                            wopts['adaptive_rms_box'] = False
                if j <= 3:
                    wopts['ini_gausfit'] = 'default'
                else:
                    wopts['ini_gausfit'] = 'nobeam'
                wid = (l + (l - 1) * (2 ** (j - 1) - 1))  # a "/ 3.0" divisor was commented out here
                b1, b2 = img.pixel_beam()[0:2]
                b1 = b1 * fwsig
                b2 = b2 * fwsig
                cdelt = img.wcs_obj.acdelt[:2]

                wimg = Image(wopts)
                # Beam of the wavelet image: wid and the fwsig-scaled pixel beam
                # added in quadrature, scaled by cdelt. NOTE(review): the extra
                # factor 2.0 is not explained here -- confirm.
                wimg.beam = (sqrt(wid * wid + b1 * b1) * cdelt[0] * 2.0, sqrt(wid * wid + b2 * b2) * cdelt[1] * 2.0, 0.0)
                wimg.orig_beam = img.beam
                wimg.pixel_beam = img.pixel_beam
                wimg.pixel_beamarea = img.pixel_beamarea
                wimg.log = 'Wavelet.'
                wimg.basedir = img.basedir
                wimg.extraparams['bbsprefix'] = suffix
                wimg.extraparams['bbsname'] = img.imagename + '.wavelet'
                wimg.extraparams['bbsappend'] = True
                wimg.bbspatchnum = img.bbspatchnum
                wimg.waveletimage = True
                wimg.j = j
                if hasattr(img, '_adapt_rms_isl_pos'):
                    wimg._adapt_rms_isl_pos = img._adapt_rms_isl_pos

                self.init_image_simple(wimg, img, w, '.atrous.' + suffix)
                for op in wchain:
                    op(wimg)
                    gc.collect()
                    if isinstance(op, Op_islands) and img.opts.atrous_orig_isl:
                        if wimg.nisl > 0:

                            # Find islands that do not share any pixels with
                            # islands in original ch0 image.
                            good_isl = []

                            # Make original rank image boolean; rank counts from 0, with -1 being
                            # outside any island
                            orig_rankim_bool = N.array(img.pyrank + 1, dtype=bool)

                            # Multiply rank images
                            old_islands = orig_rankim_bool * (wimg.pyrank + 1) - 1

                            # Exclude islands that don't overlap with a ch0 island.
                            valid_ids = set(old_islands.flatten())
                            for idx, wvisl in enumerate(wimg.islands):
                                if idx in valid_ids:
                                    wvisl.valid = True
                                    good_isl.append(wvisl)
                                else:
                                    wvisl.valid = False

                            wimg.islands = good_isl
                            wimg.nisl = len(good_isl)
                            mylogger.userinfo(mylog, "Number of islands found", '%i' % wimg.nisl)

                            # Renumber islands:
                            for wvindx, wvisl in enumerate(wimg.islands):
                                wvisl.island_id = wvindx

                    if isinstance(op, Op_gausfit):
                        # If opts.atrous_orig_isl then exclude Gaussians outside of
                        # the original ch0 islands
                        nwvgaus = 0
                        if img.opts.atrous_orig_isl:
                            gaul = wimg.gaussians
                            tot_flux = 0.0

                            # New Gaussians continue the numbering of img.gaussians.
                            if img.ngaus == 0:
                                gaus_id = -1
                            else:
                                gaus_id = img.gaussians[-1].gaus_num
                            wvgaul = []  # NOTE(review): unused
                            for g in gaul:
                                if not hasattr(g, 'valid'):
                                    g.valid = False
                                if not g.valid:
                                    try:
                                        # NOTE(review): the +1 offset on centre_pix when indexing
                                        # pyrank is not explained here -- confirm it is intended.
                                        isl_id = img.pyrank[int(g.centre_pix[0] + 1), int(g.centre_pix[1] + 1)]
                                    except IndexError:
                                        isl_id = -1  # index out of range: treated as outside any island
                                    if isl_id >= 0:
                                        isl = img.islands[isl_id]
                                        gcenter = (g.centre_pix[0] - isl.origin[0],
                                                   g.centre_pix[1] - isl.origin[1])
                                        # NOTE(review): gcenter is not cast to int before indexing
                                        # mask_active; recent numpy rejects float indices -- confirm.
                                        if not isl.mask_active[gcenter]:
                                            # Copy the Gaussian into img, attached to ch0 island isl.
                                            gaus_id += 1
                                            gcp = Gaussian(img, g.parameters[:], isl.island_id, gaus_id)
                                            gcp.gaus_num = gaus_id
                                            gcp.wisland_id = g.island_id
                                            gcp.jlevel = j
                                            g.valid = True
                                            isl.gaul.append(gcp)
                                            isl.ngaus += 1
                                            img.gaussians.append(gcp)
                                            nwvgaus += 1
                                            tot_flux += gcp.total_flux
                                        else:
                                            g.valid = False
                                            g.jlevel = 0
                                    else:
                                        g.valid = False
                                        g.jlevel = 0
                            # Keep only the Gaussians that were accepted above.
                            vg = []
                            for g in wimg.gaussians:
                                if g.valid:
                                    vg.append(g)
                            wimg.gaussians = vg
                            mylogger.userinfo(mylog, "Number of valid wavelet Gaussians", str(nwvgaus))
                        else:
                            # Keep all Gaussians and merge islands that overlap
                            tot_flux = check_islands_for_overlap(img, wimg)

                            # Now renumber the islands and adjust the rank image before going to next wavelet image
                            renumber_islands(img)

                # NOTE(review): tot_flux is only bound in the Op_gausfit branch above;
                # if wchain has no Op_gausfit this raises NameError (or reuses the
                # previous scale's value) -- confirm wchain always includes it.
                total_flux += tot_flux
                if img.opts.interactive and has_pl:
                    # Show this wavelet image; answering 'q' stops after this scale.
                    dc = '\033[34;1m'
                    nc = '\033[0m'
                    print dc + '--> Displaying islands and rms image...' + nc
                    if max(wimg.ch0_arr.shape) > 4096:
                        print dc + '--> Image is large. Showing islands only.' + nc
                        wimg.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                                      ch0_islands=True, gresid_image=False, sresid_image=False,
                                      gmodel_image=False, smodel_image=False, pyramid_srcs=False)
                    else:
                        wimg.show_fit()
                    prompt = dc + "Press enter to continue or 'q' stop fitting wavelet images : " + nc
                    answ = raw_input_no_history(prompt)
                    while answ != '':
                        if answ == 'q':
                            img.wavelet_jmax = j
                            stop_wav = True
                            break
                        answ = raw_input_no_history(prompt)
                if len(wimg.gaussians) > 0:
                    img.resid_wavelets_arr = self.subtract_wvgaus(img.opts, img.resid_wavelets_arr, wimg.gaussians, wimg.islands)
                    if img.opts.atrous_sum:
                        im_old = self.subtract_wvgaus(img.opts, im_old, wimg.gaussians, wimg.islands)
                if stop_wav == True:
                    break

        # Renumber islands and their Gaussians consecutively and rebuild the
        # rank image from the (possibly extended) island list.
        pyrank = N.zeros(img.pyrank.shape, dtype=N.int32)
        for i, isl in enumerate(img.islands):
            isl.island_id = i
            for g in isl.gaul:
                g.island_id = i
            for dg in isl.dgaul:
                dg.island_id = i
            pyrank[isl.bbox] += N.invert(isl.mask_active) * (i + 1)
        pyrank -= 1  # align pyrank values with island ids and set regions outside of islands to -1
        img.pyrank = pyrank

        pdir = img.basedir + '/misc/'  # NOTE(review): unused in this method
        img.ngaus += ntot_wvgaus
        img.total_flux_gaus += total_flux
        mylogger.userinfo(mylog, "Total flux density in model on all scales", '%.3f Jy' % img.total_flux_gaus)
        if img.opts.output_all:
            # NOTE(review): im_new is unbound if the scale loop ran zero times (jmax < 1).
            func.write_image_to_file('fits', img.imagename + '.atrous.cJ.fits',
                                     im_new, img, bdir)
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.cJ.fits'))
            func.write_image_to_file('fits', img.imagename + '.resid_wavelets.fits',
                                     (img.ch0_arr - img.resid_gaus_arr + img.resid_wavelets_arr), img, bdir + '/residual/')
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.resid_wavelets.fits'))
            func.write_image_to_file('fits', img.imagename + '.model_wavelets.fits',
                                     (img.resid_gaus_arr - img.resid_wavelets_arr), img, bdir + '/model/')
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.model_wavelets.fits'))
        img.completed_Ops.append('wavelet_atrous')
def _run_op_list(img, chain):
    """Run an Image object through a chain of op's.

    This is separate from execute() so that other modules (such as
    interface.py) can use it as well.

    Parameters
    ----------
    img : Image
        Image to process; each op is called as ``op(img)``.
    chain : list
        Op classes or op instances. Classes are instantiated with no
        arguments. Collection stops after Op_readimage when
        ``opts.stop_at == 'read'`` and after Op_islands when it is 'isl'.

    Returns
    -------
    bool
        False if the user answers 'q' at the interactive prompt shown
        before Gaussian fitting, True otherwise.

    Each op instance gets ``__start_time``/``__stop_time`` attributes,
    which are used for the optional timing report.
    """
    from time import time
    from types import ClassType, TypeType
    from interface import raw_input_no_history
    from gausfit import Op_gausfit
    import mylogger
    import gc

    ops = []
    stopat = img.opts.stop_at
    # Make sure all op's are instances
    for entry in chain:
        if isinstance(entry, (ClassType, TypeType)):
            op = entry()
        else:
            op = entry
        ops.append(op)
        # Test the instance, not ``entry``: a class is never an instance of
        # itself, so stop_at was ignored for chains given as classes.
        if stopat == 'read' and isinstance(op, Op_readimage):
            break
        if stopat == 'isl' and isinstance(op, Op_islands):
            break

    # Log all non-default parameters
    mylog = mylogger.logging.getLogger("PyBDSF.Init")
    mylog.info("PyBDSF version %s" % (__version__, ))
    par_msg = "Non-default input parameters:\n"
    for k, v in img.opts.to_list():
        val = img.opts.__getattribute__(k)
        if val != v._default and v.group() != 'hidden':
            par_msg += ' %-20s = %s\n' % (k, repr(val))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline

    # Run all op's
    dc = '\033[34;1m'
    nc = '\033[0m'
    for op in ops:
        if isinstance(op, Op_gausfit) and img.opts.interactive:
            # Let the user inspect islands before fitting, and optionally quit.
            print(dc + '--> Displaying islands and rms image...' + nc)
            if max(img.ch0_arr.shape) > 4096:
                print(dc + '--> Image is large. Showing islands only.' + nc)
                img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            else:
                img.show_fit(rms_image=True, mean_image=True,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            prompt = dc + "Press enter to continue or 'q' to quit .. : " + nc
            answ = raw_input_no_history(prompt)
            while answ != '':
                if answ == 'q':
                    return False
                answ = raw_input_no_history(prompt)
        op.__start_time = time()
        op(img)
        op.__stop_time = time()
        gc.collect()

    if img.opts.interactive and not img._pi:
        print(dc + 'Fitting complete. Displaying results...' + nc)
        show_smod = show_sres = bool(img.opts.shapelet_do)
        show_spec = bool(img.opts.spectralindex_do)
        if max(img.ch0_arr.shape) > 4096:
            print(dc + '--> Image is large. Showing Gaussian residual image only.' + nc)
            img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                         ch0_islands=False, gresid_image=True, sresid_image=False,
                         gmodel_image=False, smodel_image=False, pyramid_srcs=False,
                         source_seds=show_spec)
        else:
            img.show_fit(smodel_image=show_smod, sresid_image=show_sres,
                         source_seds=show_spec)

    if img.opts.print_timing:
        print("=" * 36)
        print("%18s : %10s" % ("Module", "Time (sec)"))
        print("-" * 36)
        # Timestamps live on the instances in ``ops``; ``chain`` may hold
        # classes, which never carry them.
        timed = [op for op in ops if hasattr(op, '__start_time')]
        for op in timed:
            print("%18s : %f" % (op.__class__.__name__,
                                 (op.__stop_time - op.__start_time)))
        print("=" * 36)
        if timed:
            print("%18s : %f" % ("Total",
                                 (timed[-1].__stop_time - timed[0].__start_time)))

    # Log all internally derived parameters
    mylog = mylogger.logging.getLogger("PyBDSF.Final")
    par_msg = "Internally derived parameters:\n"
    import inspect
    import types
    simple_types = (int, str, bool, float, types.NoneType, tuple, list)
    for name, opt_val in inspect.getmembers(img.opts):
        if name[0] == '_' or not isinstance(opt_val, simple_types):
            continue
        if not hasattr(img, name):
            continue
        used = img.__getattribute__(name)
        # Check the type before comparing: != on e.g. a numpy array is
        # elementwise and its truth value is ambiguous.
        if isinstance(used, simple_types) and used != opt_val:
            par_msg += ' %-20s : %s\n' % (name, repr(used))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline
    return True
def on_press(event):
    """Handle keypresses in the interactive plot window.

    Keys handled:

    * ``'0'`` -- reset the colour limits of all image panels to the
      defaults (``vmin``, ``vmax``).
    * ``'m'`` -- prompt for new min/max limits (Jy/beam) and apply them to
      all image panels.
    * ``'c'`` -- prompt for a source ID and redraw the SED panel(s) for it.
    * ``'i'`` -- print flux-density and rms statistics for the region
      visible in the first image panel.
    * ``'n'`` -- toggle the visibility of the island-number markers.

    State is shared with the plotting code through module-level globals.
    ``images[k]`` describes the k-th axes of ``fig``: either an image array
    or one of the string tags 'wavelets' / 'seds'. Colour limits are kept
    in log space: a limit ``v`` corresponds to a pixel value of
    ``pow(10, v) - low``.
    """
    from interface import raw_input_no_history
    import numpy
    global img_ch0, img_rms, img_mean, img_gaus_mod, img_shap_mod
    global pixels_per_beam, vmin, vmax, vmin_cur, vmax_cur, img_pi
    global ch0min, ch0max, low, fig, images, src_list, srcid_cur
    global markers

    # Sentinel returned by the prompt helpers when no usable value was read.
    cancelled = object()

    def is_tag(entry, tag):
        # ``images`` mixes ndarrays with string tags; comparing an ndarray
        # to a string with ==/!= is elementwise (or warns) in numpy, so
        # check the type before comparing.
        return isinstance(entry, str) and entry == tag

    def is_image_panel(entry):
        # Every panel that is not a wavelet or SED panel shows an image.
        return not (is_tag(entry, 'wavelets') or is_tag(entry, 'seds'))

    def ask(prompt, default, convert, fail_msg):
        # Prompt until ``convert`` accepts the answer; an empty answer keeps
        # ``default``. Returns ``cancelled`` if prompting is not possible.
        answer = 'a'  # deliberately invalid so that the first pass prompts
        while isinstance(answer, str):
            try:
                if answer == '':
                    return default
                answer = convert(answer)
            except ValueError:
                try:
                    answer = raw_input_no_history(prompt)
                except RuntimeError:
                    print(fail_msg)
                    return cancelled
        return answer

    def ask_scale(which, current):
        # Ask for a new 'min'/'max' limit in Jy/beam and return it in the
        # log10(value + low) space used for the colour limits.
        default = pow(10, current) - low
        value = ask("Enter %s value (current = %.4f Jy/beam) : " % (which, default),
                    default, float, 'Sorry, unable to change scaling.')
        if value is cancelled:
            return cancelled
        if value + low <= 0:
            # log10 of a non-positive number is -inf/nan, which would break
            # the colour scaling.
            print('%s value must be greater than %.4f Jy/beam!'
                  % (which.capitalize(), -low))
            return cancelled
        return N.log10(value + low)

    if event.key == '0':
        print('Resetting limits to defaults (%.4f -- %.4f Jy/beam)'
              % (pow(10, vmin) - low, pow(10, vmax) - low))
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                ax.get_images()[0].set_clim(vmin, vmax)
        vmin_cur = vmin
        vmax_cur = vmax
        pl.draw()

    elif event.key == 'm':
        # Modify scaling; only meaningful when at least one panel is an image.
        if not any(isinstance(im, numpy.ndarray) for im in images):
            return
        minscl = ask_scale('min', vmin_cur)
        if minscl is cancelled:
            return
        maxscl = ask_scale('max', vmax_cur)
        if maxscl is cancelled:
            return
        if maxscl <= minscl:
            print('Max value must be greater than min value!')
            return
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                ax.get_images()[0].set_clim(minscl, maxscl)
        vmin_cur = minscl
        vmax_cur = maxscl
        pl.draw()

    elif event.key == 'c':
        # Change the source whose SED is plotted; requires an SED panel.
        if not any(is_tag(entry, 'seds') for entry in images):
            return
        srcid = ask("Enter source ID (current = %i) : " % (srcid_cur,),
                    srcid_cur, int, 'Sorry, unable to change source.')
        if srcid is cancelled:
            return
        sed_src = get_src(src_list, srcid)
        if sed_src is None:
            print('Source not found!')
            return
        srcid_cur = srcid
        for axindx, ax in enumerate(fig.get_axes()):
            if is_tag(images[axindx], 'seds'):
                plot_sed(sed_src, ax)
        pl.draw()

    elif event.key == 'i':
        # Print info about the region visible in the first image panel.
        for axindx, ax in enumerate(fig.get_axes()):
            if is_image_panel(images[axindx]):
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                break
        else:
            return  # no image panel

        # Clip to the image and truncate to int: axis limits are floats and
        # numpy rejects float slice indices.
        xmin = int(max(xmin, 0))
        xmax = int(min(xmax, img_ch0.shape[0]))
        ymin = int(max(ymin, 0))
        ymax = int(min(ymax, img_ch0.shape[1]))
        region = (slice(xmin, xmax), slice(ymin, ymax))

        ch0 = img_ch0[region]
        flux = N.nansum(ch0) / pixels_per_beam
        num_pix_unmasked = float(N.count_nonzero(~N.isnan(ch0)))
        mean_rms = N.nansum(img_rms[region]) / num_pix_unmasked
        mean_map_flux = N.nansum(img_mean[region]) / pixels_per_beam
        if img_gaus_mod is None:
            gaus_mod_flux = 0.0
        else:
            gaus_mod_flux = N.nansum(img_gaus_mod[region]) / pixels_per_beam
        print('Visible region (%i:%i, %i:%i) :' % (xmin, xmax, ymin, ymax))
        print(' ch0 flux density from sum of pixels ... : %f Jy' % (flux,))
        print(' Background mean map flux density ...... : %f Jy' % (mean_map_flux,))
        print(' Gaussian model flux density ........... : %f Jy' % (gaus_mod_flux,))
        if img_shap_mod is not None:
            shap_mod_flux = N.nansum(img_shap_mod[region]) / pixels_per_beam
            print(' Shapelet model flux density ........... : %f Jy' % (shap_mod_flux,))
        print(' Mean rms (from rms map) ............... : %f Jy/beam' % (mean_rms,))

    elif event.key == 'n':
        # Show/Hide island numbers
        if markers:
            for marker in markers:
                marker.set_visible(not marker.get_visible())
            pl.draw()
def _run_op_list(img, chain):
    """Run an Image object through a chain of op's.

    This is separate from execute() so that other modules (such as
    interface.py) can use it as well.

    Parameters
    ----------
    img : Image
        Image to process; each op is called as ``op(img)``.
    chain : list
        Op classes or op instances. Classes are instantiated with no
        arguments. Collection stops after Op_readimage when
        ``opts.stop_at == 'read'`` and after Op_islands when it is 'isl'.

    Returns
    -------
    bool
        False if the user answers 'q' at the interactive prompt shown
        before Gaussian fitting, True otherwise.

    Each op instance gets ``__start_time``/``__stop_time`` attributes,
    which are used for the optional timing report.
    """
    from time import time
    from types import ClassType, TypeType
    from interface import raw_input_no_history
    from gausfit import Op_gausfit
    import mylogger
    import gc

    ops = []
    stopat = img.opts.stop_at
    # Make sure all op's are instances
    for entry in chain:
        if isinstance(entry, (ClassType, TypeType)):
            op = entry()
        else:
            op = entry
        ops.append(op)
        # Test the instance, not ``entry``: a class is never an instance of
        # itself, so stop_at was ignored for chains given as classes.
        if stopat == 'read' and isinstance(op, Op_readimage):
            break
        if stopat == 'isl' and isinstance(op, Op_islands):
            break

    # Log all non-default parameters
    mylog = mylogger.logging.getLogger("PyBDSM.Init")
    mylog.info("PyBDSM version %s (LUS revision %s)" % (__version__, __revision__))
    par_msg = "Non-default input parameters:\n"
    for k, v in img.opts.to_list():
        val = img.opts.__getattribute__(k)
        if val != v._default and v.group() != 'hidden':
            par_msg += ' %-20s = %s\n' % (k, repr(val))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline

    # Run all op's
    dc = '\033[34;1m'
    nc = '\033[0m'
    for op in ops:
        if isinstance(op, Op_gausfit) and img.opts.interactive:
            # Let the user inspect islands before fitting, and optionally quit.
            print(dc + '--> Displaying islands and rms image...' + nc)
            if max(img.ch0_arr.shape) > 4096:
                print(dc + '--> Image is large. Showing islands only.' + nc)
                img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            else:
                img.show_fit(rms_image=True, mean_image=True,
                             ch0_islands=True, gresid_image=False, sresid_image=False,
                             gmodel_image=False, smodel_image=False, pyramid_srcs=False)
            prompt = dc + "Press enter to continue or 'q' to quit .. : " + nc
            answ = raw_input_no_history(prompt)
            while answ != '':
                if answ == 'q':
                    return False
                answ = raw_input_no_history(prompt)
        op.__start_time = time()
        op(img)
        op.__stop_time = time()
        gc.collect()

    if img.opts.interactive and not img._pi:
        print(dc + 'Fitting complete. Displaying results...' + nc)
        show_smod = show_sres = bool(img.opts.shapelet_do)
        show_spec = bool(img.opts.spectralindex_do)
        if max(img.ch0_arr.shape) > 4096:
            print(dc + '--> Image is large. Showing Gaussian residual image only.' + nc)
            img.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                         ch0_islands=False, gresid_image=True, sresid_image=False,
                         gmodel_image=False, smodel_image=False, pyramid_srcs=False,
                         source_seds=show_spec)
        else:
            img.show_fit(smodel_image=show_smod, sresid_image=show_sres,
                         source_seds=show_spec)

    if img.opts.print_timing:
        print("=" * 36)
        print("%18s : %10s" % ("Module", "Time (sec)"))
        print("-" * 36)
        # Timestamps live on the instances in ``ops``; ``chain`` may hold
        # classes, which never carry them.
        timed = [op for op in ops if hasattr(op, '__start_time')]
        for op in timed:
            print("%18s : %f" % (op.__class__.__name__,
                                 (op.__stop_time - op.__start_time)))
        print("=" * 36)
        if timed:
            print("%18s : %f" % ("Total",
                                 (timed[-1].__stop_time - timed[0].__start_time)))

    # Log all internally derived parameters
    mylog = mylogger.logging.getLogger("PyBDSM.Final")
    par_msg = "Internally derived parameters:\n"
    import inspect
    import types
    simple_types = (int, str, bool, float, types.NoneType, tuple, list)
    for name, opt_val in inspect.getmembers(img.opts):
        if name[0] == '_' or not isinstance(opt_val, simple_types):
            continue
        if not hasattr(img, name):
            continue
        used = img.__getattribute__(name)
        # Check the type before comparing: != on e.g. a numpy array is
        # elementwise and its truth value is ambiguous.
        if isinstance(used, simple_types) and used != opt_val:
            par_msg += ' %-20s : %s\n' % (name, repr(used))
    mylog.info(par_msg[:-1])  # -1 is to trim final newline
    return True
def __call__(self, img):
    """Decompose the Gaussian residual image into a-trous wavelet planes.

    Does nothing unless ``img.opts.atrous_do`` is set. Otherwise
    ``img.resid_gaus_arr`` is recomputed with ``Op_make_residimage`` and
    smoothed scale by scale with ``self.atrous`` using the 'b3' (default)
    or 'tr' low-pass filter. The plane ``w`` for scale ``j`` is the
    smoothed image itself when ``opts.atrous_sum`` is set, otherwise the
    difference between consecutive smoothings.

    When ``opts.atrous_bdsm_do`` is set, each plane is wrapped in a new
    ``Image`` and run through the op chain returned by
    ``self.setpara_bdsm``. Gaussians fitted there are added to ``img``
    (directly when ``opts.atrous_orig_isl`` is set, otherwise via
    ``check_islands_for_overlap``) and passed to ``self.subtract_wvgaus``
    to update ``img.resid_wavelets_arr``.

    Attributes of ``img`` written here include ``wavelet_lpf``,
    ``wavelet_jmax``, ``resid_wavelets_arr``, ``pyrank``, ``ngaus``,
    ``total_flux_gaus`` and ``completed_Ops``. FITS files are written
    under ``img.basedir + '/wavelet/'`` when ``opts.output_all`` is set.
    """
    mylog = mylogger.logging.getLogger("PyBDSM." + img.log + "Wavelet")

    if img.opts.atrous_do:
        if img.nisl == 0:
            # Nothing to decompose around; still mark the op as completed.
            mylog.warning("No islands found. Skipping wavelet decomposition.")
            img.completed_Ops.append('wavelet_atrous')
            return

        mylog.info("Decomposing gaussian residual image into a-trous wavelets")
        bdir = img.basedir + '/wavelet/'
        if img.opts.output_all:
            if not os.path.isdir(bdir): os.makedirs(bdir)
            if not os.path.isdir(bdir + '/residual/'): os.makedirs(bdir + '/residual/')
            if not os.path.isdir(bdir + '/model/'): os.makedirs(bdir + '/model/')
        dobdsm = img.opts.atrous_bdsm_do
        # Low-pass filter kernels (1-D taps). NOTE: this local name shadows
        # the builtin filter() for the rest of the method.
        filter = {'tr': {'size': 3, 'vec': [1. / 4, 1. / 2, 1. / 4], 'name': 'Triangle'},
                  'b3': {'size': 5, 'vec': [1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16], 'name': 'B3 spline'}}

        # wchain/wopts: op chain and options used to fit each wavelet image.
        if dobdsm: wchain, wopts = self.setpara_bdsm(img)

        n, m = img.ch0_arr.shape

        # Calculate residual image that results from normal (non-wavelet) Gaussian fitting
        Op_make_residimage()(img)
        resid = img.resid_gaus_arr

        lpf = img.opts.atrous_lpf
        if lpf not in ['b3', 'tr']: lpf = 'b3'  # unknown filter names fall back to 'b3'
        jmax = img.opts.atrous_jmax
        l = len(filter[lpf]['vec'])  # number of filter taps: 3 ('tr') or 5 ('b3')
        if jmax < 1 or jmax > 15:
            # atrous_jmax outside 1..15: derive jmax from the image size.
            # Check if largest island size is
            # smaller than 1/3 of image size. If so, use it to determine jmax.
            min_size = min(resid.shape)
            max_isl_shape = (0, 0)
            for isl in img.islands:
                if isl.image.shape[0] * isl.image.shape[1] > max_isl_shape[0] * max_isl_shape[1]:
                    max_isl_shape = isl.image.shape
            if max_isl_shape != (0, 0) and min(max_isl_shape) < min(resid.shape) / 3.0:
                min_size = min(max_isl_shape) * 4.0
            else:
                min_size = min(resid.shape)
            # NOTE(review): "/ 3.0 * 3.0" cancels out; kept unchanged.
            jmax = int(floor(log((min_size / 3.0 * 3.0 - l) / (l - 1) + 1) / log(2.0) + 1.0)) + 1
            # Same expression as y1/wid below, evaluated at j = jmax + 1;
            # drop one scale if it reaches 55% of min_size.
            if min_size * 0.55 <= (l + (l - 1) * (2 ** (jmax) - 1)): jmax = jmax - 1
        img.wavelet_lpf = lpf
        img.wavelet_jmax = jmax
        mylog.info("Using " + filter[lpf]['name'] + ' filter with J_max = ' + str(jmax))

        img.atrous_islands = []
        img.atrous_gaussians = []
        img.atrous_sources = []
        img.atrous_opts = []
        # Working copy of the Gaussian residual; wavelet Gaussians are
        # handed to subtract_wvgaus against it further down.
        img.resid_wavelets_arr = cp(img.resid_gaus_arr)

        im_old = img.resid_wavelets_arr
        total_flux = 0.0
        # NOTE(review): ntot_wvgaus is never incremented in this method, so
        # "img.ngaus += ntot_wvgaus" below adds 0 even when Gaussians were
        # appended to img.gaussians -- confirm this is intended.
        ntot_wvgaus = 0
        stop_wav = False
        # Blanked (NaN) pixels of the residual; re-blanked after each smoothing.
        pix_masked = N.where(N.isnan(resid) == True)
        jmin = 1
        if img.opts.ncores is None:
            numcores = 1
        else:
            numcores = img.opts.ncores
        for j in range(jmin, jmax + 1):  # extra +1 is so we can do bdsm on cJ as well
            mylogger.userinfo(mylog, "\nWavelet scale #" + str(j))
            im_new = self.atrous(im_old, filter[lpf]['vec'], lpf, j, numcores=numcores, use_scipy_fft=img.opts.use_scipy_fft)
            im_new[pix_masked] = N.nan  # since fftconvolve wont work with blanked pixels
            if img.opts.atrous_sum:
                w = im_new
            else:
                w = im_old - im_new
            im_old = im_new  # the next scale smooths the current smoothed image
            suffix = 'w' + `j`  # Python 2 backticks, i.e. repr(j)
            filename = img.imagename + '.atrous.' + suffix + '.fits'
            if img.opts.output_all:
                func.write_image_to_file('fits', filename, w, img, bdir)
                mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.' + suffix + '.fits'))

            # now do bdsm on each wavelet image.
            if dobdsm:
                wopts['filename'] = filename
                wopts['basedir'] = bdir
                box = img.rms_box[0]
                # y1 grows with scale j (same expression as wid below).
                y1 = (l + (l - 1) * (2 ** (j - 1) - 1))
                bs = max(5 * y1, box)  # changed from 10 to 5
                if bs > min(n, m) / 2:
                    # rms box would exceed half the image: no rms map, constant mean.
                    wopts['rms_map'] = False
                    wopts['mean_map'] = 'const'
                    wopts['rms_box'] = None
                else:
                    wopts['rms_box'] = (bs, bs/3)
                    if hasattr(img, '_adapt_rms_isl_pos'):
                        bs_bright = max(5 * y1, img.rms_box_bright[0])
                        if bs_bright < bs/1.5:
                            wopts['adaptive_rms_box'] = True
                            wopts['rms_box_bright'] = (bs_bright, bs_bright/3)
                        else:
                            wopts['adaptive_rms_box'] = False
                if j <= 3:
                    wopts['ini_gausfit'] = 'default'
                else:
                    wopts['ini_gausfit'] = 'nobeam'
                wid = (l + (l - 1) * (2 ** (j - 1) - 1))  # a "/ 3.0" divisor was commented out here
                b1, b2 = img.pixel_beam()[0:2]
                b1 = b1 * fwsig
                b2 = b2 * fwsig
                cdelt = img.wcs_obj.acdelt[:2]

                wimg = Image(wopts)
                # Beam of the wavelet image: wid and the fwsig-scaled pixel beam
                # added in quadrature, scaled by cdelt. NOTE(review): the extra
                # factor 2.0 is not explained here -- confirm.
                wimg.beam = (sqrt(wid * wid + b1 * b1) * cdelt[0] * 2.0, sqrt(wid * wid + b2 * b2) * cdelt[1] * 2.0, 0.0)
                wimg.orig_beam = img.beam
                wimg.pixel_beam = img.pixel_beam
                wimg.pixel_beamarea = img.pixel_beamarea
                wimg.log = 'Wavelet.'
                wimg.basedir = img.basedir
                wimg.extraparams['bbsprefix'] = suffix
                wimg.extraparams['bbsname'] = img.imagename + '.wavelet'
                wimg.extraparams['bbsappend'] = True
                wimg.bbspatchnum = img.bbspatchnum
                wimg.waveletimage = True
                wimg.j = j
                if hasattr(img, '_adapt_rms_isl_pos'):
                    wimg._adapt_rms_isl_pos = img._adapt_rms_isl_pos

                self.init_image_simple(wimg, img, w, '.atrous.' + suffix)
                for op in wchain:
                    op(wimg)
                    gc.collect()
                    if isinstance(op, Op_islands) and img.opts.atrous_orig_isl:
                        if wimg.nisl > 0:

                            # Find islands that do not share any pixels with
                            # islands in original ch0 image.
                            good_isl = []

                            # Make original rank image boolean; rank counts from 0, with -1 being
                            # outside any island
                            orig_rankim_bool = N.array(img.pyrank + 1, dtype = bool)

                            # Multiply rank images
                            old_islands = orig_rankim_bool * (wimg.pyrank + 1) - 1

                            # Exclude islands that don't overlap with a ch0 island.
                            valid_ids = set(old_islands.flatten())
                            for idx, wvisl in enumerate(wimg.islands):
                                if idx in valid_ids:
                                    wvisl.valid = True
                                    good_isl.append(wvisl)
                                else:
                                    wvisl.valid = False

                            wimg.islands = good_isl
                            wimg.nisl = len(good_isl)
                            mylogger.userinfo(mylog, "Number of islands found", '%i' % wimg.nisl)

                            # Renumber islands:
                            for wvindx, wvisl in enumerate(wimg.islands):
                                wvisl.island_id = wvindx

                    if isinstance(op, Op_gausfit):
                        # If opts.atrous_orig_isl then exclude Gaussians outside of
                        # the original ch0 islands
                        nwvgaus = 0
                        if img.opts.atrous_orig_isl:
                            gaul = wimg.gaussians
                            tot_flux = 0.0

                            # New Gaussians continue the numbering of img.gaussians.
                            if img.ngaus == 0:
                                gaus_id = -1
                            else:
                                gaus_id = img.gaussians[-1].gaus_num
                            wvgaul = []  # NOTE(review): unused
                            for g in gaul:
                                if not hasattr(g, 'valid'):
                                    g.valid = False
                                if not g.valid:
                                    try:
                                        # NOTE(review): the +1 offset on centre_pix when indexing
                                        # pyrank is not explained here -- confirm it is intended.
                                        isl_id = img.pyrank[int(g.centre_pix[0] + 1), int(g.centre_pix[1] + 1)]
                                    except IndexError:
                                        isl_id = -1  # index out of range: treated as outside any island
                                    if isl_id >= 0:
                                        isl = img.islands[isl_id]
                                        gcenter = (g.centre_pix[0] - isl.origin[0],
                                                   g.centre_pix[1] - isl.origin[1])
                                        # NOTE(review): gcenter is not cast to int before indexing
                                        # mask_active; recent numpy rejects float indices -- confirm.
                                        if not isl.mask_active[gcenter]:
                                            # Copy the Gaussian into img, attached to ch0 island isl.
                                            gaus_id += 1
                                            gcp = Gaussian(img, g.parameters[:], isl.island_id, gaus_id)
                                            gcp.gaus_num = gaus_id
                                            gcp.wisland_id = g.island_id
                                            gcp.jlevel = j
                                            g.valid = True
                                            isl.gaul.append(gcp)
                                            isl.ngaus += 1
                                            img.gaussians.append(gcp)
                                            nwvgaus += 1
                                            tot_flux += gcp.total_flux
                                        else:
                                            g.valid = False
                                            g.jlevel = 0
                                    else:
                                        g.valid = False
                                        g.jlevel = 0
                            # Keep only the Gaussians that were accepted above.
                            vg = []
                            for g in wimg.gaussians:
                                if g.valid:
                                    vg.append(g)
                            wimg.gaussians = vg
                            mylogger.userinfo(mylog, "Number of valid wavelet Gaussians", str(nwvgaus))
                        else:
                            # Keep all Gaussians and merge islands that overlap
                            tot_flux = check_islands_for_overlap(img, wimg)

                            # Now renumber the islands and adjust the rank image before going to next wavelet image
                            renumber_islands(img)

                # NOTE(review): tot_flux is only bound in the Op_gausfit branch above;
                # if wchain has no Op_gausfit this raises NameError (or reuses the
                # previous scale's value) -- confirm wchain always includes it.
                total_flux += tot_flux
                if img.opts.interactive and has_pl:
                    # Show this wavelet image; answering 'q' stops after this scale.
                    dc = '\033[34;1m'
                    nc = '\033[0m'
                    print dc + '--> Displaying islands and rms image...' + nc
                    if max(wimg.ch0_arr.shape) > 4096:
                        print dc + '--> Image is large. Showing islands only.' + nc
                        wimg.show_fit(rms_image=False, mean_image=False, ch0_image=False,
                                      ch0_islands=True, gresid_image=False, sresid_image=False,
                                      gmodel_image=False, smodel_image=False, pyramid_srcs=False)
                    else:
                        wimg.show_fit()
                    prompt = dc + "Press enter to continue or 'q' stop fitting wavelet images : " + nc
                    answ = raw_input_no_history(prompt)
                    while answ != '':
                        if answ == 'q':
                            img.wavelet_jmax = j
                            stop_wav = True
                            break
                        answ = raw_input_no_history(prompt)
                if len(wimg.gaussians) > 0:
                    img.resid_wavelets_arr = self.subtract_wvgaus(img.opts, img.resid_wavelets_arr, wimg.gaussians, wimg.islands)
                    if img.opts.atrous_sum:
                        im_old = self.subtract_wvgaus(img.opts, im_old, wimg.gaussians, wimg.islands)
                if stop_wav == True:
                    break

        # Renumber islands and their Gaussians consecutively and rebuild the
        # rank image from the (possibly extended) island list.
        pyrank = N.zeros(img.pyrank.shape, dtype=N.int32)
        for i, isl in enumerate(img.islands):
            isl.island_id = i
            for g in isl.gaul:
                g.island_id = i
            for dg in isl.dgaul:
                dg.island_id = i
            pyrank[isl.bbox] += N.invert(isl.mask_active) * (i + 1)
        pyrank -= 1  # align pyrank values with island ids and set regions outside of islands to -1
        img.pyrank = pyrank

        pdir = img.basedir + '/misc/'  # NOTE(review): unused in this method
        img.ngaus += ntot_wvgaus
        img.total_flux_gaus += total_flux
        mylogger.userinfo(mylog, "Total flux density in model on all scales" , '%.3f Jy' % img.total_flux_gaus)
        if img.opts.output_all:
            # NOTE(review): im_new is unbound if the scale loop ran zero times (jmax < 1).
            func.write_image_to_file('fits', img.imagename + '.atrous.cJ.fits',
                                     im_new, img, bdir)
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.atrous.cJ.fits'))
            func.write_image_to_file('fits', img.imagename + '.resid_wavelets.fits',
                                     (img.ch0_arr - img.resid_gaus_arr + img.resid_wavelets_arr), img, bdir + '/residual/')
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.resid_wavelets.fits'))
            func.write_image_to_file('fits', img.imagename + '.model_wavelets.fits',
                                     (img.resid_gaus_arr - img.resid_wavelets_arr), img, bdir + '/model/')
            mylog.info('%s %s' % ('Wrote ', img.imagename + '.model_wavelets.fits'))
        img.completed_Ops.append('wavelet_atrous')