Example #1
def run_data_coop_game(seed, swf, N=500):
    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:
            return np.array([5.]), np.array([5.])

    rng = jax.random.PRNGKey(seed)

    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])

    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = 0.05 * jax.random.normal(iter_rng, shape=(1, d))
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = 0.05 * jax.random.normal(iter_rng, shape=(1, d))

    r_Cs = []
    r_DDOs = []

    for _ in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)

        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)

        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)

        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO',
                  rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))

    print(logits_C, logits_DDO)
    #print(.5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    return swf(np.array(r_Cs), np.array(r_DDOs))
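Example #1 assumes that jax, rlax, jit, grad and an np alias (most plausibly jax.numpy) are already in scope, and that the caller supplies a social welfare function swf. A minimal, hypothetical driver could look like the sketch below; the import aliases and the utilitarian_swf helper are assumptions for illustration, not part of the original code.

import jax
import jax.numpy as np  # the examples use np for jax.numpy
import rlax
from jax import grad, jit

def utilitarian_swf(r_Cs, r_DDOs):
    # Average social utility over all N rounds (assumed welfare criterion).
    return 0.5 * (np.mean(r_Cs) + np.mean(r_DDOs))

welfare = run_data_coop_game(seed=0, swf=utilitarian_swf, N=500)
print('utilitarian welfare:', welfare)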
Example #2
def run_simple_RL():
    def simple_reward(action):
        if action == 0:
            return np.array([1.])
        else:
            return np.array([0.])

    rng = jax.random.PRNGKey(0)

    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])

    d = 4
    rng, iter_rng = jax.random.split(rng)
    logits = jax.random.normal(iter_rng, shape=(1, d))

    N = 100
    for _ in range(N):
        # sample action given policy
        rng, iter_rng = jax.random.split(rng)
        a = jax.random.categorical(iter_rng, logits)

        # observe reward
        r = simple_reward(a)

        # update policy
        logits -= 0.1 * grad_PG_loss(logits, a, r, w_t)
        print(rlax.policy_gradient_loss(logits, a, r, w_t))

    print(logits)
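For reference, rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t) expects logits of shape [T, num_actions], integer actions of shape [T], advantages of shape [T] (the toy examples above simply pass raw rewards), and per-step weights of shape [T]; the single-step calls above use T = 1. A small self-contained sketch with made-up shapes:

import jax.numpy as jnp
import rlax

T, num_actions = 5, 3
logits = jnp.zeros((T, num_actions))           # uniform policy at every step
actions = jnp.array([0, 1, 2, 1, 0])           # [T] sampled action indices
advantages = jnp.array([1., 0., -1., 2., 0.])  # [T] advantage estimates
weights = jnp.ones(T)                          # [T] per-step weights

loss = rlax.policy_gradient_loss(logits, actions, advantages, weights)
print(loss)  # scalar surrogate loss whose gradient is the policy gradient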
Example #3
        def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
            """"Computes a linear combination of the policy gradient loss and value loss
      and regularizes it with an entropy term."""
            inputs = pack(trajectory)

            # Dynamically unroll the network. This Haiku utility function unpacks the
            # list of input tensors such that the i^{th} row from each input tensor
            # is presented to the i^{th} unrolled RNN module.
            (logits, values, _, _,
             state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
                 network, inputs, rnn_unroll_state)
            trajectory_len = trajectory.actions.shape[0]

            # Compute the combined loss given the output of the model.
            td_errors = rlax.td_lambda(v_tm1=values[:-1, 0],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:, 0],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1, 0],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones(trajectory_len))
            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

            combined_loss = (actor_loss + critic_cost * critic_loss +
                             entropy_cost * entropy_loss)

            return combined_loss, new_rnn_unroll_state
Example #4
  def loss_fn(params: hk.Params,
              sample: reverb.ReplaySample) -> jnp.DeviceArray:
    """Batched, entropy-regularised actor-critic loss with V-trace."""

    # Extract the data.
    data = sample.data
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
    behaviour_logits = extra['logits']

    # Apply reward clipping.
    rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

    # Unroll current policy over observations.
    (logits, values), _ = unroll_fn(params, observations, initial_state)

    # Compute importance sampling weights: current policy / behavior policy.
    rhos = rlax.categorical_importance_sampling_ratios(logits[:-1],
                                                       behaviour_logits[:-1],
                                                       actions[:-1])

    # Critic loss.
    vtrace_returns = rlax.vtrace_td_error_and_advantage(
        v_tm1=values[:-1],
        v_t=values[1:],
        r_t=rewards[:-1],
        discount_t=discounts[:-1] * discount,
        rho_tm1=rhos)
    critic_loss = jnp.square(vtrace_returns.errors)

    # Policy gradient loss.
    policy_gradient_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=actions[:-1],
        adv_t=vtrace_returns.pg_advantage,
        w_t=jnp.ones_like(rewards[:-1]))

    # Entropy regulariser.
    entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1]))

    # Combine weighted sum of actor & critic losses, averaged over the sequence.
    mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss +
                         entropy_cost * entropy_loss)  # []

    metrics = {
        'policy_loss': jnp.mean(policy_gradient_loss),
        'critic_loss': jnp.mean(baseline_cost * critic_loss),
        'entropy_loss': jnp.mean(entropy_cost * entropy_loss),
        'entropy': jnp.mean(entropy_loss),
    }

    return mean_loss, metrics
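Losses of this form are usually differentiated with jax.grad(..., has_aux=True) so the metrics dict is carried alongside the gradients, and the gradients are then applied with an optax optimizer. The step below is only a sketch under that assumption; optimiser, opt_state and sample are placeholders supplied by the surrounding learner, not defined in the original example.

import jax
import optax

def sgd_step(params, opt_state, sample):
    # has_aux=True returns (gradients, metrics) because loss_fn returns (loss, metrics).
    grads, metrics = jax.grad(loss_fn, has_aux=True)(params, sample)
    updates, new_opt_state = optimiser.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics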
Example #5
        def loss(params: hk.Params,
                 sample: reverb.ReplaySample) -> jnp.ndarray:
            """Entropy-regularised actor-critic loss."""

            # Extract the data.
            observations, actions, rewards, discounts, extra = sample.data
            initial_state = tree.map_structure(lambda s: s[0],
                                               extra['core_state'])
            behaviour_logits = extra['logits']

            # Keep the first T-1 transitions and clip the rewards.
            actions = actions[:-1]  # [T-1]
            rewards = rewards[:-1]  # [T-1]
            discounts = discounts[:-1]  # [T-1]
            rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

            # Unroll current policy over observations.
            net = functools.partial(network.apply, params)
            (logits, values), _ = hk.static_unroll(net, observations,
                                                   initial_state)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = rlax.categorical_importance_sampling_ratios(
                logits[:-1], behaviour_logits[:-1], actions)

            # Critic loss.
            vtrace_returns = rlax.vtrace_td_error_and_advantage(
                v_tm1=values[:-1],
                v_t=values[1:],
                r_t=rewards,
                discount_t=discounts * discount,
                rho_t=rhos)
            critic_loss = jnp.square(vtrace_returns.errors)

            # Policy gradient loss.
            policy_gradient_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=actions,
                adv_t=vtrace_returns.pg_advantage,
                w_t=jnp.ones_like(rewards))

            # Entropy regulariser.
            entropy_loss = rlax.entropy_loss(logits[:-1],
                                             jnp.ones_like(rewards))

            # Combine weighted sum of actor & critic losses.
            mean_loss = jnp.mean(policy_gradient_loss +
                                 baseline_cost * critic_loss +
                                 entropy_cost * entropy_loss)

            return mean_loss
Example #6
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            logits, values = network(trajectory.observations)
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss
Example #7
        def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            observations, rewards, actions = pack(trajectory)
            logits, values, _, _, _ = network(observations, rewards, actions)

            td_errors = rlax.td_lambda(v_tm1=values[:-1],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))

            return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss
Example #8
    def loss(
        weights,
        observations,
        actions,
        rewards,
        td_lambda=0.2,
        discount=0.99,
        policy_cost=0.25,
        entropy_cost=1e-3,
    ):
        """Actor-critic loss."""
        logits, values = network(weights, observations)
        values = jnp.append(values, jnp.sum(rewards))

        # replace -inf values by tiny finite value
        logits = jnp.maximum(logits, MINIMUM_LOGIT)

        td_errors = rlax.td_lambda(
            v_tm1=values[:-1],
            r_t=rewards,
            discount_t=jnp.full_like(rewards, discount),
            v_t=values[1:],
            lambda_=td_lambda,
        )
        critic_loss = jnp.mean(td_errors ** 2)

        if type_ == "a2c":
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits,
                a_t=actions,
                adv_t=td_errors,
                w_t=jnp.ones(td_errors.shape[0]),
            )
        elif type_ == "supervised":
            actor_loss = jnp.mean(cross_entropy(logits, actions))

        entropy_loss = -jnp.mean(entropy(logits))

        return policy_cost * actor_loss, critic_loss, entropy_cost * entropy_loss
Example #9
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """"Actor-critic loss."""
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, trajectory.observations[:, None, ...], rnn_unroll_state)
      seq_len = trajectory.actions.shape[0]
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1, 0],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones(seq_len))
      entropy_loss = jnp.mean(
          rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

      return combined_loss, new_rnn_unroll_state
Example #10
def run_data_coop_game_with_regulator(seed, swf, N=500):
    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:
            return np.array([5.]), np.array([5.])

    # NOTE: this single-tax variant is shadowed by the two-regulator
    # redistribute() defined below and is never called.
    def redistribute(r_C, r_DDO, a_R):
        tax = 0.
        if a_R == 0:
            tax = 0.
        elif a_R == 1:
            tax = 0.15
        elif a_R == 2:
            tax = 0.3
        else:
            tax = 0.5

        wealth = tax * (r_C + r_DDO)
        r_C = r_C - tax * r_C + wealth / 2.
        r_DDO = r_DDO - tax * r_DDO + wealth / 2.

        return r_C, r_DDO, tax

    def redistribute(r_C, r_DDO, a_R1, a_R2):
        tax1 = 0.
        if a_R1 == 0:
            tax1 = 0.
        elif a_R1 == 1:
            tax1 = 0.15
        elif a_R1 == 2:
            tax1 = 0.3
        else:
            tax1 = 0.5

        tax2 = 0.
        if a_R2 == 0:
            tax2 = 0.
        elif a_R2 == 1:
            tax2 = 0.15
        elif a_R2 == 2:
            tax2 = 0.3
        else:
            tax2 = 0.5

        wealth = tax1 * r_C + tax2 * r_DDO
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax2 * r_DDO + wealth / 2.

        return r_C, r_DDO, tax1, tax2

    rng = jax.random.PRNGKey(seed)

    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])

    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = 0.1 * np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = 0.1 * np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_R1 = 0.1 * np.array([[1, 1, 1, 1.]])
    logits_R2 = 0.1 * np.array([[1, 1, 1, 1.]])

    r_Cs = []
    r_DDOs = []
    taxes1 = []
    taxes2 = []

    for i in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)
        rng, iter_rng = jax.random.split(rng)
        a_R1 = jax.random.categorical(iter_rng, logits_R1)
        rng, iter_rng = jax.random.split(rng)
        a_R2 = jax.random.categorical(iter_rng, logits_R2)

        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)

        r_C, r_DDO, tax1, tax2 = redistribute(r_C, r_DDO, a_R1, a_R2)
        taxes1.append(tax1)
        taxes2.append(tax2)

        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)
        lag = 50
        if i % lag == 1:
            R = np.array(r_Cs[-lag:]).mean() + np.array(r_DDOs[-lag:]).mean()
            logits_R1 -= 0.005 * grad_PG_loss(logits_R1, a_R1,
                                              .5 * np.array([R]), w_t)
            logits_R2 -= 0.005 * grad_PG_loss(logits_R2, a_R2,
                                              .5 * np.array([R]), w_t)

        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO',
                  rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))

    print('logits:', logits_C, logits_DDO, logits_R1, logits_R2)
    print('mean SU:',
          .5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    print('mean tax1', np.array(taxes1).mean())
    print('mean tax2', np.array(taxes2).mean())

    return swf(np.array(r_Cs), np.array(r_DDOs))
Example #11
def run_data_coop_game_with_gaussian_regulator(seed, swf, N=500):
    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:
            return np.array([5.]), np.array([5.])

    def gaussian_logprob(logits, a):
        return np.mean(-((a - logits) / .1)**2)

    def redistribute(r_C, r_DDO, a_R1, a_R2):
        tax1 = 0.5 * jax.nn.sigmoid(a_R1)
        tax2 = 0.5 * jax.nn.sigmoid(a_R2)

        wealth = tax1 * r_C + tax2 * r_DDO
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax2 * r_DDO + wealth / 2.

        return r_C, r_DDO, tax1, tax2

    # NOTE: unused variant that applies a single pooled tax rate to both parties.
    def redistributed(r_C, r_DDO, a_R1, a_R2):
        tax1 = 0.5 * jax.nn.sigmoid(a_R1)
        tax2 = 0.5 * jax.nn.sigmoid(a_R2)

        wealth = tax1 * (r_C + r_DDO)
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax1 * r_DDO + wealth / 2.

        return r_C, r_DDO, tax1, tax2

    rng = jax.random.PRNGKey(seed)

    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])

    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_R1 = np.array([1.])  # the mean of the Gaussian
    logits_R2 = np.array([1.])  # the mean of the Gaussian

    r_Cs = []
    r_DDOs = []
    taxes1 = []
    taxes2 = []

    for i in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)
        rng, iter_rng = jax.random.split(rng)
        a_R1 = 0.1 * jax.random.normal(iter_rng) + logits_R1
        rng, iter_rng = jax.random.split(rng)
        a_R2 = 0.1 * jax.random.normal(iter_rng) + logits_R2

        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)

        r_C, r_DDO, tax1, tax2 = redistribute(r_C, r_DDO, a_R1, a_R2)
        taxes1.append(tax1)
        taxes2.append(tax2)

        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)
        lag = 50
        if i > 0:
            if i % lag == 0:
                R = np.array(r_Cs[-lag:]).mean() + np.array(
                    r_DDOs[-lag:]).mean()
                logits_R1 -= 0.005 * R * grad(gaussian_logprob)(logits_R1,
                                                                a_R1)
                logits_R2 -= 0.005 * R * grad(gaussian_logprob)(logits_R2,
                                                                a_R2)

        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO',
                  rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))

    print('logits:', logits_C, logits_DDO, logits_R1, logits_R2)
    print('mean SU:',
          .5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    print('mean tax1', np.array(taxes1).mean())
    print('mean tax2', np.array(taxes2).mean())

    # return 0.5 * (np.array(r_Cs) + np.array(r_DDOs))
    return swf(np.array(r_Cs), np.array(r_DDOs))
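Each of the three data-cooperative games takes the welfare criterion swf as an argument, so different criteria can be run on the same seeds and compared. A short hypothetical driver, using an assumed Nash-style welfare function alongside the games defined above (nash_swf is an illustration, not part of the original code):

import jax.numpy as np

def nash_swf(r_Cs, r_DDOs):
    # Log Nash welfare: sum of log mean utilities (assumed criterion).
    return np.log(np.mean(r_Cs)) + np.log(np.mean(r_DDOs))

for seed in range(3):
    for game in (run_data_coop_game,
                 run_data_coop_game_with_regulator,
                 run_data_coop_game_with_gaussian_regulator):
        print(game.__name__, seed, game(seed, nash_swf))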