コード例 #1
0
 def apply_fun(params, x_t, rng=None):
     """Run a forward and a backward GRU over `x_t` and join their end states.

     Arguments:
       params: dictionary with the 'fwd_rnn' and 'bwd_rnn' parameter sets
       x_t: input array; the backward pass sees it reversed along axis 0
       rng: random key, mandatory despite the `None` default

     Returns:
       concatenation along axis 1 of the first entry of the (re-flipped)
       backward encoding and the last entry of the forward encoding

     Raises:
       ValueError: if no `rng` key is supplied.
     """
     if rng is None:
         raise ValueError("BidirectionalGRU apply_fun requires rng key.")

     # One sub-key per direction; the forward pass consumes the first one.
     _, key_iter = utils.keygen(rng, 2)
     fwd_key = next(key_iter)
     bwd_key = next(key_iter)

     forward_states = run_gru(params['fwd_rnn'], x_t, rng=fwd_key)

     # Feed the input reversed, then flip the output back so that index 0 of
     # the backward encoding lines up with index 0 of the input.
     reversed_input = np.flipud(x_t)
     backward_states = np.flipud(
         run_gru(params['bwd_rnn'], reversed_input, rng=bwd_key))

     return np.concatenate([backward_states[0], forward_states[-1]], axis=1)
コード例 #2
0
ファイル: nvif.py プロジェクト: helange23/nvif
def compute_pis(p, n, EPS):
    """Rescale the log-weights `p` for every target sum from `n` up to N - 1.

    For a target sum m the log-weights are shifted so that their exponentials
    sum to m; while any entry still exceeds `EPS`, all entries are clamped to
    at most 0 and the shift is applied again.  Each target sum is warm-started
    from the solution of the previous one.

    Arguments:
      p: 1D array of log-weights; N is its length
      n: smallest target sum
      EPS: tolerance on the largest log-value that ends the inner iteration

    Returns:
      tuple (pis, iters):
        pis: array with one row per target sum n, n + 1, ..., N - 1, reversed
          so the largest target sum comes first, with every entry clamped to
          at most 0
        iters: number of clamp/rescale iterations, accumulated over the
          initial solve and every scan step
    """
    N = p.shape[0]

    def compute_pi(q, n):
        # Shift so that exp(q) sums to n.
        q = q - logsumexp(q) + jnp.log(n)
        init = (q, 0, n)

        def condfun(op):
            # Keep iterating while some log-value is still above the tolerance.
            return jnp.max(op[0]) > EPS

        def bodyfun(carry):
            (q, iters, n) = carry
            # Upper-bound-only clamp.  jnp.minimum is used instead of
            # jnp.clip(a_min=..., a_max=...) because those keyword names are
            # deprecated in newer JAX releases.
            q = jnp.minimum(q, 0.0)
            q = q - logsumexp(q) + jnp.log(n)
            return (q, iters + 1, n)

        (q, iters, n) = jax.lax.while_loop(condfun, bodyfun, init)

        return q, iters

    # Solve once for the smallest target sum to seed the scan carry.
    init = compute_pi(p, n)
    xs = jnp.arange(N - n) + n

    def f(carry, n):
        pi_nm1, iters = carry
        # Warm-start from the previous solution.
        pi, itr = compute_pi(pi_nm1, n)
        return (pi, iters + itr), pi

    (_, iters), pis = jax.lax.scan(f, init, xs)

    return jnp.minimum(jnp.flipud(pis), 0.0), iters
コード例 #3
0
def compute_eigenvalue_decomposition(Ms,
                                     sort_by='magnitude',
                                     do_compute_lefts=True):
    """Compute a sorted eigendecomposition of every matrix in Ms.

    No assumptions are made on the matrices. Eigenvalues are ordered from
    largest to smallest according to `sort_by`, and the eigenvectors are
    reordered to match.

    Arguments:
      Ms: 3D np.array nmatrices x dim x dim (any iterable of square matrices)
      sort_by: 'magnitude' to order by absolute value, 'real' to order by
        real part
      do_compute_lefts: Compute the left eigenvectors? Requires a
        pseudo-inverse call.

    Returns:
      list of dictionaries, one per matrix, with keys 'evals' (sorted
        eigenvalues), 'R' (sorted right eigenvectors as columns) and 'L'
        (sorted left eigenvectors as columns, or None when
        do_compute_lefts is False).

    Raises:
      ValueError: if `sort_by` is not one of the supported orderings.
    """
    if sort_by == 'magnitude':
        sort_fun = onp.abs
    elif sort_by == 'real':
        sort_fun = onp.real
    else:
        # An explicit raise survives `python -O`, unlike an assert.
        raise ValueError("Not implemented yet.")

    decomps = []
    for M in Ms:
        evals, R = onp.linalg.eig(M)
        # Descending order. Everything stays in host NumPy; the stable sort
        # keeps the relative order of tied keys (e.g. conjugate pairs under
        # 'magnitude') deterministic.
        indices = onp.argsort(sort_fun(evals), kind='stable')[::-1]
        L = None
        if do_compute_lefts:
            L = onp.linalg.pinv(R).T  # as columns
            L = L[:, indices]
        decomps.append({'evals': evals[indices], 'R': R[:, indices], 'L': L})

    return decomps
コード例 #4
0
def run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t):
  """Encode a time series with one RNN running forwards and one backwards.

  Arguments:
    params: dictionary holding the 'fwd_rnn' and 'bwd_rnn' parameter sets
    fwd_rnn: function for running the forward rnn encoding
    bwd_rnn: function for running the backward rnn encoding
    x_t: np array data for RNN input with leading dim being time

  Returns:
    tuple (full_enc, enc_ends):
      full_enc: forward encoding and re-flipped backward encoding,
        concatenated along axis 1
      enc_ends: first entry of the backward encoding followed by the last
        entry of the forward encoding, concatenated along axis 1
  """
  forward = run_rnn(params['fwd_rnn'], fwd_rnn, x_t)

  # The backward RNN consumes the input reversed along axis 0; flipping its
  # output once more lines it up index-for-index with the forward encoding.
  reversed_input = np.flipud(x_t)
  backward = np.flipud(run_rnn(params['bwd_rnn'], bwd_rnn, reversed_input))

  both_directions = np.concatenate([forward, backward], axis=1)
  end_states = np.concatenate([backward[0], forward[-1]], axis=1)
  return both_directions, end_states
コード例 #5
0
    def other_cases(a, f):
        # Fold the order `a` step by step, transforming `f` to match, and then
        # hand both to `chirp_opts`.  `shft`, `sN` and `chirp_opts` are taken
        # from the enclosing scope.  Every step defines its own pair of branch
        # functions so each `lax.cond` call receives fresh callables that read
        # the current `a` and `f`.

        # Step 1: order above 2 -> reverse the signal, lower the order by 2.
        def flip_branch(_):
            return a - 2.0, jnp.flipud(f)

        def skip_flip(_):
            return a, f

        a, f = lax.cond(a > 2.0, None, flip_branch, None, skip_flip)

        # Step 2: order above 1.5 -> scaled FFT of the `shft`-ordered samples
        # written back at `shft`, lower the order by 1.
        def fft_branch(_):
            spectrum = jnp.fft.fft(f[shft]) / sN
            return a - 1.0, index_update(f, index[shft], spectrum)

        def skip_fft(_):
            return a, f

        a, f = lax.cond(a > 1.5, None, fft_branch, None, skip_fft)

        # Step 3: order below 0.5 -> scaled inverse FFT written back at
        # `shft`, raise the order by 1.
        def ifft_branch(_):
            inverse = jnp.fft.ifft(f[shft]) * sN
            return a + 1.0, index_update(f, index[shft], inverse)

        def skip_ifft(_):
            return a, f

        a, f = lax.cond(a < 0.5, None, ifft_branch, None, skip_ifft)

        return chirp_opts(a, f)
コード例 #6
0
def flipud(x):
  """Apply `jnp.flipud` to `x` and wrap the result in a `JaxArray`.

  A `JaxArray` argument is unwrapped through its `.value` attribute first;
  any other argument is passed to `jnp.flipud` unchanged.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  flipped = jnp.flipud(raw)
  return JaxArray(flipped)
コード例 #7
0
def frft(f, a):
    """
    fast fractional fourier transform.

    The order `a` is first reduced modulo 4. Orders of exactly 0, 1, 2 and 3
    are answered directly (the signal itself, a scaled FFT, the reversed
    signal, a scaled inverse FFT). Every other order is folded by
    `other_cases` and evaluated by the chirp-based routine `chirp_opts`.

    Parameters
        f : [jax.]numpy array
            The signal to be transformed.
        a : float
            fractional power
    Returns
        data : [jax.]numpy array
            The transformed signal.
    reference:
        https://github.com/nanaln/python_frft
    """
    # Hand both inputs to device_put before any computation.
    f = device_put(f)
    a = device_put(a)

    # Placeholder result; replaced by whichever case below applies.
    ret = jnp.zeros_like(f, dtype=jnp.complex64)
    f = f.astype(jnp.complex64)
    N = f.shape[0]

    # Index permutation (i + fix(N / 2)) mod N, used to reorder the samples
    # around the FFT calls.
    shft = jnp.fmod(jnp.arange(N) + jnp.fix(N / 2), N).astype(int)
    sN = jnp.sqrt(N)
    # Reduce the order modulo 4.
    a = jnp.remainder(a, 4.0)

    # Array-valued flags returned from the cond branches below (presumably so
    # that both branches of each cond yield the same types).
    TRUE = jnp.array(True)
    FALSE = jnp.array(False)

    # NOTE(review): the five-argument lax.cond form (pred, operand, fun,
    # operand, fun) and index_update/index are legacy JAX APIs; confirm the
    # pinned JAX version still provides them.

    # simple cases: `done` records whether one of them matched.
    # a == 0: the result is the signal itself.
    ret, done = lax.cond(
        a == 0.0,
        None,
        lambda _: (f, TRUE),
        None,
        lambda _: (ret, FALSE))

    # a == 2: the result is the reversed signal.
    ret, done = lax.cond(
        a == 2.0,
        None,
        lambda _: (jnp.flipud(f), TRUE),
        None,
        lambda _: (ret, done))

    # a == 1: FFT of the shft-ordered samples, divided by sqrt(N), written
    # back at positions shft.
    ret, done = lax.cond(
        a == 1.0,
        None,
        lambda _: (index_update(ret, index[shft], jnp.fft.fft(f[shft]) / sN), TRUE),
        None,
        lambda _: (ret, done))

    # a == 3: inverse FFT of the shft-ordered samples, multiplied by sqrt(N),
    # written back at positions shft.
    ret, done = lax.cond(
        a == 3.0,
        None,
        lambda _: (index_update(ret, index[shft], jnp.fft.ifft(f[shft]) * sN), TRUE),
        None,
        lambda _: (ret, done))

    # Sinc interpolation: place x in the even slots of a length 2N - 1 buffer
    # of zeros, convolve with a sinc kernel and trim the result.
    @jit
    def sincinterp(x):
        N = x.shape[0]
        y = jnp.zeros(2 * N -1, dtype=x.dtype)
        y = index_update(y, index[:2 * N:2], x)
        xint = fftconvolve(
           y[:2 * N],
           jnp.sinc(jnp.arange(-(2 * N - 3), (2 * N - 2)).T / 2),
        )
        return xint[2 * N - 3: -2 * N + 3]

    # Chirp multiply / chirp convolve / chirp multiply evaluation of the
    # transform; uses N from the enclosing scope.
    @jit
    def chirp_opts(a, f):
        # the general case for 0.5 < a < 1.5
        alpha = a * jnp.pi / 2
        tana2 = jnp.tan(alpha / 2)
        sina = jnp.sin(alpha)
        # Interpolate the signal and pad it with N - 1 zeros on each side.
        f = jnp.hstack((jnp.zeros(N - 1), sincinterp(f), jnp.zeros(N - 1))).T

        # chirp premultiplication
        chrp = jnp.exp(-1j * jnp.pi / N * tana2 / 4 *
                         jnp.arange(-2 * N + 2, 2 * N - 1).T ** 2)
        f = chrp * f

        # chirp convolution
        c = jnp.pi / N / sina / 4
        ret = fftconvolve(
            jnp.exp(1j * c * jnp.arange(-(4 * N - 4), 4 * N - 3).T ** 2),
            f,
        )
        # Keep the central part of the convolution and scale it.
        ret = ret[4 * N - 4:8 * N - 7] * jnp.sqrt(c / jnp.pi)

        # chirp post multiplication
        ret = chrp * ret

        # normalizing constant; the slice drops the padding and keeps every
        # second sample
        ret = jnp.exp(-1j * (1 - a) * jnp.pi / 4) * ret[N - 1:-N + 1:2]

        return ret

    # Fold the order step by step, transforming f to match, then evaluate the
    # remaining order with chirp_opts.
    def other_cases(a, f):

        # Order above 2: reverse the signal and lower the order by 2.
        a, f = lax.cond(
            a > 2.0,
            None,
            lambda _: (a - 2.0, jnp.flipud(f)),
            None,
            lambda _: (a, f))

        # Order above 1.5: apply the scaled FFT step and lower the order by 1.
        a, f = lax.cond(
            a > 1.5,
            None,
            lambda _: (a - 1.0, index_update(f, index[shft], jnp.fft.fft(f[shft]) / sN)),
            None,
            lambda _: (a, f))

        # Order below 0.5: apply the scaled inverse FFT step and raise the
        # order by 1.
        a, f = lax.cond(
            a < 0.5,
            None,
            lambda _: (a + 1.0, index_update(f, index[shft], jnp.fft.ifft(f[shft]) * sN)),
            None,
            lambda _: (a, f))

        return chirp_opts(a, f)

    # Keep the simple-case result when one matched; otherwise compute the
    # general case.
    ret = lax.cond(
        done,
        None,
        lambda _: ret,
        None,
        lambda _: other_cases(a, f))

    return ret