Code Example #1
def searchdask(a, v, how=None, atol=None):
    n_a = a.shape[0]
    searchfunc, args = presearch(a, v)

    if how == 'nearest':
        l_index = da.maximum(searchfunc(*args, side='right') - 1, 0)
        r_index = da.minimum(searchfunc(*args), n_a - 1)
        cond = 2 * v < (select(a, r_index) + select(a, l_index))
        indexer = da.maximum(da.where(cond, l_index, r_index), 0)
    elif how == 'bfill':
        indexer = searchfunc(*args)
    elif how == 'ffill':
        indexer = searchfunc(*args, side='right') - 1
        indexer = da.where(indexer == -1, n_a, indexer)
    elif how is None:
        l_index = searchfunc(*args)
        r_index = searchfunc(*args, side='right')
        indexer = da.where(l_index == r_index, n_a, l_index)
    else:
        raise NotImplementedError(how)

    if atol is not None:
        a2 = da.concatenate([a, [atol + da.max(v) + 1]])
        indexer = da.where(
            da.absolute(select(a2, indexer) - v) > atol, n_a, indexer)
    return indexer
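For intuition, the how= modes above follow np.searchsorted semantics. A minimal, self-contained NumPy sketch of the 'nearest' branch (the grid a and values v are made up for illustration):

import numpy as np

a = np.array([0.0, 1.0, 2.0, 3.0])
v = np.array([0.9, 2.4])

# 'ffill' candidate: last grid point <= v
l = np.maximum(np.searchsorted(a, v, side='right') - 1, 0)
# 'bfill' candidate: first grid point >= v
r = np.minimum(np.searchsorted(a, v), len(a) - 1)
# 'nearest': pick whichever neighbour is closer (ties go right)
nearest = np.where(2 * v < a[r] + a[l], l, r)
print(nearest)  # [1 2] -> grid values 1.0 and 2.0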
Code Example #2
File: test_array_core.py  Project: hc10024/dask
def test_elemwise_consistent_names():
    a = da.from_array(np.arange(5, dtype='f4'), chunks=(2,))
    b = da.from_array(np.arange(5, dtype='f4'), chunks=(2,))
    assert same_keys(a + b, a + b)
    assert same_keys(a + 2, a + 2)
    assert same_keys(da.exp(a), da.exp(a))
    assert same_keys(da.exp(a, dtype='f8'), da.exp(a, dtype='f8'))
    assert same_keys(da.maximum(a, b), da.maximum(a, b))
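same_keys is a helper from the dask test suite; it presumably asserts that equivalent expressions produce identical graph keys. A standalone check of the same determinism property using only the public dask.array API:

import numpy as np
import dask.array as da

a = da.from_array(np.arange(5, dtype='f4'), chunks=(2,))
b = da.from_array(np.arange(5, dtype='f4'), chunks=(2,))
# building the same expression twice should yield the same (hashed) name
assert da.maximum(a, b).name == da.maximum(a, b).name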
Code Example #3
File: test_filter.py  Project: auag92/summerOfCode
def _response(x_data, n_space, n_state):
    return pipe(
        np.linspace(0, 1, n_state),
        lambda h: da.maximum(1 - abs(x_data[:, :, None] - h) /
                             (h[1] - h[0]), 0), dafft(axis=1),
        lambda fx: da.sum(_fcoeff(n_space, n_state)[None] * fx, axis=-1),
        daifft(axis=1)).real
Code Example #4
File: bases.py  Project: auag92/summerOfCode
def discretize(x_data, n_state, min_=0.0, max_=1.0, chunks=()):
    """Primitive discretization of a microstructure.

    Args:
      x_data: the data to discretize
      n_state: the number of local states
      min_: the minimum local state
      max_: the maximum local state

    Returns:
      the discretized microstructure

    >>> discretize(da.random.random((12, 9), chunks=(3, 9)),
    ...            3,
    ...            chunks=(1,)).chunks
    ((3, 3, 3, 3), (9,), (1, 1, 1))

    >>> discretize(np.array([[0, 1], [0.5, 0.5]]), 3, chunks=(1,)).chunks
    ((2,), (2,), (1, 1, 1))

    >>> discretize(np.array([[0, 1], [0.5, 0.5]]), 3, chunks=(1,)).compute()
    array([[[ 1.,  0.,  0.],
            [ 0.,  0.,  1.]],
    <BLANKLINE>
           [[ 0.,  1.,  0.],
            [ 0.,  1.,  0.]]])
    """
    return da.maximum(
        discretize_nomax(
            da.clip(x_data, min_, max_),
            da.linspace(min_, max_, n_state, chunks=chunks or (n_state, ))), 0)
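discretize_nomax is defined elsewhere in the project; judging from the doctest output, it presumably computes triangular (hat-function) weights between neighbouring local states. A self-contained sketch of that weighting, reproducing the doctest values:

import numpy as np
import dask.array as da

x = da.from_array(np.array([[0.0, 1.0], [0.5, 0.5]]), chunks=(2, 2))
states = da.linspace(0.0, 1.0, 3, chunks=3)  # local states [0, 0.5, 1]
dh = 0.5  # spacing between states
weights = da.maximum(1 - abs(x[..., None] - states) / dh, 0)
print(weights.compute())
# [[[1. 0. 0.]
#   [0. 0. 1.]]
#  [[0. 1. 0.]
#   [0. 1. 0.]]]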
Code Example #5
def euclidean_distances(X,
                        Y=None,
                        Y_norm_squared=None,
                        squared=False,
                        X_norm_squared=None):
    if X_norm_squared is not None:
        XX = X_norm_squared
        if XX.shape == (1, X.shape[0]):
            XX = XX.T
        elif XX.shape != (X.shape[0], 1):
            raise ValueError(
                "Incompatible dimensions for X and X_norm_squared")
    else:
        XX = row_norms(X, squared=True)[:, np.newaxis]
    if X is Y:
        YY = XX.T
    elif Y_norm_squared is not None:
        if Y_norm_squared.ndim < 2:
            YY = Y_norm_squared[:, np.newaxis]
        else:
            YY = Y_norm_squared
        if YY.shape != (1, Y.shape[0]):
            raise ValueError(
                "Incompatiable dimensions for Y and Y_norm_squared")
    else:
        YY = row_norms(Y, squared=True)[np.newaxis, :]

    # TODO: this often emits a warning. Silence it here?
    distances = -2 * X.dot(Y.T) + XX + YY
    distances = da.maximum(distances, 0)
    # TODO: scikit-learn sets the diagonal to 0 when X is Y.

    return distances if squared else da.sqrt(distances)
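A quick standalone illustration of why the da.maximum clamp above matters: the expansion -2*x_i.x_j + |x_i|^2 + |x_j|^2 is exactly zero on the diagonal in exact arithmetic, but floating-point round-off can leave tiny negative values that would turn into NaN under the square root. (row_norms is replaced here by an explicit squared sum.)

import numpy as np
import dask.array as da

X = da.from_array(np.random.RandomState(0).rand(4, 3), chunks=(2, 3))
XX = (X ** 2).sum(axis=1)[:, None]  # squared row norms
d2 = -2 * X.dot(X.T) + XX + XX.T    # squared distances, may dip below 0
d2 = da.maximum(d2, 0)              # clamp round-off negatives
print(da.sqrt(d2).compute())        # diagonal is exactly 0, no NaNs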
Code Example #6
    def get_value(self, group, corr, extras, flag, flag_row, chanslice):
        coldata = self.get_column_data(group)
        # correlation may be pre-set by plot type, or may be passed to us
        corr = self.corr if self.corr is not None else corr
        # apply correlation reduction
        if coldata is not None and coldata.ndim == 3:
            assert corr is not None
            # the mapper can't have a specific axis set
            if self.mapper.axis is not None:
                raise TypeError(f"{self.name}: unexpected column with ndim=3")
            coldata = self.ms.corr_data_mappers[corr](coldata)
        # apply mapping function
        coldata = self.mapper.mapper(
            coldata, **{name: extras[name]
                        for name in self.mapper.extras})
        # scalar expanded to row vector
        if numpy.isscalar(coldata):
            coldata = da.full_like(flag_row,
                                   fill_value=coldata,
                                   dtype=type(coldata))
            flag = flag_row
        else:
            # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray)
            if type(coldata) is xarray.DataArray and 'chan' in coldata.dims:
                coldata = coldata[dict(chan=chanslice)]
            # determine flags -- start with original flags
            if flag is not None:
                if coldata.ndim == 2:
                    flag = self.ms.corr_flag_mappers[corr](flag)
                elif coldata.ndim == 1:
                    if not self.mapper.axis:
                        flag = flag_row
                    elif self.mapper.axis == 1:
                        flag = None
                # shapes must now match
                if flag is not None and coldata.shape != flag.shape:
                    raise TypeError(f"{self.name}: unexpected column shape")
        # discretize
        if self.nlevels:
            # minmax set? discretize over that
            if self.discretized_delta is not None:
                coldata = da.floor(
                    (coldata - self.minmax[0]) / self.discretized_delta)
                coldata = da.minimum(da.maximum(coldata, 0),
                                     self.nlevels - 1).astype(COUNT_DTYPE)
            else:
                # without an explicit min/max, only boolean/integer data
                # can be binned directly
                if not (coldata.dtype == bool
                        or numpy.issubdtype(coldata.dtype, numpy.integer)):
                    raise TypeError(
                        f"{self.name}: min/max must be set to colour by non-integer values"
                    )
                coldata = da.remainder(coldata,
                                       self.nlevels).astype(COUNT_DTYPE)

        if flag is not None:
            flag |= ~da.isfinite(coldata)
            return dama.masked_array(coldata, flag)
        else:
            return dama.masked_array(coldata, ~da.isfinite(coldata))
Code Example #7
File: brdf.py  Project: jgrss/geowombat
    def get_distance(tan1, tan2, cos3):
        """
        Gets distance component of Li kernels
        """

        temp = tan1 * tan1 + tan2 * tan2 - 2.0 * tan1 * tan2 * cos3

        return da.sqrt(da.maximum(temp, 0))
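As a quick check of the clamp (calling the helper above unbound, with made-up inputs): when tan1 == tan2 and cos3 == 1 the expression is analytically zero, and round-off could otherwise push it slightly negative before the square root.

import numpy as np
import dask.array as da

tan1 = da.from_array(np.array([0.5, 1.0]), chunks=2)
tan2 = da.from_array(np.array([0.5, 2.0]), chunks=2)
cos3 = da.from_array(np.array([1.0, 1.0]), chunks=2)
print(get_distance(tan1, tan2, cos3).compute())  # [0. 1.]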
Code Example #8
def searchdaskuniform(a0, step, n_a, v, how=None, atol=None):
    index = (v - a0) / step
    if how == 'nearest':
        indexer = da.maximum(da.minimum(da.around(index), n_a - 1), 0)
    elif how == 'bfill':
        indexer = da.maximum(da.ceil(index), 0)
    elif how == 'ffill':
        indexer = da.minimum(da.floor(index), n_a - 1)
    elif how is None:
        indexer = da.ceil(index)
        indexer = da.where(indexer != index, n_a, indexer)

    if atol is not None:
        indexer = da.where((da.absolute(indexer - index) * step > atol) |
                           (indexer < 0) | (indexer >= n_a), n_a, indexer)
    else:
        indexer = da.where((indexer < 0) | (indexer >= n_a), n_a, indexer)
    return indexer.astype(int)
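A small usage sketch of the function above (grid parameters made up): with the uniform grid a = [0, 1, 2, 3, 4], the sentinel value n_a marks look-ups that fall outside the grid or beyond atol.

import numpy as np
import dask.array as da

v = da.from_array(np.array([0.4, 2.6, 7.0]), chunks=(3,))
idx = searchdaskuniform(0.0, 1.0, 5, v, how='nearest', atol=0.5)
print(idx.compute())  # [0 3 5] -> 5 == n_a means "no grid point within atol"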
Code Example #9
File: ufuncs.py  Project: TCvanLeth/PyHAD
def count(a):
    """counts number of distinct "on" switches in a boolean 1d array
    and assigns them cumulatively
    """
    a = a.astype(int)
    count = da.cumsum(da.insert(da.maximum(a[1:] - a[:-1], 0), a[0], 0,
                                axis=0),
                      axis=0)
    return da.where(a == True, count, 0)
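A usage sketch of count (assuming the definition above is in scope): each off-to-on transition gets an increasing label, and off samples stay zero.

import numpy as np
import dask.array as da

switches = da.from_array(np.array([0, 1, 1, 0, 1, 0], dtype=bool), chunks=(3,))
print(count(switches).compute())  # [0 1 1 0 2 0] -> two distinct "on" runs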
Code Example #10
File: k_means.py  Project: sb123456789sb/dask-ml
def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means||',
                         verbose=False, x_squared_norms=None,
                         random_state=None, tol=1e-4,
                         precompute_distances=True,
                         oversampling_factor=2,
                         init_max_iter=None):
    centers = k_init(X, n_clusters, init=init,
                     oversampling_factor=oversampling_factor,
                     random_state=random_state, max_iter=init_max_iter)
    dt = X.dtype
    P = X.shape[1]
    for i in range(max_iter):
        t0 = tic()
        labels, distances = pairwise_distances_argmin_min(
            X, centers, metric='euclidean', metric_kwargs={"squared": True}
        )

        labels = labels.astype(np.int32)
        # distances is always float64, but we need it to match X.dtype
        # for centers_dense, but remain float64 for inertia
        r = da.atop(_centers_dense, 'ij',
                    X, 'ij',
                    labels, 'i',
                    n_clusters, None,
                    distances.astype(X.dtype), 'i',
                    adjust_chunks={"i": n_clusters, "j": P},
                    dtype=X.dtype)
        new_centers = da.from_delayed(
            sum(r.to_delayed().flatten()),
            (n_clusters, P),
            X.dtype
        )
        counts = da.bincount(labels, minlength=n_clusters)
        # Require at least one per bucket, to avoid division by 0.
        counts = da.maximum(counts, 1)
        new_centers = new_centers / counts[:, None]
        new_centers, = compute(new_centers)

        # Convergence check
        shift = squared_norm(centers - new_centers)
        t1 = tic()
        logger.info("Lloyd loop %2d. Shift: %0.4f [%.2f s]", i, shift, t1 - t0)
        if shift < tol:
            break
        centers = new_centers

    if shift > 1e-7:
        labels, distances = pairwise_distances_argmin_min(X, centers)

    inertia = distances.sum()
    centers = centers.astype(dt)

    return labels, inertia, centers, i + 1
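The bincount/maximum pair above guards against empty clusters; a minimal standalone illustration of that idiom:

import numpy as np
import dask.array as da

labels = da.from_array(np.array([0, 0, 2, 2, 2]), chunks=(5,))
counts = da.bincount(labels, minlength=3)  # cluster 1 is empty -> count 0
counts = da.maximum(counts, 1)             # avoid dividing centers by zero
print(counts.compute())                    # [2 1 3]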
Code Example #11
def presearch(a, v):
    if a.ndim > 1:
        a = a.squeeze()

    n_a = a.shape[0]
    step_a = tuple(max(x) for x in a.chunks)[0]
    if not any(ichunk[0] < ishape
               for ichunk, ishape in zip(v.chunks, v.shape)):
        if a.chunks[0][0] >= n_a:
            args = (a, v)
            searchfunc = searchsingle
        else:
            a_block = a[np.arange(0, n_a, step_a)].rechunk(len(a.chunks[0]))
            b_indxrs = da.maximum(searchsingle(a_block, v) - 1, 0).compute()
            if v.size == 1:
                args = (a, v, b_indxrs[0], step_a)
                searchfunc = searchblock
            else:
                firsts = np.insert(np.diff(b_indxrs).nonzero()[0] + 1, 0, 0)
                b_indxrs = b_indxrs[firsts]
                firsts = np.append(firsts, len(v))
                slicers = [
                    slice(firsts[i], firsts[i + 1])
                    for i in range(len(firsts) - 1)
                ]
                args = (a, v, b_indxrs, step_a, slicers)
                searchfunc = searchblocks
    else:
        step_v = tuple(max(x) for x in v.chunks)[0]

        a_block = a[np.arange(0, n_a, step_a)].rechunk(len(a.chunks[0]))
        v_block = v[np.arange(0, v.shape[0], step_v)].rechunk(len(v.chunks[0]))
        b_indxrs = da.maximum(searchsingle(a_block, v_block) - 1, 0).compute()
        b_indxrs = np.append(b_indxrs, len(a_block) - 1)
        args = (a, v, b_indxrs, step_a, a_block.compute())
        searchfunc = searchindexblocks
    return searchfunc, args
Code Example #12
File: brdf.py  Project: jgrss/geowombat
    def get_overlap(cos1, cos2, tan1, tan2, sin3, distance, hb, m_pi):
        """
        Applies the HB ratio transformation
        """

        OverlapInfo = namedtuple('OverlapInfo', 'tvar sint overlap temp')

        temp = (1.0 / cos1) + (1.0 / cos2)

        cost = da.clip(
            hb * da.sqrt(distance * distance +
                         tan1 * tan1 * tan2 * tan2 * sin3 * sin3) / temp, -1,
            1)

        tvar = da.arccos(cost)
        sint = da.sin(tvar)

        overlap = 1.0 / m_pi * (tvar - sint * cost) * temp
        overlap = da.maximum(overlap, 0)

        return OverlapInfo(tvar=tvar, sint=sint, overlap=overlap, temp=temp)
Code Example #13
def discretize(x_data, n_state=2, min_=0.0, max_=1.0, chunks=None):
    """Primitive discretization of a microstructure.

    Args:
      x_data: the data to discretize
      n_state: the number of local states
      min_: the minimum local state
      max_: the maximum local state
      chunks: chunks size for state axis

    Returns:
      the discretized microstructure

    >>> discretize(da.random.random((12, 9), chunks=(3, 9)),
    ...            3,
    ...            chunks=1).chunks
    ((3, 3, 3, 3), (9,), (1, 1, 1))

    >>> discretize(np.array([[0, 1], [0.5, 0.5]]), 3, chunks=1).chunks
    ((2,), (2,), (1, 1, 1))

    >>> assert np.allclose(
    ...     discretize(
    ...         np.array([[0, 1], [0.5, 0.5]]),
    ...         3,
    ...         chunks=1
    ...     ).compute(),
    ...     [[[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 1, 0]]]
    ... )
    """
    return da.maximum(
        discretize_nomax(
            da.clip(x_data, min_, max_),
            da.linspace(min_, max_, n_state, chunks=(chunks or n_state,)),
        ),
        0,
    )
Code Example #14
File: regression.py  Project: hristog/dask-ml
def mean_absolute_percentage_error(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    sample_weight: Optional[ArrayLike] = None,
    multioutput: Optional[str] = "uniform_average",
    compute: bool = True,
) -> ArrayLike:
    """Mean absolute percentage error regression loss.

    Note here that we do not represent the output as a percentage in range
    [0, 100]. Instead, we represent it in range [0, 1/eps]. Read more in
    https://scikit-learn.org/stable/modules/model_evaluation.html#mean-absolute-percentage-error

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.
    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.
    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.
    multioutput : {'raw_values', 'uniform_average'} or array-like
        Defines aggregating of multiple output values.
        Array-like value defines weights used to average errors.
        If input is list then the shape must be (n_outputs,).
        'raw_values' :
            Returns a full set of errors in case of multioutput input.
        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.
    compute : bool
        Whether to compute this result (default ``True``)

    Returns
    -------
    loss : float or array-like of floats in the range [0, 1/eps]
        If multioutput is 'raw_values', then mean absolute percentage error
        is returned for each output separately.
        If multioutput is 'uniform_average' or ``None``, then the
        equally-weighted average of all output errors is returned.
        MAPE output is non-negative floating point. The best value is 0.0.
        But note that bad predictions can lead to arbitrarily large
        MAPE values, especially if some y_true values are very close to zero.
        Note that we return a large value instead of `inf` when y_true is zero.
    """
    _check_sample_weight(sample_weight)
    epsilon = np.finfo(np.float64).eps
    mape = abs(y_pred - y_true) / da.maximum(y_true, epsilon)
    output_errors = mape.mean(axis=0)

    if isinstance(multioutput, str) or multioutput is None:
        if multioutput == "raw_values":
            if compute:
                return output_errors.compute()
            else:
                return output_errors
    else:
        raise ValueError("Weighted 'multioutput' not supported.")
    result = output_errors.mean()
    if compute:
        result = result.compute()
    return result
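A standalone sketch of the epsilon guard at the heart of the function above (values made up): da.maximum keeps the denominator away from zero, so exact zeros in y_true yield a large finite error instead of inf.

import numpy as np
import dask.array as da

y_true = da.from_array(np.array([3.0, 0.0, 2.0]), chunks=(3,))
y_pred = da.from_array(np.array([2.5, 0.5, 2.0]), chunks=(3,))
epsilon = np.finfo(np.float64).eps
mape = (abs(y_pred - y_true) / da.maximum(y_true, epsilon)).mean()
print(mape.compute())  # dominated by the y_true == 0 term (~0.5 / eps)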
Code Example #15
File: cones.py  Project: bungun/scs-dask
def project_cone(K, x):
    return da.maximum(x, 0)
Code Example #16
def relu(input_samples):
    # da.maximum(input_samples, 0, out= input_samples)  # This might be faster as it puts result in same variable.
    return da.maximum(input_samples, 0)
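A quick check of the ReLU above (assuming the definition is in scope):

import numpy as np
import dask.array as da

x = da.from_array(np.array([-2.0, -0.5, 0.0, 1.5]), chunks=(2,))
print(relu(x).compute())  # [0.  0.  0.  1.5]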
Code Example #17
File: test_array_core.py  Project: hc10024/dask
def test_arithmetic():
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2,))
    b = da.from_array(y, chunks=(2,))
    c = da.from_array(z, chunks=(2,))
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a ** b, x ** y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a ** 2, x ** 2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2 ** b, 2 ** y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))
    assert eq(~(a == b), ~(x == y))

    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b/10), np.arcsin(y/10))
    assert eq(da.arccos(b/10), np.arccos(y/10))
    assert eq(da.arctan(b/10), np.arctan(y/10))
    assert eq(da.arctan2(b*10, a), np.arctan2(y*10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b*10), np.arcsinh(y*10))
    assert eq(da.arccosh(b*10), np.arccosh(y*10))
    assert eq(da.arctanh(b/10), np.arctanh(y/10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    assert eq(da.maximum(a, 5 - a), np.maximum(a, 5 - a))
    assert eq(da.minimum(a, 5 - a), np.minimum(a, 5 - a))
    assert eq(da.fmax(a, 5 - a), np.fmax(a, 5 - a))
    assert eq(da.fmin(a, 5 - a), np.fmin(a, 5 - a))

    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
Code Example #18
    def volume_curvature(self,
                         darray_il,
                         darray_xl,
                         dip_factor=10,
                         kernel=(3, 3, 3),
                         preview=None):
        """
        Description
        -----------
        Compute volume curvature attributes from 3D seismic dips
        
        Parameters
        ----------
        darray_il : Array-like, Inline dip - acceptable inputs include 
            Numpy, HDF5, or Dask Arrays
        darray_xl : Array-like, Crossline dip - acceptable inputs include 
            Numpy, HDF5, or Dask Arrays
        
        Keyword Arguments
        -----------------  
        dip_factor : Number, scalar for dip values
        kernel : tuple (len 3), operator size
        preview : str, enables or disables preview mode and specifies direction
            Acceptable inputs are (None, 'inline', 'xline', 'z')
            Optimizes chunk size in different orientations to facilitate rapid
            screening of algorithm output
        
        Returns
        -------
        H, K, Kmax, Kmin, KMPos, KMNeg : Dask Array, {H : 'Mean Curvature', 
                                                      K : 'Gaussian Curvature',
                                                      Kmax : 'Max Curvature',
                                                      Kmin : 'Min Curvature',
                                                      KMPos : Most Positive Curvature,
                                                      KMNeg : Most Negative Curvature}
        """

        np.seterr(all='ignore')

        # Generate Dask Array as necessary
        darray_il, chunks_init = self.create_array(darray_il,
                                                   kernel,
                                                   preview=preview)
        darray_xl, chunks_init = self.create_array(darray_xl,
                                                   kernel,
                                                   preview=preview)

        u = -darray_il / dip_factor
        v = -darray_xl / dip_factor
        w = da.ones_like(u, chunks=u.chunks)

        # Compute Gradients
        ux = sp().first_derivative(u, axis=0)
        uy = sp().first_derivative(u, axis=1)
        uz = sp().first_derivative(u, axis=2)
        vx = sp().first_derivative(v, axis=0)
        vy = sp().first_derivative(v, axis=1)
        vz = sp().first_derivative(v, axis=2)

        # Smooth Gradients
        ux = ux.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)
        uy = uy.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)
        uz = uz.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)
        vx = vx.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)
        vy = vy.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)
        vz = vz.map_blocks(ndi.uniform_filter, size=kernel, dtype=ux.dtype)

        u = util.trim_dask_array(u, kernel)
        v = util.trim_dask_array(v, kernel)
        w = util.trim_dask_array(w, kernel)
        ux = util.trim_dask_array(ux, kernel)
        uy = util.trim_dask_array(uy, kernel)
        uz = util.trim_dask_array(uz, kernel)
        vx = util.trim_dask_array(vx, kernel)
        vy = util.trim_dask_array(vy, kernel)
        vz = util.trim_dask_array(vz, kernel)

        wx = da.zeros_like(ux, chunks=ux.chunks, dtype=ux.dtype)
        wy = da.zeros_like(ux, chunks=ux.chunks, dtype=ux.dtype)
        wz = da.zeros_like(ux, chunks=ux.chunks, dtype=ux.dtype)

        uv = u * v
        vw = v * w
        u2 = u * u
        v2 = v * v
        w2 = w * w
        u2pv2 = u2 + v2
        v2pw2 = v2 + w2
        s = da.sqrt(u2pv2 + w2)

        # Measures of surfaces
        E = da.ones_like(u, chunks=u.chunks, dtype=u.dtype)
        F = -(u * w) / (da.sqrt(u2pv2) * da.sqrt(v2pw2))
        G = da.ones_like(u, chunks=u.chunks, dtype=u.dtype)
        D = -(-uv * vx + u2 * vy + v2 * ux - uv * uy) / (u2pv2 * s)
        Di = -(vw * (uy + vx) - 2 * u * w * vy - v2 * (uz + wx) + uv *
               (vz + wy)) / (2 * da.sqrt(u2pv2) * da.sqrt(v2pw2) * s)
        Dii = -(-vw * wy + v2 * wz + w2 * vy - vw * vz) / (v2pw2 * s)
        H = (E * Dii - 2 * F * Di + G * D) / (2 * (E * G - F * F))
        K = (D * Dii - Di * Di) / (E * G - F * F)
        Kmin = H - da.sqrt(H * H - K)
        Kmax = H + da.sqrt(H * H - K)

        H[da.isnan(H)] = 0
        K[da.isnan(K)] = 0
        Kmax[da.isnan(Kmax)] = 0
        Kmin[da.isnan(Kmin)] = 0

        KMPos = da.maximum(Kmax, Kmin)
        KMNeg = da.minimum(Kmax, Kmin)

        return (H, K, Kmax, Kmin, KMPos, KMNeg)
Code Example #19
File: cones.py  Project: bungun/scs-dask
def project_cone(K, x):
    assert x.size == K.dim**2, 'input dimension compatible'
    chunks = x[:K.dim].chunks[0]
    X = da.reshape(x, (K.dim, K.dim)).rechunk((chunks, (K.dim, )))
    U, S, V = da.linalg.svd(da.reshape(x, (K.dim, K.dim)))
    return U.dot(da.maximum(0, S).reshape(-1, 1) * V).reshape(-1)
Code Example #20
File: pfbclean.py  Project: ratt-ru/pfb-clean
def _main(dest=sys.stdout):
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        args.mem_limit = int(psutil.virtual_memory()[0] /
                             1e9)  # 100% of memory by default

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits, compare_headers,
                                data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print('     %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list([tmp_freq]))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(
        args.ms,
        args.nx,
        args.ny,
        args.cell_size,
        nband=args.nband,
        nthreads=args.nthreads,
        do_wstacking=args.do_wstacking,
        row_chunks=args.row_chunks,
        psf_oversize=args.psf_oversize,
        data_column=args.data_column,
        epsilon=args.epsilon,
        weight_column=args.weight_column,
        imaging_weight_column=args.imaging_weight_column,
        model_column=args.model_column,
        flag_column=args.flag_column,
        weighting=args.weighting,
        robust=args.robust,
        mem_limit=int(
            0.8 * args.mem_limit))  # assumes gridding accounts for 80% memory
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600, R.nx_psf,
                      R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf = load_fits(args.psf, dtype=args.real_type).squeeze()
        except BaseException:
            # fall back to recomputing the PSF if the provided one is unusable
            psf = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf, hdr_psf)
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty).squeeze()
        except BaseException:
            # fall back to recomputing the dirty image if the provided one is unusable
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except BaseException:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
            residual /= wsum
        except BaseException:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model[-5:] == '.fits':
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")

        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                beam_image[v] = beam.I(xx, yy, freq_out[v])

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf,
                     model,
                     residual,
                     mask=mask_array,
                     beam_image=beam_image,
                     hdr=hdr,
                     hdr_mfs=hdr_mfs,
                     outfile=args.outfile,
                     maxit=1,
                     nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf,
                         model,
                         residual,
                         mask=mask_array,
                         beam_image=beam_image,
                         hessian=R.convolve,
                         wsum=wsum,
                         adapt_sig21=args.adapt_sig21,
                         hdr=hdr,
                         hdr_mfs=hdr_mfs,
                         outfile=args.outfile,
                         cpsf=cpsf,
                         nthreads=args.nthreads,
                         sig_21=args.sig_21,
                         sigma_frac=args.sigma_frac,
                         maxit=args.minormaxit,
                         tol=args.minortol,
                         gamma=args.gamma,
                         psi_levels=args.psi_levels,
                         psi_basis=args.psi_basis,
                         pdtol=args.pdtol,
                         pdmaxit=args.pdmaxit,
                         pdverbose=args.pdverbose,
                         positivity=args.positivity,
                         cgtol=args.cgtol,
                         cgminit=args.cgminit,
                         cgmaxit=args.cgmaxit,
                         cgverbose=args.cgverbose,
                         pmtol=args.pmtol,
                         pmmaxit=args.pmmaxit,
                         pmverbose=args.pmverbose)

        elif args.deconv_mode == 'clean':
            model = clean(psf,
                          model,
                          residual,
                          mask=mask_array,
                          beam=beam_image,
                          nthreads=args.nthreads,
                          maxit=args.minormaxit,
                          gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold,
                          hbgamma=args.hbgamma,
                          hbpf=args.hbpf,
                          hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf,
                             model,
                             residual,
                             mask=mask_array,
                             beam_image=beam_image,
                             hessian=R.convolve,
                             wsum=wsum,
                             adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs,
                             hdr=hdr,
                             hdr_mfs=hdr_mfs,
                             outfile=args.outfile,
                             sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads,
                             gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit,
                             tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma,
                             hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit,
                             hbverbose=args.hbverbose,
                             pdtol=args.pdtol,
                             pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose,
                             cgtol=args.cgtol,
                             cgminit=args.cgminit,
                             cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose,
                             pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit,
                             pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode ", args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits', model,
                  hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps),
              file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess,
                mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype),
                M=M,
                tol=0.1 * args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f" %
              (rmax, rms),
              file=dest)

    # if args.interp_model:
    #     nband = args.nband
    #     order = args.spectral_poly_order
    #     phi.trim_fat(model)
    #     I = np.argwhere(phi.mask).squeeze()
    #     Ix = I[:, 0]
    #     Iy = I[:, 1]
    #     npix = I.shape[0]

    #     # get components
    #     beta = model[:, Ix, Iy]

    #     # fit integrated polynomial to model components
    #     # we are given frequencies at bin centers, convert to bin edges
    #     ref_freq = np.mean(freq_out)
    #     delta_freq = freq_out[1] - freq_out[0]
    #     wlow = (freq_out - delta_freq/2.0)/ref_freq
    #     whigh = (freq_out + delta_freq/2.0)/ref_freq
    #     wdiff = whigh - wlow

    #     # set design matrix for each component
    #     Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
    #     for i in range(1, args.spectral_poly_order+1):
    #         Xdesign[:, i-1] = (whigh**i - wlow**i)/(i*wdiff)

    #     weights = psf_max[:, None]
    #     dirty_comps = Xdesign.T.dot(weights*beta)

    #     hess_comps = Xdesign.T.dot(weights*Xdesign)

    #     comps = np.linalg.solve(hess_comps, dirty_comps)

    #     np.savez(args.outfile + "spectral_comps", comps=comps, ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
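The running-maximum idiom used above to accumulate the uv extent works for any stream of lazily loaded arrays; a minimal sketch with made-up data:

import numpy as np
import dask.array as da

u_max = 0.0
for block in [np.array([-3.0, 7.0]), np.array([9.5, -1.0])]:
    uvw = da.from_array(block, chunks=2)
    u_max = da.maximum(u_max, abs(uvw).max())  # stays lazy until .compute()
print(float(u_max.compute()))  # 9.5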
Code Example #21
File: psf.py  Project: ratt-ru/pfb-clean
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])
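            # (mask convention: 1 means use the sample, 0 means discard,
            # hence the negation of the combined flags)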

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Code Example #24
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append([tmp_freq])

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception:
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    psf_max[psf_max < 1e-15] = 1e-15  # LB - is this the right thing to do?

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception:
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=np.float64)
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual, dtype=np.float64)
                except Exception:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
        except Exception:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    #  preconditioning matrix
    def hess(x):
        return mask * psf.convolve(mask * x) + x / args.sig_l2**2

    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            dirty.shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    if args.psi_basis is None:
        print("Using Dirac + db1-4 dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads)
        # psi = PSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels)
    else:
        if not isinstance(args.psi_basis, list):
            args.psi_basis = list(args.psi_basis)
        print("Using ", args.psi_basis, " dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads,
                      bases=args.psi_basis)
        # psi = PSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels, bases=args.psi_basis)
    nbasis = psi.nbasis
    weights_21 = np.ones((psi.nbasis, psi.nmax), dtype=np.float64)
    dual = np.zeros((psi.nbasis, args.nband, psi.nmax), dtype=np.float64)

    # Reweighting
    if args.reweight_iters is not None:
        if not isinstance(args.reweight_iters, list):
            reweight_iters = [args.reweight_iters]
        else:
            reweight_iters = list(args.reweight_iters)
    else:
        reweight_iters = list(
            np.arange(args.reweight_start, args.reweight_end,
                      args.reweight_freq))
        reweight_iters.append(args.reweight_end)

    # Reporting
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    M = lambda x: x * args.sig_l2**2  # preconditioner
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                mask * residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_update.fits', x, hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  psi,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity)

        # reweighting
        if i in reweight_iters:
            v = psi.hdot(model)
            l2_norm = norm(v, axis=1)
            l2_norm = np.where(l2_norm < args.sig_21 * weights_21, 0.0,
                               l2_norm)
            for m in range(psi.nbasis):
                indnz = l2_norm[m].nonzero()
                alpha = np.percentile(l2_norm[m, indnz].flatten(),
                                      args.reweight_alpha_percent)
                alpha = np.maximum(alpha, args.reweight_alpha_min)
                print("Reweighting - ", m, alpha)
                weights_21[m] = alpha / (l2_norm[m] + alpha)
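                # weights tend to 1 for components well below alpha and to
                # alpha / l2_norm for dominant ones (standard l1 reweighting)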
            args.reweight_alpha_percent *= args.reweight_alpha_ff
            # print(" reweight alpha percent = ", args.reweight_alpha_percent)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

    if args.write_model:
        R.write_model(model)

    if args.make_restored:
        x = pcg(hess,
                residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit)
        restored = model + x

        # get residual
        residual = R.make_residual(restored) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After restoring peak of residual is %f and rms is %f" %
              (rmax, rms))

        # save current iteration
        save_fits(args.outfile + '_restored.fits', restored, hdr)

        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)

        save_fits(args.outfile + '_restored_residual.fits', residual, hdr)

        save_fits(args.outfile + '_restored_residual_mfs.fits', residual_mfs,
                  hdr_mfs)
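
Example #24 estimates the spectral norm beta of hess with pfb's power_method. As a reference, a generic power iteration for a self-adjoint positive operator looks like the sketch below; this is an assumption about the algorithm, not pfb's implementation:

import numpy as np

def power_method_sketch(op, shape, tol=1e-5, maxit=100):
    # power iteration: for a self-adjoint positive operator the dominant
    # eigenvalue equals the spectral norm
    x = np.random.randn(*shape)
    x /= np.linalg.norm(x)
    beta = 1.0
    for _ in range(maxit):
        y = op(x)
        beta_new = np.linalg.norm(y)
        x = y / beta_new
        if abs(beta_new - beta) / beta < tol:
            break
        beta = beta_new
    return beta_new
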
Code Example #25
File: ssclean.py  Project: gijzelaerr/pfb-clean
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list(tmp_freq))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                optimise_chunks=True,
                data_column=args.data_column,
                weight_column=args.weight_column,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        compare_headers(hdr_psf, fits.getheader(args.psf))
        psf_array = load_fits(args.psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    psf_max[psf_max < 1e-15] = 1e-15

    if args.dirty is not None:
        compare_headers(hdr, fits.getheader(args.dirty))
        dirty = load_fits(args.dirty)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= psf_max_mean

    # mfs residual
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)
        if mask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((args.nx, args.ny), dtype=np.int64)

    if args.point_mask is not None:
        pmask = load_fits(args.point_mask, dtype=bool)
        if pmask.shape != (args.nx, args.ny):
            raise ValueError("Point mask has incorrect shape")
    else:
        pmask = None

    # Reporting
    print("At iteration 0 peak of residual is %f and rms is %f" % (rmax, rms))
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # set up point sources
    phi = Dirac(args.nband, args.nx, args.ny, mask=pmask)
    dual = np.zeros((args.nband, args.nx, args.ny), dtype=np.float64)
    weights_21 = np.where(phi.mask, 1, np.inf)

    # preconditioning matrix
    def hess(beta):
        return phi.hdot(psf.convolve(
            phi.dot(beta))) + beta / args.sig_l2**2  # vague prior on beta

    # get new spectral norm
    L = power_method(hess, dirty.shape, tol=args.pmtol, maxit=args.pmmaxit)

    # deconvolve
    eps = 1.0
    i = 0
    residual = dirty.copy()
    model = np.zeros(dirty.shape, dtype=dirty.dtype)
    for i in range(1, args.maxit):
        # find point source candidates
        if args.do_clean:
            model_tmp = hogbom(mask[None] * residual / psf_max[:, None, None],
                               psf_array / psf_max[:, None, None],
                               gamma=args.cgamma,
                               pf=args.peak_factor)
            phi.update_locs(np.any(model_tmp, axis=0))
            # get new spectral norm
            L = power_method(hess,
                             model.shape,
                             tol=args.pmtol,
                             maxit=args.pmmaxit)
        else:
            model_tmp = np.zeros_like(residual, dtype=residual.dtype)

        # solve for beta updates
        x = pcg(hess,
                phi.hdot(residual),
                phi.hdot(model_tmp),
                M=lambda x: x * args.sig_l2**2,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        modelp = model.copy()
        model += args.gamma * x

        # impose sparsity and positivity in point sources
        weights_21 = np.where(phi.mask, 1, 1e10)  # 1e10 for effective infinity
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  phi,
                                  weights_21,
                                  L,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  axis=0,
                                  positivity=args.positivity,
                                  report_freq=100)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(mask * residual_mfs).max()
        rms = np.std(mask * residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits',
                      residual / psf_max[:, None, None], hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            print("We have convergence!")
            break

    # final iteration with only a positivity constraint on pixel locs
    tmp = phi.hdot(model)
    x = pcg(hess,
            phi.hdot(residual),
            np.zeros_like(tmp, dtype=tmp.dtype),
            M=lambda x: x * args.sig_l2**2,
            tol=args.cgtol,
            maxit=args.cgmaxit,
            verbosity=args.cgverbose)

    modelp = model.copy()
    model += args.gamma * x
    model, dual = primal_dual(hess,
                              model,
                              modelp,
                              dual,
                              0.0,
                              phi,
                              weights_21,
                              L,
                              tol=args.pdtol,
                              maxit=args.pdmaxit,
                              axis=0,
                              report_freq=100)

    # get residual
    residual = R.make_residual(model) / psf_max_mean

    # check stopping criteria
    residual_mfs = np.sum(residual, axis=0) / wsum
    rmax = np.abs(mask * residual_mfs).max()
    rms = np.std(mask * residual_mfs)
    print("At final iteration peak of residual is %f and rms is %f" %
          (rmax, rms))

    save_fits(args.outfile + '_model.fits', model, hdr)

    model_mfs = np.mean(model, axis=0)
    save_fits(args.outfile + '_model_mfs.fits', model_mfs, hdr_mfs)

    save_fits(args.outfile + '_residual.fits',
              residual / psf_max[:, None, None], hdr)

    save_fits(args.outfile + '_residual_mfs.fits', residual_mfs, hdr_mfs)

    if args.interp_model:
        nband = args.nband
        order = args.spectral_poly_order
        phi.trim_fat(model)
        I = np.argwhere(phi.mask).squeeze()
        Ix = I[:, 0]
        Iy = I[:, 1]
        npix = I.shape[0]

        # get components
        beta = model[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, phi.mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
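
In the interp_model block above, each column of Xdesign is the average of the monomial w**(i - 1) over a frequency bin: integrating w**(i - 1) from wlow to whigh gives (whigh**i - wlow**i) / i, and dividing by the bin width wdiff yields the bin average. A quick numerical check of that identity:

import numpy as np

wlow, whigh = 0.9, 1.1
i = 3  # third column, i.e. the bin average of w**2
analytic = (whigh**i - wlow**i) / (i * (whigh - wlow))
grid = np.linspace(wlow, whigh, 100001)
numeric = (grid**(i - 1)).mean()  # brute-force bin average on a fine grid
print(analytic, numeric)  # both ~1.003333
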
Code Example #26
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs,
                  chanslice, subset,
                  noflags, noconj,
                  iter_field, iter_spw, iter_scan, iter_ant, iter_baseline,
                  join_corrs=False,
                  row_chunk_size=100000):

    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    ms_cols.update(msinfo.indexing_columns.keys())
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})
    # get visibility columns
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    total_num_points = 0  # total number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna_or_baseline)
    # If any of these axes is not being iterated over, then the index at that position is None
    output_dataframes = OrderedDict()

    # number of rows per each dataframe
    output_rows = OrderedDict()

    # output subsets of indexing columns, indexed by same tuple
    output_subsets = OrderedDict()

    if iter_ant:
        antenna_subsets = zip(subset.ant.numbers, subset.ant.names)
    else:
        antenna_subsets = [(None, None)]
    taql = mytaql

    for antenna, antname in antenna_subsets:
        if antenna is not None:
            taql = f"({mytaql})&&(ANTENNA1=={antenna} || ANTENNA2=={antenna})" if mytaql else \
                    f"(ANTENNA1=={antenna} || ANTENNA2=={antenna})"
        # add baselines to group columns
        if iter_baseline:
            group_cols = list(group_cols) + ["ANTENNA1", "ANTENNA2"]

        # get MS data
        msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=taql,
                                    chunks=dict(row=row_chunk_size))
        nrow = sum([len(group.row) for group in msdata])
        if not nrow:
            continue

        if antenna is not None:
            log.info(f': Indexing sub-MS (antenna {antname}) and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')
        else:
            log.info(f': Indexing MS and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')

        # iterate over groups
        for group in msdata:
            if not len(group.row):
                continue
            ddid     =  group.DATA_DESC_ID  # always present
            fld      =  group.FIELD_ID      # always present
            if fld not in subset.field or ddid not in subset.spw:
                log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
                continue
            scan = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans
            if iter_baseline:
                ant1    = getattr(group, 'ANTENNA1', None)   # will be present if iterating over baselines
                ant2    = getattr(group, 'ANTENNA2', None)   # will be present if iterating over baselines
                baseline = msinfo.baseline_number(ant1, ant2)
            else:
                baseline = None

            # Make frame key -- data subset corresponds to this frame
            dataframe_key = (fld if iter_field else None,
                             ddid if iter_spw else None,
                             scan if iter_scan else None,
                             antenna if antenna is not None else baseline)

            # update subsets of MS indexing columns that we've seen for this dataframe
            output_subset1 = output_subsets.setdefault(dataframe_key,
                                                {column:set() for column in msinfo.indexing_columns.keys()})
            for column, _ in msinfo.indexing_columns.items():
                value = getattr(group, column)
                if np.isscalar(value):
                    output_subset1[column].add(value)
                else:
                    output_subset1[column].update(value.compute().data)

            # number of rows in dataframe
            nrows0 = output_rows.setdefault(dataframe_key, 0)

            # always read flags -- easier that way
            flag = group.FLAG if not noflags else None
            flag_row = group.FLAG_ROW if not noflags else None

            a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data)
            a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data)
            baselines = msinfo.baseline_number(a1, a2)

            freqs = chan_freqs[ddid]
            chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
            wavel = freq_to_wavel(freqs)
            extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

            nchan = len(group.chan)
            if flag is not None:
                flag = flag[dict(chan=chanslice)]
                nchan = flag.shape[1]
            shape = (len(group.row), nchan)

            arrays = OrderedDict()
            shapes = OrderedDict()
            ddf = None
            num_points = 0  # counts number of new points generated


            for corr in subset.corr.numbers:
                # make dictionary of extra values for DataMappers
                extras['corr'] = corr
                # loop over datums to be computed
                for axis in DataAxis.all_axes.values():
                    value = arrays.get(axis.label)
                    # a datum was already computed?
                    if value is not None:
                        # if not joining correlations, then that's the only one we'll need, so continue
                        if not join_corrs:
                            continue
                        # joining correlations, and datum has a correlation dependence: compute another one
                        if axis.corr is None:
                            value = None
                    if value is None:
                        value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice)
                        # print(axis.label, value.compute().min(), value.compute().max())
                        num_points = max(num_points, value.size)
                        if value.ndim == 0:
                            shapes[axis.label] = ()
                        elif value.ndim == 1:
                            timefreq_axis = axis.mapper.axis or 0
                            assert value.shape[0] == shape[timefreq_axis], \
                                   f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                            shapes[axis.label] = ("row",) if timefreq_axis == 0 else ("chan",)
                        # else 2D value better match expected shape
                        else:
                            assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                            shapes[axis.label] = ("row", "chan")
                        arrays[axis.label] = value
                # any new data generated for this correlation? Make dataframe
                if num_points:
                    total_num_points += num_points
                    args = (v for pair in ((array, shapes[key]) for key, array in arrays.items()) for v in pair)
                    df1 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys())
                    # if any axis needs to be conjugated, double up all of them
                    if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]):
                        arr_shape = [(-arrays[axis.label] if axis.conjugate else arrays[axis.label], shapes[axis.label])
                                                for axis in DataAxis.all_axes.values()]
                        args = (v for pair in arr_shape  for v in pair)
                        df2 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys())
                        df1 = dask_df.concat([df1, df2], axis=0)
                    ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1

            # do we already have a frame for this key
            ddf0 = output_dataframes.get(dataframe_key)

            if ddf0 is None:
                log.debug(f"first frame for {dataframe_key}")
                output_dataframes[dataframe_key] = ddf
            else:
                log.debug(f"appending to frame for {dataframe_key}")
                output_dataframes[dataframe_key] = dask_df.concat([ddf0, ddf], axis=0)

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    # print("===")
    # for ddf in output_dataframes.values():
    #     for axis in DataAxis.all_axes.values():
    #         value = ddf[axis.label].values.compute()
    #         print(axis.label, np.nanmin(value), np.nanmax(value))

    log.info(": complete")
    return output_dataframes, output_subsets, total_num_points
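
get_plot_data orders each antenna pair with da.minimum/da.maximum before encoding it, so (p, q) and (q, p) land on the same baseline number. msinfo.baseline_number is shadems' own encoder; the ordering step itself is easy to verify with NumPy:

import numpy as np

ant1 = np.array([3, 0, 5])
ant2 = np.array([1, 2, 5])
a1 = np.minimum(ant1, ant2)
a2 = np.maximum(ant1, ant2)
print([(int(p), int(q)) for p, q in zip(a1, a2)])
# [(1, 3), (0, 2), (5, 5)] -- the pair ordering is canonical
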
Code Example #27
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list(tmp_freq))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception:
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    psf_max[psf_max < 1e-15] = 1e-15  # LB - is this the right thing to do?

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception:
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    residual = dirty.copy()

    model = np.zeros((2, args.nband, args.nx, args.ny))
    recompute_residual = False
    if args.beta0 is not None:
        compare_headers(hdr, fits.getheader(args.beta0))
        model[0] = load_fits(args.beta0).squeeze()
        recompute_residual = True

    if args.alpha0 is not None:
        compare_headers(hdr, fits.getheader(args.alpha0))
        model[1] = load_fits(args.alpha0).squeeze()
        recompute_residual = True

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    # point mask
    pmask = load_fits(args.point_mask, dtype=bool)[None, :, :]
    if pmask.shape != (1, args.nx, args.ny):
        raise ValueError("Point mask has incorrect shape")

    # set up splitting operator
    phi = lambda x: x[0] * pmask + x[1] * mask
    phih = lambda x: np.concatenate(
        ((pmask * x)[None], (mask * x)[None]), axis=0)

    if recompute_residual:
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum

    # Gaussian "prior" used for preconditioning extended emission
    A = Gauss(args.sig_l2a, args.nband, args.nx, args.ny, args.nthreads)

    #  preconditioning matrix
    def hess(x):
        return phih(psf.convolve(phi(x))) + np.concatenate(
            (x[0:1] / args.sig_l2b**2, A.idot(x[1])[None]), axis=0)
        # return  phih(psf.convolve(phi(x))) + np.concatenate((x[0:1]/args.sig_l2b**2, x[1::]/args.sig_l2a**2), axis=0)

    # M_func = lambda x: np.concatenate((x[0:1] * args.sig_l2b**2, x[1::] * args.sig_l2a**2), axis=0)
    M_func = lambda x: np.concatenate(
        (x[0:1] * args.sig_l2b**2, A.convolve(x[1])[None]), axis=0)

    par_shape = phih(dirty).shape
    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            par_shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    theta = DaskTheta(args.nband, args.nx, args.ny, nthreads=args.nthreads)
    nbasis = theta.nbasis
    weights_21 = np.ones((theta.nbasis + 1, theta.nmax), dtype=np.float64)
    tmp = np.pad(pmask.ravel(), (0, theta.nmax - args.nx * args.ny),
                 mode='constant')
    weights_21[0] = np.where(tmp, args.sig_21b / args.sig_21a, 1e15)
    dual = np.zeros((theta.nbasis + 1, args.nband, theta.nmax),
                    dtype=np.float64)

    # Reporting
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                phih(residual),
                np.zeros(par_shape, dtype=np.float64),
                M=M_func,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_point_update.fits', x[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff_update.fits', x[1], hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21a,
                                  theta,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity,
                                  gamma=args.gamma)

        # get residual
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', image, hdr)

            save_fits(args.outfile + str(i) + '_point.fits', model[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff.fits', model[1], hdr)

            model_mfs = np.mean(image, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            break

    if args.interp_model:
        nband = args.nband
        order = args.spectral_poly_order
        mask = np.where(model_mfs > 1e-10, 1, 0)
        I = np.argwhere(mask).squeeze()
        Ix = I[:, 0]
        Iy = I[:, 1]
        npix = I.shape[0]

        # get components
        beta = image[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)
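            # each column is the bin-averaged monomial w**(i - 1), i.e. the
            # integral of w**(i - 1) over [wlow, whigh] divided by the width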

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
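
Example #27 splits the model into point-source and diffuse components with phi and maps images back with phih. For phih to serve as the adjoint of phi inside pcg and primal_dual, the pair must satisfy <phi(x), y> = <x, phih(y)>; a quick dot-product test with random stand-in masks:

import numpy as np

nband, nx, ny = 2, 4, 4
pmask = (np.random.rand(1, nx, ny) > 0.5).astype(float)
mask = np.ones((1, nx, ny))

phi = lambda x: x[0] * pmask + x[1] * mask
phih = lambda y: np.concatenate(((pmask * y)[None], (mask * y)[None]), axis=0)

x = np.random.randn(2, nband, nx, ny)
y = np.random.randn(nband, nx, ny)
print(np.allclose(np.vdot(phi(x), y), np.vdot(x, phih(y))))  # True
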
Code Example #28
def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads
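    # e.g. 16 cores split as 4 workers x 1 thread gives 4 chunks in flight
    # and 4 gridder threads per chunk (illustrative numbers, not defaults)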

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import residual as im2residim
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2
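    # (weights are real floats, i.e. half the size of the complex data,
    # hence the factor of a half relative to bytes_per_row)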

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

    # single scalar bound on |u| and |v| over all measurement sets
    uv_max = da.maximum(u_max, v_max).compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)
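    # cell_N is the Nyquist cell size in radians: half the reciprocal of the
    # longest baseline measured in wavelengths at the highest frequency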

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180  # arcsec -> rad
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
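        # round up to the next FFT-friendly size (ducc0's good_size returns
        # the smallest efficient transform length >= npix)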
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx per-band image size
    # this underestimates the footprint when multiple SPWs map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes
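    # plan_row_chunk presumably picks the largest row chunk for which the
    # rows plus the per-band image footprint fit in the given memory budget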

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    # use the phase centre of the first field encountered; datasets with a
    # different phase centre are skipped below
    radec = None
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # materialise the small subtables in one pass
        ddids, fields, spws, pols = dask.compute(ddids, fields, spws, pols)

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # FIXME: this is not correct in general; DATA_DESC_ID should be
            # mapped to the SPW id through the DATA_DESCRIPTION subtable
            spw = ds.DATA_DESC_ID

            # clone re-keys the UVW graph so these reads are computed
            # independently rather than shared with other tasks
            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if weights.ndim < 3:  # WEIGHT rather than WEIGHT_SPECTRUM
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if imaging_weights.ndim < 3:  # WEIGHT-style column
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        data.shape,
                        chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply the adjoint of the Mueller term:
            # phases modify the data, amplitudes modify the weights
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - suppress the spurious divide-by-zero warning
            # (data / weights is still evaluated where weights == 0)
            data = da.where(weights, data / weights, 0.0j)
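            # data now holds Stokes I = (wxx*XX + wyy*YY) / (wxx + wyy),
            # with zeros wherever the summed weight vanishes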

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])
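            # mask == 1 marks samples the gridder should use, hence the
            # negation of the combined flags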

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)