def apply_fun(params, x_t, rng=None):
    """Run forward and backward GRUs over ``x_t`` and return their end states.

    Arguments:
      params: dict with 'fwd_rnn' and 'bwd_rnn' GRU parameters.
      x_t: input data; the backward pass reverses axis 0 (presumably time —
        TODO confirm against run_gru).
      rng: random key (required); two keys are derived from it via
        ``utils.keygen``, one per direction.

    Returns:
      The backward encoding at index 0 concatenated with the forward encoding
      at the last index, joined along the last (feature) axis.

    Raises:
      ValueError: if ``rng`` is None.
    """
    if rng is None:
        raise ValueError("BidirectionalGRU apply_fun requires rng key.")
    _, keys = utils.keygen(rng, 2)
    fwd_enc_t = run_gru(params['fwd_rnn'], x_t, rng=next(keys))
    # Encode the reversed input, then flip the outputs back so index t of the
    # backward encoding lines up with index t of the forward encoding.
    bwd_enc_t = np.flipud(
        run_gru(params['bwd_rnn'], np.flipud(x_t), rng=next(keys)))
    # axis=-1 joins along the feature axis whether each end state is
    # unbatched (hidden,) or batched (batch, hidden); a hard-coded axis=1
    # raises for the unbatched case.
    enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=-1)
    return enc_ends
def compute_pis(p, n, EPS):
    """Compute capped, renormalised log-probability vectors for sizes n..N-1.

    Starting from the log-weights ``p`` (N = p.shape[0]; presumably a 1-D
    vector — TODO confirm), each size k in ``n, n + 1, ..., N - 1`` gets a
    vector q whose entries are <= 0 (probability <= 1) and whose
    exponentials sum to approximately k. The projection for size k
    alternates "cap q at 0" and "shift q so that logsumexp(q) == log(k)"
    until max(q) <= EPS. Each size is warm-started from the previous
    size's result.

    Arguments:
      p: log-weights.
      n: smallest size to project onto.
      EPS: convergence tolerance on max(q) for each projection.

    Returns:
      (pis, iters): ``pis`` has one row per size, ordered from size N - 1
      (first row) down to size n (last row), with entries clipped to <= 0.
      ``iters`` is the total number of cap-and-renormalise iterations over
      all projections.

    NOTE(review): ``xs`` starts at n, so the first scan step re-projects size
    n from the already-projected ``init``. Confirm whether xs should start
    at n + 1.
    NOTE(review): the while loop has no iteration cap. If fewer than k
    entries of q are finite, it presumably never reaches max(q) <= EPS.
    """
    N = p.shape[0]

    def compute_pi(q, n):
        # Shift so that the probabilities sum to n.
        q = q - logsumexp(q) + jnp.log(n)
        init = (q, 0, n)

        def condfun(carry):
            return jnp.max(carry[0]) > EPS

        def bodyfun(carry):
            (q, iters, n) = carry
            # Cap each probability at 1 (log-prob at 0), then renormalise.
            # jnp.minimum(q, 0.0) equals jnp.clip(q, -inf, 0) and avoids
            # the deprecated a_min/a_max keywords of jnp.clip.
            q = jnp.minimum(q, 0.0)
            q = q - logsumexp(q) + jnp.log(n)
            return (q, iters + 1, n)

        (q, iters, _) = jax.lax.while_loop(condfun, bodyfun, init)
        return q, iters

    init = compute_pi(p, n)
    xs = jnp.arange(N - n) + n

    def f(carry, n):
        pi_nm1, iters = carry
        # Warm-start the projection for size n from the previous result.
        pi, itr = compute_pi(pi_nm1, n)
        return (pi, iters + itr), pi

    (_, iters), pis = jax.lax.scan(f, init, xs)
    # Rows are produced for increasing size; flip so the largest size is
    # first, and remove the residual (<= EPS) positive excess.
    return jnp.minimum(jnp.flipud(pis), 0.0), iters
def compute_eigenvalue_decomposition(Ms, sort_by='magnitude',
                                     do_compute_lefts=True):
    """Compute a sorted eigendecomposition of each matrix in Ms.

    No assumptions (e.g. symmetry) are made on the matrices.

    Arguments:
      Ms: iterable of square matrices, e.g. a 3D np.array of shape
        nmatrices x dim x dim.
      sort_by: 'magnitude' sorts eigenvalues by descending absolute value;
        'real' sorts them by descending real part.
      do_compute_lefts: Compute the left eigenvectors? Requires a
        pseudo-inverse call.

    Returns:
      list with one dictionary per matrix, with keys
        'evals': sorted eigenvalues,
        'R': right eigenvectors (as columns) in the same order,
        'L': left eigenvectors (as columns) in the same order, or None when
          do_compute_lefts is False.

    Raises:
      NotImplementedError: if sort_by is not 'magnitude' or 'real'.
    """
    if sort_by == 'magnitude':
        sort_fun = onp.abs
    elif sort_by == 'real':
        sort_fun = onp.real
    else:
        # An explicit raise, unlike `assert`, survives `python -O`.
        raise NotImplementedError("Not implemented yet.")

    decomps = []
    for M in Ms:
        evals, R = onp.linalg.eig(M)
        # Sort in plain NumPy on the float64 keys; descending via a reversed
        # stable sort, which keeps tie order deterministic (complex-conjugate
        # pairs always tie in magnitude).
        indices = onp.argsort(sort_fun(evals), kind='stable')[::-1]
        L = None
        if do_compute_lefts:
            L = onp.linalg.pinv(R).T  # as columns
            L = L[:, indices]
        decomps.append({'evals': evals[indices], 'R': R[:, indices], 'L': L})
    return decomps
def run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t):
    """Run an RNN encoder backwards and forwards over some time series data.

    Arguments:
      params: a dictionary of bidirectional RNN encoder parameters, with
        'fwd_rnn' and 'bwd_rnn' entries
      fwd_rnn: function for running forward rnn encoding
      bwd_rnn: function for running backward rnn encoding
      x_t: np array data for RNN input with leading dim being time

    Returns:
      tuple of
        np array of the forward and backward encodings concatenated along the
          last (feature) axis, and
        np array of the concatenation [backward_enc(1), forward_enc(T)] along
          the last axis (presumably each direction's state after seeing the
          whole sequence).
    """
    fwd_enc_t = run_rnn(params['fwd_rnn'], fwd_rnn, x_t)
    # Encode the time-reversed input, then flip back so both encodings are
    # aligned on the same time index.
    bwd_enc_t = np.flipud(run_rnn(params['bwd_rnn'], bwd_rnn, np.flipud(x_t)))
    # Join on the feature axis (-1) for both outputs. A fixed axis=1 cannot
    # suit both joins: for (T, H) encodings the 1-D end states have no axis 1,
    # and for (T, B, H) axis 1 is the batch axis.
    full_enc = np.concatenate([fwd_enc_t, bwd_enc_t], axis=-1)
    enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=-1)
    return full_enc, enc_ends
def other_cases(a, f):
    """Reduce the fractional order into [0.5, 1.5], then call ``chirp_opts``.

    Three traced branches are applied in sequence:
      * a > 2.0: subtract 2 and reverse ``f`` along axis 0;
      * a > 1.5: subtract 1 and apply the forward DFT on the ``shft``-permuted
        entries, scaled by 1 / sN;
      * a < 0.5: add 1 and apply the inverse DFT on the ``shft``-permuted
        entries, scaled by sN.

    NOTE(review): ``shft``, ``sN``, ``chirp_opts``, ``index_update`` and
    ``index`` are free names here. They are presumably supplied by an
    enclosing scope such as ``frft`` — confirm they exist where this
    definition lives.
    """
    def _pass_through(_):
        # Read at call time, so this always yields the current (a, f).
        return a, f

    def _drop_two(_):
        return a - 2.0, jnp.flipud(f)

    a, f = lax.cond(a > 2.0, None, _drop_two, None, _pass_through)

    def _drop_one(_):
        spectrum = jnp.fft.fft(f[shft]) / sN
        return a - 1.0, index_update(f, index[shft], spectrum)

    a, f = lax.cond(a > 1.5, None, _drop_one, None, _pass_through)

    def _add_one(_):
        signal = jnp.fft.ifft(f[shft]) * sN
        return a + 1.0, index_update(f, index[shft], signal)

    a, f = lax.cond(a < 0.5, None, _add_one, None, _pass_through)

    return chirp_opts(a, f)
def flipud(x):
    """Reverse ``x`` along axis 0 and wrap the result in a ``JaxArray``.

    A ``JaxArray`` input is unwrapped through its ``.value`` attribute first;
    any other input is handed to ``jnp.flipud`` unchanged.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.flipud(raw))
def frft(f, a):
    """Fast fractional Fourier transform (JAX port of python_frft).

    The order ``a`` is taken modulo 4.
      * Orders 0, 1, 2 and 3 are handled directly: the identity; a DFT on
        entries permuted by ``shft`` and scaled by 1/sqrt(N); reversal along
        axis 0; and an inverse DFT on the permuted entries scaled by sqrt(N).
      * Any other order is reduced into [0.5, 1.5] (see ``other_cases``).
        It is then evaluated with the chirp premultiplication / chirp
        convolution / chirp postmultiplication scheme in ``chirp_opts``.

    Parameters
    ----------
    f : [jax.]numpy array
        The signal to be transformed; its length N is read from axis 0.
    a : float
        Fractional power (order) of the transform.

    Returns
    -------
    data : jax.numpy array
        The transformed signal, as complex64.

    Notes
    -----
    Branches are selected with ``lax.cond`` in its five-argument form
    ``(pred, true_operand, true_fun, false_operand, false_fun)``.

    NOTE(review): for N == 1 the final slice ``ret[N - 1:-N + 1:2]`` in
    ``chirp_opts`` is ``ret[0:0:2]`` (empty). Its shape then differs from
    ``ret``, which presumably makes the last ``lax.cond`` fail for
    length-1 inputs — confirm.

    Reference: https://github.com/nanaln/python_frft
    """
    f = device_put(f)
    a = device_put(a)
    # Output buffer; the a == 1 and a == 3 cases scatter their result into it.
    ret = jnp.zeros_like(f, dtype=jnp.complex64)
    f = f.astype(jnp.complex64)
    N = f.shape[0]
    # Index permutation rotating positions by fix(N / 2). The DFT cases gather
    # with it before the FFT and scatter back with it afterwards.
    shft = jnp.fmod(jnp.arange(N) + jnp.fix(N / 2), N).astype(int)
    sN = jnp.sqrt(N)
    a = jnp.remainder(a, 4.0)

    TRUE = jnp.array(True)
    FALSE = jnp.array(False)

    # simple cases
    # Each cond either produces the exact result and sets `done`, or passes
    # the previous (ret, done) through. The first cond initialises done=FALSE.
    ret, done = lax.cond(
        a == 0.0,
        None, lambda _: (f, TRUE),
        None, lambda _: (ret, FALSE))
    ret, done = lax.cond(
        a == 2.0,
        None, lambda _: (jnp.flipud(f), TRUE),
        None, lambda _: (ret, done))
    ret, done = lax.cond(
        a == 1.0,
        None, lambda _: (index_update(ret, index[shft], jnp.fft.fft(f[shft]) / sN), TRUE),
        None, lambda _: (ret, done))
    ret, done = lax.cond(
        a == 3.0,
        None, lambda _: (index_update(ret, index[shft], jnp.fft.ifft(f[shft]) * sN), TRUE),
        None, lambda _: (ret, done))

    # The helpers below are redefined (and re-wrapped by jit) on every frft
    # call, since they close over this call's N.
    @jit
    def sincinterp(x):
        # Place x on the even indices of a length-(2N - 1) buffer, then
        # convolve with sinc sampled at half-integer steps (k / 2).
        N = x.shape[0]
        y = jnp.zeros(2 * N -1, dtype=x.dtype)
        y = index_update(y, index[:2 * N:2], x)
        xint = fftconvolve(
            y[:2 * N],
            jnp.sinc(jnp.arange(-(2 * N - 3), (2 * N - 2)).T / 2),
        )
        return xint[2 * N - 3: -2 * N + 3]

    @jit
    def chirp_opts(a, f):
        # the general case for 0.5 < a < 1.5
        alpha = a * jnp.pi / 2
        tana2 = jnp.tan(alpha / 2)
        sina = jnp.sin(alpha)
        # Zero-pad the interpolated signal by N - 1 on each side.
        f = jnp.hstack((jnp.zeros(N - 1), sincinterp(f), jnp.zeros(N - 1))).T

        # chirp premultiplication
        chrp = jnp.exp(-1j * jnp.pi / N * tana2 / 4 * jnp.arange(-2 * N + 2, 2 * N - 1).T ** 2)
        f = chrp * f

        # chirp convolution
        c = jnp.pi / N / sina / 4
        ret = fftconvolve(
            jnp.exp(1j * c * jnp.arange(-(4 * N - 4), 4 * N - 3).T ** 2),
            f,
        )
        ret = ret[4 * N - 4:8 * N - 7] * jnp.sqrt(c / jnp.pi)

        # chirp post multiplication
        ret = chrp * ret

        # normalizing constant
        # The stride-2 slice also drops the interior padding and keeps N
        # samples (for N >= 2).
        ret = jnp.exp(-1j * (1 - a) * jnp.pi / 4) * ret[N - 1:-N + 1:2]

        return ret

    def other_cases(a, f):
        # Reduce a into [0.5, 1.5] with the exact order-2 (reversal) and
        # order-(+/-)1 (DFT / inverse DFT) steps, then use the chirp algorithm.
        a, f = lax.cond(
            a > 2.0,
            None, lambda _: (a - 2.0, jnp.flipud(f)),
            None, lambda _: (a, f))
        a, f = lax.cond(
            a > 1.5,
            None, lambda _: (a - 1.0, index_update(f, index[shft], jnp.fft.fft(f[shft]) / sN)),
            None, lambda _: (a, f))
        a, f = lax.cond(
            a < 0.5,
            None, lambda _: (a + 1.0, index_update(f, index[shft], jnp.fft.ifft(f[shft]) * sN)),
            None, lambda _: (a, f))
        return chirp_opts(a, f)

    # Keep the simple-case result if one matched; otherwise use the general
    # path.
    ret = lax.cond(
        done,
        None, lambda _: ret,
        None, lambda _: other_cases(a, f))
    return ret