Example #1
0
def Ainv(*x):
    """Apply the inverse affine map ``ainv @ (X - p0)``.

    ``X`` is the row-stack of the coordinate arrays passed positionally.
    ``ainv`` and ``p0`` are module-level values defined elsewhere (presumably
    ``p0`` is a column vector so that it broadcasts -- confirm).

    Returns:
        Tuple of the first two rows of the transformed array.
    """
    stacked = np.vstack(list(x))
    transformed = np.dot(ainv, stacked - p0)
    return transformed[0, :], transformed[1, :]
Example #2
0
 def one_rollout(rollout_rng, p):
   """Return the discounted sum of rewards for one rollout of ``policy(p)``.

   ``rollout``, ``env``, ``policy``, ``num_timesteps`` and ``gamma`` come from
   the enclosing scope; rewards are weighted by ``gamma ** arange(num_timesteps)``.
   """
   _first, _second, rewards = rollout(rollout_rng, env, policy(p), num_timesteps)
   weights = jp.power(gamma, jp.arange(num_timesteps))
   return jp.dot(rewards, weights)
Example #3
0
 def loss(x, y):
     """Mean logistic loss ``log(1 + exp(-y * (x . w)))`` over the batch axis.

     ``ndim`` and the weight variable ``w`` come from the enclosing scope. ``y``
     is tiled ``ndim`` times and transposed so that it scales each row of ``x``.
     """
     xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
     # logaddexp(0, z) == log(1 + exp(z)), computed stably: the naive
     # jn.log(jn.exp(z) + 1) overflows to inf once z exceeds ~88 (float32).
     return jn.logaddexp(0., -xyw).mean(0)
Example #4
0
 def f(x):
     """Return ``4 * dot(sin(x), x.T) + x`` using ``jnp.dot`` semantics."""
     sin_dot = jnp.dot(jnp.sin(x), x.T)
     return sin_dot * 4 + x
Example #5
0
 def restricted_func_and_grad(t):
     """Evaluate ``f`` along the ray ``xk + t * pk``.

     Returns:
         ``(phi, dphi, g)``: the value of ``f`` at the point, the directional
         derivative ``g . pk``, and the full gradient ``g``.
     """
     point = xk + t * pk
     value, grad = jax.value_and_grad(f)(point)
     return value, jnp.dot(grad, pk), grad
Example #6
0
 def loss(x, y):
     """Half mean-squared error of ``x . w + b``; also mutates ``w`` and ``b``.

     The prediction is taken before the updates, so the returned loss uses the
     pre-update parameters. Afterwards ``b`` is incremented by 1 and ``w`` is
     decremented by 1.
     """
     prediction = jn.dot(x, w.value) + b.value
     b.assign(b.value + 1)
     w.assign(w.value - 1)
     error = y - prediction
     return 0.5 * (error ** 2).mean()
Example #7
0
 def precision_matrix(self):
     """Return ``inv(L).T . inv(L)`` where ``L`` is ``self.scale_tril``.

     This is the inverse of the covariance ``L . L.T``.
     """
     inv_tril = np.linalg.inv(self.scale_tril)
     return np.dot(np.transpose(inv_tril), inv_tril)
Example #8
0
def logistic_predictions(weights, inputs):
    """Apply ``sigmoid`` (defined elsewhere) to the linear scores ``inputs . weights``."""
    logits = np.dot(inputs, weights)
    return sigmoid(logits)
Example #9
0
File: stax.py Project: yotarok/jax
 def apply_fun(params, inputs, **kwargs):
     """Dense-layer forward pass ``inputs . W + b``; extra kwargs are ignored."""
     weights, bias = params
     activations = np.dot(inputs, weights)
     return activations + bias
Example #10
0
 def covariance_matrix(self):
     """Return ``L . L.T`` where ``L`` is ``self.scale_tril``."""
     tril = self.scale_tril
     return np.dot(tril, np.transpose(tril))
Example #11
0
def model(data, labels):
    """Logistic-regression model with a standard-normal prior on the weights.

    Samples one coefficient per column of ``data`` from ``Normal(0, 1)``. It
    then observes ``labels`` through a Bernoulli likelihood on the logits
    ``data . coefs``.
    """
    _, dim = data.shape  # only the feature count is needed
    coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
    logits = np.dot(data, coefs)
    return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
Example #12
0
 def pmvm(a, b):
   """Compute ``a . b`` by splitting the rows of ``a`` into ``nrep`` shards.

   Each shard is multiplied by ``b`` under ``pmap`` and the pieces are put
   back together. ``a.shape[0]`` must be divisible by ``nrep`` (``nrep`` comes
   from the enclosing scope).
   """
   num_rows = a.shape[0]
   shards = a.reshape((nrep, -1, a.shape[1]))
   func = pmap(lambda z: np.dot(z, b))
   # The product has one row per row of ``a``. Reshaping to ``b.shape`` was
   # only correct when ``a`` was square.
   return func(shards).reshape((num_rows,) + tuple(b.shape[1:]))
Example #13
0
    nC = -1
    basis = 'ELMTanh'
else:
    m = 20
    nC = [2, 2]
    basis = 'LeP'

# Create the TFC Class:
N = [n, n]
myTfc = mtfc(N, nC, m, dim=2, basis=basis, x0=x0, xf=xf)

# Create the constrained expression:
H = myTfc.H
x = myTfc.x


def u1(xi, *coords):
    """Free function minus its linear interpolation between x = 0 and x = 1.

    Let ``g(x, t) = H(x, t) . xi``. This returns
    ``g(x, t) - (1 - x) * g(0, t) - x * g(1, t)``, which is zero at x = 0
    and at x = 1 for any ``xi``.
    """
    x_pos, t_pos = coords[0], coords[1]
    free = np.dot(H(*coords), xi)
    at_left = np.dot(H(np.zeros_like(x_pos), t_pos), xi)
    at_right = np.dot(H(np.ones_like(x_pos), t_pos), xi)
    return free - (1. - x_pos) * at_left - x_pos * at_right


# Derivative of u1 with respect to argument 2, i.e. t.
u1t = egrad(u1, 2)


def u(xi, *coords):
    """Constrained expression with u(x, 0) = sin(pi x) and du/dt(x, 0) = 0."""
    x_pos, t_pos = coords[0], coords[1]
    t_zero = np.zeros_like(t_pos)
    return (u1(xi, *coords) + np.sin(np.pi * x_pos) - u1(xi, x_pos, t_zero)
            - t_pos * u1t(xi, x_pos, t_zero))


# Create the residual
uxx = egrad(egrad(u, 1), 1)
utt = egrad(egrad(u, 2), 2)


def L(xi, *coords):
    """Residual u_xx - u_tt of the constrained expression."""
    return uxx(xi, *coords) - utt(xi, *coords)


# Solve the problem
xi = np.zeros(H(*x).shape[1])
Example #14
0
 def matvec(A, b):
     """Return the product ``A . b`` with ``np.dot`` semantics."""
     product = np.dot(A, b)
     return product
Example #15
0
    def initialize(self, u_dim, y_dim, hid_dim=64):
        """
        Description: Randomly initialize the RNN (an LSTM cell with a linear
            output layer).
        Args:
            u_dim (int): Input dimension.
            y_dim (int): Observation/output dimension.
            hid_dim (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series: the output layer applied to
            the initial (all-zero) hidden state.
        """

        self.T = 0
        self.initialized = True

        self.u_dim = u_dim  # input dimension
        self.y_dim = y_dim  # output dimension
        self.hid_dim = hid_dim  # hidden state dimension
        self.cell_dim = hid_dim  # observable state dimension

        self.rollout_controller = None
        # Target output for the squared-error loss, drawn uniformly in [-1, 1].
        self.target = jax.random.uniform(generate_key(),
                                         shape=(self.y_dim, ),
                                         minval=-1,
                                         maxval=1)

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_hh = glorot_init(
            generate_key(),
            (4 * self.hid_dim, self.hid_dim))  # maps h_t to gates
        self.W_uh = glorot_init(
            generate_key(),
            (4 * self.hid_dim, self.u_dim))  # maps u_t to gates
        self.b_h = np.zeros(4 * self.hid_dim)
        # Set the forget-gate slice (second quarter of the gate vector) of the
        # bias to 1 so the cell state is retained at initialization.
        self.b_h = jax.ops.index_update(
            self.b_h, jax.ops.index[self.hid_dim:2 * self.hid_dim],
            np.ones(self.hid_dim))
        self.W_out = glorot_init(
            generate_key(), (self.y_dim, self.hid_dim))  # maps h_t to output
        # Packed recurrent state: [hidden (short-term) | cell (long-term)].
        self.hid_cell = np.hstack(
            (np.zeros(self.hid_dim), np.zeros(self.hid_dim)))

        def _dynamics(hid_cell_state, u):
            """One LSTM step: (packed state, input) -> (next packed state, y)."""
            hid = hid_cell_state[:self.hid_dim]
            cell = hid_cell_state[self.hid_dim:]

            def sigmoid(z):
                return 1. / (1. + np.exp(-z))

            gate = np.dot(self.W_hh, hid) + np.dot(self.W_uh, u) + self.b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            # The input gate scales the candidate tanh(g); adding the two
            # would break the LSTM cell update.
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (np.hstack((next_hid, next_cell)), y)

        self._dynamics = jax.jit(
            _dynamics
        )  # MUST store as self._dynamics for default rollout implementation to work
        # Scalar squared error between the output y (the second element of the
        # (state, y) pair returned by _dynamics) and the target. It must be a
        # scalar for jax.grad below; subtracting the whole tuple cannot work.
        self._loss = lambda x, u: np.sum(
            (self.target - self._dynamics(x, u)[1])**2)

        # stack the jacobians of environment dynamics gradient
        # NOTE(review): _dynamics returns a (state, y) pair, so jacobian(x, u)
        # is a nested tuple; confirm np.hstack over it gives the intended
        # layout (callers may expect only the state Jacobian).
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

        def _rollout(act, dyn, x_0, T):
            # NOTE(review): if dyn is self._dynamics, x_next is a (state, y)
            # pair rather than a state array; confirm the scan carry is right.
            def f(x, i):
                u = act(x)
                x_next = dyn(x, u)
                return x_next, np.hstack((x, u))

            _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
            return trajectory

        self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))

        return np.dot(self.W_out, self.hid_cell[:self.hid_dim])
Example #16
0
 def f(x, y):
     """Return the bilinear form ``x.T @ amat @ y`` plus ``y . y``."""
     bilinear = x.T @ amat @ y
     return bilinear + np.dot(y, y)
Example #17
0
def ppo_loss_given_predictions(log_probab_actions_new,
                               log_probab_actions_old,
                               value_predictions_old,
                               padded_actions,
                               rewards_to_actions,
                               padded_rewards,
                               reward_mask,
                               gamma,
                               lambda_,
                               epsilon):
  """PPO objective, with an eventual minus sign, given predictions.

  Steps:
    1. Compute GAE advantages from the old value predictions and standardize
       them.
    2. Project the advantages and the reward mask from reward steps onto
       action steps through ``rewards_to_actions``.
    3. Average the clipped objective over the resulting action mask.

  Args:
    log_probab_actions_new: (B, AT, A) new log-probabilities.
    log_probab_actions_old: (B, AT, A) old log-probabilities.
    value_predictions_old: (B, RT + 1) old value predictions.
    padded_actions: (B, AT) actions.
    rewards_to_actions: (RT + 1, AT) reward-step to action-step matrix.
    padded_rewards: (B, RT) rewards.
    reward_mask: (B, RT) reward mask.
    gamma: discount passed to ``deltas`` and ``gae_advantages``.
    lambda_: GAE lambda passed to ``gae_advantages``.
    epsilon: clipping parameter passed to ``clipped_objective``.

  Returns:
    ``(ppo_loss, summaries)``. ``ppo_loss`` is the negated average objective;
    ``summaries`` holds the loss and the advantage mean and std.
  """
  batch, reward_steps = padded_rewards.shape
  _, action_steps, num_actions = log_probab_actions_old.shape

  # Shape contract for every input.
  assert (batch, reward_steps) == padded_rewards.shape
  assert (batch, action_steps) == padded_actions.shape
  assert (batch, reward_steps) == reward_mask.shape

  assert (batch, reward_steps + 1) == value_predictions_old.shape
  assert (batch, action_steps, num_actions) == log_probab_actions_old.shape
  assert (batch, action_steps, num_actions) == log_probab_actions_new.shape

  assert (reward_steps + 1, action_steps) == rewards_to_actions.shape

  # TD residuals and GAE advantages, both (B, RT).
  td_deltas = deltas(value_predictions_old, padded_rewards, reward_mask,
                     gamma=gamma)
  raw_advantages = gae_advantages(td_deltas, reward_mask,
                                  lambda_=lambda_, gamma=gamma)

  # Standardize; the 1e-8 guards against a zero standard deviation.
  advantage_mean = np.mean(raw_advantages)
  advantage_std = np.std(raw_advantages)
  standardized = (raw_advantages - advantage_mean) / (advantage_std + 1e-8)

  # rewards_to_actions has RT + 1 rows, so append one zero column before
  # projecting onto the AT action steps.
  pad_last_step = ((0, 0), (0, 1))
  advantages = np.dot(np.pad(standardized, pad_last_step), rewards_to_actions)
  action_mask = np.dot(np.pad(reward_mask, pad_last_step), rewards_to_actions)

  # (B, AT)
  ratios = compute_probab_ratios(log_probab_actions_new,
                                 log_probab_actions_old,
                                 padded_actions,
                                 action_mask)
  assert (batch, action_steps) == ratios.shape

  # (B, AT)
  objective = clipped_objective(ratios, advantages, action_mask,
                                epsilon=epsilon)
  assert (batch, action_steps) == objective.shape

  # The loss is the negated mean objective over the action mask.
  ppo_loss = -(np.sum(objective) / np.sum(action_mask))

  summaries = dict(
      ppo_loss=ppo_loss,
      advantage_mean=advantage_mean,
      advantage_std=advantage_std,
  )
  return ppo_loss, summaries
Example #18
0
def line_search(f,
                xk,
                pk,
                old_fval=None,
                old_old_fval=None,
                gfk=None,
                c1=1e-4,
                c2=0.9,
                maxiter=20):
    """Inexact line search that satisfies strong Wolfe conditions.

    Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999,
    pg. 59-61.

    Args:
      f: function of the form f(x) where x is a flat ndarray and returns a real
        scalar. The function should be composed of operations with vjp defined.
      xk: initial guess.
      pk: direction to search in. Assumes the direction is a descent direction.
      old_fval, gfk: initial value of value_and_gradient as position. Both must
        be given to skip the initial evaluation at ``xk``.
      old_old_fval: unused argument, only for scipy API compliance.
      c1, c2: Wolfe criteria constant, see ref.
      maxiter: maximum number of iterations to search.

    Returns: LineSearchResults
    """
    def restricted_func_and_grad(t):
        # Value, directional derivative along pk, and gradient at xk + t * pk.
        phi, g = jax.value_and_grad(f)(xk + t * pk)
        dphi = jnp.dot(g, pk)
        return phi, dphi, g

    # Decide *before* gfk may be overwritten below whether the starting point
    # has to be evaluated. Re-testing `gfk is None` afterwards would miss the
    # evaluation made when old_fval is given without gfk.
    evaluate_initial = old_fval is None or gfk is None
    if evaluate_initial:
        phi_0, dphi_0, gfk = restricted_func_and_grad(0.)
    else:
        phi_0 = old_fval
        dphi_0 = jnp.dot(gfk, pk)

    def wolfe_one(a_i, phi_i):
        # actually negation of W1
        return phi_i > phi_0 + c1 * a_i * dphi_0

    def wolfe_two(dphi_i):
        return jnp.abs(dphi_i) <= -c2 * dphi_0

    state = _LineSearchState(
        done=False,
        failed=False,
        # algorithm begins at 1 as per Wright and Nocedal, however Scipy has a
        # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157
        i=1,
        a_i1=0.,
        phi_i1=phi_0,
        dphi_i1=dphi_0,
        nfev=1 if evaluate_initial else 0,
        ngev=1 if evaluate_initial else 0,
        a_star=0.,
        phi_star=phi_0,
        dphi_star=dphi_0,
        g_star=gfk,
        saddle_point=False,
    )

    def body(state):
        # no amax in this version, we just double as in scipy.
        # unlike original algorithm we do our next choice at the start of this loop
        a_i = jnp.where(state.i == 1, 1., state.a_i1 * 2.)
        # if a_i <= 0 then something went wrong. In practice any really small step
        # length is a failure. Likely means the search pk is not good, perhaps we
        # are at a saddle point.
        saddle_point = a_i < 1e-5
        state = state._replace(failed=saddle_point, saddle_point=saddle_point)

        phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
        state = state._replace(nfev=state.nfev + 1, ngev=state.ngev + 1)

        star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) &
                                                 (state.i > 1))
        star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1)
        star_to_zoom2 = (dphi_i >= 0.) & (~star_to_zoom1) & (~star_to_i)

        zoom1 = _zoom(restricted_func_and_grad, wolfe_one, wolfe_two,
                      state.a_i1, state.phi_i1, state.dphi_i1, a_i, phi_i,
                      dphi_i, gfk, ~star_to_zoom1)

        state = state._replace(nfev=state.nfev + zoom1.nfev,
                               ngev=state.ngev + zoom1.ngev)

        zoom2 = _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_i,
                      phi_i, dphi_i, state.a_i1, state.phi_i1, state.dphi_i1,
                      gfk, ~star_to_zoom2)

        state = state._replace(nfev=state.nfev + zoom2.nfev,
                               ngev=state.ngev + zoom2.ngev)

        state = state._replace(
            done=star_to_zoom1 | state.done,
            failed=(star_to_zoom1 & zoom1.failed) | state.failed,
            **_binary_replace(
                star_to_zoom1,
                state._asdict(),
                zoom1._asdict(),
                keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
            ),
        )
        state = state._replace(
            done=star_to_i | state.done,
            **_binary_replace(
                star_to_i,
                state._asdict(),
                dict(
                    a_star=a_i,
                    phi_star=phi_i,
                    dphi_star=dphi_i,
                    g_star=g_i,
                ),
            ),
        )
        state = state._replace(
            done=star_to_zoom2 | state.done,
            failed=(star_to_zoom2 & zoom2.failed) | state.failed,
            **_binary_replace(
                star_to_zoom2,
                state._asdict(),
                zoom2._asdict(),
                keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
            ),
        )
        state = state._replace(i=state.i + 1,
                               a_i1=a_i,
                               phi_i1=phi_i,
                               dphi_i1=dphi_i)
        return state

    state = while_loop(
        lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
        body, state)

    status = jnp.where(
        state.failed & (~state.saddle_point),
        jnp.array(1),  # zoom failed
        jnp.where(
            state.failed & state.saddle_point,
            jnp.array(2),  # saddle point reached,
            jnp.where(
                state.i > maxiter,
                jnp.array(3),  # maxiter reached
                jnp.array(0),  # passed (should be)
            ),
        ),
    )
    results = _LineSearchResults(
        failed=state.failed | (~state.done),
        nit=state.i - 1,  # because iterations started at 1
        nfev=state.nfev,
        ngev=state.ngev,
        k=state.i,
        a_k=state.a_star,
        f_k=state.phi_star,
        g_k=state.g_star,
        status=status,
    )
    return results
Example #19
0
def combine_svd(u, s, vT):
    """Reconstruct a matrix from its SVD factors: ``u @ diag(s) @ vT``.

    The rows of ``vT`` are scaled by ``s`` via broadcasting rather than by
    building the dense ``k x k`` matrix ``diag(s)``. This avoids an O(k^2)
    allocation and an O(k^2 n) matrix product.

    Args:
        u: left singular vectors, expected shape (m, k).
        s: 1-D singular values, expected shape (k,).
        vT: transposed right singular vectors, expected shape (k, n).

    Returns:
        The (m, n) reconstructed matrix.
    """
    s = np.asarray(s)
    return np.dot(u, s[:, None] * vT)
Example #20
0
 def predict(params, inputs):
     """Forward pass of a tanh MLP whose final layer output is left linear.

     Args:
         params: iterable of ``(W, b)`` pairs; each layer computes
             ``W . inputs + b`` (weights applied on the left).
         inputs: input array for the first layer.

     Returns:
         The pre-activation output of the last layer.

     Raises:
         ValueError: if ``params`` yields no layers.
     """
     outputs = None
     for W, b in params:
         outputs = np.dot(W, inputs) + b
         inputs = np.tanh(outputs)
     if outputs is None:
         # Without this check an empty params fell through to UnboundLocalError.
         raise ValueError("params must contain at least one (W, b) layer")
     return outputs
Example #21
0
 def _state_init_body_fn(current_arr, _):
     """Scan body: square ``current_arr`` and emit the square as both carry and output."""
     squared = jnp.dot(current_arr, current_arr, precision=precision)
     return squared, squared
Example #22
0
 def jloss(wb, x, y):
     """Half mean-squared error of the linear model ``x . w + b``, with ``wb = (w, b)``."""
     weights, bias = wb
     residual = y - (jn.dot(x, weights) + bias)
     return 0.5 * (residual ** 2).mean()
Example #23
0
 def body(state):
     """One step of a repeated-squaring loop over state ``(n, z, r)``.

     Squares ``z`` and halves ``n`` (integer division). When the bit shifted
     out of ``n`` was set, ``r`` is left-multiplied by the squared ``z``.
     """
     n, z, r = state
     z_sq = jnp.dot(z, z, precision=precision)
     n_next, low_bit = jnp.divmod(n, 2)
     r_next = jnp.where(low_bit, jnp.dot(z_sq, r, precision=precision), r)
     return n_next, z_sq, r_next
Example #24
0
 def loss(x, y):
     """Half mean-squared error of the linear model ``x . w + b`` (variables from scope)."""
     residual = y - (jn.dot(x, w.value) + b.value)
     return 0.5 * (residual ** 2).mean()
Example #25
0
 def step(_, mat):
     """Loop body: return ``mat . mat`` at HIGHEST matmul precision; the index is ignored."""
     highest = jax.lax.Precision.HIGHEST
     return jnp.dot(mat, mat, precision=highest)
Example #26
0
def A(*X):
    """Apply the affine map ``a @ X + p0`` to the row-stacked coordinates.

    ``a`` and ``p0`` are module-level values defined elsewhere.

    Returns:
        Tuple of the first two rows of the mapped array.
    """
    stacked = np.vstack(list(X))
    mapped = np.dot(a, stacked) + p0
    return mapped[0, :], mapped[1, :]
Example #27
0
def dot(X, Z):
    """Contract ``X`` with ``Z`` treated as a column vector.

    A trailing unit axis is appended to ``Z``, ``jnp.dot`` is applied, and
    that unit axis is dropped from the result.
    """
    z_column = jnp.expand_dims(Z, -1)
    return jnp.dot(X, z_column)[..., 0]
Example #28
0
def phi_tf(t, X, Y, Z):  # M x 1, M x D, M x 1, M x D
    """Drift term ``0.05 * (Y - <X, Z>)``, where ``<., .>`` is a per-sample inner product.

    For batched inputs with the shapes in the comment above, the inner product
    is taken row by row and kept as an M x 1 column to match ``Y``.
    ``jnp.dot(X, Z)`` on two M x D arrays either fails, or (when M == D)
    silently produces an M x M matrix. For single-sample 1-D inputs (e.g.
    under ``vmap``) the plain dot product is used, as before. ``t`` is unused.
    """
    if jnp.ndim(X) == 1:
        xz = jnp.dot(X, Z)
    else:
        xz = jnp.sum(X * Z, axis=-1, keepdims=True)  # M x 1
    return 0.05 * (Y - xz)  # M x 1
Example #29
0
 def __call__(self, x):
     """Return ``x . v`` where ``v`` is the current value of ``self.v1``."""
     weights = self.v1.value
     return jn.dot(x, weights)
Example #30
0
def predict(params, inputs):
  """Forward pass of a tanh MLP whose final layer output is left linear.

  Args:
    params: iterable of ``(W, b)`` pairs; each layer computes
      ``inputs . W + b`` (weights applied on the right).
    inputs: input array for the first layer.

  Returns:
    The pre-activation output of the last layer.

  Raises:
    ValueError: if ``params`` yields no layers.
  """
  outputs = None
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  if outputs is None:
    # Without this check an empty params fell through to UnboundLocalError.
    raise ValueError("params must contain at least one (W, b) layer")
  return outputs
Example #31
0
 def step(h, x):
     """One recurrent update ``h' = tanh(W . h + W . x)`` with no per-step output.

     The same matrix ``W`` (from the enclosing scope) multiplies both the
     state and the input.
     """
     pre_activation = np.dot(W, h) + np.dot(W, x)
     return np.tanh(pre_activation), ()