Example #1
def searchdask(a, v, how=None, atol=None):
    n_a = a.shape[0]
    searchfunc, args = presearch(a, v)

    if how == 'nearest':
        l_index = da.maximum(searchfunc(*args, side='right') - 1, 0)
        r_index = da.minimum(searchfunc(*args), n_a - 1)
        cond = 2 * v < (select(a, r_index) + select(a, l_index))
        indexer = da.maximum(da.where(cond, l_index, r_index), 0)
    elif how == 'bfill':
        indexer = searchfunc(*args)
    elif how == 'ffill':
        indexer = searchfunc(*args, side='right') - 1
        indexer = da.where(indexer == -1, n_a, indexer)
    elif how is None:
        l_index = searchfunc(*args)
        r_index = searchfunc(*args, side='right')
        indexer = da.where(l_index == r_index, n_a, l_index)
    else:
        raise NotImplementedError(how)

    if atol is not None:
        a2 = da.concatenate([a, [atol + da.max(v) + 1]])
        indexer = da.where(
            da.absolute(select(a2, indexer) - v) > atol, n_a, indexer)
    return indexer
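The helpers `presearch` and `select` are not shown in this example. A minimal usage sketch, under the assumption that `presearch` returns a searchsorted-style function with its arguments and `select` gathers values at integer indices:

import dask.array as da

def presearch(a, v):
    # assumed contract: return a binary-search function and its arguments
    return da.searchsorted, (a, v)

def select(a, idx):
    # assumed contract: gather a's values at (dask) integer indices
    return a[idx]

a = da.from_array([0.0, 1.0, 2.5, 4.0], chunks=2)
v = da.from_array([1.2, 3.9])
print(searchdask(a, v, how='nearest').compute())  # [1 3]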
Example #2
    def envelope(self, darray, preview=None):
        """
        Description
        -----------
        Compute the Envelope of the input data
        
        Parameters
        ----------
        darray : Array-like, acceptable inputs include Numpy, HDF5, or Dask Arrays
        
        Keyword Arguments
        -----------------
        preview : str, enables or disables preview mode and specifies direction
            Acceptable inputs are (None, 'inline', 'xline', 'z')
            Optimizes chunk size in different orientations to facilitate rapid
            screening of algorithm output
        
        Returns
        -------
        result : Dask Array
        """

        kernel = (1, 1, 25)
        darray, chunks_init = self.create_array(darray,
                                                kernel,
                                                preview=preview)
        analytical_trace = darray.map_blocks(util.hilbert, dtype=darray.dtype)
        result = da.absolute(analytical_trace)
        result = util.trim_dask_array(result, kernel)

        return result
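At its core this attribute is the magnitude of the analytic signal. A self-contained sketch of the same computation, assuming `scipy.signal.hilbert` stands in for the `util.hilbert` wrapper (not shown):

import numpy as np
import dask.array as da
from scipy.signal import hilbert

traces = da.random.random((10, 10, 100), chunks=(5, 5, 100))  # keep the trace axis whole
analytic = traces.map_blocks(hilbert, dtype=np.complex128)    # Hilbert transform along axis -1
envelope = da.absolute(analytic)                              # instantaneous amplitude
print(envelope.shape)  # (10, 10, 100)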
Example #3
    def __call__(self, projectables, optional_datasets=None, **info):
        """Get the corrected reflectance when removing Rayleigh scattering.

        Uses pyspectral.
        """
        from pyspectral.rayleigh import Rayleigh
        if not optional_datasets or len(optional_datasets) != 4:
            vis, red = self.match_data_arrays(projectables)
            sata, satz, suna, sunz = self.get_angles(vis)
            red.data = da.rechunk(red.data, vis.data.chunks)
        else:
            vis, red, sata, satz, suna, sunz = self.match_data_arrays(
                projectables + optional_datasets)
            sata, satz, suna, sunz = optional_datasets
            # get the dask array underneath
            sata = sata.data
            satz = satz.data
            suna = suna.data
            sunz = sunz.data

        # First make sure the two azimuth angles are in the range 0-360:
        sata = sata % 360.
        suna = suna % 360.
        ssadiff = da.absolute(suna - sata)
        ssadiff = da.minimum(ssadiff, 360 - ssadiff)
        del sata, suna

        atmosphere = self.attrs.get('atmosphere', 'us-standard')
        aerosol_type = self.attrs.get('aerosol_type', 'marine_clean_aerosol')
        rayleigh_key = (vis.attrs['platform_name'], vis.attrs['sensor'],
                        atmosphere, aerosol_type)
        logger.info(
            "Removing Rayleigh scattering with atmosphere '%s' and "
            "aerosol type '%s' for '%s'", atmosphere, aerosol_type,
            vis.attrs['name'])
        if rayleigh_key not in self._rayleigh_cache:
            corrector = Rayleigh(vis.attrs['platform_name'],
                                 vis.attrs['sensor'],
                                 atmosphere=atmosphere,
                                 aerosol_type=aerosol_type)
            self._rayleigh_cache[rayleigh_key] = corrector
        else:
            corrector = self._rayleigh_cache[rayleigh_key]

        try:
            refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff,
                                                      vis.attrs['name'],
                                                      red.data)
        except (KeyError, IOError):
            logger.warning(
                "Could not get the reflectance correction using band name: %s",
                vis.attrs['name'])
            logger.warning(
                "Will try use the wavelength, however, this may be ambiguous!")
            refl_cor_band = corrector.get_reflectance(
                sunz, satz, ssadiff, vis.attrs['wavelength'][1], red.data)
        proj = vis - refl_cor_band
        proj.attrs = vis.attrs
        self.apply_modifier_info(vis, proj)
        return proj
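The two modulo/minimum lines reduce any pair of azimuths to a relative sun-satellite azimuth difference in [0, 180] degrees; a small worked check:

import numpy as np
import dask.array as da

suna = da.from_array(np.array([10.0, 350.0, 725.0]))
sata = da.from_array(np.array([350.0, 10.0, 200.0]))
suna, sata = suna % 360., sata % 360.
ssadiff = da.absolute(suna - sata)
ssadiff = da.minimum(ssadiff, 360 - ssadiff)
print(ssadiff.compute())  # [ 20.  20. 165.]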
Example #4
    def instantaneous_frequency(self, darray, sample_rate=4, preview=None):
        """
        Description
        -----------
        Compute the Instantaneous Frequency of the input data
        
        Parameters
        ----------
        darray : Array-like, acceptable inputs include Numpy, HDF5, or Dask Arrays
        
        Keyword Arguments
        -----------------
        sample_rate : Number, sample rate in milliseconds (ms)
        preview : str, enables or disables preview mode and specifies direction
            Acceptable inputs are (None, 'inline', 'xline', 'z')
            Optimizes chunk size in different orientations to facilitate rapid
            screening of algorithm output
        
        Returns
        -------
        result : Dask Array
        """

        darray, chunks_init = self.create_array(darray, preview=preview)

        fs = 1000 / sample_rate
        phase = self.instantaneous_phase(darray)
        phase = da.deg2rad(phase)
        phase = phase.map_blocks(np.unwrap, dtype=darray.dtype)
        phase_prime = sp().first_derivative(phase, axis=-1)
        result = da.absolute((phase_prime / (2.0 * np.pi) * fs))

        return result
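As a sanity check on the formula f = |dphi/dt| * fs / (2*pi), a self-contained sketch on a pure 30 Hz tone, with `np.gradient` standing in for the `first_derivative` operator (not shown):

import numpy as np
import dask.array as da
from scipy.signal import hilbert

fs = 1000 / 4                                    # 4 ms sample rate -> 250 Hz
t = np.arange(0, 1, 1 / fs)
trace = np.cos(2 * np.pi * 30.0 * t)             # 30 Hz tone
phase = np.unwrap(np.angle(hilbert(trace)))      # instantaneous phase in radians
phase_prime = da.from_array(np.gradient(phase))  # derivative per sample
inst_freq = da.absolute(phase_prime / (2.0 * np.pi) * fs)
print(float(inst_freq[10:-10].mean().compute()))  # ~30.0 away from the edges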
Example #5
    def _select_by_prediction(self, unlabel_index, predict, batch_size=1):
        predict_shape = da.shape(predict)

        assert (len(predict_shape) in [1, 2])
        if len(predict_shape) == 2:
            if predict_shape[1] != 1:
                raise Exception(
                    '1d or 2d with 1 column array is expected, but received: \n%s'
                    % str(predict))
            else:
                pv = da.absolute(predict.flatten())
        else:
            pv = da.absolute(predict)

        tpl = da.from_array(unlabel_index)
        return tpl[nsmallestarg(pv, batch_size)].compute()
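`nsmallestarg` is not defined in this example; a plausible stand-in, assuming it returns the indices of the `k` smallest entries, is dask's `argtopk` with a negative `k`:

import dask.array as da

def nsmallestarg(a, k):
    # assumed contract: indices of the k smallest values
    # (da.argtopk extracts the -k smallest when k is negative)
    return da.argtopk(a, -k)

pv = da.from_array([0.9, 0.1, 0.5, 0.05])
print(nsmallestarg(pv, 2).compute())  # [3 1]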
Example #6
def power_spectrum(filter, time):
    """Compute the mean power spectrum over all particles at a given time.

    This routine gives the power spectrum (power spectral density) for
    each of the sampled variables within `filter`, as a mean over
    all particles. It will run a single advection step at the
    specified time. The resulting dictionary contains a `freq` item,
    with the FFT frequency bins for the output spectra.

    Args:
        filter (filtering.LagrangeFilter): The pre-configured filter object
            to use for running the analysis.
        time (float): The time at which to perform the analysis.

    Returns:
        Dict[str, numpy.ndarray]: A dictionary of power spectra for each of
            the sampled variables on the filter.

    """

    psds = {}
    advection_data = filter.advection_step(time, output_time=True)
    time_series = advection_data.pop("time")

    for v, a in advection_data.items():
        spectra = da.fft.fft(a[1].rechunk((-1, "auto")), axis=0)
        mean_spectrum = da.nanmean(da.absolute(spectra) ** 2, axis=1)
        psds[v] = mean_spectrum.compute()

    psds["freq"] = 2 * np.pi * np.fft.fftfreq(time_series.size, filter.output_dt)

    return psds
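A sketch of the spectral step in isolation, on a synthetic (time, particle) array rather than a real `LagrangeFilter` advection output; `output_dt` here is a made-up output interval:

import numpy as np
import dask.array as da

output_dt = 3600.0                                      # hypothetical output interval [s]
nt, nparticles = 128, 1000
a = da.random.random((nt, nparticles), chunks=(nt, 250))
spectra = da.fft.fft(a.rechunk((-1, "auto")), axis=0)   # FFT along time
psd = da.nanmean(da.absolute(spectra) ** 2, axis=1)     # mean over particles
freq = 2 * np.pi * np.fft.fftfreq(nt, output_dt)        # angular frequency bins
print(psd.compute().shape, freq.shape)                  # (128,) (128,)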
Example #7
    def instantaneous_bandwidth(self, darray, preview=None):
        """
        Description
        -----------
        Compute the Instantaneous Bandwidth of the input data
        
        Parameters
        ----------
        darray : Array-like, acceptable inputs include Numpy, HDF5, or Dask Arrays
        
        Keyword Arguments
        -----------------
        preview : str, enables or disables preview mode and specifies direction
            Acceptable inputs are (None, 'inline', 'xline', 'z')
            Optimizes chunk size in different orientations to facilitate rapid
            screening of algorithm output
        
        Returns
        -------
        result : Dask Array
        """

        darray, chunks_init = self.create_array(darray, preview=preview)
        rac = self.relative_amplitude_change(darray)
        result = da.absolute(rac) / (2.0 * np.pi)

        return result
Example #8
def _ttest_finish(df, t):
    """Common code between all 3 t-test functions."""
    # XXX: np.abs -> da.absolute
    # XXX: delayed(distributions.t.sf)
    prob = delayed(distributions.t.sf)(da.absolute(t),
                                       df) * 2  # use np.abs to get upper tail
    if t.ndim == 0:
        t = t[()]

    return t, prob
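A usage sketch, assuming `distributions` is `scipy.stats.distributions` and `delayed` is `dask.delayed`, as the XXX notes suggest:

import dask.array as da
from dask import delayed               # assumed import
from scipy.stats import distributions  # assumed import

t = da.from_array([2.5, -1.0])
t, prob = _ttest_finish(df=10, t=t)
print(prob.compute())  # two-sided p-values, approx [0.031, 0.341]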
Example #9
def pis_mVc(x, y, beta):
    '''
    Rewrite of mVc and mMSE; both share the probabilities `p` and the
    residuals `dif`.
    '''
    p = logistic_func(beta, x)
    dif = da.absolute(y - p)
    xnorm = da.linalg.norm(x, axis=1)
    pis = dif * xnorm
    pi = pis / da.sum(pis)
    return pi
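`logistic_func` is not shown; a plausible stand-in and a tiny run, assuming it returns the logistic probabilities sigmoid(x @ beta):

import dask.array as da

def logistic_func(beta, x):
    # assumed contract: P(y=1 | x) under a logistic model
    return 1.0 / (1.0 + da.exp(-x.dot(beta)))

x = da.random.random((100, 3), chunks=(50, 3))
y = da.random.randint(0, 2, size=100, chunks=50)
beta = da.from_array([0.1, -0.2, 0.3])
pi = pis_mVc(x, y, beta)
print(float(pi.sum().compute()))  # sampling probabilities sum to 1.0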
Example #11
    def __call__(self, projectables, optional_datasets=None, **info):
        """Get the corrected reflectance when removing Rayleigh scattering.

        Uses pyspectral.
        """
        from pyspectral.rayleigh import Rayleigh
        if not optional_datasets or len(optional_datasets) != 4:
            vis, red = self.check_areas(projectables)
            sata, satz, suna, sunz = self.get_angles(vis)
            red.data = da.rechunk(red.data, vis.data.chunks)
        else:
            vis, red, sata, satz, suna, sunz = self.check_areas(
                projectables + optional_datasets)
            sata, satz, suna, sunz = optional_datasets
            # get the dask array underneath
            sata = sata.data
            satz = satz.data
            suna = suna.data
            sunz = sunz.data

        LOG.info('Removing Rayleigh scattering and aerosol absorption')

        # First make sure the two azimuth angles are in the range 0-360:
        sata = sata % 360.
        suna = suna % 360.
        ssadiff = da.absolute(suna - sata)
        ssadiff = da.minimum(ssadiff, 360 - ssadiff)
        del sata, suna

        atmosphere = self.attrs.get('atmosphere', 'us-standard')
        aerosol_type = self.attrs.get('aerosol_type', 'marine_clean_aerosol')
        rayleigh_key = (vis.attrs['platform_name'],
                        vis.attrs['sensor'], atmosphere, aerosol_type)
        if rayleigh_key not in self._rayleigh_cache:
            corrector = Rayleigh(vis.attrs['platform_name'], vis.attrs['sensor'],
                                 atmosphere=atmosphere,
                                 aerosol_type=aerosol_type)
            self._rayleigh_cache[rayleigh_key] = corrector
        else:
            corrector = self._rayleigh_cache[rayleigh_key]

        try:
            refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff,
                                                      vis.attrs['name'],
                                                      red.data)
        except (KeyError, IOError):
            LOG.warning("Could not get the reflectance correction using band name: %s", vis.attrs['name'])
            LOG.warning("Will try use the wavelength, however, this may be ambiguous!")
            refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff,
                                                      vis.attrs['wavelength'][1],
                                                      red.data)
        proj = vis - refl_cor_band
        proj.attrs = vis.attrs
        self.apply_modifier_info(vis, proj)
        return proj
Example #12
def compute_power_spectrum(xarr, r_bins=100, x_dim='X', y_dim='Y'):
    """
    Compute the radially-binned power spectrum of individual images. The
    intended use case is for xarr to be a set of z-stacks of brightfield images.

    Parameters
    ----------
    xarr : xarray.DataArray
        DataArray backed by dask arrays. If the DataArray does not have named
        dimensions "x" and "y", the last two dimensions are assumed to
        correspond to the image dimensions.
    r_bins : int
        Number of bins to use for radial histogram. Default 100.
    x_dim : str default 'X'
        Name of dimension corresponding to X pixels
    y_dim : str default 'Y'
        Name of dimension corresponding to Y pixels

    Returns
    -------
    log_power_spectrum : (..., r_bins)
        Log power spectrum of each individual image in xarr.

    """

    if not isinstance(xarr, xr.DataArray):
        raise TypeError(
            "Can only compute power spectra for xarray.DataArrays.")

    if not isinstance(xarr.data, da.Array):
        xarr.data = da.array(xarr.data)

    fft_mags = xr.DataArray(
        da.fft.fftshift(da.absolute(da.fft.fft2(xarr.data))**2),
        dims=xarr.dims,
        coords=xarr.coords,
    )
    fft_mags.coords[x_dim] = np.arange(
        fft_mags[x_dim].shape[0]) - fft_mags[x_dim].shape[0] / 2
    fft_mags.coords[y_dim] = np.arange(
        fft_mags[y_dim].shape[0]) - fft_mags[y_dim].shape[0] / 2
    logR = 0.5 * xr.ufuncs.log1p(fft_mags.coords[x_dim]**2 +
                                 fft_mags.coords[y_dim]**2)
    log_power_spectra = xr.ufuncs.log1p(
        fft_mags.groupby_bins(
            logR, bins=r_bins).mean(dim=f"stacked_{x_dim}_{y_dim}"))
    log_power_spectra["group_bins"] = pd.IntervalIndex(
        log_power_spectra.group_bins.values).mid.values

    return log_power_spectra
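A usage sketch on synthetic data; note the function uses `xr.ufuncs`, so it assumes an older xarray release that still ships that module:

import xarray as xr
import dask.array as da

imgs = xr.DataArray(
    da.random.random((5, 64, 64), chunks=(1, 64, 64)),
    dims=("Z", "Y", "X"),
)
spectra = compute_power_spectrum(imgs, r_bins=32)
print(spectra.dims)  # ('Z', 'group_bins'): one radial spectrum per z-slice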
Example #13
File: utils.py Project: ratt-ru/ragavi
def calc_amplitude(ydata):
    """Convert complex data to amplitude (absolute value)

    Parameters
    ----------
    ydata : :obj:`xarray.DataArray`
        y axis data to be processed

    Returns
    -------
    amplitude : :obj:`xarray.DataArray`
        :attr:`ydata` converted to an amplitude
    """
    logger.debug("Calculating amplitude data")

    amplitude = da.absolute(ydata)
    return amplitude
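A quick check that the amplitude is just the complex modulus, on a dask-backed DataArray:

import numpy as np
import xarray as xr
import dask.array as da

ydata = xr.DataArray(da.from_array(np.array([3 + 4j, 1 + 0j])))
print(calc_amplitude(ydata).values)  # [5. 1.]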
Example #14
def searchdaskuniform(a0, step, n_a, v, how=None, atol=None):
    index = (v - a0) / step
    if how == 'nearest':
        indexer = da.maximum(da.minimum(da.around(index), n_a - 1), 0)
    elif how == 'bfill':
        indexer = da.maximum(da.ceil(index), 0)
    elif how == 'ffill':
        indexer = da.minimum(da.floor(index), n_a - 1)
    elif how is None:
        indexer = da.ceil(index)
        indexer = da.where(indexer != index, n_a, indexer)

    if atol is not None:
        indexer = da.where((da.absolute(indexer - index) * step > atol) |
                           (indexer < 0) | (indexer >= n_a), n_a, indexer)
    else:
        indexer = da.where((indexer < 0) | (indexer >= n_a), n_a, indexer)
    return indexer.astype(int)
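A worked run on the uniform grid a0 + i*step, here 0.0, 0.5, ..., 4.5 with n_a=10; out-of-range or beyond-atol values map to the sentinel index n_a:

import dask.array as da

v = da.from_array([0.95, 1.3, 9.9])
print(searchdaskuniform(0.0, 0.5, 10, v, how='nearest').compute())            # [2 3 9]
print(searchdaskuniform(0.0, 0.5, 10, v, how='nearest', atol=0.1).compute())  # [ 2 10 10]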
Example #15
def wiener(data, aux, fr, fr_npy, L):
    return uirdft2((da.conj(fr) * urdft2(data) + L * urdft2(aux)) /
                   (da.absolute(fr_npy)**2 + L))
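`urdft2` and `uirdft2` are not shown; a plausible reading is a unitary 2-D real DFT pair, sketched here as orthonormally scaled wrappers around dask's rfft2/irfft2:

import numpy as np
import dask.array as da

def urdft2(x):
    # assumed contract: unitary real 2-D DFT (orthonormal 1/sqrt(N) scaling)
    n = x.shape[-2] * x.shape[-1]
    return da.fft.rfft2(x) / np.sqrt(n)

def uirdft2(x, s=None):
    # assumed contract: matching unitary inverse
    y = da.fft.irfft2(x, s=s)
    return y * np.sqrt(y.shape[-2] * y.shape[-1])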
Example #16
def smooth(xds,
           dv='IMAGE',
           kernel='gaussian',
           size=[1., 1., 30.],
           current=None,
           scale=1.0,
           name='BEAM'):
    """                                                                                                                                                                                                     
    Smooth data along the spatial plane of the image cube.
    
    Computes a correcting beam to produce defined size when kernel=gaussia and current is defined.  Otherwise the size
    or existing beam is used directly.

    Parameters
    ----------
    xds : xarray.core.dataset.Dataset
        input Image Dataset
    dv : str
        name of data_var in xds to smooth. Default is 'IMAGE'
    kernel : str
        Type of kernel to use:'boxcar', 'gaussian' or the name of a data var in this xds.  Default is 'gaussian'.
    size : list of floats
        list of three values corresponding to major and minor axes (in arcseconds) and position angle (in degrees).
    current : list of floats
        same structure as size, a list of three values corresponding to major and minor axes (in arcseconds) and position
        angle (in degrees) of the current beam applied to the image.  Default is None
    scale : float
        gain factor after convolution. Default is unity gain (1.0)
    name : str
        dataset variable name for kernel, overwrites if already present
        
    Returns
    -------
    xarray.core.dataset.Dataset
        output Image
    """
    import xarray
    import dask.array as da
    import numpy as np
    import cngi._helper.beams as chb

    # compute kernel beam
    size_corr = None
    if kernel in xds.data_vars:
        beam = xds[kernel] / xds[kernel].sum(axis=[0, 1])
    elif kernel == 'gaussian':
        beam, parms_tar = chb.synthesizedbeam(size[0], size[1], size[2],
                                              len(xds.d0), len(xds.d1),
                                              xds.incr[:2])
        beam = xarray.DataArray(da.from_array(beam /
                                              np.sum(beam, axis=(0, 1))),
                                dims=['d0', 'd1'],
                                name=name)  # normalized to unity
        cf_tar = ((4 * np.pi**2) /
                  (4 * parms_tar[0] * parms_tar[2] -
                   parms_tar[1]**2)) * parms_tar  # equation 12
        size_corr = size
    else:  # boxcar
        incr = np.abs(xds.incr[:2]) * 180 / np.pi * 60 * 60
        xx, yy = np.mgrid[:int(np.round(size[0] / incr[0])
                               ), :int(np.round(size[1] / incr[1]))]
        box = np.array(
            (xx.ravel() - np.max(xx) // 2,
             yy.ravel() - np.max(yy) // 2)) + np.array(
                 [len(xds.d0) // 2, len(xds.d1) // 2])[:, None]
        beam = np.zeros((len(xds.d0), len(xds.d1)))
        beam[box[0], box[1]] = 1.0
        beam = xarray.DataArray(da.from_array(beam /
                                              np.sum(beam, axis=(0, 1))),
                                dims=['d0', 'd1'],
                                name=name)  # normalized to unity

    # compute the correcting beam if necessary
    # this is done analytically using the parameters of the current beam, not the actual data
    # see equations 19 - 26 here:
    # https://casa.nrao.edu/casadocs-devel/stable/memo-series/casa-memos/casa_memo10_restoringbeam.pdf/view
    if (kernel == 'gaussian') and (current is not None):
        parms_curr = chb.synthesizedbeam(current[0], current[1], current[2],
                                         len(xds.d0), len(xds.d1),
                                         xds.incr[:2])[1]
        cf_curr = ((4 * np.pi**2) /
                   (4 * parms_curr[0] * parms_curr[2] -
                    parms_curr[1]**2)) * parms_curr  # equation 12
        cf_corr = (cf_tar - cf_curr)  # equation 19
        c_corr = ((4 * np.pi**2) / (4 * cf_corr[0] * cf_corr[2] -
                                    cf_corr[1]**2)) * cf_corr  # equation 12
        # equations 21 - 23
        d1 = np.sqrt(8 * np.log(2) /
                     ((c_corr[0] + c_corr[2]) -
                      np.sqrt(c_corr[0]**2 - 2 * c_corr[0] * c_corr[2] +
                              c_corr[2]**2 + c_corr[1]**2)))
        d2 = np.sqrt(8 * np.log(2) /
                     ((c_corr[0] + c_corr[2]) +
                      np.sqrt(c_corr[0]**2 - 2 * c_corr[0] * c_corr[2] +
                              c_corr[2]**2 + c_corr[1]**2)))
        theta = 0.5 * np.arctan2(-c_corr[1], c_corr[2] - c_corr[0])

        # make a beam out of the correcting size
        incr_arcsec = np.abs(xds.incr[:2]) * 180 / np.pi * 60 * 60
        size_corr = [
            d1 * incr_arcsec[0], d2 * incr_arcsec[1], theta * 180 / np.pi
        ]
        scale_corr = (4 * np.log(2) / (np.pi * d1 * d2)) * (
            size[0] * size[1] / (current[0] * current[1]))  # equation 20
        beam = scale_corr * chb.synthesizedbeam(size_corr[0], size_corr[1],
                                                size_corr[2], len(xds.d0),
                                                len(xds.d1), xds.incr[:2])[0]
        beam = xarray.DataArray(
            da.from_array(beam),
            dims=[xds[dv].dims[dd] for dd in range(beam.ndim)],
            name=name)

    # scale and FFT the kernel beam
    da_beam = da.atleast_3d(beam.data)
    if da_beam.ndim < 4: da_beam = da_beam[:, :, :, None]
    ft_beam = da.fft.fft2((da_beam * scale), axes=[0, 1])

    # FFT the image, multiply by the kernel beam FFT, then inverse FFT it back
    ft_image = da.fft.fft2(xds[dv].data, axes=[0, 1])
    ft_smooth = ft_image * ft_beam
    ift_smooth = da.fft.fftshift(da.fft.ifft2(ft_smooth, axes=[0, 1]),
                                 axes=[0, 1])

    # store the smooth image and kernel beam back in the xds
    xda_smooth = xarray.DataArray(da.absolute(ift_smooth),
                                  dims=xds[dv].dims,
                                  coords=xds[dv].coords)
    new_xds = xds.assign({dv: xda_smooth, name: beam * scale})
    if size_corr is not None:
        new_xds = new_xds.assign_attrs({name + '_params': tuple(size_corr)})
    return new_xds
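The smoothing itself is plain FFT convolution; the core step in isolation on synthetic data, with a kernel normalized to unit sum so that the image mean is preserved:

import numpy as np
import dask.array as da

image = da.random.random((64, 64), chunks=(64, 64))
kernel = da.from_array(np.ones((64, 64)) / 64**2)  # normalized to unity
ft_smooth = da.fft.fft2(image) * da.fft.fft2(kernel)
smoothed = da.absolute(da.fft.fftshift(da.fft.ifft2(ft_smooth)))
diff = smoothed.mean() - image.mean()
print(float(da.absolute(diff).compute()) < 1e-8)  # True: unit-sum kernel preserves the mean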
Example #17
def absolute(A):
    return da.absolute(A)
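A trivial usage check:

import dask.array as da
print(absolute(da.from_array([-1.0, 2.0])).compute())  # [1. 2.]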
Example #18
    def make_psf(self):
        print("Making PSF", file=log)
        psfs = []
        self.stokes_weights = {}
        self.uvws = {}
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]
            self.stokes_weights[ims] = {}
            self.uvws[ims] = {}

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                # this is not correct, need to use spw
                spw = ds.DATA_DESC_ID

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              flag.shape,
                                              chunks=flag.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            flag.shape,
                            chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # for the PSF we need to scale the weights by the
                # Mueller amplitudes squared
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    weightsxx *= da.absolute(mueller[:, :, 0])**2
                    weightsyy *= da.absolute(mueller[:, :, -1])**2

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy

                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                weights *= flag

                data = weights.astype(np.complex64)

                psf = vis2im(uvw,
                             freq,
                             data,
                             freq_bin_idx,
                             freq_bin_counts,
                             self.nx_psf,
                             self.ny_psf,
                             self.cell,
                             flag=flag.astype(np.uint8),
                             nthreads=self.nthreads,
                             epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking,
                             double_accum=True)

                psfs.append(psf)

                # assumes that stokes weights and uvw fit into memory
                # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0:-1}))[0]
                # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0:-1}))[0]

                # for comparison with numpy implementation
                # self.stokes_weights[ims][spw] = dask.compute(weights)[0]
                # self.uvws[ims][spw] = dask.compute(uvw)[0]

        # import pdb
        # pdb.set_trace()

        psfs = dask.compute(psfs, scheduler='single-threaded')[0]
        return accumulate_dirty(psfs, self.nband,
                                self.band_mapping).astype(self.real_type)
Example #19
    def make_dirty(self):
        print("Making dirty", file=log)
        dirty = da.zeros((self.nband, self.nx, self.ny),
                         dtype=np.float32,
                         chunks=(1, self.nx, self.ny),
                         name=False)
        dirties = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]
                freq_chunk = freq_bin_counts[0].compute()

                uvw = ds.UVW.data

                data = getattr(ds, self.data_column).data
                dataxx = data[:, :, 0]
                datayy = data[:, :, -1]

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            data.shape,
                            chunks=data.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # apply adjoint of mueller term.
                # Phases modify the data; amplitudes modify the weights.
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                    datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                    weightsxx *= da.absolute(mueller[:, :, 0])
                    weightsyy *= da.absolute(mueller[:, :, -1])

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = (weightsxx * dataxx + weightsyy * datayy)
                # TODO - turn off this stupid warning
                data = da.where(weights, data / weights, 0.0j)

                # only keep data where both corrs are unflagged
                flag = getattr(ds, self.flag_column).data
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                # ducc0 convention uses uint8 mask not flag
                flag = ~(flagxx | flagyy)

                dirty = vis2im(uvw,
                               freq,
                               data,
                               freq_bin_idx,
                               freq_bin_counts,
                               self.nx,
                               self.ny,
                               self.cell,
                               weights=weights,
                               flag=flag.astype(np.uint8),
                               nthreads=self.nthreads,
                               epsilon=self.epsilon,
                               do_wstacking=self.do_wstacking,
                               double_accum=True)

                dirties.append(dirty)

        dirties = dask.compute(dirties, scheduler='single-threaded')[0]

        return accumulate_dirty(dirties, self.nband,
                                self.band_mapping).astype(self.real_type)
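Both `make_psf` and `make_dirty` reduce the XX and YY correlations to Stokes I with a per-visibility weighted average; the reduction in isolation:

import dask.array as da

dataxx = da.from_array([1 + 1j, 2 + 0j])
datayy = da.from_array([1 - 1j, 0 + 0j])
weightsxx = da.from_array([2.0, 1.0])
weightsyy = da.from_array([2.0, 0.0])
weights = weightsxx + weightsyy
data = da.where(weights, (weightsxx * dataxx + weightsyy * datayy) / weights, 0.0j)
print(data.compute())  # [1.+0.j 2.+0.j]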
Example #20
           nx=2 * args.npix)

# Should only be one correlation
assert psf.shape[2] == 1, psf.shape

# FFT the PSF
psf_fft = da.fft.fftshift(da.fft.ifft2(da.fft.ifftshift(psf[:, :, 0])))

# Dirty image composed of the diagonal correlations
if ncorr == 1:
    dirty = dirty_fft[0].real
else:
    dirty = (dirty_fft[0].real + dirty_fft[ncorr - 1].real) * 0.5

# Normalised Amplitude
psf = da.absolute(psf_fft.real)
psf = (psf / da.max(psf))

# Scale the dirty image by the psf
# x4 because of the N**2 FFT normalization factor
# on a square image of double the size
dirty = dirty / (da.max(psf) * 4.)

# Visualise profiling if we have bokeh
try:
    import bokeh  # noqa
except ImportError:
    from dask.diagnostics import ProgressBar

    with ProgressBar():
        dirty = dirty.compute()
Example #21
def model_flux_per_scan_dask(time_chunks,
                             freq_chunks,
                             fluxcols,
                             w,
                             f,
                             filename="M1",
                             outdir="./soln-intervals",
                             indices=None):
    """compute the flux per interval scans"""

    if len(fluxcols) == 1:
        m0 = getattr(xds[0], fluxcols[0]).data
        __sub_model = False
    else:
        m1 = getattr(xds[0], fluxcols[0]).data
        m0 = getattr(xds[0], fluxcols[1]).data
        __sub_model = True

    # apply flags and weights

    m0 *= (f == False)
    m0 *= w

    if __sub_model:
        # p*=(f==False) select based on m only
        m1 *= w

    LOGGER.debug("Done applying weights and flags")

    if indices is None:

        nt, nv = len(time_chunks), len(freq_chunks)

        flux = np.zeros((nt, nv))

        for tt, time_chunk in enumerate(time_chunks):
            for ff, freq_chunk in enumerate(freq_chunks):
                tsel = slice(time_chunk[0], time_chunk[1])
                fsel = slice(freq_chunk[0], freq_chunk[1])

                if __sub_model:
                    sel = m0[tsel, fsel, :][..., [0, 3]] != 0
                    model_abs = da.absolute(
                        m1[tsel, fsel, :][..., [0, 3]][sel] -
                        m0[tsel, fsel, :][..., [0, 3]][sel])
                else:
                    sel = m0[tsel, fsel, :][..., [0, 3]] != 0
                    model_abs = da.absolute(
                        m0[tsel, fsel, :][..., [0, 3]][sel])

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    flux[tt, ff] = np.mean(model_abs.compute())

            if tt % 6 == 0:
                LOGGER.info(
                    "Done computing model flux for %d/%d time chunks" %
                    (tt + 1, len(time_chunks)))

        LOGGER.info("Done computing model flux")

        np.save(outdir + "/" + filename + "flux.npy", flux)

        return flux

    else:

        flux = np.zeros(len(indices))

        for loc, index in enumerate(indices):
            tsel = slice(time_chunks[index[0]][0], time_chunks[index[0]][1])
            fsel = slice(freq_chunks[index[1]][0], freq_chunks[index[1]][1])

            if __sub_model:
                sel = m0[tsel, fsel, :][..., [0, 3]] != 0
                model_abs = da.absolute(
                    m1[tsel, fsel, :][..., [0, 3]][sel] -
                    m0[tsel, fsel, :][..., [0, 3]][sel])
            else:
                sel = m0[tsel, fsel, :][..., [0, 3]] != 0
                model_abs = da.absolute(
                    m0[tsel, fsel, :][..., [0, 3]][sel])

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                flux[loc] = np.mean(model_abs.compute())

        LOGGER.info("Done computing model flux")

        return flux
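The per-chunk statistic is just the mean modulus of the non-zero diagonal-correlation (XX, YY) model visibilities; in isolation:

import numpy as np
import dask.array as da

m0 = da.from_array(np.array([[[1 + 1j, 0, 0, 2 + 0j],
                              [0j, 0, 0, 0j]]]))  # (row, chan, corr)
diag = m0[..., [0, 3]]                            # keep XX and YY only
model_abs = da.absolute(diag[diag != 0])          # drop empty model entries
print(float(np.mean(model_abs.compute())))        # mean of |1+1j| and |2| ~= 1.707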
Example #22
def visualize(xda, axis=None, tsize=250):
    """
    Plot a preview of any xarray DataArray contents

    Parameters
    ----------
    xda : xarray.core.dataarray.DataArray
        input DataArray
    axis : str or list
        DataArray coordinate(s) to plot against data. Default None uses range
    tsize : int
        target size of the preview plot (might be smaller). Default is 250 points per axis

    Returns
    -------
      Open matplotlib window
    """
    import matplotlib.pyplot as plt
    import xarray
    import numpy as np
    import dask.array as da
    from pandas.plotting import register_matplotlib_converters
    register_matplotlib_converters()

    fig, axes = plt.subplots(1, 1)

    # fast decimate to roughly the desired size
    thinf = np.ceil(np.array(xda.shape) / tsize)
    txda = xda.thin(
        dict([(xda.dims[ii], int(thinf[ii])) for ii in range(len(thinf))]))

    # can't plot complex numbers, bools (sometimes), or strings
    if txda.dtype == 'complex128':
        txda = da.absolute(txda)
    elif txda.dtype == 'bool':
        txda = txda.astype(int)
    elif txda.dtype.type is np.str_:
        txda = xarray.DataArray(np.unique(txda, return_inverse=True)[1],
                                dims=txda.dims,
                                coords=txda.coords,
                                name=txda.name)

    # default pcolormesh plot axes
    if (txda.ndim > 1) and (axis is None):
        axis = np.array(txda.dims[:2])
        if 'chan' in txda.dims: axis[-1] = 'chan'

    # collapse data to 1-D or 2-D
    if axis is not None:
        axis = np.atleast_1d(axis)
        if txda.ndim > 1:
            txda = txda.max(dim=[dd for dd in txda.dims if dd not in axis])

    # different types of plots depending on shape and parameters
    if (txda.ndim == 1) and (axis is None):
        dname = txda.name if txda.name is not None else 'value'
        pxda = xarray.DataArray(np.arange(len(txda)),
                                dims=[dname],
                                coords={dname: txda.values})
        pxda.plot.line(ax=axes, y=pxda.dims[0], marker='.', linewidth=0.0)
        plt.title(dname)
    elif (txda.ndim == 1):
        txda.plot.line(ax=axes, x=axis[0], marker='.', linewidth=0.0)
        if txda.name is not None: plt.title(txda.name + ' vs ' + axis[0])
    else:  # more than 1-D
        txda.plot.pcolormesh(ax=axes, x=axis[0], y=axis[1])
        plt.title(txda.name + ' ' + axis[1] + ' vs ' + axis[0])

    plt.show()
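A usage sketch on a small complex-valued DataArray (opens a matplotlib window with an amplitude pcolormesh):

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
xda = xr.DataArray(rng.random((40, 30)) + 1j * rng.random((40, 30)),
                   dims=("time", "chan"),
                   name="vis")
visualize(xda, tsize=100)  # complex data is reduced to amplitude before plotting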
Example #23
def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import residual as im2residim
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify the data; amplitudes modify the weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #24
File: psf.py Project: ratt-ru/pfb-clean
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data_shape,
                        chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in ds:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data ==
                                           ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])
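            # NOTE: in ducc0's wgridder convention mask != 0 marks samples to
            # *use*, hence the logical complement of the combined flags above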

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum
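        # normalise so the MFS PSF peaks at 1; wsum (the PSF peak) is the
        # effective sum of imaging weights and is also the natural
        # normalisation for the corresponding dirty image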

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #25
def plot_wavelet_spectrogram(s, f, t, *, kind='amplitude', freq_limits=None, time_limits=None,
                             zscore=False, zscore_axis=None, n_workers=cpu_count(),
                             timescale='seconds', relative_time=False, center_time=False,
                             colorbar=True, plot_type='pcolormesh', fig=None,
                             ax=None, **kwargs):
    """
    Plot a wavelet based spectrogram. Note that the spectrogram
    data need not fit into memory but the data used for the
    plot should.

    Parameters
    ----------
    s : np.ndarray, with shape (n_freqs, n_timepoints)
        The spectrogram. Must be complex-valued; amplitude or
        power is computed internally.
    f : np.ndarray, with shape (n_freqs, )
        Spectrogram frequencies in Hz. Must be uniformly
        spaced and strictly increasing or strictly
        decreasing.
    t : np.ndarray, with shape (n_timepoints, )
        Midpoints of each spectrogram time window in seconds.
        Must be uniformly spaced and strictly increasing
        or strictly decreasing.
    kind : {'amplitude', 'power'}, optional
        Whether to display the wavelet coefficient amplitudes or
        their power (squared amplitude).
        Default is 'amplitude'.
    freq_limits : list or None, optional
        If not None, a list consisting of the half-open interval
        [min_frequency, max_frequency) to show in the plot.
        Note that these are only approximate since the data may
        not contain those exact frequencies. However, the
        plot is guaranteed not to show anything < min_frequency or
        >= max_frequency.
        Default is None, where all frequencies are shown.
    time_limits : list or None, optional
        If not None, a list consisting of the half-open interval
        [min_time, max_time) to show in the plot. Note that these
        are only approximate since the data may not contain those
        exact time values. However, the plot is guaranteed not to
        show anything < min_time or >= max_time.
        Default is None, where all time points are shown.
    zscore : boolean, optional
        Whether to zscore the data before visualizing. zscoring
        is applied before restricting the display according to
        'freq_limits' and 'time_limits'.
        Default is False.
    zscore_axis : None or int, optional
        The axis of s over which to apply zscoring. If None, zscoring
        is applied over the entire array s.
        Default is None.
    n_workers : integer, optional
        Number of parallel computations. Only used if 'zscore' is True.
        Default is the total number of CPUs (which may be virtual).
    timescale : string, optional
        The time scale to use on the plot's x-axis. Can be
        'milliseconds', 'seconds', 'minutes', or 'hours'.
        Default is 'seconds'.
    relative_time : boolean, optional
        Whether the time axis shown on the plot is relative to the
        first displayed time point.
        Default is False.
    center_time : boolean, optional
        Whether the time axis is centered around 0. This option
        only takes effect if 'relative_time' is set to True.
        Default is False.
    colorbar : boolean, optional
        Whether to show colorbar or not.
        Default is True.
    plot_type : {'pcolormesh', 'contourf', 'quadmesh'}, optional
        Type of plot to show. Note that the 'quadmesh' option uses
        hvplot as the backend.
        Default is 'pcolormesh'
    fig : matplotlib figure
        Used if 'colorbar' is True and 'plot_type' is not 'quadmesh'
    ax : matplotlib axes
        If None, new axes will be generated. Note that this argument
        is ignored if plot_type=='quadmesh'
    **kwargs: optional
        Other keyword arguments. Passed directly to pcolormesh(),
        contourf(), or quadmesh().
        
    Returns
    -------
    If plot_type is 'pcolormesh' or 'contourf':
        fig, ax : tuple consisting of (matplotlib figure, matplotlib axes)
    Otherwise:
        plot : bokeh plot handle
    
    """
    if not np.iscomplexobj(s):
        raise TypeError("Expected input data to be complex")
    
    if kind == 'amplitude':
        title_str = "Amplitude"
    elif kind == 'power':
        title_str = "Power"
    else:
        raise ValueError("'kind' must be 'amplitude' or 'power'")
    
    if freq_limits is not None:
        freq_slices = _get_frequency_slices(f, freq_limits[0], freq_limits[1])
    else:
        freq_slices = slice(None, None, None)
    
    if time_limits is not None:
        time_slices = _get_time_slices(t, time_limits[0], time_limits[1])
    else:
        time_slices = slice(None, None, None)

    tvec = t[time_slices]
    if timescale == 'milliseconds':
        tvec = tvec * 1000
        xlabel = "Time (msec)"
    elif timescale == 'seconds':
        xlabel = "Time (sec)"
    elif timescale == 'minutes':
        tvec = tvec / 60
        xlabel = "Time (min)"
    elif timescale == 'hours':
        tvec = tvec / 3600
        xlabel = "Time (hr)"
    else:
        raise ValueError("'timescale' must be 'milliseconds', 'seconds', "
                         "'minutes', or 'hours'")

    if relative_time:
        tvec = tvec - tvec[0]
        if center_time:
            a = (tvec[0] + tvec[-1]) / 2
            tvec = tvec - a
    
    fvec = f[freq_slices]
    
    if zscore:
        # data may be large. use dask for computation
        dask_data = da.absolute(da.from_array(s))
        if kind == 'power':
            dask_data = da.square(dask_data)
        
        if zscore_axis is None:
            dask_data = (dask_data - dask_data.mean()) / dask_data.std()
        else:
            dask_data = (
                (dask_data - dask_data.mean(axis=zscore_axis, keepdims=True)) /
                 dask_data.std(axis=zscore_axis, keepdims=True)
            )
        data = dask_data[freq_slices, time_slices].compute(num_workers=n_workers)
    else:
        data = np.abs(s[freq_slices, time_slices])
        if kind == 'power':
            data = data**2
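    # NOTE: this in-memory branch mirrors the dask branch above: amplitude is
    # the modulus of the complex coefficients, squared when kind == 'power'.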
            
    if ax is None and plot_type != 'quadmesh':
        fig, ax = plt.subplots(1, 1)
        
    if colorbar:
        if (fig is None and ax is not None) or (fig is not None and ax is None): 
            raise ValueError(
                "Both 'fig' and 'ax' must be passed in if either is specified")
        
    if plot_type == 'pcolormesh':
        _set_matplotlib(True)
        im = ax.pcolormesh(tvec, fvec, data, **kwargs)
        if colorbar:
            fig.colorbar(im, ax=ax)
        ax.set_title(title_str)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")
        return fig, ax
    elif plot_type == 'contourf':
        _set_matplotlib(True)
        im = ax.contourf(tvec, fvec, data, **kwargs)
        if colorbar:
            fig.colorbar(im, ax=ax)
        ax.set_title(title_str)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")
        return fig, ax
    elif plot_type == 'quadmesh':
        _set_matplotlib(False)
        xa = xr.DataArray(
            data,
            dims=['Frequency', xlabel],
            coords={'Frequency':fvec, xlabel:tvec})
        plot = xa.hvplot.quadmesh(
            x=xlabel, y='Frequency', title=title_str,
            colorbar=colorbar, **kwargs)
        return plot
    else:
        raise ValueError(
            "'plot_type' must be 'pcolormesh', 'contourf', or 'quadmesh'")
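
# A hedged usage sketch (not part of the original source): synthetic complex
# wavelet coefficients plotted as a power spectrogram, assuming the function
# above and its plotting dependencies are importable in the current session.
import numpy as np

rng = np.random.default_rng(0)
f = np.linspace(1, 100, 64)                     # Hz, uniform, increasing
t = np.arange(0.0, 10.0, 0.01)                  # seconds, uniform
s = rng.standard_normal((64, 1000)) + 1j * rng.standard_normal((64, 1000))
fig, ax = plot_wavelet_spectrogram(s, f, t, kind='power', zscore=True)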
Example #26
def wiener(data, aux, fr, L):
    # Wiener update: (conj(fr)*F(data) + L*F(aux)) / (|fr|^2 + L)
    l_del = delayed(L)
    fr_del = delayed(fr)
    return uirdft2((da.conj(fr_del) * urdft2(data) + l_del * urdft2(aux)) *
                   ((da.absolute(fr_del)**2 + l_del)**-1))
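
# For reference, a self-contained NumPy rendering of the same Wiener/Tikhonov
# update (an interpretation, not the original implementation: urdft2/uirdft2
# are assumed to be forward/inverse real 2-D DFT wrappers and fr the transfer
# function sampled on the rfft2 grid):
import numpy as np

def wiener_np(data, aux, fr, L):
    # x_hat = F^-1[(conj(fr) * F(data) + L * F(aux)) / (|fr|^2 + L)]
    num = np.conj(fr) * np.fft.rfft2(data) + L * np.fft.rfft2(aux)
    den = np.abs(fr) ** 2 + L
    return np.fft.irfft2(num / den, s=data.shape)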