def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.Array,
    network: 'PolicyValueNet',
    optimizer: snt.Optimizer,
    max_sequence_length: int,
    td_lambda: float,
    discount: float,
    seed: int,
):
  """A simple actor-critic agent."""

  # Internalise hyperparameters.
  tf.random.set_seed(seed)
  self._td_lambda = td_lambda
  self._discount = discount

  # Internalise network and optimizer.
  self._network = network
  self._optimizer = optimizer

  # Create windowed buffer for learning from trajectories.
  self._buffer = sequence.Buffer(obs_spec, action_spec, max_sequence_length)

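# --- Usage sketch (not part of the original agent code) ---
# A minimal Sonnet network and construction call that would be compatible with
# the constructor above. The class name `ActorCritic`, the `PolicyValueNet`
# definition and all hyperparameter values below are illustrative assumptions.
import numpy as np
import sonnet as snt
import tensorflow as tf
from dm_env import specs


class PolicyValueNet(snt.Module):
  """Hypothetical MLP torso with separate policy-logit and value heads."""

  def __init__(self, hidden_sizes, num_actions):
    super().__init__(name='policy_value_net')
    self._torso = snt.Sequential(
        [snt.Flatten(), snt.nets.MLP(hidden_sizes, activate_final=True)])
    self._policy_head = snt.Linear(num_actions)
    self._value_head = snt.Linear(1)

  def __call__(self, observations):
    embedding = self._torso(observations)
    logits = self._policy_head(embedding)
    value = tf.squeeze(self._value_head(embedding), axis=-1)
    return logits, value


# Illustrative specs; in practice these come from the environment.
obs_spec = specs.Array(shape=(8, 8), dtype=np.float32, name='observation')
action_spec = specs.DiscreteArray(num_values=4, name='action')

agent = ActorCritic(  # assumed name of the class defining __init__ above
    obs_spec=obs_spec,
    action_spec=action_spec,
    network=PolicyValueNet(hidden_sizes=[64, 64],
                           num_actions=action_spec.num_values),
    optimizer=snt.optimizers.Adam(learning_rate=3e-3),
    max_sequence_length=32,
    td_lambda=0.9,
    discount=0.99,
    seed=42,
)
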
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: PolicyValueNet,
    optimizer: optax.GradientTransformation,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
):

  # Define loss function.
  def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    logits, values = network(trajectory.observations)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))

    return actor_loss + critic_loss

  # Transform the loss into a pure function.
  loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               trajectory: sequence.Trajectory) -> TrainingState:
    """Does a step of SGD over a trajectory."""
    gradients = jax.grad(loss_fn)(state.params, trajectory)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.without_apply_rng(
      hk.transform(network, apply_rng=True))
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
  initial_params = init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = TrainingState(initial_params, initial_opt_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng

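# --- Usage sketch (not part of the original agent code) ---
# A minimal Haiku network and construction call compatible with the
# constructor above. `ActorCritic` as the class name, the network definition
# and all hyperparameter values are illustrative assumptions.
import haiku as hk
import jax.numpy as jnp
import numpy as np
import optax
from dm_env import specs

# Illustrative specs; in practice these come from the environment.
obs_spec = specs.Array(shape=(8, 8), dtype=np.float32, name='observation')
action_spec = specs.DiscreteArray(num_values=4, name='action')


def policy_value_net(observations: jnp.ndarray):
  """Hypothetical MLP torso with policy-logit and value heads."""
  torso = hk.nets.MLP([64, 64], activate_final=True)
  policy_head = hk.Linear(action_spec.num_values)
  value_head = hk.Linear(1)
  embedding = torso(hk.Flatten()(observations))
  logits = policy_head(embedding)
  value = jnp.squeeze(value_head(embedding), axis=-1)
  return logits, value


agent = ActorCritic(  # assumed name of the class defining __init__ above
    obs_spec=obs_spec,
    action_spec=action_spec,
    network=policy_value_net,
    optimizer=optax.adam(3e-3),
    rng=hk.PRNGSequence(42),
    sequence_length=32,
    discount=0.99,
    td_lambda=0.9,
)
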
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.Array,
    network: 'PolicyValueRNN',
    optimizer: snt.Optimizer,
    max_sequence_length: int,
    td_lambda: float,
    discount: float,
    seed: int,
    entropy_cost: float = 0.,
):
  """A recurrent actor-critic agent."""

  # Internalise network and optimizer.
  self._forward = tf.function(network)
  self._network = network
  self._optimizer = optimizer

  # Initialise recurrent state.
  self._state = network.initial_state(1)
  self._rollout_initial_state = network.initial_state(1)

  # Set seed and internalise hyperparameters.
  tf.random.set_seed(seed)
  self._discount = discount
  self._td_lambda = td_lambda
  self._entropy_cost = entropy_cost

  # Initialise rolling experience buffer.
  self._buffer = sequence.Buffer(obs_spec, action_spec, max_sequence_length)

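# --- Network sketch (not part of the original agent code) ---
# A hypothetical Sonnet recurrent core compatible with the constructor above:
# it must expose `initial_state(batch_size)` and return ((logits, value),
# next_state) from each call. The class body and sizes are illustrative.
import sonnet as snt
import tensorflow as tf


class PolicyValueRNN(snt.RNNCore):
  """Hypothetical LSTM core with policy-logit and value heads."""

  def __init__(self, hidden_size: int, num_actions: int):
    super().__init__(name='policy_value_rnn')
    self._flatten = snt.Flatten()
    self._lstm = snt.LSTM(hidden_size)
    self._policy_head = snt.Linear(num_actions)
    self._value_head = snt.Linear(1)

  def __call__(self, inputs, prev_state):
    embedding, next_state = self._lstm(self._flatten(inputs), prev_state)
    logits = self._policy_head(embedding)
    value = tf.squeeze(self._value_head(embedding), axis=-1)
    return (logits, value), next_state

  def initial_state(self, batch_size):
    return self._lstm.initial_state(batch_size)
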
def test_buffer(self):
  # Given a buffer and some dummy data...
  max_sequence_length = 10
  obs_shape = (3, 3)
  buffer = sequence.Buffer(
      obs_spec=specs.Array(obs_shape, dtype=float),
      action_spec=specs.Array((), dtype=int),
      max_sequence_length=max_sequence_length)
  dummy_step = dm_env.transition(observation=np.zeros(obs_shape), reward=0.)

  # If we add `max_sequence_length` items to the buffer...
  for _ in range(max_sequence_length):
    buffer.append(dummy_step, 0, dummy_step)

  # Then the buffer should now be full.
  self.assertTrue(buffer.full())

  # Any further appends should throw an error.
  with self.assertRaises(ValueError):
    buffer.append(dummy_step, 0, dummy_step)

  # If we now drain this trajectory from the buffer...
  trajectory = buffer.drain()

  # The `observations` sequence should have length `T + 1`.
  self.assertLen(trajectory.observations, max_sequence_length + 1)

  # All other sequences should have length `T`.
  self.assertLen(trajectory.actions, max_sequence_length)
  self.assertLen(trajectory.rewards, max_sequence_length)
  self.assertLen(trajectory.discounts, max_sequence_length)

  # The buffer should now be empty.
  self.assertTrue(buffer.empty())

  # A second call to drain() should throw an error, since the buffer is empty.
  with self.assertRaises(ValueError):
    buffer.drain()

  # If we now append another transition...
  buffer.append(dummy_step, 0, dummy_step)

  # And immediately drain the buffer...
  trajectory = buffer.drain()

  # We should have a valid partial trajectory of length T=1.
  self.assertLen(trajectory.observations, 2)
  self.assertLen(trajectory.actions, 1)
  self.assertLen(trajectory.rewards, 1)
  self.assertLen(trajectory.discounts, 1)

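# --- Reference sketch (not part of the original test code) ---
# The test above assumes a `Trajectory` container with the four fields it
# asserts on. A plausible definition, inferred from those assertions (the
# actual one lives in the `sequence` module), would look roughly like this:
from typing import NamedTuple

import numpy as np


class Trajectory(NamedTuple):
  """One rolling window of experience: T+1 observations, T of the rest."""
  observations: np.ndarray  # [T + 1, ...]
  actions: np.ndarray       # [T]
  rewards: np.ndarray       # [T]
  discounts: np.ndarray     # [T]
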
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: RecurrentPolicyValueNet,
    initial_rnn_state: LSTMState,
    optimizer: optix.InitUpdate,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
    entropy_cost: float = 0.,
):

  # Define loss function.
  def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
    """Actor-critic loss."""
    (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
        network, trajectory.observations[:, None, ...], rnn_unroll_state)
    seq_len = trajectory.actions.shape[0]
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1, 0],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:, 0],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1, 0],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones(seq_len))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

    combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

    return combined_loss, new_rnn_unroll_state

  # Transform the loss into a pure function.
  loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: AgentState,
               trajectory: sequence.Trajectory) -> AgentState:
    """Does a step of SGD over a trajectory."""
    gradients, new_rnn_state = jax.grad(
        loss_fn, has_aux=True)(state.params, trajectory,
                               state.rnn_unroll_state)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return state._replace(
        params=new_params,
        opt_state=new_opt_state,
        rnn_unroll_state=new_rnn_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=obs_spec.dtype)
  initial_params = init(next(rng), dummy_observation, initial_rnn_state)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = AgentState(initial_params, initial_opt_state,
                           initial_rnn_state, initial_rnn_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng
  self._initial_rnn_state = initial_rnn_state

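# --- Usage sketch (not part of the original agent code) ---
# A minimal recurrent Haiku network, initial LSTM state and construction call
# compatible with the constructor above. `ActorCriticRNN` as the class name,
# the network definition and all hyperparameter values are illustrative
# assumptions. `optix` (the legacy `jax.experimental.optix`) is used only to
# match the constructor's `optix.InitUpdate` annotation.
import haiku as hk
import jax.numpy as jnp
import numpy as np
from dm_env import specs
from jax.experimental import optix

# Illustrative specs; in practice these come from the environment.
obs_spec = specs.Array(shape=(8, 8), dtype=np.float32, name='observation')
action_spec = specs.DiscreteArray(num_values=4, name='action')

hidden_size = 256
initial_rnn_state = hk.LSTMState(
    hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
    cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))


def recurrent_policy_value_net(observations: jnp.ndarray, state):
  """Hypothetical LSTM core with policy-logit and value heads."""
  lstm = hk.LSTM(hidden_size)
  embedding, next_state = lstm(hk.Flatten()(observations), state)
  logits = hk.Linear(action_spec.num_values)(embedding)
  value = jnp.squeeze(hk.Linear(1)(embedding), axis=-1)
  return (logits, value), next_state


agent = ActorCriticRNN(  # assumed name of the class defining __init__ above
    obs_spec=obs_spec,
    action_spec=action_spec,
    network=recurrent_policy_value_net,
    initial_rnn_state=initial_rnn_state,
    optimizer=optix.adam(3e-3),
    rng=hk.PRNGSequence(42),
    sequence_length=32,
    discount=0.99,
    td_lambda=0.9,
    entropy_cost=0.01,
)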