Example #1
File: loss.py Project: goforfar/atari-rl
    def optimality_tightening(self):
        with tf.name_scope('optimality_tightening'):
            taken_action_value = self.policy_network[0].taken_action_value

            # Upper bounds
            upper_bounds = []
            rewards = 0
            for t in range(-1, -self.config.optimality_tightening_steps - 1,
                           -1):
                with tf.name_scope(util.format_offset('upper_bound', t)):
                    rewards = self.reward[t] + self.discount * rewards
                    q_value = (self.discounts[t] *
                               self.target_network[t].taken_action_value)
                    upper_bound = q_value - rewards
                upper_bounds.append(upper_bound)

            upper_bound = tf.reduce_min(tf.stack(upper_bounds, axis=2), axis=2)
            upper_bound_difference = taken_action_value - upper_bound
            upper_bound_breached = tf.to_float(upper_bound_difference > 0)
            upper_bound_penalty = tf.square(tf.nn.relu(upper_bound_difference))

            # Lower bounds
            discounted_reward = tf.tile(
                self.discounted_reward[0],
                multiples=[1, self.config.num_bootstrap_heads])
            lower_bounds = [discounted_reward]
            rewards = self.reward[0]
            for t in range(1, self.config.optimality_tightening_steps + 1):
                with tf.name_scope(util.format_offset('lower_bound', t)):
                    rewards += self.reward[t] * self.discounts[t]
                    lower_bound = rewards + self.discounts[t + 1] * self.value(
                        t + 1)
                lower_bounds.append(lower_bound)

            lower_bound = tf.reduce_max(tf.stack(lower_bounds, axis=2), axis=2)
            lower_bound_difference = lower_bound - taken_action_value
            lower_bound_breached = tf.to_float(lower_bound_difference > 0)
            lower_bound_penalty = tf.square(tf.nn.relu(lower_bound_difference))

            # Penalty and rescaling
            penalty = self.config.optimality_penalty_ratio * (
                lower_bound_penalty + upper_bound_penalty)
            constraints_breached = lower_bound_breached + upper_bound_breached
            error_rescaling = 1.0 / (1.0 + constraints_breached *
                                     self.config.optimality_penalty_ratio)

            tf.summary.scalar('discounted_reward',
                              tf.reduce_mean(discounted_reward))
            tf.summary.scalar('lower_bound', tf.reduce_mean(lower_bound))
            tf.summary.scalar('upper_bound', tf.reduce_mean(upper_bound))

            return penalty, error_rescaling
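
The returned penalty and error_rescaling are presumably combined with the regular TD error elsewhere in the project. A minimal sketch of one way to do that, assuming a td_error tensor and the same TF 1.x API used above (this is not the project's actual training code):

    import tensorflow as tf

    def constrained_loss(td_error, penalty, error_rescaling):
        # Squared TD error plus the optimality-tightening penalty,
        # rescaled so the gradient magnitude stays comparable when
        # upper/lower-bound constraints are breached.
        squared_td = tf.square(td_error)
        return tf.reduce_mean(error_rescaling * (squared_td + penalty))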
Example #2
    def offset_input(self, t):
        if t not in self.offset_inputs:
            with tf.name_scope(self.scope):
                with tf.name_scope(util.format_offset('input', t)):
                    offset_input = OffsetInput(self, t)
                    self.offset_inputs[t] = offset_input
        return self.offset_inputs[t]
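
All three examples build their TensorFlow scope names through util.format_offset(name, t). The project's actual helper is not shown here; a hypothetical version that produces readable, scope-safe names from a time offset could look like this:

    def format_offset(name, t):
        # Hypothetical helper: ('input', 0) -> 'input_t',
        # ('input', -2) -> 'input_tminus2', ('input', 3) -> 'input_tplus3'.
        if t == 0:
            return '%s_t' % name
        elif t < 0:
            return '%s_tminus%d' % (name, -t)
        return '%s_tplus%d' % (name, t)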
Example #3
    def target_network(self, t=0):
        if t not in self.target_nets:
            reuse = len(self.target_nets) > 0
            with tf.variable_scope(self.target_scope, reuse=reuse) as scope:
                with tf.name_scope(self.network_scope):
                    with tf.name_scope(util.format_offset('target', t)):
                        self.target_nets[t] = dqn.Network(
                            variable_scope=scope,
                            inputs=self.inputs.offset_input(t),
                            reward_scaling=self.reward_scaling,
                            config=self.config,
                            write_summaries=False)

        return self.target_nets[t]
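
The target networks are memoized by offset, and every network built after the first passes reuse=True, so all offsets share a single set of target-network variables. A usage sketch (the agent object and attribute names are assumptions, not taken from the project):

    net_t0 = agent.target_network(0)           # builds variables in target_scope
    net_t1 = agent.target_network(1)           # reuse=True, shares those variables
    assert agent.target_network(0) is net_t0   # cached, nothing is rebuilt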