예제 #1
0
    def setpara_bdsm(self, img):
        """Return the operation chain and options for fitting a wavelet image.

        The options start as a copy of the parent image's options
        (``img.opts.to_dict()``) and are then overridden for wavelet fitting:

        * ``pi_thresh_isl`` / ``pi_thresh_pix``, when set, replace
          ``thresh_isl`` / ``thresh_pix``;
        * ``thresh`` is forced to ``'hard'``;
        * ``polarisation_do`` is disabled;
        * ``filename`` and ``detection_image`` are cleared (``__call__`` sets
          ``filename`` for each wavelet scale).

        Parameters
        ----------
        img : Image
            Parent image whose options are copied.

        Returns
        -------
        ops : list
            Operation instances to apply to the wavelet image, in order.
        opts : dict
            Option dictionary used to build the wavelet ``Image``.
        """
        import inspect

        # Entries may be operation classes (instantiated below) or
        # ready-made instances.
        chain = [
            Op_preprocess,
            Op_rmsimage(),
            Op_threshold(),
            Op_islands(),
            Op_gausfit(),
            Op_gaul2srl(),
            Op_make_residimage()
        ]

        opts = img.opts.to_dict()
        if img.opts.pi_thresh_isl is not None:
            opts['thresh_isl'] = img.opts.pi_thresh_isl
        if img.opts.pi_thresh_pix is not None:
            opts['thresh_pix'] = img.opts.pi_thresh_pix
        opts['thresh'] = 'hard'
        opts['polarisation_do'] = False
        opts['filename'] = ''
        opts['detection_image'] = ''

        # inspect.isclass accepts both old- and new-style classes on
        # Python 2 (like the former ClassType/TypeType check) and also
        # works on Python 3, where types.ClassType no longer exists.
        ops = [op() if inspect.isclass(op) else op for op in chain]

        return ops, opts
예제 #2
0
    def setpara_bdsm(self, img):
        """Return the operation chain and options for fitting a wavelet image.

        The option dictionary is built from scratch. A few entries are taken
        from the parent image: ``thresh_pix``, ``rms_map``, ``mean_map``,
        ``thresh_isl``, ``group_by_isl``, ``quiet``, ``ncores``,
        ``bbs_patches``, ``output_all`` and ``verbose_fitting``. All other
        entries are fixed constants. ``filename`` and ``detection_image`` are
        left empty; ``__call__`` sets ``filename`` for each wavelet scale.

        Parameters
        ----------
        img : Image
            Parent image supplying the inherited settings.

        Returns
        -------
        ops : list
            Operation instances to apply to the wavelet image, in order.
        opts : dict
            Option dictionary used to build the wavelet ``Image``.
        """
        import inspect

        # Entries may be operation classes (instantiated below) or
        # ready-made instances.
        chain = [
            Op_preprocess,
            Op_rmsimage(),
            Op_threshold(),
            Op_islands(),
            Op_gausfit(),
            Op_gaul2srl(),
            Op_make_residimage()
        ]

        opts = {
            'thresh': 'hard',
            # NOTE(review): read from the image itself rather than img.opts;
            # presumably the pixel threshold computed earlier -- confirm.
            'thresh_pix': img.thresh_pix,
            'kappa_clip': 3.0,
            'rms_map': img.opts.rms_map,
            'mean_map': img.opts.mean_map,
            'thresh_isl': img.opts.thresh_isl,
            'minpix_isl': 6,
            'savefits_rmsim': False,
            'savefits_meanim': False,
            'savefits_rankim': False,
            'savefits_normim': False,
            'polarisation_do': False,
            'aperture': None,
            'group_by_isl': img.opts.group_by_isl,
            'quiet': img.opts.quiet,
            'ncores': img.opts.ncores,

            # Fixed flagging parameters for wavelet-scale fits.
            'flag_smallsrc': False,
            'flag_minsnr': 0.2,
            'flag_maxsnr': 1.2,
            'flag_maxsize_isl': 2.5,
            'flag_bordersize': 0,
            'flag_maxsize_bm': 50.0,
            'flag_minsize_bm': 0.2,
            'flag_maxsize_fwhm': 0.5,
            'bbs_patches': img.opts.bbs_patches,
            'filename': '',
            'output_all': img.opts.output_all,
            # Previously assigned twice with the same value; set once here.
            'verbose_fitting': img.opts.verbose_fitting,
            'split_isl': False,
            'peak_fit': True,
            'peak_maxsize': 30.0,
            'detection_image': '',
        }

        # inspect.isclass accepts both old- and new-style classes on
        # Python 2 (like the former ClassType/TypeType check) and also
        # works on Python 3, where types.ClassType no longer exists.
        ops = [op() if inspect.isclass(op) else op for op in chain]

        return ops, opts
예제 #3
0
    def subtract_wvgaus(self, opts, residim, gaussians, islands):
        """Remove the model of every valid Gaussian from ``residim``.

        Gaussians whose ``valid`` flag is false are ignored. For each valid
        Gaussian:

        * Its island is ``islands[g.wisland_id]`` if the Gaussian has a
          ``wisland_id`` attribute, otherwise ``islands[g.island_id]``.
        * A half-size is obtained from ``Op_make_residimage.find_bbox`` at a
          level of ``opts.fittedimage_clip`` times that island's ``rms``.
        * The model from ``functions.gaussian_fcn`` is evaluated on the box
          around ``centre_pix``, clipped to the image bounds, and subtracted
          there.

        Parameters
        ----------
        opts : options object
            Supplies ``fittedimage_clip``.
        residim : array
            Image indexed by two slices; updated through slice assignment.
        gaussians : iterable
            Gaussian objects to subtract.
        islands : sequence
            Islands indexed by island id.

        Returns
        -------
        The same ``residim`` object after subtraction.
        """
        import functions as func
        from make_residimage import Op_make_residimage as opp

        finder = opp()
        dims = residim.shape
        clip_level = opts.fittedimage_clip

        for gaus in gaussians:
            if not gaus.valid:
                continue

            xcen, ycen = gaus.centre_pix
            isl_key = (gaus.wisland_id if hasattr(gaus, 'wisland_id')
                       else gaus.island_id)
            island = islands[isl_key]
            half = opp.find_bbox(finder, clip_level * island.rms, gaus)

            # Box around the centre, clipped to the image edges.
            rows = slice(max(0, int(xcen - half)),
                         min(dims[0], int(xcen + half + 1)))
            cols = slice(max(0, int(ycen - half)),
                         min(dims[1], int(ycen + half + 1)))
            box = (rows, cols)

            xgrid, ygrid = N.mgrid[box]
            model = func.gaussian_fcn(gaus, xgrid, ygrid)
            residim[box] = residim[box] - model

        return residim
예제 #4
0
    def subtract_wvgaus(self, opts, residim, gaussians, islands):
        """Subtract valid Gaussians from a residual image.

        Gaussians with a false ``valid`` flag are skipped. For each remaining
        Gaussian, the model from ``functions.gaussian_fcn`` is evaluated on a
        box around ``centre_pix``. The box is clipped to the image bounds,
        and the model is subtracted from ``residim`` there.

        The box half-size comes from ``Op_make_residimage.find_bbox`` at a
        level of ``opts.fittedimage_clip`` times the island rms. The island
        is chosen as follows:

        * ``islands[g.wisland_id]`` when the Gaussian carries ``wisland_id``
          (set on Gaussians copied from wavelet scales in ``__call__``);
        * otherwise ``islands[g.island_id]``.

        Parameters
        ----------
        opts : options object
            Supplies ``fittedimage_clip``.
        residim : array
            Image indexed by two slices; updated through slice assignment.
        gaussians : iterable
            Gaussian objects to subtract.
        islands : sequence
            Islands indexed by island id.

        Returns
        -------
        The same ``residim`` object after subtraction.
        """
        import functions as func
        from make_residimage import Op_make_residimage as opp

        dummy = opp()
        shape = residim.shape
        thresh = opts.fittedimage_clip

        for g in gaussians:
            if not g.valid:
                continue
            C1, C2 = g.centre_pix
            if hasattr(g, 'wisland_id'):
                isl = islands[g.wisland_id]
            else:
                isl = islands[g.island_id]
            # Half-size of the box enclosing the Gaussian, presumably down to
            # thresh * rms (see Op_make_residimage.find_bbox).
            b = opp.find_bbox(dummy, thresh * isl.rms, g)
            # Clip the box to the image bounds.
            bbox = N.s_[max(0, int(C1 - b)):min(shape[0], int(C1 + b + 1)),
                        max(0, int(C2 - b)):min(shape[1], int(C2 + b + 1))]
            x_ax, y_ax = N.mgrid[bbox]
            ffimg = func.gaussian_fcn(g, x_ax, y_ax)
            residim[bbox] = residim[bbox] - ffimg

        return residim
예제 #5
0
    def __call__(self, img):
        """Decompose the Gaussian residual image into a-trous wavelet scales.

        Does nothing unless ``img.opts.atrous_do`` is set. Otherwise:

        1. The Gaussian residual image is (re)computed with
           ``Op_make_residimage``.
        2. It is decomposed on scales j = 1 .. J_max using the Triangle
           ('tr') or B3-spline ('b3', the fallback) low-pass filter.
        3. When ``img.opts.atrous_bdsm_do`` is set, each scale is run through
           the chain from ``setpara_bdsm``. The resulting Gaussians are
           either restricted to the original ch0 islands
           (``atrous_orig_isl``) or merged via
           ``check_islands_for_overlap``.
        4. Those Gaussians are subtracted from ``img.resid_wavelets_arr``.

        Attributes written on ``img`` include ``wavelet_lpf``,
        ``wavelet_jmax``, the ``atrous_*`` lists, ``resid_wavelets_arr``,
        ``pyrank``, ``ngaus``, ``total_flux_gaus`` and ``completed_Ops``.

        With ``output_all`` set, the following are written below
        ``img.basedir + '/wavelet/'``:

        * the per-scale images;
        * the last smoothed image (cJ);
        * the wavelet residual and model images.

        Parameters
        ----------
        img : Image
            Image being processed; modified in place.
        """
        mylog = mylogger.logging.getLogger("PyBDSM." + img.log + "Wavelet")

        if not img.opts.atrous_do:
            return

        if img.nisl == 0:
            mylog.warning("No islands found. Skipping wavelet decomposition.")
            img.completed_Ops.append('wavelet_atrous')
            return

        mylog.info("Decomposing gaussian residual image into a-trous wavelets")
        bdir = img.basedir + '/wavelet/'
        if img.opts.output_all:
            for outdir in (bdir, bdir + '/residual/', bdir + '/model/'):
                if not os.path.isdir(outdir):
                    os.makedirs(outdir)
        dobdsm = img.opts.atrous_bdsm_do
        # Available low-pass filters (named ``filters`` so that the
        # ``filter`` builtin is not shadowed).
        filters = {
            'tr': {
                'size': 3,
                'vec': [1. / 4, 1. / 2, 1. / 4],
                'name': 'Triangle'
            },
            'b3': {
                'size': 5,
                'vec': [1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16],
                'name': 'B3 spline'
            }
        }

        if dobdsm:
            wchain, wopts = self.setpara_bdsm(img)

        n, m = img.ch0_arr.shape

        # Calculate residual image that results from normal (non-wavelet)
        # Gaussian fitting
        Op_make_residimage()(img)
        resid = img.resid_gaus_arr

        lpf = img.opts.atrous_lpf
        if lpf not in ['b3', 'tr']:
            lpf = 'b3'
        jmax = img.opts.atrous_jmax
        # 1st 3 is arbitrary and 2nd 3 is what's expected for a-trous
        l = len(filters[lpf]['vec'])
        if jmax < 1 or jmax > 15:  # determine jmax
            jmax = self._auto_jmax(img, resid, l)
        img.wavelet_lpf = lpf
        img.wavelet_jmax = jmax
        mylog.info("Using " + filters[lpf]['name'] +
                   ' filter with J_max = ' + str(jmax))

        img.atrous_islands = []
        img.atrous_gaussians = []
        img.atrous_sources = []
        img.atrous_opts = []
        img.resid_wavelets_arr = cp(img.resid_gaus_arr)

        im_old = img.resid_wavelets_arr
        # Bound before the loop so that the cJ image written below is
        # defined even when no scale is processed (auto-computed jmax < 1).
        im_new = im_old
        total_flux = 0.0
        ntot_wvgaus = 0
        stop_wav = False
        pix_masked = N.where(N.isnan(resid))
        jmin = 1
        if img.opts.ncores is None:
            numcores = 1
        else:
            numcores = img.opts.ncores
        for j in range(jmin, jmax + 1):  # extra +1 is so we can do bdsm on cJ as well
            mylogger.userinfo(mylog, "\nWavelet scale #" + str(j))
            im_new = self.atrous(im_old,
                                 filters[lpf]['vec'],
                                 lpf,
                                 j,
                                 numcores=numcores,
                                 use_scipy_fft=img.opts.use_scipy_fft)
            # since fftconvolve wont work with blanked pixels
            im_new[pix_masked] = N.nan
            if img.opts.atrous_sum:
                w = im_new
            else:
                w = im_old - im_new
            im_old = im_new
            suffix = 'w' + str(j)
            filename = img.imagename + '.atrous.' + suffix + '.fits'
            if img.opts.output_all:
                func.write_image_to_file('fits', filename, w, img, bdir)
                mylog.info('%s %s' % ('Wrote ', img.imagename +
                                      '.atrous.' + suffix + '.fits'))

            # now do bdsm on each wavelet image.
            if not dobdsm:
                continue

            wopts['filename'] = filename
            wopts['basedir'] = bdir
            box = img.rms_box[0]
            y1 = (l + (l - 1) * (2**(j - 1) - 1))
            bs = max(5 * y1, box)  # changed from 10 to 5
            if bs > min(n, m) / 2:
                wopts['rms_map'] = False
                wopts['mean_map'] = 'const'
                wopts['rms_box'] = None
            else:
                wopts['rms_box'] = (bs, bs / 3)
                if hasattr(img, '_adapt_rms_isl_pos'):
                    bs_bright = max(5 * y1, img.rms_box_bright[0])
                    if bs_bright < bs / 1.5:
                        wopts['adaptive_rms_box'] = True
                        wopts['rms_box_bright'] = (bs_bright, bs_bright / 3)
                    else:
                        wopts['adaptive_rms_box'] = False
            if j <= 3:
                wopts['ini_gausfit'] = 'default'
            else:
                wopts['ini_gausfit'] = 'nobeam'
            wid = y1  # scale width, also used for the beam estimate below
            b1, b2 = img.pixel_beam()[0:2]
            b1 = b1 * fwsig
            b2 = b2 * fwsig
            cdelt = img.wcs_obj.acdelt[:2]

            wimg = Image(wopts)
            # Beam of the wavelet image: quadrature sum of the scale width
            # and the (fwsig-scaled) pixel beam, converted with cdelt.
            wimg.beam = (sqrt(wid * wid + b1 * b1) * cdelt[0] * 2.0,
                         sqrt(wid * wid + b2 * b2) * cdelt[1] * 2.0,
                         0.0)
            wimg.orig_beam = img.beam
            wimg.pixel_beam = img.pixel_beam
            wimg.pixel_beamarea = img.pixel_beamarea
            wimg.log = 'Wavelet.'
            wimg.basedir = img.basedir
            wimg.extraparams['bbsprefix'] = suffix
            wimg.extraparams['bbsname'] = img.imagename + '.wavelet'
            wimg.extraparams['bbsappend'] = True
            wimg.bbspatchnum = img.bbspatchnum
            wimg.waveletimage = True
            wimg.j = j
            if hasattr(img, '_adapt_rms_isl_pos'):
                wimg._adapt_rms_isl_pos = img._adapt_rms_isl_pos

            self.init_image_simple(wimg, img, w, '.atrous.' + suffix)
            tot_flux = 0.0
            for op in wchain:
                op(wimg)
                gc.collect()
                if (isinstance(op, Op_islands) and img.opts.atrous_orig_isl
                        and wimg.nisl > 0):
                    self._keep_orig_isl_overlaps(img, wimg, mylog)

                if isinstance(op, Op_gausfit):
                    if img.opts.atrous_orig_isl:
                        # Exclude Gaussians outside of the original ch0 islands
                        nwvgaus, tot_flux = self._add_orig_isl_gaussians(
                            img, wimg, j)
                        # Gaussians were appended to img.gaussians directly;
                        # count them so img.ngaus is updated after the loop.
                        ntot_wvgaus += nwvgaus
                        mylogger.userinfo(
                            mylog, "Number of valid wavelet Gaussians",
                            str(nwvgaus))
                    else:
                        # Keep all Gaussians and merge islands that overlap
                        tot_flux = check_islands_for_overlap(img, wimg)

                        # Now renumber the islands and adjust the rank image
                        # before going to next wavelet image
                        renumber_islands(img)

            total_flux += tot_flux
            if img.opts.interactive and has_pl:
                dc = '\033[34;1m'
                nc = '\033[0m'
                print(dc + '--> Displaying islands and rms image...' + nc)
                if max(wimg.ch0_arr.shape) > 4096:
                    print(dc + '--> Image is large. Showing islands only.' + nc)
                    wimg.show_fit(rms_image=False,
                                  mean_image=False,
                                  ch0_image=False,
                                  ch0_islands=True,
                                  gresid_image=False,
                                  sresid_image=False,
                                  gmodel_image=False,
                                  smodel_image=False,
                                  pyramid_srcs=False)
                else:
                    wimg.show_fit()
                prompt = dc + "Press enter to continue or 'q' stop fitting wavelet images : " + nc
                answ = raw_input_no_history(prompt)
                while answ != '':
                    if answ == 'q':
                        img.wavelet_jmax = j
                        stop_wav = True
                        break
                    answ = raw_input_no_history(prompt)
            if len(wimg.gaussians) > 0:
                img.resid_wavelets_arr = self.subtract_wvgaus(
                    img.opts, img.resid_wavelets_arr, wimg.gaussians,
                    wimg.islands)
                if img.opts.atrous_sum:
                    im_old = self.subtract_wvgaus(
                        img.opts, im_old, wimg.gaussians, wimg.islands)
            if stop_wav:
                break

        self._renumber_all_islands(img)

        img.ngaus += ntot_wvgaus
        img.total_flux_gaus += total_flux
        mylogger.userinfo(mylog,
                          "Total flux density in model on all scales",
                          '%.3f Jy' % img.total_flux_gaus)
        if img.opts.output_all:
            func.write_image_to_file('fits',
                                     img.imagename + '.atrous.cJ.fits',
                                     im_new, img, bdir)
            mylog.info('%s %s' %
                       ('Wrote ', img.imagename + '.atrous.cJ.fits'))
            func.write_image_to_file(
                'fits', img.imagename + '.resid_wavelets.fits',
                (img.ch0_arr - img.resid_gaus_arr +
                 img.resid_wavelets_arr), img, bdir + '/residual/')
            mylog.info('%s %s' %
                       ('Wrote ', img.imagename + '.resid_wavelets.fits'))
            func.write_image_to_file(
                'fits', img.imagename + '.model_wavelets.fits',
                (img.resid_gaus_arr - img.resid_wavelets_arr), img,
                bdir + '/model/')
            mylog.info('%s %s' %
                       ('Wrote ', img.imagename + '.model_wavelets.fits'))
        img.completed_Ops.append('wavelet_atrous')

    def _auto_jmax(self, img, resid, l):
        """Compute J_max from the image size or the largest island size.

        If the largest island (by bounding-box area) has a smaller side below
        1/3 of the smallest image dimension, 4x that side is the reference
        size; otherwise the smallest image dimension is used. ``l`` is the
        length of the low-pass filter vector. The result may be < 1 for very
        small reference sizes.
        """
        max_isl_shape = (0, 0)
        for isl in img.islands:
            if (isl.image.shape[0] * isl.image.shape[1] >
                    max_isl_shape[0] * max_isl_shape[1]):
                max_isl_shape = isl.image.shape
        if max_isl_shape != (0, 0) and min(max_isl_shape) < min(resid.shape) / 3.0:
            min_size = min(max_isl_shape) * 4.0
        else:
            min_size = min(resid.shape)
        jmax = int(
            floor(
                log((min_size / 3.0 * 3.0 - l) /
                    (l - 1) + 1) / log(2.0) + 1.0)) + 1
        if min_size * 0.55 <= (l + (l - 1) * (2**(jmax) - 1)):
            jmax = jmax - 1
        return jmax

    def _keep_orig_isl_overlaps(self, img, wimg, mylog):
        """Keep only wavelet islands that share pixels with a ch0 island.

        Kept islands get ``valid = True`` and are renumbered from 0; the rest
        get ``valid = False`` and are dropped from ``wimg.islands``.
        """
        # Make original rank image boolean; rank counts from 0, with -1
        # being outside any island
        orig_rankim_bool = N.array(img.pyrank + 1, dtype=bool)

        # Multiply rank images
        old_islands = orig_rankim_bool * (wimg.pyrank + 1) - 1

        # Exclude islands that don't overlap with a ch0 island.
        valid_ids = set(N.unique(old_islands))
        good_isl = []
        for idx, wvisl in enumerate(wimg.islands):
            if idx in valid_ids:
                wvisl.valid = True
                good_isl.append(wvisl)
            else:
                wvisl.valid = False

        wimg.islands = good_isl
        wimg.nisl = len(good_isl)
        mylogger.userinfo(mylog, "Number of islands found", '%i' % wimg.nisl)

        # Renumber islands:
        for wvindx, wvisl in enumerate(wimg.islands):
            wvisl.island_id = wvindx

    def _add_orig_isl_gaussians(self, img, wimg, j):
        """Copy wavelet Gaussians that fall on unmasked ch0 island pixels.

        Gaussians without a ``valid`` attribute are treated as invalid. Each
        invalid Gaussian is looked up in ``img.pyrank`` at its (offset)
        centre pixel:

        * If that pixel is in an island and not set in the island's
          ``mask_active``, a copy is added to that island and to
          ``img.gaussians``, numbered after the last existing Gaussian, and
          the wavelet Gaussian becomes valid.
        * Otherwise it stays invalid with ``jlevel = 0``.

        Afterwards ``wimg.gaussians`` keeps only valid Gaussians.

        Returns
        -------
        (nwvgaus, tot_flux)
            Number of Gaussians added and the sum of their ``total_flux``.
        """
        tot_flux = 0.0
        nwvgaus = 0

        # Continue numbering after the last Gaussian already in img. This is
        # based on img.gaussians because img.ngaus is not updated while
        # Gaussians are appended here, so it can be stale across scales.
        if len(img.gaussians) == 0:
            gaus_id = -1
        else:
            gaus_id = img.gaussians[-1].gaus_num

        for g in wimg.gaussians:
            if not hasattr(g, 'valid'):
                g.valid = False
            if g.valid:
                continue
            try:
                isl_id = img.pyrank[int(g.centre_pix[0] + 1),
                                    int(g.centre_pix[1] + 1)]
            except IndexError:
                isl_id = -1
            if isl_id < 0:
                g.valid = False
                g.jlevel = 0
                continue
            isl = img.islands[isl_id]
            # Explicit int(): NumPy no longer accepts float indices (older
            # versions truncated them the same way).
            gcenter = (int(g.centre_pix[0] - isl.origin[0]),
                       int(g.centre_pix[1] - isl.origin[1]))
            if isl.mask_active[gcenter]:
                g.valid = False
                g.jlevel = 0
                continue
            gaus_id += 1
            gcp = Gaussian(img, g.parameters[:], isl.island_id, gaus_id)
            gcp.gaus_num = gaus_id
            gcp.wisland_id = g.island_id
            gcp.jlevel = j
            g.valid = True
            isl.gaul.append(gcp)
            isl.ngaus += 1
            img.gaussians.append(gcp)
            nwvgaus += 1
            tot_flux += gcp.total_flux

        wimg.gaussians = [g for g in wimg.gaussians if g.valid]
        return nwvgaus, tot_flux

    def _renumber_all_islands(self, img):
        """Renumber img's islands (and their Gaussians) from 0; rebuild pyrank."""
        pyrank = N.zeros(img.pyrank.shape, dtype=N.int32)
        for i, isl in enumerate(img.islands):
            isl.island_id = i
            for g in isl.gaul:
                g.island_id = i
            for dg in isl.dgaul:
                dg.island_id = i
            pyrank[isl.bbox] += N.invert(isl.mask_active) * (i + 1)
        pyrank -= 1  # align pyrank values with island ids and set regions outside of islands to -1
        img.pyrank = pyrank
예제 #6
0
from _version import __version__
import gc

# Standard processing chain: one instance of every operation, in the order
# the operations are applied to an image. A generator expression (rather
# than a list comprehension) keeps the loop variable out of the module
# namespace on Python 2.
default_chain = list(
    op_class() for op_class in (
        Op_readimage,
        Op_collapse,
        Op_preprocess,
        Op_rmsimage,
        Op_threshold,
        Op_islands,
        Op_gausfit,
        Op_wavelet_atrous,
        Op_shapelets,
        Op_gaul2srl,
        Op_spectralindex,
        Op_polarisation,
        Op_make_residimage,
        Op_psf_vary,
        Op_outlist,
        Op_cleanup,
    )
)
# Alias kept for legacy scripts; it is the same list object, not a copy.
fits_chain = default_chain

def execute(chain, opts):
    """Execute chain.

    Create new Image with given options and apply chain of
    operations to it. The opts input must be a dictionary.
    """
    from image import Image
    import mylogger