Example No. 1
def test_einsum_optimize(optimize_opts):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    opt1, opt2 = optimize_opts

    assert_eq(np.einsum(sig, *np_inputs, optimize=opt1),
              da.einsum(sig, *np_inputs, optimize=opt2))

    assert_eq(np.einsum(sig, *np_inputs, optimize=opt2),
              da.einsum(sig, *np_inputs, optimize=opt1))
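The optimize keyword only changes the contraction path, not the result, which is why the test above freely swaps opt1 and opt2 between NumPy and Dask. Below is a minimal self-contained sketch of the same idea; the _numpy_and_dask_inputs helper is internal to the test suite, so small random operands are built inline here as an assumption about what it produces.

import numpy as np
import dask.array as da

# Tiny operands matching the 'ea,fb,abcd,gc,hd->efgh' signature used above.
rng = np.random.default_rng(0)
shapes = [(2, 3), (2, 3), (3, 3, 3, 3), (2, 3), (2, 3)]
np_inputs = [rng.random(s) for s in shapes]
da_inputs = [da.from_array(x, chunks=2) for x in np_inputs]

sig = 'ea,fb,abcd,gc,hd->efgh'
expected = np.einsum(sig, *np_inputs, optimize=True)
result = da.einsum(sig, *da_inputs, optimize='greedy').compute()
assert np.allclose(expected, result)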
Example No. 2
def test_einsum_optimize(optimize_opts):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    opt1, opt2 = optimize_opts

    assert_eq(np.einsum(sig, *np_inputs, optimize=opt1),
              da.einsum(sig, *np_inputs, optimize=opt2))

    assert_eq(np.einsum(sig, *np_inputs, optimize=opt2),
              da.einsum(sig, *np_inputs, optimize=opt1))
Example No. 3
def test_einsum_casting(casting):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(sig, *np_inputs, casting=casting),
              da.einsum(sig, *np_inputs, casting=casting))
Example No. 4
def test_einsum_order(order):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(sig, *np_inputs, order=order),
              da.einsum(sig, *np_inputs, order=order))
Example No. 5
def _ndp_einsum(
    experimental: Union[da.Array, np.ndarray],
    simulated: Union[da.Array, np.ndarray],
) -> Union[np.ndarray, da.Array]:
    """Compute the normalized dot product between experimental and
    simulated patterns.

    Parameters
    ----------
    experimental
        Experimental patterns.
    simulated
        Simulated patterns.

    Returns
    -------
    ndp
        Normalized dot products in range [0, 1] for all comparisons,
        as :class:`np.ndarray` if both `experimental` and `simulated`
        are :class:`np.ndarray`, else :class:`da.Array`.
    """
    experimental, simulated = _normalize(experimental, simulated)
    ndp = da.einsum("ijkl,mkl->ijm", experimental, simulated, optimize=True)
    if isinstance(experimental, np.ndarray) and isinstance(
            simulated, np.ndarray):
        return ndp.compute()
    else:
        return ndp
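A standalone sketch of the 'ijkl,mkl->ijm' contraction on tiny synthetic patterns. The project's _normalize helper is not shown in this example, so a plain L2 normalisation over the pattern axes stands in for it here; that substitution is an assumption, not the original implementation.

import numpy as np
import dask.array as da

rng = np.random.default_rng(42)
experimental = da.from_array(rng.random((4, 5, 8, 8)), chunks=(2, 5, 8, 8))  # (i, j, k, l) grid of patterns
simulated = da.from_array(rng.random((10, 8, 8)), chunks=(5, 8, 8))          # (m, k, l) dictionary of patterns

# Stand-in for _normalize: unit L2 norm over each (k, l) pattern.
experimental = experimental / da.sqrt(da.sum(experimental ** 2, axis=(-2, -1), keepdims=True))
simulated = simulated / da.sqrt(da.sum(simulated ** 2, axis=(-2, -1), keepdims=True))

# One dot product per (experimental pattern, simulated pattern) pair.
ndp = da.einsum("ijkl,mkl->ijm", experimental, simulated, optimize=True)
print(ndp.shape)  # (4, 5, 10), values in [0, 1]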
Example No. 6
def _zncc_einsum(
    experimental: Union[da.Array, np.ndarray],
    simulated: Union[da.Array, np.ndarray],
) -> Union[np.ndarray, da.Array]:
    """Compute (lazily) the zero-mean normalized cross-correlation
    coefficient between experimental and simulated patterns.

    Parameters
    ----------
    experimental
        Experimental patterns.
    simulated
        Simulated patterns.

    Returns
    -------
    zncc
        Correlation coefficients in range [-1, 1] for all comparisons,
        as :class:`np.ndarray` if both `experimental` and `simulated`
        are :class:`np.ndarray`, else :class:`da.Array`.

    Notes
    -----
    Equivalent results are obtained with :func:`dask.Array.tensordot`
    with the `axes` argument `axes=([2, 3], [1, 2]))`.
    """
    experimental, simulated = _zero_mean(experimental, simulated)
    experimental, simulated = _normalize(experimental, simulated)
    zncc = da.einsum("ijkl,mkl->ijm", experimental, simulated, optimize=True)
    if isinstance(experimental, np.ndarray) and isinstance(
            simulated, np.ndarray):
        return zncc.compute()
    else:
        return zncc
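The Notes above state that the einsum is equivalent to a tensordot over the pattern axes. A small sketch checking that claim on random inputs; the zero-mean and normalisation steps are omitted since they do not affect the equivalence.

import numpy as np
import dask.array as da

rng = np.random.default_rng(0)
experimental = da.from_array(rng.random((3, 4, 6, 6)), chunks=(2, 2, 6, 6))
simulated = da.from_array(rng.random((5, 6, 6)), chunks=(5, 6, 6))

via_einsum = da.einsum("ijkl,mkl->ijm", experimental, simulated)
via_tensordot = da.tensordot(experimental, simulated, axes=([2, 3], [1, 2]))

assert np.allclose(via_einsum.compute(), via_tensordot.compute())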
Example No. 7
def test_einsum_order(order):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(sig, *np_inputs, order=order),
              da.einsum(sig, *np_inputs, order=order))
Example No. 8
def test_einsum_casting(casting):
    sig = 'ea,fb,abcd,gc,hd->efgh'
    input_sigs = sig.split('->')[0].split(',')
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(sig, *np_inputs, casting=casting),
              da.einsum(sig, *np_inputs, casting=casting))
Example No. 9
def dde_factory(args, ms, ant, field, pol, lm, utime, frequency):
    if args.beam is None:
        return None

    # Beam is requested
    corr_type = tuple(pol.CORR_TYPE.data[0])

    if not len(corr_type) == 4:
        raise ValueError("Need four correlations for DDEs")

    parangles = parallactic_angles(utime, ant.POSITION.data,
                                   field.PHASE_DIR.data[0][0])

    corr_type_set = set(corr_type)

    if corr_type_set.issubset(set([9, 10, 11, 12])):
        pol_type = 'linear'
    elif corr_type_set.issubset(set([5, 6, 7, 8])):
        pol_type = 'circular'
    else:
        raise ValueError("Cannot determine polarisation type "
                         "from correlations %s. Constructing "
                         "a feed rotation matrix will not be "
                         "possible." % (corr_type, ))

    # Construct feed rotation
    feed_rot = feed_rotation(parangles, pol_type)

    dtype = np.result_type(parangles, frequency)

    # Create zeroed pointing errors
    zpe = da.blockwise(_zero_pes, ("time", "ant", "chan", "comp"),
                       parangles, ("time", "ant"),
                       frequency, ("chan", ),
                       dtype,
                       None,
                       new_axes={"comp": 2},
                       dtype=dtype)

    # Create unity antenna scaling factors
    zas = da.blockwise(_unity_ant_scales, ("ant", "chan", "comp"),
                       parangles, ("time", "ant"),
                       frequency, ("chan", ),
                       dtype,
                       None,
                       new_axes={"comp": 2},
                       dtype=dtype)

    # Load the beam information
    beam, lm_ext, freq_map = load_beams(args.beam, corr_type, args.l_axis,
                                        args.m_axis)

    # Introduce the correlation axis
    beam = beam.reshape(beam.shape[:3] + (2, 2))

    beam_dde = beam_cube_dde(beam, lm_ext, freq_map, lm, parangles, zpe, zas,
                             frequency)

    # Multiply the beam by the feed rotation to form the DDE term
    return da.einsum("stafij,tajk->stafik", beam_dde, feed_rot)
Example No. 10
def baseline_jones_multiply(corrs, *args):
    names = args[::2]
    arrays = args[1::2]

    input_einsum_schemas = []
    corr_index = 0

    for name, array in zip(names, arrays):
        try:
            # Obtain function for prescribing the input einsum schema
            schema_fn = _rime_term_map[name]
        except KeyError:
            raise ValueError("Unknown RIME term '%s'" % name)
        else:
            # Extract it and the next corr index
            einsum_schema, corr_index = schema_fn(corrs, corr_index)
            input_einsum_schemas.append(einsum_schema)

            if not len(einsum_schema) == array.ndim:
                raise ValueError(
                    "%s len(%s) == %d != %s.ndim" %
                    (name, einsum_schema, len(einsum_schema), array.shape))

    output_schema = _bl_jones_output_schema(corrs, corr_index)
    schema = ",".join(input_einsum_schemas) + output_schema

    return da.einsum(schema, *arrays)
Example No. 11
    def trace(self, *axes):
        """
        Performs the trace over repeated axes contracting the outer-most index with the inner-most.
        
        Parameters
        ----------
        axes: str
            If given, only the listed axes are traced.
        """
        from .tunable import computable
        axes = list(axes)
        if not axes:
            axes = "all"
            
        axes = [axis for axis in set(self._expand(axes)) if self.axes_counts[axis] > 1]
            
        if len(axes) == 1:

            axis = axes[0]
            new_axes = list(self.axes)
            new_axes.remove(axis)
            new_axes.remove(axis)
            count = self.axes_counts[axis]

            @computable
            def indeces_order(indeces_order):
                indeces_order = list(indeces_order)
                indeces_order.remove(axis+"_0")
                indeces_order.remove(axis+"_%d" % (count-1))
                return indeces_order

            indeces_order = indeces_order(self.indeces_order)
            raw_fields, out = prepare(self, elemwise=False, field_type=new_axes,
                                      indeces_order=indeces_order)

            axis1 = self.indeces_order.index(axis+"_0")
            axis2 = self.indeces_order.index(axis+"_%d" % (count-1))

            from dask.array import trace
            out.field = computable(trace)(raw_fields[0], axis1=axis1, axis2=axis2)
            
            return out
        
        else:
            _i=0
            indeces = [{}, {}]
            for axis in set(self.axes):
                count = self.axes_counts[axis]
                if axis in axes:
                    indeces[0][axis] = tuple(_i+i for i in range(count-1)) + (_i,)
                    _i+=count-1
                    if len(indeces[0][axis]) > 2:
                        indeces[-1][axis] = indeces[0][axis][1:-1]
                else:
                    indeces[0][axis] = tuple(_i+i for i in range(count))
                    _i+=count
                    indeces[-1][axis] = tuple(indeces[0][axis])
                    
            return einsum(self, indeces=indeces)
Example No. 12
def test_einsum(einsum_signature):
    input_sigs = (einsum_signature.split('->')[0].replace("...",
                                                          "*").split(','))

    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(einsum_signature, *np_inputs),
              da.einsum(einsum_signature, *da_inputs))
Example No. 13
def test_einsum(einsum_signature):
    input_sigs = (einsum_signature.split('->')[0]
                                  .replace("...", "*")
                                  .split(','))

    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(np.einsum(einsum_signature, *np_inputs),
              da.einsum(einsum_signature, *da_inputs))
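The '*' substitution above appears to be only a convention of the _numpy_and_dask_inputs test helper for generating operands; da.einsum itself understands '...' directly, as in this small batched matrix-multiply sketch.

import numpy as np
import dask.array as da

a = da.random.random((5, 3, 4), chunks=(2, 3, 4))
b = da.random.random((5, 4, 6), chunks=(2, 4, 6))

# '...' broadcasts over the leading batch dimension.
batched = da.einsum("...ij,...jk->...ik", a, b)
assert batched.shape == (5, 3, 6)
assert np.allclose(batched.compute(), np.matmul(a.compute(), b.compute()))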
Example No. 14
def test_einsum_casting(casting):
    sig = "ea,fb,abcd,gc,hd->efgh"
    input_sigs = sig.split("->")[0].split(",")
    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    assert_eq(
        np.einsum(sig, *np_inputs, casting=casting),
        da.einsum(sig, *np_inputs, casting=casting),
    )
Example No. 15
def test_einsum(einsum_signature):
    input_sigs = einsum_signature.split("->")[0].replace("...", "*").split(",")

    np_inputs, da_inputs = _numpy_and_dask_inputs(input_sigs)

    with pytest.warns(None):
        assert_eq(
            np.einsum(einsum_signature, *np_inputs),
            da.einsum(einsum_signature, *da_inputs),
        )
Example No. 16
def _ndp_einsum(
    experimental: Union[da.Array, np.ndarray],
    simulated: Union[da.Array, np.ndarray],
) -> Union[np.ndarray, da.Array]:
    experimental, simulated = _normalize_expt_sim(experimental, simulated)
    rho = da.einsum("ijkl,mkl->ijm", experimental, simulated, optimize=True)
    if isinstance(experimental, np.ndarray) and isinstance(
            simulated, np.ndarray):
        return rho.compute()
    else:
        return rho
Example No. 17
    def match(
        self,
        experimental: Union[np.ndarray, da.Array],
        dictionary: Union[np.ndarray, da.Array],
    ) -> da.Array:
        return da.einsum(
            "ik,mk->im",
            experimental,
            dictionary,
            optimize=True,
            dtype=self.dtype,
        )
Example No. 18
def test_einsum_broadcasting_contraction2():
    a = np.random.rand(1, 1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    d_a = da.from_array(a, chunks=(1, 1, (2, 3), (2, 2)))
    d_b = da.from_array(b, chunks=((2, 2), (4, 2)))
    d_c = da.from_array(c, chunks=((2, 3), (4, 2)))
    d_d = da.from_array(d, chunks=((7, 3)))

    np_res = np.einsum('abjk,kl,jl', a, b, c)
    da_res = da.einsum('abjk,kl,jl', d_a, d_b, d_c)
    assert_eq(np_res, da_res)

    mul_res = da_res * d

    np_res = np.einsum('abjk,kl,jl,ab->ab', a, b, c, d)
    da_res = da.einsum('abjk,kl,jl,ab->ab', d_a, d_b, d_c, d_d)
    assert_eq(np_res, da_res)
    assert_eq(np_res, mul_res)
Example No. 19
def test_einsum_broadcasting_contraction2():
    a = np.random.rand(1, 1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    d_a = da.from_array(a, chunks=(1, 1, (2, 3), (2, 2)))
    d_b = da.from_array(b, chunks=((2, 2), (4, 2)))
    d_c = da.from_array(c, chunks=((2, 3), (4, 2)))
    d_d = da.from_array(d, chunks=((7, 3)))

    np_res = np.einsum('abjk,kl,jl', a, b, c)
    da_res = da.einsum('abjk,kl,jl', d_a, d_b, d_c)
    assert_eq(np_res, da_res)

    mul_res = da_res * d

    np_res = np.einsum('abjk,kl,jl,ab->ab', a, b, c, d)
    da_res = da.einsum('abjk,kl,jl,ab->ab', d_a, d_b, d_c, d_d)
    assert_eq(np_res, da_res)
    assert_eq(np_res, mul_res)
Example No. 20
def test_einsum_broadcasting_contraction3():
    a = np.random.rand(1, 5, 4)
    b = np.random.rand(4, 1, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    d_a = da.from_array(a, chunks=(1, (2, 3), (2, 2)))
    d_b = da.from_array(b, chunks=((2, 2), 1, (4, 2)))
    d_c = da.from_array(c, chunks=((2, 3), (4, 2)))
    d_d = da.from_array(d, chunks=((7, 3)))

    np_res = np.einsum("ajk,kbl,jl,ab->ab", a, b, c, d)
    da_res = da.einsum("ajk,kbl,jl,ab->ab", d_a, d_b, d_c, d_d)
    assert_eq(np_res, da_res)
Example No. 21
    def fit(self, error):

        self.time_step_ = float(error.time[1] - error.time[0])

        X = da.concatenate([error.QT.data, error.SLI.data], axis=1)

        # compute covariance
        nz = X.shape[0]
        nx = X.shape[-1]
        n = nx * nz
        C = da.einsum('tzyx,tfyx->yzf', X, X) / n
        C = C.compute()

        # shape is
        # (y, feat, feat)
        self.Q_ = cholesky_factor(C) * np.sqrt(self.time_step_)
        return self
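As a sanity check (not part of the original class), the 'tzyx,tfyx->yzf' signature builds one (feature, feature) matrix per y level, summing the products over time and x; a plain loop over y gives the same result.

import numpy as np
import dask.array as da

rng = np.random.default_rng(1)
X = rng.random((10, 3, 4, 5))  # (time, feature, y, x)
dX = da.from_array(X, chunks=(5, 3, 4, 5))

C = da.einsum('tzyx,tfyx->yzf', dX, dX).compute()

# Reference: one (feature, feature) matrix per y.
ref = np.stack([np.einsum('tzx,tfx->zf', X[:, :, y, :], X[:, :, y, :])
                for y in range(X.shape[2])])
assert np.allclose(C, ref)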
Example No. 22
def affine_to_grid_dask(matrix, grid, displacement=False):
    """Apply an affine transform to a grid of coordinates.

    If ``displacement`` is True, the displacement field
    (transformed grid minus the original grid) is returned instead.
    """

    ndims = len(matrix.shape)
    matrix = matrix.astype(np.float32).squeeze()
    lost_dims = ndims - len(matrix.shape)

    mm = matrix[:3, :-1]
    tt = matrix[:3, -1]
    result = da.einsum('...ij,...j->...i', mm, grid) + tt

    if displacement:
        result = result - grid

    if lost_dims > 0:
        result = result.reshape((1, ) * lost_dims + result.shape)
    return result
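A hedged sketch of the '...ij,...j->...i' pattern in isolation, applying the linear part of a made-up 4x4 affine to every coordinate in a grid of shape (..., 3) and then adding the translation, mirroring the mm/tt split above.

import numpy as np
import dask.array as da

# A made-up 4x4 affine: rotation about z plus a translation.
theta = np.pi / 6
affine = np.eye(4)
affine[:3, :3] = [[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0,            0.0,           1.0]]
affine[:3, 3] = [1.0, 2.0, 3.0]

mm = affine[:3, :-1]   # (3, 3) linear part
tt = affine[:3, -1]    # (3,) translation

grid = da.random.random((8, 8, 8, 3), chunks=(4, 4, 4, 3))  # (..., 3) coordinates
warped = da.einsum('...ij,...j->...i', mm, grid) + tt
assert warped.shape == grid.shape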
Example No. 23
def test_einsum_invalid_args():
    _, da_inputs = _numpy_and_dask_inputs('a')
    with pytest.raises(TypeError):
        da.einsum('a', *da_inputs, foo=1, bar=2)
Example No. 24
def test_einsum_split_every(split_every):
    np_inputs, da_inputs = _numpy_and_dask_inputs('a')
    assert_eq(np.einsum('a', *np_inputs),
              da.einsum('a', *da_inputs, split_every=split_every))
Example No. 25
def _predict(args):
    # get inclusion regions
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           include_regions=include_regions,
                                           point_only=args.points_only,
                                           num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    args.row_chunks, args.model_chunks = get_budget(comp_type.shape[0],
                                                    ms_rows, max_num_chan,
                                                    max_num_corr, ms_datatype,
                                                    args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each FIELD and DATA DESCRIPTOR
    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    select_fields = valid_field_ids(field_ds, args.fields)

    for xds in filter_datasets(datasets, select_fields):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data[0][0])

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'" % args.sign)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
            # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * \
                        da.log((frequency[None, :]/ref_freq[:, None])**(jj+1))
                # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * \
                        (frequency[None, :]/ref_freq[:, None]-1)**(jj+1)
            if log_spec_ind:
                Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            # stack along new axis and make it the last axis of the new array
            spectrum = da.stack([Is, Qs, Us, Vs], axis=-1)
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))

        print('-------------------------------------------')
        print('Nr sources        = {0:d}'.format(stokes.shape[0]))
        print('-------------------------------------------')
        print('stokes.shape      = {0:}'.format(stokes.shape))
        print('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra:
            print('Is.shape          = {0:}'.format(Is.shape))
        if args.spectra:
            print('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component
        # list then all sources are modelled as Gaussian components
        # (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        print('brightness.shape  = {0:}'.format(brightness.shape))
        print('phase.shape       = {0:}'.format(phase.shape))
        print('-------------------------------------------')
        print('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
        print('jones.shape       = {0:}'.format(jones.shape))
        print('-------------------------------------------')
        if gaussian_components:
            print('Some Gaussian sources found')
        else:
            print('All sources are Delta functions')
        print('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(ProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(ProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)
Example No. 26
def box_vectors_to_lengths_and_angles(a, b, c):
    """Convert box vectors into the lengths and angles defining the box.

    Adapted from mdtraj.utils.unitcell.box_vectors_to_lengths_and_angles()

    Parameters
    ----------
    a : np.ndarray
        the vector defining the first edge of the periodic box (length 3), or
        an array of this vector in multiple frames, where a[i,:] gives the
        length 3 array of vector a in each frame of a simulation
    b : np.ndarray
        the vector defining the second edge of the periodic box (length 3), or
        an array of this vector in multiple frames, where b[i,:] gives the
        length 3 array of vector b in each frame of a simulation
    c : np.ndarray
        the vector defining the third edge of the periodic box (length 3), or
        an array of this vector in multiple frames, where c[i,:] gives the
        length 3 array of vector c in each frame of a simulation

    Returns
    -------
    a_length : dask.array
        length of Bravais unit vector **a**
    b_length : dask.array
        length of Bravais unit vector **b**
    c_length : dask.array
        length of Bravais unit vector **c**
    alpha : dask.array
        angle between vectors **b** and **c**, in degrees.
    beta : dask.array
        angle between vectors **c** and **a**, in degrees.
    gamma : dask.array
        angle between vectors **a** and **b**, in degrees.
    """
    if not a.shape == b.shape == c.shape:
        raise TypeError("Shape is messed up.")
    if not a.shape[-1] == 3:
        raise TypeError("The last dimension must be length 3")
    if not (a.ndim in [1, 2]):
        raise ValueError(
            "vectors must be 1d or 2d (for a vectorized "
            "operation on multiple frames)"
        )
    last_dim = a.ndim - 1

    a_length = (da.sum(a * a, axis=last_dim)) ** (1 / 2)
    b_length = (da.sum(b * b, axis=last_dim)) ** (1 / 2)
    c_length = (da.sum(c * c, axis=last_dim)) ** (1 / 2)

    # we allow 2d input, where the first dimension is the frame index
    # so we want to do the dot product only over the last dimension
    alpha = da.arccos(da.einsum("...i, ...i", b, c) / (b_length * c_length))
    beta = da.arccos(da.einsum("...i, ...i", c, a) / (c_length * a_length))
    gamma = da.arccos(da.einsum("...i, ...i", a, b) / (a_length * b_length))

    # convert to degrees
    alpha = alpha * 180.0 / np.pi
    beta = beta * 180.0 / np.pi
    gamma = gamma * 180.0 / np.pi

    return a_length, b_length, c_length, alpha, beta, gamma
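A small usage sketch with made-up box vectors, showing the '...i, ...i' contraction on its own: a per-frame dot product over the last axis, here used for the gamma angle only.

import numpy as np
import dask.array as da

# Two frames of box vectors a and b, shape (n_frames, 3).
a = da.from_array(np.array([[4.0, 0.0, 0.0], [4.1, 0.0, 0.0]]), chunks=(1, 3))
b = da.from_array(np.array([[0.0, 5.0, 0.0], [0.1, 5.0, 0.0]]), chunks=(1, 3))

dots = da.einsum("...i, ...i", a, b)                 # one dot product per frame
a_length = da.sum(a * a, axis=-1) ** 0.5
b_length = da.sum(b * b, axis=-1) ** 0.5
gamma = da.arccos(dots / (a_length * b_length)) * 180.0 / np.pi
print(gamma.compute())  # first frame is exactly 90 degrees, second slightly less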
Example No. 27
                 marks=pytest.mark.xfail(reason='cupy.dot(numpy) fails')),
    pytest.param(lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]),
                 marks=pytest.mark.xfail(reason='cupy.dot(numpy) fails')),
    lambda x: x.sum(axis=0),
    lambda x: x.max(axis=0),
    lambda x: x.sum(axis=(1, 2)),
    lambda x: x.astype(np.complex128),
    lambda x: x.map_blocks(lambda x: x * 2),
    pytest.param(lambda x: x.round(1),
                 marks=pytest.mark.xfail(reason="cupy doesn't support round")),
    lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
    lambda x: abs(x),
    lambda x: x > 0.5,
    lambda x: x.rechunk((4, 4, 4)),
    lambda x: x.rechunk((2, 2, 1)),
    pytest.param(lambda x: da.einsum("ijk,ijk", x, x),
                 marks=pytest.mark.xfail(
                     reason='depends on resolution of https://github.com/numpy/numpy/issues/12974')),
    lambda x: np.isreal(x),
    lambda x: np.iscomplex(x),
    lambda x: np.isneginf(x),
    lambda x: np.isposinf(x),
    lambda x: np.real(x),
    lambda x: np.imag(x),
    lambda x: np.fix(x),
    lambda x: np.i0(x.reshape((24,))),
    lambda x: np.sinc(x),
    lambda x: np.nan_to_num(x),
]

Example No. 28
def test_einsum_split_every(split_every):
    np_inputs, da_inputs = _numpy_and_dask_inputs('a')
    assert_eq(np.einsum('a', *np_inputs),
              da.einsum('a', *da_inputs, split_every=split_every))
Example No. 29
def test_einsum_invalid_args():
    _, da_inputs = _numpy_and_dask_inputs('a')
    with pytest.raises(TypeError):
        da.einsum('a', *da_inputs, foo=1, bar=2)
Example No. 30
def _distance(Z, Y, epsilon):
    """ Distance function """
    Y = Y + epsilon
    return Y.sum(axis=(1, 2)) - da.einsum('ijk,ljk->il', Z, da.log(Y))
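For reference (not from the original source), 'ijk,ljk->il' is simply a matrix product over the flattened trailing axes, which makes a quick sanity check of the distance term.

import numpy as np
import dask.array as da

Z = da.random.random((6, 4, 5), chunks=(3, 4, 5))
Y = da.random.random((2, 4, 5), chunks=(2, 4, 5)) + 1e-8  # epsilon keeps the log finite

lhs = da.einsum('ijk,ljk->il', Z, da.log(Y))
rhs = Z.reshape(6, -1) @ da.log(Y).reshape(2, -1).T
assert np.allclose(lhs.compute(), rhs.compute())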
Example No. 31
def predict(args):
    # get inclusion regions
    include_regions = []
    exclude_regions = []
    if args.within:
        from regions import read_ds9
        import tempfile
        # kludge because regions cries over "FK5", wants lowercase
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile, open(
                args.within) as regfile:
            tmpfile.write(regfile.read().lower())
            tmpfile.flush()
            include_regions = read_ds9(tmpfile.name)
            log.info("read {} inclusion region(s) from {}".format(
                len(include_regions), args.within))

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           include_regions=include_regions,
                                           exclude_regions=exclude_regions,
                                           point_only=args.points_only,
                                           num=args.num_sources or None)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]
    frequencies = np.sort(
        [spw_ds[dd].CHAN_FREQ.data.flatten() for dd in range(len(spw_ds))])

    # cluster sources and refit. This only works for delta scale sources
    def __cluster(comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
                  gaussian_shape, frequencies):
        uniq_radec = np.unique(radec)
        ncomp_type = []
        nradec = []
        nstokes = []
        nspec_coef = []
        nref_freq = []
        nlog_spec_ind = []
        ngaussian_shape = []

        for urd in uniq_radec:
            print(comp_type.shape)
            print(radec.shape)
            deltasel = comp_type[radec == urd] == "POINT"
            polyspecsel = np.logical_not(spec_coef[radec == urd])
            sel = deltasel & polyspecsel
            Is = stokes[sel, 0, None] * frequency[None, :]**0
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[sel, jj, None] * (
                    frequency[None, :] / ref_freq[sel, None] - 1)**(jj + 1)
            Is = np.sum(
                Is, axis=0)  # collapse over all the sources at this position
            logpolyspecsel = np.logical_not(log_spec_coef[radec == urd])
            sel = deltasel & logpolyspecsel

            Is = np.log(stokes[sel, 0, None] * frequency[None, :]**0)
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[sel, jj, None] * da.log(
                    (frequency[None, :] / ref_freq[sel, None])**(jj + 1))
            Is = np.exp(Is)
            Islogpoly = np.sum(
                Is, axis=0)  # collapse over all the sources at this position

            popt, pfitvar = curve_fit(
                lambda i, a, b, c, d: i + a *
                (frequency / ref_freq[0, None] - 1) + b *
                (frequency / ref_freq[0, None] - 1)**2 + c *
                (frequency / ref_freq[sel, None] - 1)**3 + d *
                (frequency / ref_freq[0, None] - 1)**3, frequency,
                Ispoly + Islogpoly)
            if not np.all(np.isfinite(pfitvar)):
                popt[0] = np.sum(stokes[sel, 0, None], axis=0)
                popt[1:] = np.inf
                log.warn(
                    "Refitting at position {0:s} failed. Assuming flat spectrum source of {1:.2f} Jy"
                    .format(radec, popt[0]))
            else:
                pcov = np.sqrt(np.diag(pfitvar))
                log.info(
                    "New fitted flux {0:.3f} Jy at position {1:s} with covariance {2:s}"
                    .format(popt[0], radec,
                            ", ".join([str(poptp) for poptp in popt])))

            ncomp_type.append("POINT")
            nradec.append(urd)
            nstokes.append(popt[0])
            nspec_coef.append(popt[1:])
            nref_freq.append(ref_freq[0])
            nlog_spec_ind.append(0.0)

        # add back all the gaussians
        sel = comp_type[radec] == "GAUSSIAN"
        for rd, stks, spec, ref, lspec, gs in zip(radec[sel], stokes[sel],
                                                  spec_coef[sel],
                                                  ref_freq[sel],
                                                  log_spec_ind[sel],
                                                  gaussian_shape[sel]):
            ncomp_type.append("GAUSSIAN")
            nradec.append(rd)
            nstokes.append(stks)
            nspec_coef.append(spec)
            nref_freq.append(ref)
            nlog_spec_ind.append(lspec)
            ngaussian_shape.append(gs)

        log.info(
            "Reduced {0:d} components to {1:d} components through by refitting"
            .format(len(comp_type), len(ncomp_type)))
        return (np.array(ncomp_type), np.array(nradec), np.array(nstokes),
                np.array(nspec_coef), np.array(nref_freq),
                np.array(nlog_spec_ind), np.array(ngaussian_shape))

    if not args.dontcluster:
        (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
         gaussian_shape) = __cluster(comp_type, radec, stokes, spec_coeff,
                                     ref_freq, log_spec_ind, gaussian_shape,
                                     frequencies)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # sort out resources
    args.row_chunks, args.model_chunks = get_budget(
        comp_type.shape[0], ms_rows, max([ss.NUM_CHAN.data for ss in spw_ds]),
        max([ss.NUM_CORR.data for ss in pol_ds]), ms_datatype, args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):
        if xds.attrs['FIELD_ID'] != args.fieldid:
            continue

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'" % args.sign)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
                # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * da.log(
                        (frequency[None, :] / ref_freq[:, None])**(jj + 1))
                    # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * (
                        frequency[None, :] / ref_freq[:, None] - 1)**(jj + 1)
            if log_spec_ind: Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            spectrum = da.stack(
                [Is, Qs, Us, Vs], axis=-1
            )  # stack along new axis and make it the last axis of the new array
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))

        log.info('-------------------------------------------')
        log.info('Nr sources        = {0:d}'.format(stokes.shape[0]))
        log.info('-------------------------------------------')
        log.info('stokes.shape      = {0:}'.format(stokes.shape))
        log.info('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra: log.info('Is.shape          = {0:}'.format(Is.shape))
        if args.spectra:
            log.info('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component list then all
        # sources are modelled as Gaussian components (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        log.info('brightness.shape  = {0:}'.format(brightness.shape))
        log.info('phase.shape       = {0:}'.format(phase.shape))
        log.info('-------------------------------------------')
        log.info('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
        log.info('jones.shape       = {0:}'.format(jones.shape))
        log.info('-------------------------------------------')
        if gaussian_components: log.info('Some Gaussian sources found')
        else: log.info('All sources are Delta functions')
        log.info('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(**{args.output_column: model_data})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    if args.num_workers:
        with ProgressBar(), dask.config.set(num_workers=args.num_workers):
            dask.compute(writes)
    else:
        with ProgressBar():
            dask.compute(writes)
Example No. 32
def triclustering(Z,
                  nclusters_row,
                  nclusters_col,
                  nclusters_bnd,
                  errobj,
                  niters,
                  epsilon,
                  row_clusters_init=None,
                  col_clusters_init=None,
                  bnd_clusters_init=None):
    """
    Run the tri-clustering, Dask implementation

    :param Z: d x m x n data matrix
    :param nclusters_row: number of row clusters
    :param nclusters_col: number of column clusters
    :param nclusters_bnd: number of band clusters
    :param errobj: convergence threshold for the objective function
    :param niters: maximum number of iterations
    :param epsilon: numerical parameter, avoids zero arguments in log
    :param row_clusters_init: initial row cluster assignment
    :param col_clusters_init: initial column cluster assignment
    :param bnd_clusters_init: initial band cluster assignment
    :return: has converged, number of iterations performed, final row,
    column, and band clustering, error value
    """
    client = get_client()

    Z = da.array(Z) if not isinstance(Z, da.Array) else Z

    [d, m, n] = Z.shape
    bnd_chunks, row_chunks, col_chunks = Z.chunksize

    row_clusters = da.array(row_clusters_init) \
        if row_clusters_init is not None \
        else _initialize_clusters(m, nclusters_row, chunks=row_chunks)
    col_clusters = da.array(col_clusters_init) \
        if col_clusters_init is not None \
        else _initialize_clusters(n, nclusters_col, chunks=col_chunks)
    bnd_clusters = da.array(bnd_clusters_init) \
        if bnd_clusters_init is not None \
        else _initialize_clusters(d, nclusters_bnd, chunks=bnd_chunks)
    R = _setup_cluster_matrix(nclusters_row, row_clusters)
    C = _setup_cluster_matrix(nclusters_col, col_clusters)
    B = _setup_cluster_matrix(nclusters_bnd, bnd_clusters)

    e, old_e = 2 * errobj, 0
    s = 0
    converged = False

    Gavg = Z.mean()

    while (not converged) & (s < niters):
        logger.debug(f'Iteration # {s} ..')
        # Calculate number of elements in each tri-cluster
        nel_row_clusters = da.bincount(row_clusters, minlength=nclusters_row)
        nel_col_clusters = da.bincount(col_clusters, minlength=nclusters_col)
        nel_bnd_clusters = da.bincount(bnd_clusters, minlength=nclusters_bnd)
        logger.debug(
            'num of populated clusters: row {}, col {}, bnd {}'.format(
                da.sum(nel_row_clusters > 0).compute(),
                da.sum(nel_col_clusters > 0).compute(),
                da.sum(nel_bnd_clusters > 0).compute()))
        nel_clusters = da.einsum('i,j->ij', nel_row_clusters, nel_col_clusters)
        nel_clusters = da.einsum('i,jk->ijk', nel_bnd_clusters, nel_clusters)

        # calculate tri-cluster averages (epsilon takes care of empty clusters)
        # first sum values in each tri-cluster ..
        TriCavg = da.einsum('ij,ilm->jlm', B, Z)  # .. along band axis
        TriCavg = da.einsum('ij,kim->kjm', R, TriCavg)  # .. along row axis
        TriCavg = da.einsum('ij,kli->klj', C, TriCavg)  # .. along col axis
        # finally divide by number of elements in each tri-cluster
        TriCavg = (TriCavg + Gavg * epsilon) / (nel_clusters + epsilon)

        # unpack tri-cluster averages ..
        avg_unpck = da.einsum('ij,jkl->ikl', B, TriCavg)  # .. along band axis
        avg_unpck = da.einsum('ij,klj->kli', C, avg_unpck)  # .. along col axis
        # use these for the row cluster assignment
        idx = (1, 0, 2)
        d_row = _distance(Z.transpose(idx), avg_unpck.transpose(idx), epsilon)
        row_clusters = da.argmin(d_row, axis=1)
        R = _setup_cluster_matrix(nclusters_row, row_clusters)

        # unpack tri-cluster averages ..
        avg_unpck = da.einsum('ij,jkl->ikl', B, TriCavg)  # .. along band axis
        avg_unpck = da.einsum('ij,kjl->kil', R, avg_unpck)  # .. along row axis
        # use these for the col cluster assignment
        idx = (2, 0, 1)
        d_col = _distance(Z.transpose(idx), avg_unpck.transpose(idx), epsilon)
        col_clusters = da.argmin(d_col, axis=1)
        C = _setup_cluster_matrix(nclusters_col, col_clusters)

        # unpack tri-cluster averages ..
        avg_unpck = da.einsum('ij,kjl->kil', R, TriCavg)  # .. along row axis
        avg_unpck = da.einsum('ij,klj->kli', C, avg_unpck)  # .. along col axis
        # use these for the band cluster assignment
        d_bnd = _distance(Z, avg_unpck, epsilon)
        bnd_clusters = da.argmin(d_bnd, axis=1)
        B = _setup_cluster_matrix(nclusters_bnd, bnd_clusters)

        # Error value (actually just the band component really)
        old_e = e
        minvals = da.min(d_bnd, axis=1)
        # power 1 divergence, power 2 euclidean
        e = da.sum(da.power(minvals, 1))
        row_clusters, R, col_clusters, C, bnd_clusters, B, e = client.persist(
            [row_clusters, R, col_clusters, C, bnd_clusters, B, e])
        e = e.compute()
        logger.debug(f'Error = {e:+.15e}, dE = {e - old_e:+.15e}')
        converged = abs(e - old_e) < errobj
        s = s + 1
    if converged:
        logger.debug(f'Triclustering converged in {s} iterations')
    else:
        logger.debug(f'Triclustering not converged in {s} iterations')
    return converged, s, row_clusters, col_clusters, bnd_clusters, e
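A tiny illustration (not from the original module) of the cluster-aggregation pattern used above: with B a one-hot band-to-cluster indicator, da.einsum('ij,ilm->jlm', B, Z) sums Z over the bands assigned to each cluster, the same as a tensordot over the band axis.

import numpy as np
import dask.array as da

d, m, n, nclusters_bnd = 6, 4, 5, 2
Z = da.random.random((d, m, n), chunks=(3, m, n))

bnd_clusters = np.array([0, 1, 0, 1, 1, 0])                 # band -> cluster assignment
B = da.from_array(np.eye(nclusters_bnd)[bnd_clusters], chunks=(3, nclusters_bnd))

summed = da.einsum('ij,ilm->jlm', B, Z)                     # (nclusters_bnd, m, n)
reference = da.tensordot(B, Z, axes=([0], [0]))
assert np.allclose(summed.compute(), reference.compute())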
Example No. 33
    lambda x: x.T, lambda x: da.transpose(x, (1, 2, 0)), lambda x: x.sum(),
    pytest.mark.xfail(lambda x: x.dot(np.arange(x.shape[-1])),
                      reason='cupy.dot(numpy) fails'),
    pytest.mark.xfail(lambda x: x.dot(np.eye(x.shape[-1])),
                      reason='cupy.dot(numpy) fails'),
    pytest.mark.xfail(
        lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]),
        reason='cupy.dot(numpy) fails'), lambda x: x.sum(axis=0),
    lambda x: x.max(axis=0), lambda x: x.sum(axis=(1, 2)),
    lambda x: x.astype(np.complex128), lambda x: x.map_blocks(lambda x: x * 2),
    pytest.mark.xfail(lambda x: x.round(1),
                      reason="cupy doesn't support round"),
    lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
    lambda x: abs(x), lambda x: x > 0.5, lambda x: x.rechunk(
        (4, 4, 4)), lambda x: x.rechunk(
            (2, 2, 1)), lambda x: da.einsum("ijk,ijk", x, x)
]


@pytest.mark.parametrize('func', functions)
def test_basic(func):
    c = cupy.random.random((2, 3, 4))
    n = c.get()
    dc = da.from_array(c, chunks=(1, 2, 2), asarray=False)
    dn = da.from_array(n, chunks=(1, 2, 2))

    ddc = func(dc)
    ddn = func(dn)

    assert_eq(ddc, ddn)
Example No. 34
def identity_by_state(
    ds: Dataset,
    *,
    call_allele_frequency: Hashable = variables.call_allele_frequency,
    merge: bool = True,
) -> Dataset:
    """Compute identity by state (IBS) probabilities between
    all pairs of samples.

    The IBS probability between a pair of individuals is the
    probability that a randomly drawn allele from the first individual
    is identical in state with a randomly drawn allele from the second
    individual at a single random locus.

    Parameters
    ----------
    ds
        Dataset containing call genotype alleles.
    call_allele_frequency
        Input variable name holding call_allele_frequency as defined by
        :data:`sgkit.variables.call_allele_frequency_spec`.
        If the variable is not present in ``ds``, it will be computed
        using :func:`call_allele_frequencies`.
    merge
        If True (the default), merge the input dataset and the computed
        output variables into a single dataset, otherwise return only
        the computed output variables.
        See :ref:`dataset_merge` for more details.

    Returns
    -------
    A dataset containing :data:`sgkit.variables.stat_identity_by_state_spec`
    which is a matrix of pairwise IBS probabilities among all samples.
    The dimensions are named ``samples_0`` and ``samples_1``.

    Raises
    ------
    NotImplementedError
        If the variable holding call_allele_frequency is chunked along the
        samples dimension.

    Warnings
    --------
    This method does not currently support datasets that are chunked along the
    samples dimension.

    Examples
    --------

    >>> import sgkit as sg
    >>> ds = sg.simulate_genotype_call_dataset(n_variant=2, n_sample=3, seed=2)
    >>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
    samples    S0   S1   S2
    variants
    0         0/0  1/1  1/0
    1         1/1  1/1  1/0
    >>> sg.identity_by_state(ds)["stat_identity_by_state"].values # doctest: +NORMALIZE_WHITESPACE
    array([[1. , 0.5, 0.5],
           [0.5, 1. , 0.5],
           [0.5, 0.5, 0.5]])
    """
    ds = define_variable_if_absent(
        ds,
        variables.call_allele_frequency,
        call_allele_frequency,
        call_allele_frequencies,
    )
    variables.validate(
        ds, {call_allele_frequency: variables.call_allele_frequency_spec}
    )
    af = da.asarray(ds[call_allele_frequency])
    if len(af.chunks[1]) > 1:
        raise NotImplementedError(
            "identity_by_state does not support chunking in the samples dimension"
        )
    af0 = da.where(da.isnan(af), 0.0, af)
    num = da.einsum("ixj,iyj->xy", af0, af0)
    called = da.nansum(af, axis=-1)
    count = da.einsum("ix,iy->xy", called, called)
    denom = da.where(count == 0, np.nan, count)
    new_ds = create_dataset(
        {
            variables.stat_identity_by_state: (
                ("samples_0", "samples_1"),
                num / denom,
            )
        }
    )
    return conditional_merge_datasets(ds, new_ds, merge)
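A toy run of the two einsum steps above on a hand-made allele-frequency array (variants x samples x alleles) with one missing call, just to show what the numerator and denominator contract over; this is illustrative only and not the sgkit API.

import numpy as np
import dask.array as da

# (variants, samples, alleles) allele frequencies; sample 2 is missing at variant 0.
af = da.from_array(np.array([
    [[1.0, 0.0], [0.5, 0.5], [np.nan, np.nan]],
    [[0.5, 0.5], [0.0, 1.0], [1.0, 0.0]],
]), chunks=(1, 3, 2))

af0 = da.where(da.isnan(af), 0.0, af)
num = da.einsum("ixj,iyj->xy", af0, af0)        # sum over variants and alleles
called = da.nansum(af, axis=-1)                 # 1.0 where a sample has a call, else 0.0
count = da.einsum("ix,iy->xy", called, called)  # number of loci where both are called
ibs = num / da.where(count == 0, np.nan, count)
print(ibs.compute())                            # (3, 3) pairwise IBS matrix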
Example No. 35
def dot(*fields, axes=None, close_indeces=None, open_indeces=None):
    """
    Performs the dot product between fields.
    
    Default behaviors:
    ------------------
    
    Contractions are performed only between the degrees of freedom of the fields, e.g. field.dofs.
    For each field, indices are always contracted in pairs, combining the outer-most free index
    of the left with the inner-most of the right.
    
    I.e. dot(*fields) = dot(*fields, axes="dofs")

    Parameters:
    -----------
    fields: Field
        List of fields to perform dot product between.
    axes: str, list
        Axes over which the contraction is performed.
        Indices are contracted in pairs, combining the outer-most free index
        of the left with the inner-most of the right.
    close_indeces: str, list
        Same as axes.
    open_indeces: str, list
        Opposite of close_indeces, i.e. the indices of these axes are left open.
    
    Examples:
    ---------
    dot(vector, vector, axes="color")
      [x,y,z,t,spin,color] x [x,y,z,t,spin,color] -> [x,y,z,t,spin]
      [X,Y,Z,T, mu , c_0 ] x [X,Y,Z,T, mu , c_0 ] -> [X,Y,Z,T, mu ]

    dot(vector, vector, close_indeces="color", open_indeces="spin")
      [x,y,z,t,spin,color] x [x,y,z,t,spin,color] -> [x,y,z,t,spin,spin]
      [X,Y,Z,T, mu , c_0 ] x [X,Y,Z,T, nu , c_0 ] -> [X,Y,Z,T, mu , nu ]    
    """
    from builtins import all
    from .field import Field
    assert not (axes is not None and close_indeces is not None), """
    Only one of axes or close_indeces can be used; they are the same parameter."""
    assert all((isinstance(field, Field) for field in fields)), "All fields must be of type field."

    close_indeces = axes if close_indeces is None else close_indeces
    
    if close_indeces is None and open_indeces is None:
        close_indeces = "dofs"

    same_indeces = set()
    for field in fields:
        same_indeces.update(field.axes)
    
    if close_indeces is not None:
        if isinstance(close_indeces, str):
            close_indeces = [close_indeces]
        tmp = set()
        for axis in close_indeces:
            for field in fields:
                tmp.update(field._expand(axis))
            
        close_indeces = tmp
        assert close_indeces.issubset(same_indeces), "Trivial assertion."
        same_indeces = same_indeces.difference(close_indeces)
    else:
         close_indeces = set()

    if open_indeces is not None:
        if isinstance(open_indeces, str):
            open_indeces = [open_indeces]
        tmp = set()
        for axis in open_indeces:
            for field in fields:
                tmp.update(field._expand(axis))
            
        open_indeces = tmp
        assert open_indeces.issubset(same_indeces), "Close and open indeces cannot have axes in common."
        same_indeces = same_indeces.difference(open_indeces)
    else:
         open_indeces = set()

    _i=0
    field_indeces = []
    new_field_indeces = {}
    for field in fields:
        field_indeces.append({})
        for key, count in field.axes_counts.items():
            
            if key in same_indeces:
                if key not in new_field_indeces:
                    new_field_indeces[key] = tuple(_i+i for i in range(count))
                    _i+=count
                else:
                    assert len(new_field_indeces[key]) == count, """
                    Axis %s has count %s, but %s was found for other field(s).
                    Axes that are neither closed nor open must have the same count across all fields.
                    """ % (key, count, new_field_indeces[key])
                field_indeces[-1][key] = tuple(new_field_indeces[key])
                
            elif key in open_indeces:
                field_indeces[-1][key] = tuple(_i+i for i in range(count))
                _i+=count
                if key not in new_field_indeces:
                    new_field_indeces[key] = field_indeces[-1][key]
                else:
                    new_field_indeces[key] += field_indeces[-1][key]
                
            else:
                assert key in close_indeces, "Trivial assertion."
                if key not in new_field_indeces:
                    new_field_indeces[key] = tuple(_i+i for i in range(count))
                    _i+=count
                    field_indeces[-1][key] = tuple(new_field_indeces[key])
                else:
                    assert len(new_field_indeces[key]) > 0, "Trivial assertion."
                    field_indeces[-1][key] = (new_field_indeces[key][-1],) + tuple(_i+i for i in range(count-1))
                    new_field_indeces[key] = new_field_indeces[key][:-1] + tuple(_i+i for i in range(count-1))
                    _i+=count-1
                    if len(new_field_indeces[key]) == 0:
                        del new_field_indeces[key]
                    
    field_indeces.append(new_field_indeces)

    return einsum(*fields, indeces=field_indeces)
Example No. 36
def predict(args):
    # Numpy arrays

    # Convert source data into dask arrays
    radec, stokes = parse_sky_model(args.sky_model)
    radec = da.from_array(radec, chunks=(SOURCE_CHUNKS, 2))
    stokes = da.from_array(stokes, chunks=(SOURCE_CHUNKS, 4))

    # Get the support tables
    tables = support_tables(args, ["FIELD", "DATA_DESCRIPTION",
                                   "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)
        uvw = -xds.UVW.data if args.invert_uvw else xds.UVW.data

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)

        brightness = convert(stokes, ["I", "Q", "U", "V"],
                             corr_schema(pol))

        # (source, row, frequency, corr1, corr2)
        jones = da.einsum(einsum_schema(pol), phase, brightness)

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(MODEL_DATA=model_data)
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
Example No. 37
 pytest.param(
     lambda x: x.round(1),
     marks=pytest.mark.xfail(reason="cupy doesn't support round"),
 ),
 lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
 # Rechunking here is required, see https://github.com/dask/dask/issues/2561
 lambda x: (x.rechunk(x.shape)).reshape(
     (x.shape[1], x.shape[0], x.shape[2])),
 lambda x: x.reshape(
     (x.shape[0], x.shape[1], x.shape[2] / 2, x.shape[2] / 2)),
 lambda x: abs(x),
 lambda x: x > 0.5,
 lambda x: x.rechunk((4, 4, 4)),
 lambda x: x.rechunk((2, 2, 1)),
 pytest.param(
     lambda x: da.einsum("ijk,ijk", x, x),
     marks=pytest.mark.xfail(
         reason=
         "depends on resolution of https://github.com/numpy/numpy/issues/12974"
     ),
 ),
 lambda x: np.isneginf(x),
 lambda x: np.isposinf(x),
 pytest.param(
     lambda x: np.isreal(x),
     marks=pytest.mark.skipif(
         not IS_NEP18_ACTIVE,
         reason="NEP-18 support is not available in NumPy"),
 ),
 pytest.param(
     lambda x: np.iscomplex(x),
Example No. 38
 lambda x: x.max(axis=0),
 lambda x: x.sum(axis=(1, 2)),
 lambda x: x.astype(np.complex128),
 lambda x: x.map_blocks(lambda x: x * 2),
 pytest.param(lambda x: x.round(1)),
 lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
 # Rechunking here is required, see https://github.com/dask/dask/issues/2561
 lambda x: (x.rechunk(x.shape)).reshape(
     (x.shape[1], x.shape[0], x.shape[2])),
 lambda x: x.reshape(
     (x.shape[0], x.shape[1], x.shape[2] / 2, x.shape[2] / 2)),
 lambda x: abs(x),
 lambda x: x > 0.5,
 lambda x: x.rechunk((4, 4, 4)),
 lambda x: x.rechunk((2, 2, 1)),
 pytest.param(lambda x: da.einsum("ijk,ijk", x, x)),
 lambda x: np.isneginf(x),
 lambda x: np.isposinf(x),
 lambda x: np.isreal(x),
 lambda x: np.iscomplex(x),
 lambda x: np.real(x),
 lambda x: np.imag(x),
 lambda x: np.exp(x),
 lambda x: np.fix(x),
 lambda x: np.i0(x.reshape((24, ))),
 lambda x: np.sinc(x),
 lambda x: np.nan_to_num(x),
 lambda x: np.max(x),
 lambda x: np.min(x),
 lambda x: np.prod(x),
 lambda x: np.any(x),