def test_batch_concat(self):
  batch_size = 32
  inputs = [
      jnp.zeros(shape=(batch_size, 2)),
      {'foo': jnp.zeros(shape=(batch_size, 5, 3))},
      [jnp.zeros(shape=(batch_size, 1))],
  ]

  output_shape = utils.batch_concat(inputs).shape
  expected_shape = [batch_size, 2 + 5 * 3 + 1]
  self.assertSequenceEqual(output_shape, expected_shape)
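For reference, the contract this test exercises can be written down in a few lines: every leaf of the (possibly nested) input is flattened to [batch, -1] and the results are concatenated along the last axis. The helper below is an illustrative sketch with a hypothetical name, not the library implementation of utils.batch_concat, whose num_batch_dims handling is more general.

import jax
import jax.numpy as jnp

def batch_concat_sketch(inputs) -> jnp.ndarray:
  """Flattens each leaf to [batch, -1] and concatenates on the last axis."""
  leaves = jax.tree_util.tree_leaves(inputs)
  flat = [jnp.reshape(x, (x.shape[0], -1)) for x in leaves]
  return jnp.concatenate(flat, axis=-1)

x = [jnp.zeros((32, 2)), {'foo': jnp.zeros((32, 5, 3))}, [jnp.zeros((32, 1))]]
assert batch_concat_sketch(x).shape == (32, 2 + 5 * 3 + 1)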
def __call__(self,
             observation: jnp.ndarray,
             action: jnp.ndarray) -> jnp.ndarray:
  # Maybe transform observations and actions before feeding them on.
  if self._observation_network:
    observation = self._observation_network(observation)
  if self._action_network:
    action = self._action_network(action)

  # Concat observations and actions, with one batch dimension.
  outputs = utils.batch_concat([observation, action])

  # Maybe transform output before returning.
  if self._critic_network:
    outputs = self._critic_network(outputs)

  return outputs
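The multiplexer above relies on networks held by its enclosing module. For orientation, here is a self-contained sketch of the same pattern, written with Haiku and a plain MLP head; the framework choice, network sizes, and names are assumptions of this example, not part of the library code.

import haiku as hk
import jax
import jax.numpy as jnp

def critic_fn(observation: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
  # Flatten both inputs to [batch, -1], concatenate, and feed a Q-network.
  inputs = jnp.concatenate([
      jnp.reshape(observation, (observation.shape[0], -1)),
      jnp.reshape(action, (action.shape[0], -1)),
  ], axis=-1)
  return hk.nets.MLP([256, 256, 1])(inputs)

critic = hk.without_apply_rng(hk.transform(critic_fn))
obs, act = jnp.zeros((8, 11)), jnp.zeros((8, 3))
params = critic.init(jax.random.PRNGKey(0), obs, act)
q_values = critic.apply(params, obs, act)  # shape [8, 1]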
def apply(self, inputs: jnp.ndarray):
  # Flatten (possibly nested) inputs into a single [batch, features] array.
  inputs = utils.batch_concat(inputs)
  # Separate policy (logits) and value heads on top of the shared input.
  logits = MLP(inputs, [64, 64, num_actions])
  value = MLP(inputs, [64, 64, 1])
  value = jnp.squeeze(value, axis=-1)  # [batch, 1] -> [batch]
  return tfd.Categorical(logits=logits), value
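Downstream code consumes the returned pair as a policy distribution and a state-value estimate. A minimal, self-contained illustration of that usage, with stand-in arrays in place of real network outputs, is:

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

key = jax.random.PRNGKey(0)
logits = jnp.zeros((32, 4))  # stand-in for the policy head output
value = jnp.zeros((32,))     # stand-in for the squeezed value head output

distribution = tfd.Categorical(logits=logits)
action = distribution.sample(seed=key)     # [32] integer actions
log_prob = distribution.log_prob(action)   # [32], e.g. stored as behaviour log-probs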
def sgd_step(
    state: TrainingState,
    sample: reverb.ReplaySample
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Performs a minibatch SGD step, returning new state and metrics."""

  # Extract the data.
  data = sample.data
  # TODO(sinopalnikov): replace it with namedtuple unpacking
  observations, actions, rewards, termination, extra = (
      data.observation, data.action, data.reward, data.discount, data.extras)
  discounts = termination * discount
  behavior_log_probs = extra['log_prob']

  def get_behavior_values(
      params: networks_lib.Params,
      observations: types.NestedArray) -> jnp.ndarray:
    o = jax.tree_map(
        lambda x: jnp.reshape(x, [-1] + list(x.shape[2:])), observations)
    _, behavior_values = ppo_networks.network.apply(params, o)
    behavior_values = jnp.reshape(behavior_values, rewards.shape[0:2])
    return behavior_values

  behavior_values = get_behavior_values(state.params, observations)

  # Vmap over batch dimension
  batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0)
  advantages, target_values = batch_gae_advantages(
      rewards, discounts, behavior_values)
  # Exclude the last step - it was only used for bootstrapping.
  # The shape is [num_sequences, num_steps, ..]
  observations, actions, behavior_log_probs, behavior_values = jax.tree_map(
      lambda x: x[:, :-1],
      (observations, actions, behavior_log_probs, behavior_values))
  trajectories = Batch(
      observations=observations,
      actions=actions,
      advantages=advantages,
      behavior_log_probs=behavior_log_probs,
      target_values=target_values,
      behavior_values=behavior_values)

  # Concatenate all trajectories. Reshape from [num_sequences, num_steps,..]
  # to [num_sequences * num_steps,..]
  assert len(target_values.shape) > 1
  num_sequences = target_values.shape[0]
  num_steps = target_values.shape[1]
  batch_size = num_sequences * num_steps
  assert batch_size % num_minibatches == 0, (
      'Num minibatches must divide batch size. Got batch_size={}'
      ' num_minibatches={}.').format(batch_size, num_minibatches)
  batch = jax.tree_map(
      lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories)

  # Compute gradients.
  grad_fn = jax.grad(loss, has_aux=True)

  def model_update_minibatch(
      carry: Tuple[networks_lib.Params, optax.OptState],
      minibatch: Batch,
  ) -> Tuple[Tuple[networks_lib.Params, optax.OptState],
             Dict[str, jnp.ndarray]]:
    """Performs model update for a single minibatch."""
    params, opt_state = carry
    # Normalize advantages at the minibatch level before using them.
    advantages = (
        (minibatch.advantages - jnp.mean(minibatch.advantages, axis=0)) /
        (jnp.std(minibatch.advantages, axis=0) + 1e-8))
    gradients, metrics = grad_fn(params, minibatch.observations,
                                 minibatch.actions,
                                 minibatch.behavior_log_probs,
                                 minibatch.target_values, advantages,
                                 minibatch.behavior_values)

    # Apply updates
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)

    metrics['norm_grad'] = optax.global_norm(gradients)
    metrics['norm_updates'] = optax.global_norm(updates)
    return (params, opt_state), metrics

  def model_update_epoch(
      carry: Tuple[jnp.ndarray, networks_lib.Params, optax.OptState, Batch],
      unused_t: Tuple[()]
  ) -> Tuple[Tuple[jnp.ndarray, networks_lib.Params, optax.OptState, Batch],
             Dict[str, jnp.ndarray]]:
    """Performs model updates based on one epoch of data."""
    key, params, opt_state, batch = carry
    key, subkey = jax.random.split(key)
    permutation = jax.random.permutation(subkey, batch_size)
    shuffled_batch = jax.tree_map(
        lambda x: jnp.take(x, permutation, axis=0), batch)
    minibatches = jax.tree_map(
        lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
        shuffled_batch)

    (params, opt_state), metrics = jax.lax.scan(
        model_update_minibatch, (params, opt_state), minibatches,
        length=num_minibatches)

    return (key, params, opt_state, batch), metrics

  params = state.params
  opt_state = state.opt_state
  # Repeat training for the given number of epochs, taking a random
  # permutation for every epoch.
  (key, params, opt_state, _), metrics = jax.lax.scan(
      model_update_epoch, (state.random_key, params, opt_state, batch), (),
      length=num_epochs)

  metrics = jax.tree_map(jnp.mean, metrics)
  metrics['norm_params'] = optax.global_norm(params)
  metrics['observations_mean'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=(0, 1))),
                       observations),
          num_batch_dims=0))
  metrics['observations_std'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.std(x, axis=(0, 1)), observations),
          num_batch_dims=0))
  metrics['rewards_mean'] = jnp.mean(
      jnp.abs(jnp.mean(rewards, axis=(0, 1))))
  metrics['rewards_std'] = jnp.std(rewards, axis=(0, 1))

  new_state = TrainingState(
      params=params, opt_state=opt_state, random_key=key)
  return new_state, metrics
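The step above vmaps a per-sequence gae_advantages function whose definition is not shown. As a reading aid, here is one way to write truncated GAE(lambda) under the same conventions (discounts already folded with the discount factor, the final value used only for bootstrapping). It is a sketch with an illustrative name, not the library's implementation, which may also apply reward clipping.

from typing import Tuple

import jax
import jax.numpy as jnp

def gae_advantages_sketch(
    rewards: jnp.ndarray,    # [num_steps]
    discounts: jnp.ndarray,  # [num_steps], already multiplied by the discount factor
    values: jnp.ndarray,     # [num_steps]; the last entry only bootstraps
    lambda_: float = 0.95,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  # One-step TD errors: delta_t = r_t + d_t * V(s_{t+1}) - V(s_t).
  deltas = rewards[:-1] + discounts[:-1] * values[1:] - values[:-1]

  def backwards(gae, delta_and_discount):
    delta, disc = delta_and_discount
    gae = delta + disc * lambda_ * gae
    return gae, gae

  # Scan from the end of the sequence; outputs come back in forward order.
  _, advantages = jax.lax.scan(
      backwards, jnp.zeros((), dtype=values.dtype),
      (deltas, discounts[:-1]), reverse=True)
  advantages = jax.lax.stop_gradient(advantages)
  target_values = jax.lax.stop_gradient(values[:-1] + advantages)
  return advantages, target_values

# As in sgd_step, vmap over the leading [num_sequences] dimension.
rewards = jnp.zeros((8, 10))
discounts = 0.99 * jnp.ones((8, 10))
values = jnp.zeros((8, 10))
adv, tv = jax.vmap(gae_advantages_sketch)(rewards, discounts, values)  # [8, 9] each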
def update_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Performs one SAC update on a batch of transitions."""

  key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
  if adaptive_entropy_coefficient:
    alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                         state.policy_params, transitions,
                                         key_alpha)
    alpha = jnp.exp(state.alpha_params)
  else:
    alpha = entropy_coefficient
  critic_loss, critic_grads = critic_grad(state.q_params,
                                          state.policy_params,
                                          state.target_q_params, alpha,
                                          transitions, key_critic)
  actor_loss, actor_grads = actor_grad(state.policy_params, state.q_params,
                                       alpha, transitions, key_actor)

  # Apply policy gradients
  actor_update, policy_optimizer_state = policy_optimizer.update(
      actor_grads, state.policy_optimizer_state)
  policy_params = optax.apply_updates(state.policy_params, actor_update)

  # Apply critic gradients
  critic_update, q_optimizer_state = q_optimizer.update(
      critic_grads, state.q_optimizer_state)
  q_params = optax.apply_updates(state.q_params, critic_update)

  new_target_q_params = jax.tree_multimap(
      lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params)

  metrics = {
      'critic_loss': critic_loss,
      'actor_loss': actor_loss,
  }

  new_state = TrainingState(
      policy_optimizer_state=policy_optimizer_state,
      q_optimizer_state=q_optimizer_state,
      policy_params=policy_params,
      q_params=q_params,
      target_q_params=new_target_q_params,
      key=key,
  )
  if adaptive_entropy_coefficient:
    # Apply alpha gradients
    alpha_update, alpha_optimizer_state = alpha_optimizer.update(
        alpha_grads, state.alpha_optimizer_state)
    alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
    metrics.update({
        'alpha_loss': alpha_loss,
        'alpha': jnp.exp(alpha_params),
    })
    new_state = new_state._replace(
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=alpha_params)

  metrics['observations_mean'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)),
                       transitions.observation)))
  metrics['observations_std'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.std(x, axis=0),
                       transitions.observation)))
  metrics['next_observations_mean'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)),
                       transitions.next_observation)))
  metrics['next_observations_std'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.std(x, axis=0),
                       transitions.next_observation)))

  return new_state, metrics
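The alpha_grad call above differentiates a temperature loss whose definition is not shown; state.alpha_params holds log(alpha), which is why the step exponentiates it. The sketch below is the standard SAC temperature objective written against a batch of sampled log-probabilities; the function name, stand-in values, and target_entropy choice are illustrative, not taken from the library.

import jax
import jax.numpy as jnp

def alpha_loss_sketch(log_alpha: jnp.ndarray,
                      log_probs: jnp.ndarray,
                      target_entropy: float) -> jnp.ndarray:
  """Temperature loss: E[alpha * (-log pi(a|s) - target_entropy)]."""
  alpha = jnp.exp(log_alpha)
  return jnp.mean(alpha * jax.lax.stop_gradient(-log_probs - target_entropy))

# Gradient with respect to log(alpha) for a batch of sampled log-probabilities.
log_alpha = jnp.asarray(0.0)
log_probs = -jnp.ones((256,))  # stand-in for log pi(a|s) of freshly sampled actions
alpha_grads = jax.grad(alpha_loss_sketch)(log_alpha, log_probs, target_entropy=-3.0)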