Example 1
  def test_batch_concat(self):
    batch_size = 32
    inputs = [
        jnp.zeros(shape=(batch_size, 2)),
        {
            'foo': jnp.zeros(shape=(batch_size, 5, 3))
        },
        [jnp.zeros(shape=(batch_size, 1))],
    ]

    output_shape = utils.batch_concat(inputs).shape
    expected_shape = [batch_size, 2 + 5 * 3 + 1]
    self.assertSequenceEqual(output_shape, expected_shape)
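Example 1 exercises utils.batch_concat from Acme's JAX utilities: it flattens a nested structure of arrays and concatenates the leaves along the feature axis while preserving the leading batch dimension, which is why the expected width is 2 + 5 * 3 + 1. As a rough mental model only, here is a minimal sketch of that behaviour built from jax.tree_util; naive_batch_concat is an illustrative stand-in, not Acme's actual implementation.

import jax.numpy as jnp
from jax import tree_util

def naive_batch_concat(inputs, num_batch_dims: int = 1):
  """Illustrative stand-in for acme.jax.utils.batch_concat.

  Flattens every leaf to [batch, -1] and concatenates along the last axis;
  assumes all leaves share the same leading batch dimension(s).
  """
  leaves = tree_util.tree_leaves(inputs)
  flat = [x.reshape(x.shape[:num_batch_dims] + (-1,)) for x in leaves]
  return jnp.concatenate(flat, axis=-1)

# Mirrors the test above: 2 + 5 * 3 + 1 = 18 features per batch element.
batch = [
    jnp.zeros((32, 2)),
    {'foo': jnp.zeros((32, 5, 3))},
    [jnp.zeros((32, 1))],
]
assert naive_batch_concat(batch).shape == (32, 18)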
Example 2
    def __call__(self, observation: jnp.ndarray,
                 action: jnp.ndarray) -> jnp.ndarray:

        # Maybe transform observations and actions before feeding them on.
        if self._observation_network:
            observation = self._observation_network(observation)
        if self._action_network:
            action = self._action_network(action)

        # Concat observations and actions, with one batch dimension.
        outputs = utils.batch_concat([observation, action])

        # Maybe transform output before returning.
        if self._critic_network:
            outputs = self._critic_network(outputs)

        return outputs
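Example 2 is the call method of a critic "multiplexer": observation and action are optionally passed through their own sub-networks, merged with batch_concat into a single feature vector, and scored by the critic network. The sketch below shows the same concatenate-then-score pattern as a standalone Haiku critic; the layer sizes and input shapes are assumptions, and plain jnp.concatenate stands in for batch_concat because both inputs here are already flat [batch, features] arrays.

import haiku as hk
import jax
import jax.numpy as jnp

def critic_fn(observation: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
  # Concatenate [observation, action] along the feature axis, keeping the
  # batch dimension, then score the pair with a small MLP -> Q-value.
  inputs = jnp.concatenate([observation, action], axis=-1)
  return hk.nets.MLP([256, 256, 1])(inputs)

critic = hk.without_apply_rng(hk.transform(critic_fn))
obs = jnp.zeros((8, 17))  # batch of 8 observations (assumed size 17)
act = jnp.zeros((8, 6))   # batch of 8 actions (assumed size 6)
params = critic.init(jax.random.PRNGKey(0), obs, act)
q_values = critic.apply(params, obs, act)  # shape (8, 1)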
Example 3
  def apply(self, inputs: jnp.ndarray):
    # Flatten any nested inputs into a single [batch, features] array.
    inputs = utils.batch_concat(inputs)
    # Policy logits and a scalar state-value from separate MLP heads.
    logits = MLP(inputs, [64, 64, num_actions])
    value = MLP(inputs, [64, 64, 1])
    value = jnp.squeeze(value, axis=-1)
    return tfd.Categorical(logits=logits), value
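Example 3 is a shared-torso actor-critic head: batch_concat flattens the inputs, one MLP produces action logits and another a scalar value estimate. A self-contained sketch of the same head follows; num_actions, the input shape, and the Haiku / TensorFlow Probability wrappers are assumptions for illustration, not taken from the snippet's surrounding module.

import haiku as hk
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
num_actions = 4  # assumed for illustration

def forward(inputs: jnp.ndarray):
  # Policy head: categorical logits over the discrete actions.
  logits = hk.nets.MLP([64, 64, num_actions])(inputs)
  # Value head: scalar baseline, squeezed from [B, 1] to [B].
  value = jnp.squeeze(hk.nets.MLP([64, 64, 1])(inputs), axis=-1)
  return tfd.Categorical(logits=logits), value

network = hk.without_apply_rng(hk.transform(forward))
x = jnp.zeros((8, 10))  # batch of 8 flattened observations
params = network.init(jax.random.PRNGKey(0), x)
dist, value = network.apply(params, x)  # dist over 4 actions, value.shape == (8,)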
Example 4
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Performs a minibatch SGD step, returning new state and metrics."""

            # Extract the data.
            data = sample.data
            # TODO(sinopalnikov): replace it with namedtuple unpacking
            observations, actions, rewards, termination, extra = (
                data.observation, data.action, data.reward, data.discount,
                data.extras)
            discounts = termination * discount
            behavior_log_probs = extra['log_prob']

            def get_behavior_values(
                    params: networks_lib.Params,
                    observations: types.NestedArray) -> jnp.ndarray:
                o = jax.tree_map(
                    lambda x: jnp.reshape(x, [-1] + list(x.shape[2:])),
                    observations)
                _, behavior_values = ppo_networks.network.apply(params, o)
                behavior_values = jnp.reshape(behavior_values,
                                              rewards.shape[0:2])
                return behavior_values

            behavior_values = get_behavior_values(state.params, observations)

            # Vmap over batch dimension
            batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0)
            advantages, target_values = batch_gae_advantages(
                rewards, discounts, behavior_values)
            # Exclude the last step - it was only used for bootstrapping.
            # The shape is [num_sequences, num_steps, ..]
            observations, actions, behavior_log_probs, behavior_values = jax.tree_map(
                lambda x: x[:, :-1],
                (observations, actions, behavior_log_probs, behavior_values))
            trajectories = Batch(observations=observations,
                                 actions=actions,
                                 advantages=advantages,
                                 behavior_log_probs=behavior_log_probs,
                                 target_values=target_values,
                                 behavior_values=behavior_values)

            # Concatenate all trajectories. Reshape from [num_sequences, num_steps,..]
            # to [num_sequences * num_steps,..]
            assert len(target_values.shape) > 1
            num_sequences = target_values.shape[0]
            num_steps = target_values.shape[1]
            batch_size = num_sequences * num_steps
            assert batch_size % num_minibatches == 0, (
                'Num minibatches must divide batch size. Got batch_size={}'
                ' num_minibatches={}.').format(batch_size, num_minibatches)
            batch = jax.tree_map(
                lambda x: x.reshape((batch_size, ) + x.shape[2:]),
                trajectories)

            # Compute gradients.
            grad_fn = jax.grad(loss, has_aux=True)

            def model_update_minibatch(
                carry: Tuple[networks_lib.Params, optax.OptState],
                minibatch: Batch,
            ) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
                    str, jnp.ndarray]]:
                """Performs model update for a single minibatch."""
                params, opt_state = carry
                # Normalize advantages at the minibatch level before using them.
                advantages = ((minibatch.advantages -
                               jnp.mean(minibatch.advantages, axis=0)) /
                              (jnp.std(minibatch.advantages, axis=0) + 1e-8))
                gradients, metrics = grad_fn(params, minibatch.observations,
                                             minibatch.actions,
                                             minibatch.behavior_log_probs,
                                             minibatch.target_values,
                                             advantages,
                                             minibatch.behavior_values)

                # Apply updates
                updates, opt_state = optimizer.update(gradients, opt_state)
                params = optax.apply_updates(params, updates)

                metrics['norm_grad'] = optax.global_norm(gradients)
                metrics['norm_updates'] = optax.global_norm(updates)
                return (params, opt_state), metrics

            def model_update_epoch(
                carry: Tuple[jnp.ndarray, networks_lib.Params, optax.OptState,
                             Batch], unused_t: Tuple[()]
            ) -> Tuple[Tuple[jnp.ndarray, networks_lib.Params, optax.OptState,
                             Batch], Dict[str, jnp.ndarray]]:
                """Performs model updates based on one epoch of data."""
                key, params, opt_state, batch = carry
                key, subkey = jax.random.split(key)
                permutation = jax.random.permutation(subkey, batch_size)
                shuffled_batch = jax.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch)
                minibatches = jax.tree_map(
                    lambda x: jnp.reshape(x, [num_minibatches, -1] + list(
                        x.shape[1:])), shuffled_batch)

                (params,
                 opt_state), metrics = jax.lax.scan(model_update_minibatch,
                                                    (params, opt_state),
                                                    minibatches,
                                                    length=num_minibatches)

                return (key, params, opt_state, batch), metrics

            params = state.params
            opt_state = state.opt_state
            # Repeat training for the given number of epochs, taking a random
            # permutation for every epoch.
            (key, params, opt_state, _), metrics = jax.lax.scan(
                model_update_epoch,
                (state.random_key, params, opt_state, batch), (),
                length=num_epochs)

            metrics = jax.tree_map(jnp.mean, metrics)
            metrics['norm_params'] = optax.global_norm(params)
            metrics['observations_mean'] = jnp.mean(
                utils.batch_concat(jax.tree_map(
                    lambda x: jnp.abs(jnp.mean(x, axis=(0, 1))), observations),
                                   num_batch_dims=0))
            metrics['observations_std'] = jnp.mean(
                utils.batch_concat(jax.tree_map(
                    lambda x: jnp.std(x, axis=(0, 1)), observations),
                                   num_batch_dims=0))
            metrics['rewards_mean'] = jnp.mean(
                jnp.abs(jnp.mean(rewards, axis=(0, 1))))
            metrics['rewards_std'] = jnp.std(rewards, axis=(0, 1))
            new_state = TrainingState(params=params,
                                      opt_state=opt_state,
                                      random_key=key)
            return new_state, metrics
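Much of the plumbing in Example 4 is independent of PPO itself: every epoch the flat batch is shuffled with a fresh permutation, reshaped to [num_minibatches, minibatch_size, ...], and an update function is scanned over the leading axis with jax.lax.scan. The toy sketch below isolates that shuffle/reshape/scan pattern; fake_update, the batch contents, and the sizes are made up for illustration.

import jax
import jax.numpy as jnp

num_minibatches = 4
batch_size = 8
batch = {'x': jnp.arange(32.0).reshape(batch_size, 4)}  # toy batch of 8 rows

def fake_update(carry, minibatch):
  # Stand-in for a gradient step: accumulate the minibatch mean.
  mean = jnp.mean(minibatch['x'])
  return carry + mean, mean

def one_epoch(carry, _):
  key, total = carry
  key, subkey = jax.random.split(key)
  permutation = jax.random.permutation(subkey, batch_size)
  shuffled = jax.tree_util.tree_map(
      lambda x: jnp.take(x, permutation, axis=0), batch)
  minibatches = jax.tree_util.tree_map(
      lambda x: x.reshape((num_minibatches, -1) + x.shape[1:]), shuffled)
  total, means = jax.lax.scan(fake_update, total, minibatches,
                              length=num_minibatches)
  return (key, total), means

init = (jax.random.PRNGKey(0), jnp.zeros(()))
(_, total), epoch_means = jax.lax.scan(one_epoch, init, (), length=3)
# epoch_means has shape (3, num_minibatches): one mean per minibatch per epoch.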
Example 5
        def update_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)
            if adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                                     state.policy_params,
                                                     transitions, key_alpha)
                alpha = jnp.exp(state.alpha_params)
            else:
                alpha = entropy_coefficient
            critic_loss, critic_grads = critic_grad(state.q_params,
                                                    state.policy_params,
                                                    state.target_q_params,
                                                    alpha, transitions,
                                                    key_critic)
            actor_loss, actor_grads = actor_grad(state.policy_params,
                                                 state.q_params, alpha,
                                                 transitions, key_actor)

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            critic_update, q_optimizer_state = q_optimizer.update(
                critic_grads, state.q_optimizer_state)
            q_params = optax.apply_updates(state.q_params, critic_update)

            new_target_q_params = jax.tree_multimap(
                lambda x, y: x * (1 - tau) + y * tau, state.target_q_params,
                q_params)

            metrics = {
                'critic_loss': critic_loss,
                'actor_loss': actor_loss,
            }

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=new_target_q_params,
                key=key,
            )
            if adaptive_entropy_coefficient:
                # Apply alpha gradients
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                alpha_params = optax.apply_updates(state.alpha_params,
                                                   alpha_update)
                metrics.update({
                    'alpha_loss': alpha_loss,
                    'alpha': jnp.exp(alpha_params),
                })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=alpha_params)

            metrics['observations_mean'] = jnp.mean(
                utils.batch_concat(
                    jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)),
                                 transitions.observation)))
            metrics['observations_std'] = jnp.mean(
                utils.batch_concat(
                    jax.tree_map(lambda x: jnp.std(x, axis=0),
                                 transitions.observation)))
            metrics['next_observations_mean'] = jnp.mean(
                utils.batch_concat(
                    jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)),
                                 transitions.next_observation)))
            metrics['next_observations_std'] = jnp.mean(
                utils.batch_concat(
                    jax.tree_map(lambda x: jnp.std(x, axis=0),
                                 transitions.next_observation)))

            return new_state, metrics
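The jax.tree_multimap call in Example 5 is a Polyak (soft) target update, new_target = (1 - tau) * target + tau * online, applied leaf by leaf. Newer JAX versions deprecate jax.tree_multimap in favour of jax.tree_util.tree_map with multiple trees, and optax exposes the same update as optax.incremental_update. A small sketch with made-up parameter shapes:

import jax
import jax.numpy as jnp
import optax

tau = 0.005
online_params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
target_params = {'w': jnp.zeros((3, 3)), 'b': jnp.zeros((3,))}

# Explicit Polyak averaging, equivalent to the tree_multimap line above.
new_target = jax.tree_util.tree_map(
    lambda t, o: (1 - tau) * t + tau * o, target_params, online_params)

# optax.incremental_update(new, old, step_size) computes the same thing.
same = optax.incremental_update(online_params, target_params, tau)
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.allclose, new_target, same))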