Example #1
def f(preds, values, returns, actions, mask):
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
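A minimal, self-contained sketch of the pattern this example uses: stop_gradient detaches the value estimates when the advantages are formed, so the policy term cannot backpropagate into the value head, while the separate L2 term still trains it. The function and argument names below are illustrative, not taken from the Trax code above.

import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def joint_loss(values, returns, logps, mask, value_loss_coeff=0.1):
  # Advantages use detached values: the policy term's gradient flows into
  # logps only, never into the value estimates.
  advantages = returns - stop_gradient(values)
  policy_loss = -jnp.sum(logps * advantages * mask) / jnp.sum(mask)
  # The value head is trained solely through this regression term.
  value_loss = jnp.mean((returns - values) ** 2) * value_loss_coeff
  return policy_loss + value_loss

# Gradients w.r.t. values come only from the L2 term, as stop_gradient intends.
values, returns = jnp.array([0.5, 0.2]), jnp.array([1.0, 0.0])
logps, mask = jnp.array([-0.7, -1.2]), jnp.ones(2)
print(jax.grad(joint_loss)(values, returns, logps, mask))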
Example #2
    def _do_custom_gradients(self, x, weights, state, rng):
        """Calls this layer for a forward pass, but with custom gradients."""
        def _do_forward(y, weights):
            old_weights, old_state, old_rng = self._weights, self._state, self._rng
            self._weights = weights
            res = self.forward(y)
            s = self._state
            self._weights, self._state, self._rng = old_weights, old_state, old_rng
            return res, s

        def do_forward_vjp(y, weights):
            """Custom gradient (vjp) function."""
            old_weights, old_state, old_rng = self._weights, self._state, self._rng
            self._weights = weights
            output = self.forward(y)
            new_state = self._state
            self._weights, self._state, self._rng = old_weights, old_state, old_rng

            def vjpfun(grad):
                grad = grad[0]  # Ignore dummy gradient wrt state.
                res = self.backward(y, output, grad, weights, state, new_state,
                                    rng)
                return res

            return (output, new_state), vjpfun

        do_forward = math.custom_grad(do_forward_vjp, _do_forward)

        output, state = do_forward(x, weights)
        # TODO(lukaszkaiser): Investigate why we need this stop_gradient
        state = math.stop_gradient(state)
        return output, state
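The wrapper above routes the layer's backward method through math.custom_grad. As a rough stand-alone analogue (an assumption about how the same forward/vjp split looks outside Trax), here is a minimal example using plain jax.custom_vjp; the squaring function and its hand-written gradient are purely illustrative.

import jax

@jax.custom_vjp
def square(x):
  return x * x

def square_fwd(x):
  # Forward pass: return the output plus residuals the backward pass needs.
  return x * x, x

def square_bwd(residual, cotangent):
  # Custom vjp: d/dx (x**2) = 2x, scaled by the incoming cotangent.
  x = residual
  return (2.0 * x * cotangent,)

square.defvjp(square_fwd, square_bwd)

print(jax.grad(square)(3.0))  # 6.0, produced by the custom backward pass

The stop_gradient on the returned state at the end of _do_custom_gradients likewise keeps the state out of differentiation, much as vjpfun discards the dummy state gradient.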
Example #3
def f(log_probs, advantages, old_log_probs, mask):
    if reweight:  # Use new policy weights for sampled actions instead.
        mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
    if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
        mask *= jnp.exp(old_log_probs)
    weights = jnp.minimum(awr_weights(advantages, beta), w_max)
    return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
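The helper awr_weights is not shown in this snippet; in advantage-weighted regression it is typically exp(advantage / beta), clipped at w_max as above. A small sketch with a worked numeric check (the awr_weights body here is an assumption, not quoted from the source):

import jax.numpy as jnp

def awr_weights(advantages, beta):
  # Assumed definition: exponentiated advantages with temperature beta.
  return jnp.exp(advantages / beta)

advantages = jnp.array([2.0, -1.0, 0.0])
beta, w_max = 1.0, 5.0
weights = jnp.minimum(awr_weights(advantages, beta), w_max)
# exp(2) ~ 7.39 is clipped to w_max = 5; exp(-1) ~ 0.37; exp(0) = 1.
print(weights)  # ~ [5.0, 0.37, 1.0]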
Example #4
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss."""
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer

            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
Example #5
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """Definition of the A2C loss."""
            del dones, rewards, old_log_probs

            # Typically we have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            # actions of the shape int32[128,1] in the case of discrete actions
            # and float32[128,1,6] in the case of half-cheetah
            # actions agree with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')
            # and mask of the shape float32[128,1]
            assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
            # which agrees with returns/values/actions on the first two coordinates
            assert mask.shape[0:2] == returns.shape[0:2], (
                f'mask.shape was {mask.shape} and '
                f'returns.shape was {returns.shape}')

            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            # we insist that a2c_objective is a scalar
            assert jnp.ndim(
                a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            combined_loss = a2c_objective + l2_value_loss - entropy_loss

            return combined_loss
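Both the A2C and PPO losses above accept a normalize_advantages flag. A common implementation, shown here only as a sketch (an assumption, not the rl_layers source), standardizes advantages per batch so the scale of the policy gradient stays roughly constant across batches:

import jax.numpy as jnp

def normalize_advantages(advantages, epsilon=1e-8):
  # Zero-mean, unit-variance advantages; epsilon guards against a zero std.
  return (advantages - jnp.mean(advantages)) / (jnp.std(advantages) + epsilon)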
Example #6
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss."""
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            # We have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            # and old_log_probs of the shape float32[128,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            assert returns.shape[0:2] == old_log_probs.shape, (
                f'returns.shape was {returns.shape} and '
                f'old_log_probs.shape was {old_log_probs.shape}')
            # actions is a tensor of the shape int32[128,1]
            assert len(
                actions.shape) == 2, f'actions.shape was {actions.shape}'
            # which agrees with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')

            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            # we insist that ppo_objective is a vector of shape [128,1]
            assert len(ppo_objective.shape) == 2, (
                f'ppo_objective was {ppo_objective}')
            # which agrees with returns/values/actions on the first two coordinates
            assert ppo_objective.shape[0:2] == values.shape[0:2], (
                f'ppo_objective.shape was {ppo_objective.shape} and '
                f'values.shape was {values.shape}')

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
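For reference, the clipped surrogate that a PPO objective computes per timestep, written here from the published algorithm rather than from the rl_layers source, looks like the sketch below; as in the examples above, the advantages would be formed from returns minus stop_gradient(values).

import jax.numpy as jnp

def ppo_surrogate(log_probs, old_log_probs, advantages, epsilon=0.2):
  # Probability ratio between the current policy and the behavior policy.
  ratios = jnp.exp(log_probs - old_log_probs)
  clipped = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Element-wise objective; the caller averages it and negates it for a loss.
  return jnp.minimum(ratios * advantages, clipped * advantages)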
Example #7
def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
    preds, values, returns, actions, mask = x
    advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
    logps = self._policy_dist.log_prob(preds, actions)
    awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
        (logps, advantages, jnp.zeros_like(logps), mask))
    l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
    return awr_loss + l2_value_loss
Example #8
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the A2C loss."""
            del old_log_probs

            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            return a2c_objective.mean() + l2_value_loss - entropy_loss
Example #9
def f(log_probs, advantages, old_log_probs, mask):
    if reweight:  # Use new policy weights for sampled actions instead.
        mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
    weights = jnp.minimum(awr_weights(advantages, beta), w_max)
    return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
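Note the stop_gradient inside the reweighting term: exp(stop_gradient(log_probs) - old_log_probs) acts as an importance weight that is treated as a constant, so gradients reach log_probs only through the log_probs factor in the final sum, never through the weight itself.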