Example #1
File: fit.py Project: gehring/neos
 def bestfit_via_grad_descent(param):  # gradient descent
     mu, np = param[0], param[1:]  # split the parameter of interest (mu) from the nuisance parameters
     g = jax.grad(cnll)(np)
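     # note: the optimizer state is re-created by init() and then discarded, so
     # each call performs a single stateless gradient-descent step on np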
     updates, _ = gradient_descent.update(g, gradient_descent.init(np))
     np = optix.apply_updates(np, updates)
     param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np])
     return param
Example #2
 def update(params: hk.Params, opt_state: OptState,
            batch: Batch) -> Tuple[hk.Params, OptState]:
     """Learning rule (stochastic gradient descent)."""
     grads = jax.grad(loss)(params, batch)
     updates, opt_state = opt.update(grads, opt_state)
     new_params = optix.apply_updates(params, updates)
     return new_params, opt_state
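Every example on this page follows the same optix recipe: compute gradients with jax.grad, turn them into parameter updates with the optimizer's update function, and apply them with optix.apply_updates. For orientation, here is a minimal self-contained sketch of that recipe; the loss function, parameter shapes, and learning rate are illustrative assumptions, not code taken from any of the projects listed here.

 import jax
 import jax.numpy as jnp
 from jax.experimental import optix  # optix later became the separate optax library

 def loss(params, batch):
     # toy mean-squared-error loss over a linear model (purely illustrative)
     inputs, targets = batch
     preds = jnp.dot(inputs, params["w"]) + params["b"]
     return jnp.mean((preds - targets) ** 2)

 params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
 opt = optix.adam(1e-3)        # a gradient transformation: an (init, update) pair
 opt_state = opt.init(params)  # optimizer state is created once, outside the step

 @jax.jit
 def update(params, opt_state, batch):
     grads = jax.grad(loss)(params, batch)              # gradients w.r.t. the params pytree
     updates, opt_state = opt.update(grads, opt_state)  # transform gradients into updates
     params = optix.apply_updates(params, updates)      # apply the updates to the params
     return params, opt_state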
Example #3
File: ddpg.py Project: chisarie/jax-agents
 def train_step(self, data_batch, algo_func, algo_state):
     """Update all functions."""
     # Update Pi
     pi_grads = jax.grad(_pi_loss_batched)(algo_state.pi_params, data_batch,
                                           algo_func, algo_state)
     pi_updates, new_pi_opt_state = algo_func.pi_optimizer.update(
         pi_grads, algo_state.pi_opt_state)
     new_pi_params = optix.apply_updates(algo_state.pi_params, pi_updates)
     # Update Q
     q_grads = jax.grad(_q_loss_batched)(algo_state.q_params, data_batch,
                                         algo_func, algo_state)
     q_updates, new_q_opt_state = algo_func.q_optimizer.update(
         q_grads, algo_state.q_opt_state)
     new_q_params = optix.apply_updates(algo_state.q_params, q_updates)
     return DDPGState(new_pi_params, new_q_params, new_pi_opt_state,
                      new_q_opt_state)
Example #4
 def sgd_step(state: TrainingState,
              trajectory: sequence.Trajectory) -> TrainingState:
   """Does a step of SGD over a trajectory."""
   gradients = jax.grad(loss_fn)(state.params, trajectory)
   updates, new_opt_state = optimizer.update(gradients, state.opt_state)
   new_params = optix.apply_updates(state.params, updates)
   return TrainingState(params=new_params, opt_state=new_opt_state)
Example #5
    def sgd_step(
        state: TrainingState,
        samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
      grad_fn = jax.grad(loss, has_aux=True)
      gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                              samples)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_params = utils.update_periodically(steps,
                                                self._target_update_period,
                                                new_params, state.target_params)

      new_state = TrainingState(
          params=new_params,
          target_params=target_params,
          opt_state=new_opt_state,
          steps=steps)

      outputs = LearnerOutputs(keys=keys, priorities=priorities)

      return new_state, outputs
Example #6
 def jitted_update(params: hk.Params, opt_state: OptState,
                   x: jnp.ndarray,
                   y: jnp.ndarray) -> Tuple[hk.Params, OptState]:
     grads = jax.grad(self.loss_function)(params, x, y)
     updates, opt_state = self.optimizer.update(grads, opt_state)
     params = optix.apply_updates(params, updates)
     return params, opt_state
Example #7
File: train.py Project: tirkarthi/dm-haiku
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    _, optimizer = optix.adam(FLAGS.learning_rate)
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)
Example #8
 def update_actor(self, actor_params: hk.Params, critic_params: hk.Params,
                  actor_opt_state: OptState,
                  state: np.ndarray) -> Tuple[hk.Params, OptState]:
     """Learning rule (stochastic gradient descent)."""
     _, gradient = jax.value_and_grad(self.actor_loss)(actor_params,
                                                       critic_params, state)
     updates, opt_state = self.actor_opt_update(gradient, actor_opt_state)
     new_params = optix.apply_updates(actor_params, updates)
     return new_params, opt_state
Example #9
 def update(params: hk.Params, state: State, opt_state: OptState,
            batch: Batch) -> Tuple[hk.Params, State, OptState]:
     """Update the params."""
     (_, new_state), grads = jax.value_and_grad(train_loss,
                                                has_aux=True)(params, state,
                                                              batch)
     updates, new_opt_state = optimizer.update(grads, opt_state)
     new_params = optix.apply_updates(params, updates)
     return new_params, new_state, new_opt_state
Example #10
 def update(
     self,
     params: hk.Params,
     opt_state: OptState,
     trajs: Transition,
 ) -> Tuple[hk.Params, OptState]:
   g = jax.grad(self._agent.loss)(params, trajs)
   updates, new_opt_state = self._opt_update(g, opt_state)
   return optix.apply_updates(params, updates), new_opt_state
Example #11
        def train_step(carry, batch_indices):
            batch = jax.tree_map(lambda x: x[batch_indices], train_set)
            params_, net_state_, opt_state_ = carry
            loss, grad, net_state_ = _perdevice_log_prob_and_grad(
                batch, params_, net_state_)
            grad = jax.lax.psum(grad, axis_name='i')

            updates, opt_state_ = optimizer.update(grad, opt_state_)
            params_ = optix.apply_updates(params_, updates)
            return (params_, net_state_, opt_state_), loss
Example #12
 def update(
     params: hk.Params,
     rng_key: PRNGKey,
     opt_state: OptState,
     batch: Batch,
 ) -> Tuple[hk.Params, OptState]:
     """Single SGD update step."""
     grads = jax.grad(loss_fn)(params, rng_key, batch)
     updates, new_opt_state = optimizer.update(grads, opt_state)
     new_params = optix.apply_updates(params, updates)
     return new_params, new_opt_state
Example #13
 def learner_step(self, params, data, learner_state, unused_key):
     is_update_time = (learner_state.count % self._target_period == 0)
     target_params = rlax.periodic_update(params.online, params.target,
                                          is_update_time)
     dloss_dtheta = jax.grad(self._loss)(params.online, target_params,
                                         *data)
     updates, opt_state = self._optimizer.update(dloss_dtheta,
                                                 learner_state.opt_state)
     online_params = optix.apply_updates(params.online, updates)
     return (Params(online_params, target_params),
             LearnerState(learner_state.count + 1, opt_state))
Example #14
File: 2_rtg_pg.py Project: joaogui1/RL
    def update(
        params: hk.Params, opt_state: OptState, batch_obs: jnp.DeviceArray,
        batch_acts: jnp.DeviceArray, batch_returns: jnp.DeviceArray
    ) -> Tuple[hk.Params, OptState, jnp.DeviceArray]:

        batch_loss, g = jax.value_and_grad(compute_loss)(params, batch_obs,
                                                         batch_acts,
                                                         batch_returns)
        updates, opt_state = opt_update(g, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state, batch_loss
Example #15
 def update(
     params: Params,
     opt_state: OptState,
     inputs: np.ndarray,
     targets: np.ndarray,
 ) -> Tuple[Params, OptState]:
     """Learning rule (stochastic gradient descent)."""
     _, gradient = jax.value_and_grad(loss)(params, inputs, targets)
     updates, opt_state = opt.update(gradient, opt_state)
     new_params = optix.apply_updates(params, updates)
     return new_params, opt_state
Example #16
File: agent.py Project: stefanjuang/bsuite
 def sgd_step(state: AgentState,
              trajectory: sequence.Trajectory) -> AgentState:
     """Does a step of SGD over a trajectory."""
     gradients, new_rnn_state = jax.grad(loss_fn, has_aux=True)(
         state.params, trajectory, state.rnn_unroll_state)
     updates, new_opt_state = optimizer.update(gradients,
                                               state.opt_state)
     new_params = optix.apply_updates(state.params, updates)
     return state._replace(params=new_params,
                           opt_state=new_opt_state,
                           rnn_unroll_state=new_rnn_state)
Example #17
    def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
        """Update network weights wrt Q-learning loss."""
        def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
            q_tm1 = network.apply(net_params, obs_tm1)
            td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
            return rlax.l2_loss(td_error)

        dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1,
                                                 r_t, discount_t, q_t)
        updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
        net_params = optix.apply_updates(net_params, updates)
        return net_params, opt_state
Example #18
    def train_step(params, opt_state, batch):
        """Train for a single step."""
        value_and_grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = value_and_grad_fn(params, batch)

        # Note this is not the usual optix api as we additionally need parameter
        # values.
        # updates, opt_state = opt.update(grad, opt_state)
        updates, opt_state = opt.update_with_params(grad, params, opt_state)

        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state, loss
Example #19
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Performs an SGD step on a batch of transitions."""
            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)
Example #20
File: optix_test.py Project: yxd886/jax
    def test_apply_every(self):
        # The frequency of the application of sgd
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # experimental/optix.py sgd
        optix_sgd_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state_sgd = sgd.init(optix_sgd_params)

        # experimental/optix.py sgd apply every
        optix_sgd_apply_every_params = self.init_params
        sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                      optix.trace(decay=0, nesterov=False),
                                      optix.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optix_sgd_apply_every_params)
        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optix_sgd_params = optix.apply_updates(optix_sgd_params,
                                                   updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every)
            optix_sgd_apply_every_params = optix.apply_updates(
                optix_sgd_apply_every_params, updates_sgd_apply_every)
            if i % k == k - 1:
                # Check equivalence.
                for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                                tree_leaves(optix_sgd_params)):
                    np.testing.assert_allclose(x, y, atol=1e-6, rtol=100)
            else:
                # Check the update is zero.
                for x, y in zip(tree_leaves(updates_sgd_apply_every),
                                tree_leaves(zero_update)):
                    np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
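This test leans on the fact that optix optimizers compose: optix.chain threads each gradient transformation's init and update together, which is how the apply_every variant above is assembled. As a rough sketch of the same idea, a momentum SGD can be built from the trace and scale primitives already used above; LR is the test's own constant, while the momentum value and the params/grads pytrees are illustrative assumptions.

 momentum_sgd = optix.chain(
     optix.trace(decay=0.9, nesterov=False),  # accumulate a momentum trace of the gradients
     optix.scale(-LR),                        # step against the gradient, scaled by LR
 )
 # Expected to match optix.sgd(LR, 0.9); `params` and `grads` stand for any
 # matching parameter/gradient pytrees.
 state = momentum_sgd.init(params)
 updates, state = momentum_sgd.update(grads, state)
 params = optix.apply_updates(params, updates)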
Example #21
 def update_critic(self, critic_params: hk.Params,
                   target_critic_params: hk.Params,
                   target_actor_params: hk.Params,
                   critic_opt_state: OptState, state: np.ndarray,
                   action: np.ndarray, next_state: np.ndarray,
                   reward: np.ndarray, not_done: np.ndarray,
                   rng: jnp.ndarray) -> Tuple[hk.Params, OptState]:
     """Learning rule (stochastic gradient descent)."""
     _, gradient = jax.value_and_grad(self.critic_loss)(
         critic_params, target_critic_params, target_actor_params, state,
         action, next_state, reward, not_done, rng)
     updates, opt_state = self.critic_opt_update(gradient, critic_opt_state)
     new_params = optix.apply_updates(critic_params, updates)
     return new_params, opt_state
Example #22
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Does a step of SGD for the whole ensemble over `transitions`."""

            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)
Example #23
  def update(self, params, opt_state, batch: util.Transition):
    """The actual update function."""
    (_, logs), grads = jax.value_and_grad(
        self._loss, has_aux=True)(params, batch)

    grad_norm_unclipped = optimizers.l2_norm(grads)
    updates, updated_opt_state = self._opt.update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    weight_norm = optimizers.l2_norm(params)
    logs.update({
        'grad_norm_unclipped': grad_norm_unclipped,
        'weight_norm': weight_norm,
    })
    return params, updated_opt_state, logs
Example #24
    def update(net_params, target_params, opt_state, batch):
        """Update network weights wrt Q-learning loss."""
        def dqn_learning_loss(net_params, target_params, batch):
            obs_tm1, obs_t, a_tm1, r_t, discount_t = batch
            q_tm1 = network.apply(net_params, obs_tm1)
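            # double-DQN-style split: the target network provides the evaluation
            # values while the online network selects the greedy action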
            q_t_value = network.apply(target_params, obs_t)
            q_t_selector = network.apply(net_params, obs_t)

            td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_value,
                                    q_t_selector)
            return jnp.mean(rlax.l2_loss(td_error))

        loss, dloss_dtheta = jax.value_and_grad(dqn_learning_loss)(
            net_params, target_params, batch)
        updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
        net_params = optix.apply_updates(net_params, updates)
        return net_params, opt_state, loss
Example #25
    def update(self, state, data):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optix.apply_updates(params, updates)

        new_state = dict(
            step=state['step'] + 1,
            rng=new_rng,
            opt_state=opt_state,
            params=params,
        )
        metrics = dict(step=state['step'], loss=loss)
        return new_state, metrics
Example #26
        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            new_state = TrainingState(params=new_params,
                                      target_params=state.target_params,
                                      opt_state=new_opt_state,
                                      step=state.step + 1)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs
Example #27
File: optim.py Project: winston-ds/rljax
def optimize(
    fn_loss: Any,
    opt: Any,
    opt_state: Any,
    params_to_update: hk.Params,
    max_grad_norm: Optional[float],
    *args,
    **kwargs,
) -> Tuple[Any, hk.Params, jnp.ndarray, Any]:
    (loss, aux), grad = jax.value_and_grad(fn_loss, has_aux=True)(
        params_to_update,
        *args,
        **kwargs,
    )
    if max_grad_norm is not None:
        grad = clip_gradient_norm(grad, max_grad_norm)
    update, opt_state = opt(grad, opt_state)
    params_to_update = optix.apply_updates(params_to_update, update)
    return opt_state, params_to_update, loss, aux
Example #28
File: learning.py Project: WADRHAW/acme
        def sgd_step(state: TrainingState, sample: reverb.ReplaySample):
            # Compute gradients and optionally apply clipping.
            batch_loss = jax.vmap(loss, in_axes=(None, 0))
            mean_loss = lambda p, s: jnp.mean(batch_loss(p, s))
            grad_fn = jax.value_and_grad(mean_loss)
            loss_value, gradients = grad_fn(state.params, sample)

            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics
Example #29
File: optix_test.py Project: yxd886/jax
    def test_sgd(self):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = optimizers.sgd(LR)
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # experimental/optix.py
        optix_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state = sgd.init(optix_params)
        for _ in range(STEPS):
            updates, state = sgd.update(self.per_step_updates, state)
            optix_params = optix.apply_updates(optix_params, updates)

        # Check equivalence.
        for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
            np.testing.assert_allclose(x, y, rtol=1e-5)
Example #30
 def inner_loop(rln_params, pln_params, x_spt, y_spt, opt_state):
     for i, (_x, _y) in enumerate(zip(x_spt, y_spt)):
         (loss, acc), grads = value_and_grad(loss_acc_fn, 1,
                                             has_aux=True)(rln_params,
                                                           pln_params, _x,
                                                           _y)
         if i == 0:
             initial_loss = loss
             initial_acc = acc
         updates, opt_state = opt_update_fn(grads, opt_state, pln_params)
         pln_params = optix.apply_updates(pln_params, updates)
     return (
         pln_params,
         {
             "initial_loss": initial_loss,
             "initial_acc": initial_acc,
             "final_loss": loss,
             "final_acc": acc,
             "opt_state": opt_state,
         },
     )