Example #1
import numpy as np


def opt_sav(money, v, au, ad, bet, R, sig):
    """One optimal-savings step with logit (taste-shock) smoothing.

    Relies on a module-level smoothing scale `ts` (defined as 0.2 in
    Example #3 below).
    """
    def ufun(x):
        # CRRA utility.
        return x**(1 - sig) / (1 - sig)

    # Slope of the value function on each asset-grid interval.
    dv = (v[1:] - v[:-1]) / (au - ad)
    # Consumption implied by the first-order condition u'(c) = bet * dv.
    co = (bet * dv)**(-1 / sig)
    # Implied savings for every (cash-on-hand, interval) pair.
    so = money[:, None] - co[None, :]
    # Clip savings to each interval's bounds [ad, au].
    s_best_interval = np.fmax(np.fmin(au, so), ad)
    #print(s_best_interval)
    # Continuation value, interpolated linearly within each interval.
    v_best_interval = v[:-1][None, :] + dv[None, :] * (s_best_interval - ad)
    # Resources left for consumption, floored to keep utility finite.
    mleft = np.fmax(money[:, None] - s_best_interval, 1e-10)
    Umat = ufun(mleft) + bet * v_best_interval
    # Log-sum-exp over intervals: choice probabilities and smoothed value.
    vbest = Umat.max(axis=1)
    U_excess_norm = (Umat - vbest[:, None]) / ts
    enorm = np.exp(U_excess_norm)
    sumexp = enorm.sum(axis=1)
    ccp = enorm / sumexp[:, None]
    vopt = vbest + ts * np.log(sumexp)
    return ccp, so, vopt
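
For reference, a minimal sketch of how opt_sav might be called. The grid and parameter values below are made up for illustration; ad/au are taken as the lower/upper bounds of each grid interval, and the smoothing scale ts must exist at module level because the function reads it as a global.

import numpy as np

# Hypothetical uniform asset grid and illustrative parameters.
agrid = np.linspace(0.0, 4.0, 51)
ad, au = agrid[:-1], agrid[1:]     # per-interval lower/upper bounds
bet, R, sig = 0.95, 1.0, 1.5
ts = 0.2                           # module-level smoothing scale read by opt_sav

money = 1.0 + R * agrid                        # illustrative cash on hand on the grid
v = (1.0 + R * agrid)**(1 - sig) / (1 - sig)   # a terminal-style value guess

ccp, so, vopt = opt_sav(money, v, au, ad, bet, R, sig)
print(ccp.shape, vopt.shape)       # (51, 50) and (51,)
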
Example #2
        # Nested inside a PPO learner: `ppo_networks`, `ppo_clipping_epsilon`,
        # `clip_value`, `value_cost`, `entropy_cost`, and the rlax/typing
        # imports come from the enclosing scope.
        def loss(
            params: networks_lib.Params, observations: types.NestedArray,
            actions: jnp.ndarray, behaviour_log_probs: jnp.ndarray,
            target_values: jnp.ndarray, advantages: jnp.ndarray,
            behavior_values: jnp.ndarray
        ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Surrogate loss using clipped probability ratios."""

            distribution_params, values = ppo_networks.network.apply(
                params, observations)
            log_probs = ppo_networks.log_prob(distribution_params, actions)
            entropy = ppo_networks.entropy(distribution_params)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = jnp.exp(log_probs - behaviour_log_probs)

            policy_loss = rlax.clipped_surrogate_pg_loss(
                rhos, advantages, ppo_clipping_epsilon)

            # Value function loss. Exclude the bootstrap value
            unclipped_value_error = target_values - values
            unclipped_value_loss = unclipped_value_error**2

            if clip_value:
                # Clip values to reduce variability during critic training.
                clipped_values = behavior_values + jnp.clip(
                    values - behavior_values, -ppo_clipping_epsilon,
                    ppo_clipping_epsilon)
                clipped_value_error = target_values - clipped_values
                clipped_value_loss = clipped_value_error**2
                value_loss = jnp.mean(
                    jnp.fmax(unclipped_value_loss, clipped_value_loss))
            else:
                # For MuJoCo envs, value clipping hurts a lot; see Figure 43 in
                # https://arxiv.org/pdf/2006.05990.pdf
                value_loss = jnp.mean(unclipped_value_loss)

            # Entropy regulariser.
            entropy_loss = -jnp.mean(entropy)

            total_loss = (policy_loss + value_loss * value_cost +
                          entropy_loss * entropy_cost)
            return total_loss, {
                'loss_total': total_loss,
                'loss_policy': policy_loss,
                'loss_value': value_loss,
                'loss_entropy': entropy_loss,
            }
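
The value-clipping branch can be exercised on its own. The sketch below uses made-up toy arrays and an illustrative epsilon (not values from the example) to show how jnp.clip and jnp.fmax combine into the pessimistic clipped value loss:

import jax.numpy as jnp

# Toy inputs, purely illustrative.
behavior_values = jnp.array([1.0, 2.0, 3.0])
values = jnp.array([1.5, 1.0, 3.4])
target_values = jnp.array([1.2, 2.1, 2.9])
eps = 0.2   # illustrative clipping epsilon

# Keep the new value estimates within an eps-ball of the behaviour values,
# then take the elementwise worse (larger) of the clipped/unclipped errors.
clipped_values = behavior_values + jnp.clip(values - behavior_values, -eps, eps)
unclipped_value_loss = (target_values - values)**2
clipped_value_loss = (target_values - clipped_values)**2
value_loss = jnp.mean(jnp.fmax(unclipped_value_loss, clipped_value_loss))
print(value_loss)
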
Example #3
# `agrid` (the asset grid), its per-interval bounds `au`/`ad`, the discount
# factor `bet`, and the `jit` decorator are defined earlier in the source;
# only this fragment is shown.
import numpy as np

da = (au - ad)      # width of each asset-grid interval

t0 = 10
t1 = 15

tend = t0 + t1      # total number of periods

w1 = 0.2
w0 = 1.0
sig = 1.5           # CRRA curvature parameter

R = 1.0             # gross return on savings

# Terminal value: utility of consuming all cash on hand, w1 + R * agrid.
vlast = (w1 + R * agrid)**(1 - sig) / (1 - sig)

# Slope of the terminal value on each interval, floored away from zero.
dvlast = np.fmax((vlast[1:] - vlast[:-1]) / da, 1e-16)

ts = 0.2            # taste-shock smoothing scale used inside opt_sav


@jit
def opt_sav(money, v, au, ad, bet, R, sig):
    def ufun(x):
        return x**(1 - sig) / (1 - sig)

    dv = (v[1:] - v[:-1]) / (au - ad)
    co = (bet * dv)**(-1 / sig)
    so = money[:, None] - co[None, :]
    s_best_interval = np.fmax(np.fmin(au, so), ad)
    #print(s_best_interval)
    v_best_interval = v[:-1][None, :] + dv[None, :] * (s_best_interval - ad)
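
The period counts and terminal value above suggest a backward-induction loop that feeds vlast into repeated calls of opt_sav (Example #1). Below is a minimal sketch under that assumption; the constant cash-on-hand rule and the discount factor value are hypothetical, since the excerpt does not show how income evolves or what bet is.

# Hypothetical backward induction over the tend periods defined above.
bet = 0.95                          # assumed discount factor, not in the excerpt
v = vlast
for t in range(tend):
    money = w0 + R * agrid          # hypothetical constant cash-on-hand rule
    ccp, so, v = opt_sav(money, v, au, ad, bet, R, sig)
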
Example #4
def fmax(x1, x2):
  # Element-wise, NaN-ignoring maximum (numpy `fmax` semantics): unwrap any
  # JaxArray operands, call jnp.fmax, and wrap the result back up.
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.fmax(x1, x2))
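
A self-contained sketch of how this wrapper behaves. The JaxArray class below is a minimal stand-in with only the .value attribute the wrapper relies on (the real class in the source library does more), and the wrapper is repeated so the snippet runs on its own:

import jax.numpy as jnp

class JaxArray:
  # Minimal stand-in: just holds the underlying jax array in `.value`.
  def __init__(self, value):
    self.value = value

def fmax(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.fmax(x1, x2))

a = JaxArray(jnp.array([1.0, jnp.nan, 3.0]))
b = jnp.array([2.0, 2.0, 2.0])
out = fmax(a, b)        # wrapped and raw operands can be mixed
print(out.value)        # [2. 2. 3.] -- the NaN is ignored, as with numpy fmax
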