Example #1
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.Array,
        network: 'PolicyValueNet',
        optimizer: snt.Optimizer,
        max_sequence_length: int,
        td_lambda: float,
        discount: float,
        seed: int,
    ):
        """A simple actor-critic agent.

        Args:
          obs_spec: Observation spec, forwarded to the trajectory buffer.
          action_spec: Action spec, forwarded to the trajectory buffer.
          network: Policy/value network, stored on the agent.
          optimizer: Sonnet optimizer, stored on the agent.
          max_sequence_length: Sequence length passed to `sequence.Buffer`.
          td_lambda: Lambda hyperparameter, stored as-is.
          discount: Discount hyperparameter, stored as-is.
          seed: Passed to `tf.random.set_seed` (TensorFlow's global seed).
        """
        # Seed TensorFlow globally before building anything else.
        tf.random.set_seed(seed)

        # Hyperparameters.
        self._discount = discount
        self._td_lambda = td_lambda

        # Learning components.
        self._network, self._optimizer = network, optimizer

        # Windowed buffer that accumulates trajectories for learning.
        self._buffer = sequence.Buffer(
            obs_spec, action_spec, max_sequence_length)
Example #2
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: PolicyValueNet,
        optimizer: optax.GradientTransformation,
        rng: hk.PRNGSequence,
        sequence_length: int,
        discount: float,
        td_lambda: float,
    ):
        """Builds a feed-forward actor-critic agent with Haiku, Optax and RLax.

        Args:
          obs_spec: Observation spec. Its shape sizes the dummy observation
            used to initialise the network; it is also forwarded to the buffer.
          action_spec: Action spec, forwarded to the trajectory buffer.
          network: Function mapping a batch of observations to
            `(logits, values)`; it is Haiku-transformed below.
          optimizer: Optax gradient transformation applied in each SGD step.
          rng: Haiku PRNG sequence. One key is drawn here to initialise the
            parameters; the sequence itself is stored on the agent.
          sequence_length: Sequence length passed to `sequence.Buffer`.
          discount: Scalar multiplied into the per-step trajectory discounts.
          td_lambda: Lambda passed to `rlax.td_lambda`.
        """

        # Define loss function.
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """Actor-critic loss."""
            logits, values = network(trajectory.observations)
            # Observations hold one more step than actions/rewards/discounts,
            # so values[:-1] and values[1:] line up with r_t and discount_t.
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            # The TD(lambda) errors double as the advantage estimate.
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss

        # Transform the loss into a pure function of (params, trajectory).
        # Haiku's `apply_rng` flag is deprecated (an RNG argument is always
        # threaded through `apply`); `without_apply_rng` removes it again.
        loss_fn = hk.without_apply_rng(hk.transform(loss)).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     trajectory: sequence.Trajectory) -> TrainingState:
            """Does a step of SGD over a trajectory."""
            gradients = jax.grad(loss_fn)(state.params, trajectory)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)
            return TrainingState(params=new_params, opt_state=new_opt_state)

        # Initialize network parameters and optimiser state.
        init, forward = hk.without_apply_rng(hk.transform(network))
        # A batch of one all-zeros observation is enough to trace shapes.
        dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
        initial_params = init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = TrainingState(initial_params, initial_opt_state)
        self._forward = jax.jit(forward)
        self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
        self._sgd_step = sgd_step
        self._rng = rng
Example #3
0
  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.Array,
      network: 'PolicyValueRNN',
      optimizer: snt.Optimizer,
      max_sequence_length: int,
      td_lambda: float,
      discount: float,
      seed: int,
      entropy_cost: float = 0.,
  ):
    """A recurrent actor-critic agent.

    Args:
      obs_spec: Observation spec, forwarded to the trajectory buffer.
      action_spec: Action spec, forwarded to the trajectory buffer.
      network: Recurrent policy/value network. It is stored both as-is and
        wrapped in `tf.function`, and provides `initial_state`.
      optimizer: Sonnet optimizer, stored on the agent.
      max_sequence_length: Sequence length passed to `sequence.Buffer`.
      td_lambda: Lambda hyperparameter, stored as-is.
      discount: Discount hyperparameter, stored as-is.
      seed: Passed to `tf.random.set_seed` (TensorFlow's global seed).
      entropy_cost: Entropy-regulariser weight, stored as-is (default 0).
    """
    # Network, a tf.function-compiled view of it, and the optimizer.
    self._forward = tf.function(network)
    self._network, self._optimizer = network, optimizer

    # Two independent batch-size-1 recurrent states, created in this order.
    # NOTE(review): these are built before the seed is set below; if
    # `initial_state` were stochastic it would be unseeded — confirm.
    self._state, self._rollout_initial_state = (
        network.initial_state(1),
        network.initial_state(1),
    )

    # Seed TensorFlow globally, then record the hyperparameters.
    tf.random.set_seed(seed)
    self._discount, self._td_lambda, self._entropy_cost = (
        discount, td_lambda, entropy_cost)

    # Rolling buffer that accumulates experience for learning.
    self._buffer = sequence.Buffer(
        obs_spec, action_spec, max_sequence_length)
Example #4
0
  def test_buffer(self):
    """Checks fill/drain semantics and trajectory lengths of `sequence.Buffer`.

    A buffer of capacity T becomes full after T appends and rejects further
    ones. Draining it yields T + 1 observations and T of everything else,
    leaves it empty, and a second drain raises. A partial (T=1) trajectory
    can also be drained.
    """
    # Given a buffer and some dummy data...
    max_sequence_length = 10
    obs_shape = (3, 3)
    # `np.float` / `np.int` were aliases of the builtins `float` / `int`; they
    # were deprecated in NumPy 1.20 and removed in 1.24, so use the builtins.
    buffer = sequence.Buffer(
        obs_spec=specs.Array(obs_shape, dtype=float),
        action_spec=specs.Array((), dtype=int),
        max_sequence_length=max_sequence_length)
    dummy_step = dm_env.transition(observation=np.zeros(obs_shape), reward=0.)

    # If we add `max_sequence_length` items to the buffer...
    for _ in range(max_sequence_length):
      buffer.append(dummy_step, 0, dummy_step)

    # Then the buffer should now be full.
    self.assertTrue(buffer.full())

    # Any further appends should throw an error.
    with self.assertRaises(ValueError):
      buffer.append(dummy_step, 0, dummy_step)

    # If we now drain this trajectory from the buffer...
    trajectory = buffer.drain()

    # The `observations` sequence should have length `T + 1`.
    self.assertLen(trajectory.observations, max_sequence_length + 1)

    # All other sequences should have length `T`.
    self.assertLen(trajectory.actions, max_sequence_length)
    self.assertLen(trajectory.rewards, max_sequence_length)
    self.assertLen(trajectory.discounts, max_sequence_length)

    # The buffer should now be empty.
    self.assertTrue(buffer.empty())

    # A second call to drain() should throw an error, since the buffer is empty.
    with self.assertRaises(ValueError):
      buffer.drain()

    # If we now append another transition...
    buffer.append(dummy_step, 0, dummy_step)

    # And immediately drain the buffer...
    trajectory = buffer.drain()

    # We should have a valid partial trajectory of length T=1.
    self.assertLen(trajectory.observations, 2)
    self.assertLen(trajectory.actions, 1)
    self.assertLen(trajectory.rewards, 1)
    self.assertLen(trajectory.discounts, 1)
Example #5
0
  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: RecurrentPolicyValueNet,
      initial_rnn_state: LSTMState,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
      entropy_cost: float = 0.,
  ):
    """Builds a recurrent actor-critic agent with Haiku, Optix and RLax.

    Args:
      obs_spec: Observation spec. Its shape and dtype size the dummy
        observation used for initialisation; it is also forwarded to the
        buffer.
      action_spec: Action spec, forwarded to the trajectory buffer.
      network: Recurrent function `(observation, state) ->
        ((logits, values), new_state)`, as used by `hk.dynamic_unroll`.
      initial_rnn_state: Recurrent state used for initialisation and as the
        starting value of both recurrent-state slots in `AgentState`.
      optimizer: Optix gradient transformation applied in each SGD step.
      rng: Haiku PRNG sequence. One key is drawn here to initialise the
        parameters; the sequence itself is stored on the agent.
      sequence_length: Sequence length passed to `sequence.Buffer`.
      discount: Scalar multiplied into the per-step trajectory discounts.
      td_lambda: Lambda passed to `rlax.td_lambda`.
      entropy_cost: Weight of the entropy term in the combined loss.
    """

    # Define loss function.
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """Actor-critic loss; also returns the recurrent state after unrolling."""
      # Insert a singleton batch axis at position 1 and unroll along the
      # leading axis; outputs are indexed back with `[..., 0]` below.
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, trajectory.observations[:, None, ...], rnn_unroll_state)
      seq_len = trajectory.actions.shape[0]
      # Observations hold one more step than actions/rewards/discounts, so
      # values[:-1] and values[1:] line up with r_t and discount_t.
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      # The TD(lambda) errors double as the advantage estimate.
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1, 0],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones(seq_len))
      entropy_loss = jnp.mean(
          rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

      return combined_loss, new_rnn_unroll_state

    # Transform the loss into a pure function of (params, trajectory, state).
    # Haiku's `apply_rng` flag is deprecated (an RNG argument is always
    # threaded through `apply`); `without_apply_rng` removes it again.
    loss_fn = hk.without_apply_rng(hk.transform(loss)).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: AgentState,
                 trajectory: sequence.Trajectory) -> AgentState:
      """Does a step of SGD over a trajectory."""
      # `has_aux=True`: the loss's second output (the post-unroll recurrent
      # state) is passed through without being differentiated.
      gradients, new_rnn_state = jax.grad(
          loss_fn, has_aux=True)(state.params, trajectory,
                                 state.rnn_unroll_state)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return state._replace(
          params=new_params,
          opt_state=new_opt_state,
          rnn_unroll_state=new_rnn_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.without_apply_rng(hk.transform(network))
    # A batch of one all-zeros observation is enough to trace shapes.
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=obs_spec.dtype)
    initial_params = init(next(rng), dummy_observation, initial_rnn_state)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state. Both recurrent-state slots start from
    # `initial_rnn_state`.
    self._state = AgentState(initial_params, initial_opt_state,
                             initial_rnn_state, initial_rnn_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng
    self._initial_rnn_state = initial_rnn_state