예제 #1
0
    def __init__(self, q_values, observations, num_actions, cur_epsilon,
                 softmax, softmax_temp, model_config):
        """Build the TF ops that choose an action from ``q_values``.

        Two selection modes are supported:

        * ``softmax`` truthy: sample from ``Categorical(q_values /
          softmax_temp)``. ``self.action_prob`` is the exp of the
          distribution's sampled-action log-probability.
        * otherwise, epsilon-greedy: each batch row independently draws a
          uniform number in [0, 1). Rows whose draw is below ``cur_epsilon``
          take a uniformly random action among the unmasked ones; the other
          rows take the argmax action. ``self.action_prob`` is ``None``.

        Args:
            q_values: Q-value tensor. Assumed (batch, num_actions) since the
                greedy pick is an argmax over axis 1 -- TODO confirm.
            observations: Observation tensor; only its leading dimension
                (used as the batch size) is read here.
            num_actions: Not used by this method.
            cur_epsilon: Exploration probability for epsilon-greedy mode.
            softmax: Selects softmax sampling instead of epsilon-greedy.
            softmax_temp: Temperature dividing ``q_values`` in softmax mode.
            model_config: Not used by this method.
        """
        # NOTE: keep the op-creation order stable. When only a graph-level
        # seed is set, TF1 derives each random op's seed from creation order.
        if not softmax:
            greedy_actions = tf.argmax(q_values, axis=1)
            n_rows = tf.shape(observations)[0]

            # Masked actions carry a Q-value of exactly tf.float32.min. Give
            # them a float32.min logit (effectively zero probability) and all
            # other actions a logit of 1, so exploration samples uniformly
            # over the unmasked actions. If every action is masked, all
            # logits are equal and the draw is uniform over all of them.
            is_masked = tf.equal(q_values, tf.float32.min)
            explore_logits = tf.where(is_masked,
                                      tf.ones_like(q_values) * tf.float32.min,
                                      tf.ones_like(q_values))
            single_draw = tf.multinomial(explore_logits, 1)
            explore_actions = tf.squeeze(single_draw, axis=1)

            # One independent coin flip per batch row.
            coin = tf.random_uniform(tf.stack([n_rows]),
                                     minval=0,
                                     maxval=1,
                                     dtype=tf.float32)
            take_random = coin < cur_epsilon
            self.action = tf.where(take_random, explore_actions,
                                   greedy_actions)
            self.action_prob = None
        else:
            scaled_q = q_values / softmax_temp
            dist = Categorical(scaled_q)
            self.action = dist.sample()
            self.action_prob = tf.exp(dist.sampled_action_logp())
예제 #2
0
파일: dqn_policy.py 프로젝트: xuman2019/ray
def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
                                 action_space, config):
    """Compute actions from the Q network and record its outputs on ``policy``.

    Side effects on ``policy``:
        * ``q_values``: first output of ``_compute_q_values``.
        * ``q_func_vars``: ``q_model.variables()``.
        * ``action_probs``: softmax of the Q-values. Set only when
          ``config["parameter_noise"]`` is truthy, after parameter noise has
          been built over every Q variable whose name lacks "LayerNorm".
        * ``output_actions``: the chosen actions.
        * ``action_prob``: probability of the sampled action in soft-Q mode,
          ``None`` otherwise.

    Args:
        policy: Policy object the results are attached to.
        q_model: Q network; its ``variables()`` are recorded on ``policy``.
        input_dict: Batch dict; only ``SampleBatch.CUR_OBS`` is read.
        obs_space: Observation space, forwarded to ``_compute_q_values``.
        action_space: Action space, forwarded to ``_compute_q_values``.
        config: Dict providing ``"parameter_noise"``, ``"soft_q"`` and, in
            soft-Q mode, ``"softmax_temp"``.

    Returns:
        Tuple ``(policy.output_actions, log_prob)``. ``log_prob`` is the
        sampled-action log-probability in soft-Q mode and ``None`` when
        actions are the greedy argmax over axis 1.
    """
    current_obs = input_dict[SampleBatch.CUR_OBS]
    q_values, _q_logits, _q_dist = _compute_q_values(
        policy, q_model, current_obs, obs_space, action_space)
    policy.q_values = q_values
    policy.q_func_vars = q_model.variables()

    if config["parameter_noise"]:
        # Layer-normalization variables are deliberately kept noise-free.
        noisy_vars = [
            v for v in policy.q_func_vars if "LayerNorm" not in v.name
        ]
        _build_parameter_noise(policy, noisy_vars)
        policy.action_probs = tf.nn.softmax(policy.q_values)

    # TODO(sven): Move soft_q logic to different Exploration child-component.
    if not config["soft_q"]:
        # Greedy: no sampling, hence no log-probability to report.
        policy.output_actions = tf.argmax(q_values, axis=1)
        policy.action_prob = None
        return policy.output_actions, None

    temperature = config["softmax_temp"]
    dist = Categorical(q_values / temperature)
    policy.output_actions = dist.sample()
    logp = dist.sampled_action_logp()
    policy.action_prob = tf.exp(logp)
    return policy.output_actions, logp