Example #1
0
def dhor(y1, y2, norm=2):
    """Horizontal distance between two sampled curves.

    Both curves are taken to be sampled on a uniform x grid with spacing 0.05
    starting at 0.  For 100 levels in [0, min(max(y1), max(y2))) the x value of
    each curve at that level is obtained by inverse interpolation with
    ``spline.qinterp`` on the y-sorted samples.  The sum of
    |xint1 - xint2| ** norm is then returned.

    NOTE(review): inverting via argsort is only meaningful when each curve is
    monotonic in y over the sampled range -- confirm for the inputs used.

    Args:
        y1: samples of the first curve (1-D).
        y2: samples of the second curve (1-D).
        norm: exponent applied to the absolute x differences.

    Returns:
        Scalar sum of |xint1 - xint2| ** norm.
    """
    x1 = jnp.arange(y1.size) * 0.05
    x2 = jnp.arange(y2.size) * 0.05
    # jnp.minimum (not the builtin min) keeps this traceable under jax.jit.
    ymax = jnp.minimum(jnp.max(y1), jnp.max(y2))
    yint = jnp.linspace(0, ymax, 100, endpoint=False)
    # Inverse interpolation needs the y samples in increasing order.
    idx1 = jnp.argsort(y1)
    idx2 = jnp.argsort(y2)
    xint1 = spline.qinterp(yint, y1[idx1], x1[idx1])
    xint2 = spline.qinterp(yint, y2[idx2], x2[idx2])
    return jnp.sum(jnp.power(jnp.abs(xint1 - xint2), norm))
Example #2
0
def getix(x, xv):
    """Locate ``x`` on the grid ``xv`` for two-point linear weighting (jnp version).

    Args:
        x: x array
        xv: x grid (must be increasing, as required by ``jnp.interp``)

    Returns:
        cont: weight of grid point ``index + 1``; ``1 - cont`` is the weight of
            grid point ``index`` and every other grid point has weight zero.
        index: integer index of the grid point at or below ``x``.

    Note:
        ``jnp.interp`` clamps values outside the grid to the first/last grid
        index, so for such values ``cont`` is 0.

    Example:

       >>> from exojax.spec.dit import getix
       >>> import jax.numpy as jnp
       >>> cont, index = getix(jnp.array([1.1, 4.3]), jnp.arange(6))
       >>> index.tolist()
       [1, 4]
    """
    # Fractional position of each x measured in grid-index units.
    fractional_pos = jnp.interp(x, xv, jnp.arange(len(xv)))
    index = fractional_pos.astype(int)
    return fractional_pos - index, index
Example #3
0
 def _scale_action(do_scale, action):
     """Optionally rescale an action from [-1, 1] to [0, 1].

     Args:
         do_scale: if falsy, ``action`` is returned unchanged (same object).
         action: action value(s); presumably produced in [-1, 1] -- verify
             against the policy that emits them.

     Returns:
         When ``do_scale`` is truthy, ``action`` linearly mapped from [-1, 1]
         onto [0, 1] via ``jnp.interp``.  Inputs outside [-1, 1] are clamped
         to 0 or 1.
     """
     if not do_scale:
         return action
     source_range = jnp.array([-1, 1])
     target_range = jnp.array([0, 1])
     return jnp.interp(action, source_range, target_range)
Example #4
0
def weighted_percentile(x, w, ps, assume_sorted=False):
    """Return the weighted percentile(s) ``ps`` (0-100 scale) of a single vector.

    ``x`` and ``w`` are flattened.  Unless ``assume_sorted`` is set, the
    (value, weight) pairs are first sorted by value.  The sort order comes
    from a stop-gradient copy of ``x``, so no gradient flows through the
    permutation.  Each percentile is read off by linear interpolation of the
    values against the running sum of weights.
    """
    values = x.reshape(-1)
    weights = w.reshape(-1)
    if not assume_sorted:
        order = jnp.argsort(jax.lax.stop_gradient(values))
        values = values[order]
        weights = weights[order]
    cumulative = jnp.cumsum(weights)
    # Convert percentiles into positions on the cumulative-weight axis.
    targets = jnp.array(ps) * (cumulative[-1] / 100)
    return jnp.interp(targets, cumulative, values)
Example #5
0
    def gradient_transform(params, inputs):
        """Apply the forward transform and return its log-abs-det term.

        The log-abs-det term is the log of the empirical PDF stored in
        ``params`` (grid ``support_pdf``, values ``empirical_pdf``), linearly
        interpolated at ``inputs``.

        Returns:
            (outputs, logabsdet)
        """
        transformed = forward_transform(params, inputs)
        pdf_at_inputs = np.interp(inputs, params.support_pdf,
                                  params.empirical_pdf)
        # NOTE(review): log(0) -> -inf wherever the interpolated PDF is zero.
        return transformed, np.log(pdf_at_inputs)
Example #6
0
def smooth_wave_fft(wavelength,
                    spectrum,
                    outwave,
                    sigma_out=1.0,
                    inres=0.0,
                    **extras):
    """Smooth a spectrum in wavelength space, using FFTs.  This is fast, but
    makes some assumptions about the input spectrum, and can have some
    issues at the ends of the spectrum depending on how it is padded.

    :param wavelength:
        Wavelength vector of the input spectrum.
    :param spectrum:
        Flux vector of the input spectrum.
    :param outwave:
        Desired output wavelength vector.  If ``None``, the result is returned
        on the internal grid produced by ``resample_wave``.
    :param sigma_out:
        Desired resolution (*not* FWHM) in wavelength units.
    :param inres:
        Resolution of the input, in wavelength units (dispersion not FWHM).
    :returns flux:
        The output smoothed flux vector, same length as ``outwave``.  When
        ``sigma_out <= inres`` no extra broadening is applied.  The input
        spectrum is then only interpolated onto the output grid.
    """
    # restrict wavelength range (for speed)
    # should also make nearest power of 2
    wave, spec = resample_wave(wavelength, spectrum, linear=True)

    # The kernel width for the convolution is the broadening that must be
    # added in quadrature to ``inres`` to reach ``sigma_out``.  The variance
    # is tested before the square root: sqrt of a negative value is NaN, and
    # a ``sigma < 0`` comparison on NaN is always False.
    sigma_sq = sigma_out**2 - inres**2
    if sigma_sq <= 0:
        # Input is already at (or broader than) the requested resolution, and
        # a zero-width kernel means no smoothing: just change grids.
        if outwave is None:
            return spec
        return np.interp(outwave, wavelength, spectrum)
    sigma = np.sqrt(sigma_sq)

    # get grid resolution (*not* the resolution of the input spectrum) and make
    # sure it's nearly constant.  Should be by design (see resample_wave)
    Rgrid = np.diff(wave)
    assert Rgrid.max() / Rgrid.min() < 1.05
    dw = np.median(Rgrid)

    # Do the convolution
    spec_conv = smooth_fft(dw, spec, sigma)
    # interpolate onto output grid
    if outwave is not None:
        spec_conv = np.interp(outwave, wave, spec_conv)
    return spec_conv
Example #7
0
def test_hist_params_inv_transform():
    """Inverse-interpolating the fitted quantiles recovers the original samples."""
    samples = rng.randn(1_000)

    uniform_samples, fitted = get_hist_params(
        samples, support_extension=10, precision=100, alpha=1e-5)

    # Map uniform values back through the quantiles -> support lookup table.
    recovered = np.interp(uniform_samples, fitted.quantiles, fitted.support)

    chex.assert_tree_all_close(np.array(samples), recovered, atol=1e-4)
Example #8
0
def test_hist_params_transform():
    """The uniformized output of get_hist_params equals re-interpolating the samples."""
    samples = rng.randn(100)

    uniform_samples, fitted = get_hist_params(
        samples, support_extension=10, precision=50, alpha=1e-5)

    # Push the raw samples through the support -> quantiles lookup table.
    reinterpolated = np.interp(samples, fitted.support, fitted.quantiles)

    chex.assert_tree_all_close(uniform_samples, reinterpolated)
Example #9
0
def sampling(nusd, nus, F, RV):
    """Sample a spectrum on a new wavenumber grid, including a radial-velocity shift.

    Args:
        nusd: wavenumber grid to sample onto
        nus: wavenumber grid of the input spectrum (increasing)
        F: input spectrum on ``nus``
        RV: radial velocity (km/s); divided by the module-level ``c``, which
            is presumably the speed of light in km/s -- verify.

    Returns:
        ``F`` linearly interpolated at ``nusd * (1 + RV / c)``.
    """
    doppler_factor = 1.0 + RV / c
    shifted_grid = nusd * doppler_factor
    return jnp.interp(shifted_grid, nus, F)
Example #10
0
def smooth_lsf(wave,
               spec,
               outwave,
               sigma=None,
               lsf=None,
               return_kernel=False,
               **kwargs):
    """Broaden a spectrum using a wavelength dependent line spread function.
    This function is only approximate because it doesn't actually do the
    integration over pixels, so for sparsely sampled points you'll have
    problems.  This function needs to be checked and possibly rewritten.

    :param wave:
        Input wavelengths.  ndarray of shape (nin,)
    :param spec:
        Input spectrum.  ndarray of same shape as ``wave``.
    :param outwave:
        Output wavelengths, ndarray of shape (nout,)
    :param sigma: (optional, default: None)
        The dispersion (not FWHM) as a function of wavelength that you want to
        apply to the input spectrum.  ``None``, a scalar (constant dispersion),
        or ndarray of same length as ``outwave``.  If ``None`` then the
        wavelength dependent dispersion will be calculated from the function
        supplied with the ``lsf`` keyword.
    :param lsf:
        A function that returns the gaussian dispersion at each wavelength.
        This is assumed to be in sigma, not FWHM.
    :param kwargs:
        Passed to the function supplied in the ``lsf`` keyword.
    :param return_kernel: (optional, default: False)
        If True, return the kernel used to broaden the spectrum as ndarray of
        shape (nout, nin).  Ignored when both ``sigma`` and ``lsf`` are
        ``None``; only the interpolated spectrum is returned then.
    :returns newspec:
        The broadened spectrum, same length as ``outwave``.
    """
    if (lsf is None) and (sigma is None):
        # Nothing to broaden with: plain linear interpolation onto outwave.
        return np.interp(outwave, wave, spec)
    # Local input pixel widths, used as quadrature weights.
    dw = np.gradient(wave)
    if sigma is None:
        sigma = lsf(outwave, **kwargs)
    # Accept a scalar dispersion as well as one value per output wavelength;
    # the kernel code below indexes sigma as shape (nout,).
    sigma = np.broadcast_to(np.asarray(sigma), np.shape(outwave))
    # Gaussian kernel: rows are output wavelengths, columns input wavelengths.
    kernel = outwave[:, None] - wave[None, :]
    kernel = (1 / (sigma * np.sqrt(np.pi * 2))[:, None] *
              np.exp(-kernel**2 / (2 * sigma[:, None]**2)) * dw[None, :])
    # Normalize over the input axis (axis=1) so each output pixel's weights
    # sum to one.
    kernel = kernel / kernel.sum(axis=1)[:, None]
    newspec = np.dot(kernel, spec)
    if return_kernel:
        return newspec, kernel
    return newspec
Example #11
0
def get_alpha(name: str, lam=None) -> Array:
    """Load alpha of a material from the materials library.

    Reads ``resources/{name}.csv`` next to this module.  The CSV's first
    column is used as the wavelength grid and its fifth column (index 4) as
    the absorption coefficient.  The coefficient is then linearly
    interpolated onto ``lam``.

    Args:
        name (str): Name of material
        lam (array, optional): Wavelengths to interpolate onto, in the units
            of the CSV's first column.  Defaults to the module-level
            ``lam_interp`` grid (the previous, implicit behaviour).

    Returns:
        Array: Array of absorption coefficients for material
    """
    if lam is None:
        lam = lam_interp
    df = pd.read_csv(
        os.path.join(os.path.dirname(__file__), f"resources/{name}.csv"))
    _lam = jnp.array(df[df.columns[0]])
    _alpha = jnp.array(df[df.columns[4]])
    # NOTE(review): jnp.interp requires _lam to be increasing; confirm the
    # resource CSVs are sorted by wavelength.
    return jnp.interp(lam, _lam, _alpha)
Example #12
0
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs=None,
              mimofn=af.rde,
              mimokwargs=None,
              mimoinitargs=None):
    """Frame-based frequency-offset estimation and compensation aided by a MIMO stage.

    Steps, as implemented below:
      1. ``preslicer(signal)`` is passed through a child ``mimoaf`` module named
         ``'MIMO4FOE'``.  This yields an auxiliary output ``y`` and its time
         descriptor ``ty``.
      2. ``y`` is cut into non-overlapping frames of ``framesize`` samples.
         The ``af.frame_cpr_kf`` filter (wrapped with ``af.array`` for
         ``dims=2``) is iterated over those frames.
      3. The per-frame estimates ``wf`` are averaged over the ``dims`` axis,
         linearly interpolated from the frame centres onto a grid with
         ``sps=2`` points per ``y`` sample, and divided by ``sps``.  Their
         cumulative sum, offset by the stored phase, gives ``psi``.
      4. ``psi`` is linearly extrapolated to cover ``signal.t``, and
         ``exp(-1j * psi)`` multiplies every column of ``signal``.

    The state ``(phase, step, stats)`` persists in the ``'af_state'``
    collection under the name ``'framefoeaf'``.

    Args:
        foekwargs: keyword dict forwarded to ``af.frame_cpr_kf``; ``None``
            means ``{}``.
        mimokwargs, mimoinitargs: forwarded to ``mimoaf``; ``None`` means ``{}``.
    """
    # Fresh dicts per call instead of shared mutable default arguments.
    foekwargs = {} if foekwargs is None else foekwargs
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf', lambda *_:
                           (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats,
                                              yf)
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate per-frame estimates (placed at frame centres) to sps rate.
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi, psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
Example #13
0
def hist_forward_transform(params: UniHistParams, X: Array):
    """Forward univariate uniformize transformation (histogram parameters).

    Parameters
    ----------
    params : UniHistParams
        Fitted parameters.  ``params.support`` is the data grid and
        ``params.quantiles`` the matching uniform-domain values.
        See `rbig_jax.transforms.uniformize` for details.
    X : Array
        The univariate data to be transformed.

    Returns
    -------
    X_trans : Array
        ``X`` linearly interpolated through ``support -> quantiles``.
    """
    grid, cdf_values = params.support, params.quantiles
    return jnp.interp(X, grid, cdf_values)
Example #14
0
def kde_inverse_transform(params: UniKDEParams, X: Array) -> Array:
    """Inverse univariate uniformize transformation (KDE parameters).

    Parameters
    ----------
    params : UniKDEParams
        Fitted parameters.  ``params.quantiles`` is the uniform-domain grid and
        ``params.support`` the matching data values.
        See `rbig_jax.transforms.histogram` for details.
    X : jnp.ndarray
        The uniform univariate data to be transformed.

    Returns
    -------
    X_trans : jnp.ndarray
        ``X`` linearly interpolated through ``quantiles -> support``.
    """
    return jnp.interp(X, xp=params.quantiles, fp=params.support)
Example #15
0
def kde_gradient_transform(params: UniKDEParams, X: Array) -> Array:
    """Gradient of the forward univariate uniformize transformation, at ``X``.

    Parameters
    ----------
    params : UniKDEParams
        Fitted parameters.  ``params.support_pdf`` is the grid and
        ``params.empirical_pdf`` the density values on it.
        See `rbig_jax.transforms.histogram` for details.
    X : jnp.ndarray
        The univariate data at which to evaluate the gradient.

    Returns
    -------
    X_grad : jnp.ndarray
        The empirical PDF linearly interpolated at ``X``.
    """
    pdf_grid = params.support_pdf
    pdf_values = params.empirical_pdf
    return jnp.interp(X, pdf_grid, pdf_values)
Example #16
0
def interp_QT284(T, T_gQT, gQT_284species):
    """interpolated partition function of all 284 species.

    Args:
        T: temperature
        T_gQT: temperature in the grid obtained from the adb instance [N_grid(42)]
        gQT_284species: partition function in the grid from the adb instance [N_species(284) x N_grid(42)]

    Returns:
        QT_284: interpolated partition function at T Q(T) for all 284 Atomic Species [284]

    Note:
        Each row of ``gQT_284species`` is linearly interpolated at ``T``
        independently.  Any number of species/grid points works, and the
        function can be used under ``jax.jit``.
    """
    import jax

    # vmap over the species axis replaces a Python loop over .tolist() rows,
    # which forced a host round trip and could not be traced by jit.
    per_species = jax.vmap(lambda gQT: jnp.interp(T, T_gQT, gQT))
    return per_species(jnp.asarray(gQT_284species))
Example #17
0
def uniformize_transform(X: np.ndarray, params: UniParams) -> np.ndarray:
    """Forward univariate uniformize transformation.

    Parameters
    ----------
    X : np.ndarray
        The univariate data to be transformed.
    params : UniParams
        Fitted parameters.  ``params.support`` is the data grid and
        ``params.quantiles`` the matching uniform-domain values.
        See `rbig_jax.transforms.uniformize` for details.

    Returns
    -------
    X_trans : np.ndarray
        ``X`` linearly interpolated through ``support -> quantiles``.
    """
    return np.interp(X, xp=params.support, fp=params.quantiles)
Example #18
0
def _power_local(y, frame_size, frame_step, sps):
    """Per-frame mean power of ``y``, interpolated back onto a per-sample grid.

    ``y`` is framed with ``xop.frame(y, frame_size, frame_step, True)``.  For
    each frame, the mean of |frame|**2 over axis 0 is taken.  Each frame's
    value is attached to the frame centre (``k * frame_step +
    frame_size // 2``).  The values are then linearly interpolated onto
    ``arange(N * sps) / sps``, vmapped over the last axis.  This presumably
    yields one column per channel -- confirm ``y`` is (N, channels).
    """
    framed = xop.frame(y, frame_size, frame_step, True)
    n_samples = y.shape[0]
    n_frames = framed.shape[0]

    def frame_power(carry, frame):
        return carry, jnp.mean(jnp.abs(frame)**2, axis=0)

    _, power = xop.scan(frame_power, None, framed)

    frame_centres = jnp.arange(n_frames) * frame_step + frame_size // 2
    sample_grid = jnp.arange(n_samples * sps) / sps

    interp_cols = vmap(lambda q, knots, vals: jnp.interp(q, knots, vals),
                       in_axes=(None, None, -1),
                       out_axes=-1)
    return interp_cols(sample_grid, frame_centres, power)
Example #19
0
 def _generate_lambda(self):
     """Draw a batch of complex values into ``self.lam`` and advance ``self.rng_key``.

     Real parts are uniform between ``lam_low`` and
     ``self.lambda_real_interval[1]``.  Imaginary parts are uniform within
     ``self.lambda_imag_interval``.  ``lam_low`` is
     ``self.lambda_real_interval[0]``, unless
     ``self.lambda_real_interpolation_interval`` is set.  In that case it is
     linearly interpolated from ``self.lambda_real_interval_reversed`` at
     ``self.num_episodes``.
     """
     # The number of episodes is always smaller than the number of time
     # steps; keep that in mind for the interpolation hyperparameters.
     if self.lambda_real_interpolation_interval is None:
         lam_low = self.lambda_real_interval[0]
     else:
         lam_low = jnp.interp(self.num_episodes,
                              self.lambda_real_interpolation_interval,
                              self.lambda_real_interval_reversed)
     intermediate_key, real_key = jax.random.split(self.rng_key)
     self.rng_key, imag_key = jax.random.split(intermediate_key)
     batch_shape = (self.batch_size, )
     real_part = jax.random.uniform(real_key, batch_shape,
                                    minval=lam_low,
                                    maxval=self.lambda_real_interval[1])
     imag_part = jax.random.uniform(imag_key, batch_shape,
                                    minval=self.lambda_imag_interval[0],
                                    maxval=self.lambda_imag_interval[1])
     self.lam = real_part + 1j * imag_part
Example #20
0
    def predictcont(self, labels):
        """Predict the continuum for a set of labels with the trained NN.

        :params labels:
            list of label values for the labels used to train the NN,
            e.g. [Teff, log(g), [Fe/H], [alpha/Fe]]

        :returns predict_cont:
            The NN continuum converted from F_nu to F_lambda, divided by its
            NaN-ignoring median, and linearly interpolated onto
            ``self.anns.wavelength``.
        """
        raw_cont = self.Canns.eval(labels)
        cont_wave = self.Canns.wavelength

        # F_nu -> F_lambda: scale by c / lambda**2.  The 1e-8 factor suggests
        # wavelengths in Angstrom converted to cm -- confirm the units of
        # speedoflight.
        cont_flam = raw_cont * (speedoflight / ((cont_wave * 1E-8)**2.0))
        cont_norm = cont_flam / np.nanmedian(cont_flam)

        return np.interp(self.anns.wavelength, cont_wave, cont_norm)
Example #21
0
def resample_wave(wavelength, spectrum, linear=False):
    """Resample spectrum, so that the number of elements is the next highest
    power of two.  This uses np.interp.  Note that if the input wavelength grid
    did not critically sample the spectrum then there is no guarantee the
    output wavelength grid will.

    :param wavelength:
        Input wavelength vector (increasing, as required by ``np.interp``).
    :param spectrum:
        Flux vector on ``wavelength``.
    :param linear: (optional, default: False)
        If True the new grid is evenly spaced in wavelength.  Otherwise it is
        evenly spaced in log-wavelength (constant 1/R).
    :returns w, s:
        The new wavelength grid and the spectrum interpolated onto it.
    """
    wmin, wmax = wavelength.min(), wavelength.max()
    nw = len(wavelength)
    # Smallest power of two >= nw.
    nnew = int(2.0**(nnp.ceil(nnp.log2(nw))))
    if linear:
        w = np.linspace(wmin, wmax, nnew)
    else:
        w = np.exp(np.linspace(np.log(wmin), np.log(wmax), nnew))
    s = np.interp(w, wavelength, spectrum)
    return w, s
Example #22
0
def _foe_local(y, frame_size, frame_step, sps):
    """Frame-wise frequency-offset estimates, interpolated onto a per-sample grid.

    ``y`` is framed with ``xop.frame(y, frame_size, frame_step, True)``.
    Each frame goes through ``foe_mpowfftmax``, and only its first return
    value is kept.  The estimates are divided by ``sps`` and attached to the
    frame centres (``k * frame_step + frame_size // 2``).  They are then
    linearly interpolated onto ``arange(N * sps) / sps``, vmapped over the
    last axis.
    """
    framed = xop.frame(y, frame_size, frame_step, True)
    n_samples = y.shape[0]
    n_frames = framed.shape[0]

    def estimate(carry, frame):
        offset, _ = foe_mpowfftmax(frame)
        return carry, offset

    _, fo_hat = xop.scan(estimate, None, framed)

    frame_centres = jnp.arange(n_frames) * frame_step + frame_size // 2
    sample_grid = jnp.arange(n_samples * sps) / sps
    fo_hat /= sps

    interp_cols = vmap(lambda q, knots, vals: jnp.interp(q, knots, vals),
                       in_axes=(None, None, -1),
                       out_axes=-1)
    return interp_cols(sample_grid, frame_centres, fo_hat)
Example #23
0
def hist_inverse_transform(params: UniHistParams, X: JaxArray) -> np.ndarray:
    """Map uniform-domain ``X`` back to data space via ``quantiles -> support``."""
    uniform_grid, data_grid = params.quantiles, params.support
    return np.interp(X, uniform_grid, data_grid)
Example #24
0
 def fcia(x, i):
     """Linearly interpolate column ``i`` of ``logac`` on the ``tcia`` grid at ``x``."""
     column = logac[:, i]
     return jnp.interp(x, tcia, column)
 # Vectorized over the column index i; x is shared (in_axes None).
 vfcia = vmap(fcia, (None, 0), 0)
Example #25
0
    X: jnp.ndarray,
    support_extension: Union[int, float] = 10,
    precision: int = 1_000,
):
    # generate support points
    lb, ub = get_domain_extension(X, support_extension)
    grid = jnp.linspace(lb, ub, precision)

    bw = scotts_method(X.shape[0], 1) * 0.5

    # calculate the cdf for grid points
    factor = normalization_factor(X, bw)

    x_cdf = broadcast_kde_cdf(grid, X, factor)

    X_transform = jnp.interp(X, grid, x_cdf)

    return X_transform


def broadcast_kde_pdf(eval_points, samples, bandwidth):

    n_samples = samples.shape[0]
    # print(n_samples, bandwidth)

    # distances (use broadcasting)
    rescaled_x = (eval_points[:, jnp.newaxis] -
                  samples[jnp.newaxis, :]) / bandwidth  # (2 * bandwidth ** 2)
    # print(rescaled_x.shape)
    # compute the gaussian kernel
    gaussian_kernel = 1 / jnp.sqrt(2 * jnp.pi) * jnp.exp(-0.5 * rescaled_x**2)
Example #26
0
 def at(self, t):
     """Periodically interpolate the samples ``(self.xp, self.fp)`` at ``t``.

     NOTE(review): the period is hard-coded to 3.  ``self.period`` is not
     used, although a commented-out variant used it.  Confirm this is
     intended.
     """
     knots, values = self.xp, self.fp
     return jnp.interp(t, knots, values, period=3)
Example #27
0
 def inverse_transform(params, inputs):
     """Map uniform-domain ``inputs`` back to data space via ``quantiles -> support``."""
     lookup_x = params.quantiles
     lookup_y = params.support
     return np.interp(inputs, lookup_x, lookup_y)
Example #28
0
    lb, ub = get_domain_extension(X, support_extension)

    # get new bin edges
    new_bin_edges = np.hstack((
        lb,
        np.min(X),
        bin_centers + incr_bin,
        ub,
    ))

    extended_cdf = np.hstack((0.0, 1.0 / n_samples, cdf, 1.0))

    new_support = np.linspace(new_bin_edges[0], new_bin_edges[-1],
                              int(precision))

    uniform_cdf = jax.lax.cummax(np.interp(new_support, new_bin_edges,
                                           extended_cdf),
                                 axis=0)

    # Normalize CDF estimation
    uniform_cdf /= np.max(uniform_cdf)

    # forward transformation
    outputs = np.interp(X, new_support, uniform_cdf)

    if return_params is True:

        # initialize parameters
        params = UniHistParams(
            support=new_support,
            quantiles=uniform_cdf,
            support_pdf=pdf_support,
Example #29
0
def hist_gradient_transform(params: UniHistParams, X: JaxArray) -> np.ndarray:
    """Empirical PDF (``support_pdf -> empirical_pdf``) linearly interpolated at ``X``."""
    return np.interp(X, xp=params.support_pdf, fp=params.empirical_pdf)
Example #30
0
def hist_forward_transform(params: UniHistParams, X: JaxArray):
    """Map data ``X`` to the uniform domain via ``support -> quantiles``."""
    data_grid = params.support
    uniform_grid = params.quantiles
    return np.interp(X, data_grid, uniform_grid)