Example #1
  def test_vtrace(self):
    """Tests V-trace against ground truth data calculated in python."""
    batch_size = 5
    seq_len = 5

    # Create log_rhos such that rho will span from near-zero to above the
    # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
    # so that rho is in approx [0.08, 12.2).
    log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
    log_rhos = 5 * (log_rhos - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).
    values = {
        'behaviour_action_log_probs': tf.zeros_like(log_rhos),
        'target_action_log_probs': log_rhos,
        # Shape [T, B]: batch element b uses the constant discount 0.9 / (b + 1).
        'discounts': np.array([[0.9 / (b + 1) for b in range(batch_size)]
                               for _ in range(seq_len)]),
        'rewards': _shaped_arange(seq_len, batch_size),
        'values': _shaped_arange(seq_len, batch_size) / batch_size,
        'bootstrap_value': _shaped_arange(batch_size) + 1.0,
        'clip_rho_threshold': 3.7,
        'clip_pg_rho_threshold': 2.2,
    }

    output = vtrace.from_importance_weights(**values)
    ground_truth_v = _ground_truth_calculation(**values)
    self.assertAllClose(output, ground_truth_v)
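Note: the `_ground_truth_calculation` and `_shaped_arange` helpers referenced above are not shown in this excerpt (`_shaped_arange` presumably just builds an increasing float array of the requested shape). As a rough illustration of what such a ground-truth check involves, here is a minimal NumPy sketch of the standard V-trace recursion from the IMPALA paper, applied to log_rhos = target_action_log_probs - behaviour_action_log_probs. The function name and argument handling are assumptions, not the test's actual helper.

import numpy as np

def reference_vtrace(log_rhos, discounts, rewards, values, bootstrap_value,
                     clip_rho_threshold=None, clip_pg_rho_threshold=None,
                     lambda_=1.0):
    """Backward-loop V-trace sketch. All inputs are [T, B] except
    bootstrap_value, which is [B]."""
    rhos = np.exp(log_rhos)
    if clip_rho_threshold is not None:
        clipped_rhos = np.minimum(clip_rho_threshold, rhos)
    else:
        clipped_rhos = rhos
    cs = lambda_ * np.minimum(1.0, rhos)

    # V(x_{t+1}) for every step t; the final step uses the bootstrap value.
    values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None]], axis=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion:
    #   vs_t - V(x_t) = delta_t + discount_t * c_t * (vs_{t+1} - V(x_{t+1})).
    vs_minus_v_xs = np.zeros_like(values)
    acc = np.zeros_like(bootstrap_value)
    for t in reversed(range(len(rewards))):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v_xs[t] = acc
    vs = vs_minus_v_xs + values

    # Policy-gradient advantages use the separately clipped importance weights.
    vs_t_plus_1 = np.concatenate([vs[1:], bootstrap_value[None]], axis=0)
    if clip_pg_rho_threshold is not None:
        clipped_pg_rhos = np.minimum(clip_pg_rho_threshold, rhos)
    else:
        clipped_pg_rhos = rhos
    pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
    return vs, pg_advantages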
Example #2
    def test_vtrace_from_logits(self, batch_size):
        """Tests V-trace calculated from logits."""
        seq_len = 5
        num_actions = 3
        clip_rho_threshold = None  # No clipping.
        clip_pg_rho_threshold = None  # No clipping.

        values = {
            'behaviour_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
            'target_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
            'actions':
            np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)),
            # Shape [T, B]: batch element b uses the constant discount 0.9 / (b + 1).
            'discounts':
            np.array([[0.9 / (b + 1) for b in range(batch_size)]
                      for _ in range(seq_len)]),
            'rewards':
            _shaped_arange(seq_len, batch_size),
            'values':
            _shaped_arange(seq_len, batch_size) / batch_size,
            'bootstrap_value':
            _shaped_arange(batch_size) + 1.0,  # B
        }

        from_logits_output = vtrace.from_logits(
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
            **values)

        target_log_probs = vtrace.log_probs_from_logits_and_actions(
            values['target_policy_logits'], values['actions'])
        behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
            values['behaviour_policy_logits'], values['actions'])
        ground_truth_log_rhos = target_log_probs - behaviour_log_probs
        ground_truth_target_action_log_probs = target_log_probs
        ground_truth_behaviour_action_log_probs = behaviour_log_probs

        # Calculate V-trace directly from the ground-truth log importance weights.
        from_iw = vtrace.from_importance_weights(
            log_rhos=ground_truth_log_rhos,
            discounts=values['discounts'],
            rewards=values['rewards'],
            values=values['values'],
            bootstrap_value=values['bootstrap_value'],
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold)

        self.assertAllClose(from_iw.vs, from_logits_output.vs)
        self.assertAllClose(from_iw.pg_advantages,
                            from_logits_output.pg_advantages)
        self.assertAllClose(ground_truth_behaviour_action_log_probs,
                            from_logits_output.behaviour_action_log_probs)
        self.assertAllClose(ground_truth_target_action_log_probs,
                            from_logits_output.target_action_log_probs)
        self.assertAllClose(ground_truth_log_rhos, from_logits_output.log_rhos)
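For reference, `log_probs_from_logits_and_actions` maps logits of shape [T, B, NUM_ACTIONS] and integer actions of shape [T, B] to per-step action log-probabilities of shape [T, B]. Below is a minimal sketch of the usual way to compute this quantity, assuming those shapes; it is an illustration, not necessarily the library's exact implementation.

import tensorflow as tf

def log_probs_sketch(policy_logits, actions):
    # Sparse softmax cross-entropy with the taken action as the label equals
    # -log softmax(policy_logits)[action], so negating it gives log pi(a_t | x_t).
    return -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=actions, logits=policy_logits)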
Example #3
  def test_vtrace_vs_seed(self):
    values = tf.random.uniform((21, 10), maxval=3)
    rewards = tf.random.uniform((20, 10), maxval=3)
    target_action_log_probs = tf.random.uniform((20, 10), minval=-2, maxval=2)
    behaviour_action_log_probs = tf.random.uniform((20, 10), minval=-2,
                                                   maxval=2)
    # Roughly 5% of steps are terminal (note: probs, not logits).
    done_terminated = tf.cast(
        tfp.distributions.Bernoulli(probs=0.05).sample((20, 10)), tf.bool)
    done_abandoned = tf.zeros_like(rewards, dtype=tf.bool)

    tested_targets, unused_tested_advantages = advantages.vtrace(
        values, rewards,
        done_terminated, done_abandoned, 0.99,
        target_action_log_probs, behaviour_action_log_probs,
        lambda_=0.95)

    seed_output = vtrace.from_importance_weights(
        target_action_log_probs, behaviour_action_log_probs,
        0.99 * tf.cast(~done_terminated, tf.float32), rewards,
        values[:-1], values[-1], lambda_=0.95)

    self.assertAllClose(tested_targets, seed_output.vs)
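This test feeds the same random data to both implementations: `advantages.vtrace` consumes termination/abandonment flags and a [T+1, B] value tensor, whereas `vtrace.from_importance_weights` expects per-step discounts, [T, B] values, and a separate bootstrap value. A small hypothetical adapter performing the conversion the test does inline might look like this (illustrative name; `vtrace` is assumed to be the same module the test imports):

import tensorflow as tf

def vtrace_from_done_flags(values, rewards, target_action_log_probs,
                           behaviour_action_log_probs, done_terminated,
                           discounting=0.99, lambda_=0.95):
    """values is [T+1, B]; all other per-step tensors are [T, B]."""
    # Terminated steps get a zero discount; every other step uses `discounting`.
    discounts = discounting * tf.cast(~done_terminated, tf.float32)
    return vtrace.from_importance_weights(
        target_action_log_probs, behaviour_action_log_probs,
        discounts, rewards,
        values[:-1],   # V(x_t) for t = 0 .. T-1
        values[-1],    # V(x_T), used as the bootstrap value
        lambda_=lambda_)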
Example #4
def compute_loss(parametric_action_distribution, agent, agent_state,
                 prev_actions, env_outputs, agent_outputs):
    # agent((prev_actions[t], env_outputs[t]), agent_state)
    #   -> agent_outputs[t], agent_state'
    learner_outputs, _ = agent((prev_actions, env_outputs),
                               agent_state,
                               unroll=True,
                               is_training=True)

    # Use last baseline value (from the value function) to bootstrap.
    bootstrap_value = learner_outputs.baseline[-1]

    # At this point, we have unroll length + 1 steps. The last step is only
    # used as the bootstrap value, so it's removed.
    agent_outputs = tf.nest.map_structure(lambda t: t[:-1], agent_outputs)
    rewards, done, _ = tf.nest.map_structure(lambda t: t[1:], env_outputs)
    learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

    if FLAGS.max_abs_reward:
        rewards = tf.clip_by_value(rewards, -FLAGS.max_abs_reward,
                                   FLAGS.max_abs_reward)
    discounts = tf.cast(~done, tf.float32) * FLAGS.discounting

    target_action_log_probs = parametric_action_distribution.log_prob(
        learner_outputs.policy_logits, agent_outputs.action)
    behaviour_action_log_probs = parametric_action_distribution.log_prob(
        agent_outputs.policy_logits, agent_outputs.action)

    # Compute V-trace returns and weights.
    vtrace_returns = vtrace.from_importance_weights(
        target_action_log_probs=target_action_log_probs,
        behaviour_action_log_probs=behaviour_action_log_probs,
        discounts=discounts,
        rewards=rewards,
        values=learner_outputs.baseline,
        bootstrap_value=bootstrap_value,
        lambda_=FLAGS.lambda_)

    # Policy loss based on Policy Gradients
    policy_loss = -tf.reduce_mean(target_action_log_probs * tf.stop_gradient(
        vtrace_returns.pg_advantages))

    # Value function loss
    v_error = vtrace_returns.vs - learner_outputs.baseline
    v_loss = FLAGS.baseline_cost * 0.5 * tf.reduce_mean(tf.square(v_error))

    # Entropy reward
    entropy = tf.reduce_mean(
        parametric_action_distribution.entropy(learner_outputs.policy_logits))
    entropy_loss = FLAGS.entropy_cost * -entropy

    # KL(old_policy|new_policy) loss
    kl = behaviour_action_log_probs - target_action_log_probs
    kl_loss = FLAGS.kl_cost * tf.reduce_mean(kl)

    total_loss = policy_loss + v_loss + entropy_loss + kl_loss

    # logging
    del log_keys[:]
    log_values = []

    def log(key, value):
        # `key` is a Python value, so this append only runs when the tf.function is traced.
        log_keys.append(key)
        # `value` is a tensor, so it is computed on every call.
        log_values.append(value)

    # value function
    log('V/value function', tf.reduce_mean(learner_outputs.baseline))
    log('V/L2 error', tf.sqrt(tf.reduce_mean(tf.square(v_error))))
    # losses
    log('losses/policy', policy_loss)
    log('losses/V', v_loss)
    log('losses/entropy', entropy_loss)
    log('losses/kl', kl_loss)
    log('losses/total', total_loss)
    # policy
    dist = parametric_action_distribution.create_dist(
        learner_outputs.policy_logits)
    if hasattr(dist, 'scale'):
        log('policy/std', tf.reduce_mean(dist.scale))
    log('policy/max_action_abs(before_tanh)',
        tf.reduce_max(tf.abs(agent_outputs.action)))
    log('policy/entropy', entropy)
    log('policy/kl(old|new)', tf.reduce_mean(kl))

    return total_loss, log_values
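compute_loss above records diagnostics through a module-level `log_keys` list, which is cleared and refilled while the tf.function is traced, and returns the matching tensors in `log_values`. A hedged sketch of how a caller might pair them up and write summaries follows; the writer, step counter, and function name are illustrative assumptions, not part of the source.

import tensorflow as tf

# Assumed to exist elsewhere in the training loop.
summary_writer = tf.summary.create_file_writer('/tmp/logdir')
training_step = tf.Variable(0, dtype=tf.int64)

def write_logs(log_keys, log_values):
    # log_keys was populated at trace time; log_values holds the per-step tensors.
    with summary_writer.as_default():
        for key, value in zip(log_keys, log_values):
            tf.summary.scalar(key, value, step=training_step)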
Example #5
def compute_loss(logger, parametric_action_distribution, agent, agent_state,
                 prev_actions, env_outputs, agent_outputs):
    learner_outputs, _ = agent(prev_actions,
                               env_outputs,
                               agent_state,
                               unroll=True,
                               is_training=True)

    # Use last baseline value (from the value function) to bootstrap.
    bootstrap_value = learner_outputs.baseline[-1]

    # At this point, we have unroll length + 1 steps. The last step is only
    # used as the bootstrap value, so it's removed.
    agent_outputs = tf.nest.map_structure(lambda t: t[:-1], agent_outputs)
    rewards, done, _, _, _ = tf.nest.map_structure(lambda t: t[1:],
                                                   env_outputs)
    learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

    if FLAGS.max_abs_reward:
        rewards = tf.clip_by_value(rewards, -FLAGS.max_abs_reward,
                                   FLAGS.max_abs_reward)
    discounts = tf.cast(~done, tf.float32) * FLAGS.discounting

    target_action_log_probs = parametric_action_distribution.log_prob(
        learner_outputs.policy_logits, agent_outputs.action)
    behaviour_action_log_probs = parametric_action_distribution.log_prob(
        agent_outputs.policy_logits, agent_outputs.action)

    # Compute V-trace returns and weights.
    vtrace_returns = vtrace.from_importance_weights(
        target_action_log_probs=target_action_log_probs,
        behaviour_action_log_probs=behaviour_action_log_probs,
        discounts=discounts,
        rewards=rewards,
        values=learner_outputs.baseline,
        bootstrap_value=bootstrap_value,
        lambda_=FLAGS.lambda_)

    # Policy loss based on Policy Gradients
    policy_loss = -tf.reduce_mean(target_action_log_probs * tf.stop_gradient(
        vtrace_returns.pg_advantages))

    # Value function loss
    v_error = vtrace_returns.vs - learner_outputs.baseline
    v_loss = FLAGS.baseline_cost * 0.5 * tf.reduce_mean(tf.square(v_error))

    # Entropy reward
    entropy = tf.reduce_mean(
        parametric_action_distribution.entropy(learner_outputs.policy_logits))
    entropy_loss = tf.stop_gradient(agent.entropy_cost()) * -entropy

    # KL(old_policy|new_policy) loss
    kl = behaviour_action_log_probs - target_action_log_probs
    kl_loss = FLAGS.kl_cost * tf.reduce_mean(kl)

    # Entropy cost adjustment (Lagrange multiplier style)
    if FLAGS.target_entropy:
        entropy_adjustment_loss = agent.entropy_cost() * tf.stop_gradient(
            tf.reduce_mean(entropy) - FLAGS.target_entropy)
    else:
        # Keep a zero-valued dependency on the entropy cost so its variable
        # still receives a (zero) gradient instead of None.
        entropy_adjustment_loss = 0. * agent.entropy_cost()

    total_loss = (policy_loss + v_loss + entropy_loss + kl_loss +
                  entropy_adjustment_loss)

    # value function
    session = logger.log_session()
    logger.log(session, 'V/value function',
               tf.reduce_mean(learner_outputs.baseline))
    logger.log(session, 'V/L2 error',
               tf.sqrt(tf.reduce_mean(tf.square(v_error))))
    # losses
    logger.log(session, 'losses/policy', policy_loss)
    logger.log(session, 'losses/V', v_loss)
    logger.log(session, 'losses/entropy', entropy_loss)
    logger.log(session, 'losses/kl', kl_loss)
    logger.log(session, 'losses/total', total_loss)
    # policy
    dist = parametric_action_distribution.create_dist(
        learner_outputs.policy_logits)
    if hasattr(dist, 'scale'):
        logger.log(session, 'policy/std', tf.reduce_mean(dist.scale))
    logger.log(session, 'policy/max_action_abs(before_tanh)',
               tf.reduce_max(tf.abs(agent_outputs.action)))
    logger.log(session, 'policy/entropy', entropy)
    logger.log(session, 'policy/entropy_cost', agent.entropy_cost())
    logger.log(session, 'policy/kl(old|new)', tf.reduce_mean(kl))

    return total_loss, session
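Compared with Example #4, the entropy coefficient here comes from `agent.entropy_cost()` and is itself adapted toward `FLAGS.target_entropy` through `entropy_adjustment_loss`, Lagrange-multiplier style: the gradient of that term with respect to the cost is the stopped entropy gap, so the cost rises while entropy is below the target and falls while it is above. A hedged sketch of one way such an adjustable cost could be exposed; the parametrization is an assumption, not necessarily the agent's actual code.

import tensorflow as tf

class AdjustableEntropyCost(tf.Module):
    """Hypothetical entropy-cost module: stored as a log-scale variable so the
    cost stays positive while remaining trainable by gradient descent."""

    def __init__(self, initial_cost=0.01):
        super().__init__()
        self._log_cost = tf.Variable(tf.math.log(initial_cost))

    def __call__(self):
        return tf.exp(self._log_cost)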