Example No. 1
def _q_function(params, obs, action, scope):
    # `initializer` and `last_initializer` are free variables defined
    # outside this snippet.
    return q_function(params.fcs,
                      obs,
                      action,
                      params.concat_index,
                      tf.nn.tanh,
                      w_init=initializer,
                      last_w_init=last_initializer,
                      last_b_init=last_initializer,
                      scope=scope)
Example No. 2
    def test_q_function(self):
        inpt = make_tf_inpt()
        fcs = make_fcs()
        w_init = tf.random_uniform_initializer(-0.1, 0.1)
        action = tf.constant(np.random.random((int(inpt.shape[0]),
                             np.random.randint(10) + 1)), dtype=tf.float32)
        concat_index = np.random.randint(len(fcs))

        value = q_function(
            fcs, inpt, action, concat_index, w_init=w_init,
            last_w_init=w_init, last_b_init=w_init)

        # build a training op to check that gradients flow through the graph
        optimizer = tf.train.AdamOptimizer(1e-4)
        optimize_expr = optimizer.minimize(tf.reduce_mean(value))

        assert int(value.shape[0]) == int(inpt.shape[0])
        assert int(value.shape[1]) == 1

        hiddens = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'action_value/hiddens')

        concat = hiddens[concat_index * 2]
        if concat_index == 0:
            dim = int(inpt.shape[1])
        else:
            dim = fcs[concat_index - 1]
        assert int(concat.shape[0]) == dim + int(action.shape[1])

        output = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'action_value/output')[0]
        assert int(output.shape[0]) == fcs[-1]
        assert int(output.shape[1]) == 1

        variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'value')

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            before = sess.run(variable)
            for var in before:
                assert_variable_range(var, -0.1, 0.1)

            sess.run(optimize_expr)

            after = sess.run(variable)
            assert_variable_mismatch(before, after)
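
The test above pins down the layout that `q_function` is expected to produce: hidden layers under an `action_value/hiddens` scope, the action concatenated to the input of the layer at `concat_index`, and a single-unit `action_value/output` layer. The sketch below is a hypothetical re-implementation consistent with those shape checks, not the library's actual code (`q_function_sketch` is an invented name):

import tensorflow as tf


def q_function_sketch(fcs, inpt, action, concat_index, activation=tf.nn.tanh,
                      w_init=None, last_w_init=None, last_b_init=None,
                      scope='action_value'):
    # Hypothetical sketch: mirrors the variable layout asserted in the test.
    with tf.variable_scope(scope):
        out = inpt
        with tf.variable_scope('hiddens'):
            for i, hidden in enumerate(fcs):
                if i == concat_index:
                    # concatenating here makes the kernel's first dimension
                    # equal to dim + action_dim, as the test asserts
                    out = tf.concat([out, action], axis=1)
                out = tf.layers.dense(out, hidden, activation=activation,
                                      kernel_initializer=w_init,
                                      name='hidden{}'.format(i))
        with tf.variable_scope('output'):
            # final kernel has shape [fcs[-1], 1]
            out = tf.layers.dense(out, 1, kernel_initializer=last_w_init,
                                  bias_initializer=last_b_init)
    return out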
Example No. 3
    def _build(self, fcs, concat_index, state_shape, num_actions, gamma, tau,
               actor_lr, critic_lr):
        with tf.variable_scope('ddpg', reuse=tf.AUTO_REUSE):
            # placeholders
            obs_t_ph = self.obs_t_ph = tf.placeholder(tf.float32, [None] +
                                                      list(state_shape),
                                                      name='obs_t')
            actions_t_ph = self.actions_t_ph = tf.placeholder(
                tf.float32, [None, num_actions], name='actions_t')
            rewards_tp1_ph = self.rewards_tp1_ph = tf.placeholder(
                tf.float32, [None], name='rewards_tp1')
            obs_tp1_ph = self.obs_tp1_ph = tf.placeholder(tf.float32, [None] +
                                                          list(state_shape),
                                                          name='obs_tp1')
            dones_tp1_ph = self.dones_tp1_ph = tf.placeholder(tf.float32,
                                                              [None],
                                                              name='dones_tp1')

            # `initializer` used below is a free variable defined outside
            # this snippet
            last_initializer = tf.random_uniform_initializer(-3e-3, 3e-3)

            raw_policy_t = deterministic_policy_function(
                fcs,
                obs_t_ph,
                num_actions,
                tf.nn.tanh,
                w_init=initializer,
                last_w_init=last_initializer,
                last_b_init=last_initializer,
                scope='actor')
            policy_t = tf.nn.tanh(raw_policy_t)
            raw_policy_tp1 = deterministic_policy_function(
                fcs,
                obs_tp1_ph,
                num_actions,
                tf.nn.tanh,
                w_init=initializer,
                last_w_init=last_initializer,
                last_b_init=last_initializer,
                scope='target_actor')
            policy_tp1 = tf.nn.tanh(raw_policy_tp1)

            q_t = q_function(fcs,
                             obs_t_ph,
                             actions_t_ph,
                             concat_index,
                             tf.nn.tanh,
                             w_init=initializer,
                             last_w_init=last_initializer,
                             last_b_init=last_initializer,
                             scope='critic')
            q_t_with_actor = q_function(fcs,
                                        obs_t_ph,
                                        policy_t,
                                        concat_index,
                                        tf.nn.tanh,
                                        w_init=initializer,
                                        last_w_init=last_initializer,
                                        last_b_init=last_initializer,
                                        scope='critic')
            q_tp1 = q_function(fcs,
                               obs_tp1_ph,
                               policy_tp1,
                               concat_index,
                               tf.nn.tanh,
                               w_init=initializer,
                               last_w_init=last_initializer,
                               last_b_init=last_initializer,
                               scope='target_critic')

            # prepare for loss calculation
            rewards_tp1 = tf.reshape(rewards_tp1_ph, [-1, 1])
            dones_tp1 = tf.reshape(dones_tp1_ph, [-1, 1])

            # critic loss
            self.critic_loss = build_critic_loss(q_t, rewards_tp1, q_tp1,
                                                 dones_tp1, gamma)
            # actor loss
            self.actor_loss = -tf.reduce_mean(q_t_with_actor)

            # target update
            self.update_target_critic = build_target_update(
                'ddpg/critic', 'ddpg/target_critic', tau)
            self.update_target_actor = build_target_update(
                'ddpg/actor', 'ddpg/target_actor', tau)

            # optimization
            self.critic_optimize_expr = build_optim(self.critic_loss,
                                                    critic_lr, 'ddpg/critic')
            self.actor_optimize_expr = build_optim(self.actor_loss, actor_lr,
                                                   'ddpg/actor')

            # action
            self.action = policy_t
            self.value = tf.reshape(q_t_with_actor, [-1])
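
This DDPG graph delegates the TD target and the soft target-network update to `build_critic_loss` and `build_target_update`. The bodies below are assumptions based on the standard DDPG update rules (signatures inferred from the call sites), not the library's actual helpers:

import tensorflow as tf


def build_critic_loss(q_t, rewards_tp1, q_tp1, dones_tp1, gamma):
    # TD target: r + gamma * (1 - done) * Q'(s', pi'(s')); the target network
    # output is treated as a constant for this loss.
    target = rewards_tp1 + gamma * (1.0 - dones_tp1) * tf.stop_gradient(q_tp1)
    return tf.reduce_mean(tf.square(target - q_t))


def build_target_update(src_scope, target_scope, tau):
    # soft update: theta_target <- tau * theta + (1 - tau) * theta_target
    src_vars = sorted(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        src_scope), key=lambda v: v.name)
    target_vars = sorted(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           target_scope), key=lambda v: v.name)
    updates = [t.assign(tau * s + (1.0 - tau) * t)
               for s, t in zip(src_vars, target_vars)]
    return tf.group(*updates)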
Example No. 4
    def _build(self, params):
        with tf.variable_scope('sac'):
            self.obs_t_ph = tf.placeholder(tf.float32,
                                           (None, ) + params.state_shape,
                                           name='obs_t')
            self.actions_t_ph = tf.placeholder(tf.float32,
                                               (None, params.num_actions),
                                               name='actions_t')
            self.rewards_tp1_ph = tf.placeholder(tf.float32, (None, ),
                                                 name='rewards_tp1')
            self.obs_tp1_ph = tf.placeholder(tf.float32,
                                             (None, ) + params.state_shape,
                                             name='obs_tp1')
            self.dones_tp1_ph = tf.placeholder(tf.float32, (None, ),
                                               name='dones_tp1')

            # policy function
            pi_t = stochastic_policy_function(params.fcs,
                                              self.obs_t_ph,
                                              params.num_actions,
                                              tf.nn.relu,
                                              share=True,
                                              w_init=XAVIER_INIT,
                                              last_w_init=XAVIER_INIT,
                                              last_b_init=XAVIER_INIT,
                                              scope='pi')
            squashed_action_t, log_prob_t = squash_action(pi_t)

            # value function
            v_t = value_function(params.fcs,
                                 self.obs_t_ph,
                                 tf.nn.relu,
                                 XAVIER_INIT,
                                 XAVIER_INIT,
                                 ZEROS_INIT,
                                 scope='v')
            # target value function
            v_tp1 = value_function(params.fcs,
                                   self.obs_tp1_ph,
                                   tf.nn.relu,
                                   XAVIER_INIT,
                                   XAVIER_INIT,
                                   ZEROS_INIT,
                                   scope='target_v')

            # two q functions
            q1_t_with_pi = q_function(params.fcs,
                                      self.obs_t_ph,
                                      squashed_action_t,
                                      params.concat_index,
                                      tf.nn.relu,
                                      XAVIER_INIT,
                                      XAVIER_INIT,
                                      ZEROS_INIT,
                                      scope='q1')
            q1_t = q_function(params.fcs,
                              self.obs_t_ph,
                              self.actions_t_ph,
                              params.concat_index,
                              tf.nn.relu,
                              XAVIER_INIT,
                              XAVIER_INIT,
                              ZEROS_INIT,
                              scope='q1')
            q2_t_with_pi = q_function(params.fcs,
                                      self.obs_t_ph,
                                      squashed_action_t,
                                      params.concat_index,
                                      tf.nn.relu,
                                      XAVIER_INIT,
                                      XAVIER_INIT,
                                      ZEROS_INIT,
                                      scope='q2')
            q2_t = q_function(params.fcs,
                              self.obs_t_ph,
                              self.actions_t_ph,
                              params.concat_index,
                              tf.nn.relu,
                              XAVIER_INIT,
                              XAVIER_INIT,
                              ZEROS_INIT,
                              scope='q2')

            # prepare for loss
            rewards_tp1 = tf.reshape(self.rewards_tp1_ph, [-1, 1])
            dones_tp1 = tf.reshape(self.dones_tp1_ph, [-1, 1])

            # value function loss
            self.v_loss = build_v_loss(v_t, q1_t_with_pi, q2_t_with_pi,
                                       log_prob_t)
            # q function loss
            self.q1_loss = build_q_loss(q1_t, rewards_tp1, v_tp1, dones_tp1,
                                        params.gamma)
            self.q2_loss = build_q_loss(q2_t, rewards_tp1, v_tp1, dones_tp1,
                                        params.gamma)
            # policy function loss
            self.pi_loss = build_pi_loss(log_prob_t, q1_t_with_pi,
                                         q2_t_with_pi)
            # policy regularization
            policy_decay = build_policy_reg(pi_t, params.reg)

            # target update
            self.target_update = build_target_update('sac/v', 'sac/target_v',
                                                     params.tau)

            # optimization
            self.v_optimize_expr = build_optim(self.v_loss, params.v_lr,
                                               'sac/v')
            self.q1_optimize_expr = build_optim(self.q1_loss, params.q_lr,
                                                'sac/q1')
            self.q2_optimize_expr = build_optim(self.q2_loss, params.q_lr,
                                                'sac/q2')
            self.pi_optimize_expr = build_optim(self.pi_loss + policy_decay,
                                                params.pi_lr, 'sac/pi')

            # for inference
            self.action = squashed_action_t[0]
            self.value = tf.reshape(v_t, [-1])[0]
            self.log_prob = tf.reshape(log_prob_t, [-1])[0]
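
The SAC losses above are built by three helpers whose bodies are not shown. Assuming they follow the original soft actor-critic objectives (twin Q-functions, a state value network and an entropy term), minimal sketches of `build_v_loss`, `build_q_loss` and `build_pi_loss` could look like this; the next example inlines the squashing math that `squash_action` performs here:

import tensorflow as tf


def build_v_loss(v_t, q1_t, q2_t, log_prob_t):
    # V(s) regresses towards min(Q1, Q2) - log pi(a|s) under the current policy
    target = tf.stop_gradient(tf.minimum(q1_t, q2_t) - log_prob_t)
    return 0.5 * tf.reduce_mean(tf.square(v_t - target))


def build_q_loss(q_t, rewards_tp1, v_tp1, dones_tp1, gamma):
    # soft Bellman backup: r + gamma * (1 - done) * V'(s')
    target = rewards_tp1 + gamma * (1.0 - dones_tp1) * tf.stop_gradient(v_tp1)
    return 0.5 * tf.reduce_mean(tf.square(q_t - target))


def build_pi_loss(log_prob_t, q1_t, q2_t):
    # the policy maximizes min(Q1, Q2) - log pi, so the loss is the negation
    return tf.reduce_mean(log_prob_t - tf.minimum(q1_t, q2_t))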
Example No. 5
    def _build(self,
               fcs,
               concat_index,
               state_shape,
               num_actions,
               gamma,
               tau,
               pi_lr,
               q_lr,
               v_lr,
               reg):
        with tf.variable_scope('sac'):
            obs_t_ph = self.obs_t_ph = tf.placeholder(
                tf.float32, (None,) + state_shape, name='obs_t')
            actions_t_ph = self.actions_t_ph = tf.placeholder(
                tf.float32, (None, num_actions), name='actions_t')
            rewards_tp1_ph = self.rewards_tp1_ph = tf.placeholder(
                tf.float32, (None,), name='rewards_tp1')
            obs_tp1_ph = self.obs_tp1_ph = tf.placeholder(
                tf.float32, (None,) + state_shape, name='obs_tp1')
            dones_tp1_ph = self.dones_tp1_ph = tf.placeholder(
                tf.float32, (None,), name='dones_tp1')

            # initializers
            zeros_init = tf.zeros_initializer()
            w_init = tf.contrib.layers.xavier_initializer()
            last_w_init = tf.contrib.layers.xavier_initializer()
            last_b_init = tf.contrib.layers.xavier_initializer()

            # policy function
            pi_t = stochastic_policy_function(fcs, obs_t_ph, num_actions,
                                              tf.nn.relu, share=True,
                                              w_init=w_init,
                                              last_w_init=last_w_init,
                                              last_b_init=last_b_init,
                                              scope='pi')
            sampled_action_t = pi_t.sample(1)[0]
            squashed_action_t = tf.nn.tanh(sampled_action_t)
            diff = tf.reduce_sum(
                tf.log(1 - squashed_action_t ** 2 + 1e-6),
                axis=1, keepdims=True)
            log_prob_t = tf.reshape(
                pi_t.log_prob(sampled_action_t), [-1, 1]) - diff

            # value function
            v_t = value_function(
                fcs, obs_t_ph, tf.nn.relu, w_init,
                last_w_init, zeros_init, scope='v')
            # target value function
            v_tp1 = value_function(
                fcs, obs_tp1_ph, tf.nn.relu, w_init,
                last_w_init, zeros_init, scope='target_v')

            # two q functions
            q1_t_with_pi = q_function(fcs, obs_t_ph, squashed_action_t,
                                      concat_index, tf.nn.relu, w_init,
                                      last_w_init, zeros_init, scope='q1')
            q1_t = q_function(fcs, obs_t_ph, actions_t_ph, concat_index,
                              tf.nn.relu, w_init, last_w_init,
                              zeros_init, scope='q1')
            q2_t_with_pi = q_function(fcs, obs_t_ph, squashed_action_t,
                                      concat_index, tf.nn.relu, w_init,
                                      last_w_init, zeros_init, scope='q2')
            q2_t = q_function(fcs, obs_t_ph, actions_t_ph, concat_index,
                              tf.nn.relu, w_init, last_w_init,
                              zeros_init, scope='q2')

            # prepare for loss
            rewards_tp1 = tf.reshape(rewards_tp1_ph, [-1, 1])
            dones_tp1 = tf.reshape(dones_tp1_ph, [-1, 1])

            # value function loss
            self.v_loss = build_v_loss(
                v_t, q1_t_with_pi, q2_t_with_pi, log_prob_t)
            # q function loss
            self.q1_loss = build_q_loss(
                q1_t, rewards_tp1, v_tp1, dones_tp1, gamma)
            self.q2_loss = build_q_loss(
                q2_t, rewards_tp1, v_tp1, dones_tp1, gamma)
            # policy function loss
            self.pi_loss = build_pi_loss(
                log_prob_t, q1_t_with_pi, q2_t_with_pi)

            # target update
            self.target_update = build_target_update(
                'sac/v', 'sac/target_v', tau)

            # policy regularization
            pi_mean_loss = 0.5 * tf.reduce_mean(pi_t.mean() ** 2)
            pi_logstd_loss = 0.5 * tf.reduce_mean(tf.log(pi_t.stddev()) ** 2)
            policy_decay = reg * (pi_mean_loss + pi_logstd_loss)

            # optimization
            self.v_optimize_expr = build_optim(self.v_loss, v_lr, 'sac/v')
            self.q1_optimize_expr = build_optim(self.q1_loss, q_lr, 'sac/q1')
            self.q2_optimize_expr = build_optim(self.q2_loss, q_lr, 'sac/q2')
            self.pi_optimize_expr = build_optim(self.pi_loss + policy_decay,
                                                pi_lr, 'sac/pi')

            # for inference
            self.action = squashed_action_t[0]
            self.value = tf.reshape(v_t, [-1])[0]
            self.log_prob = tf.reshape(log_prob_t, [-1])[0]
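
All three `_build` methods hand their losses to `build_optim`, which is expected to return a training op restricted to the variables under the given scope. A minimal sketch under that assumption, using Adam as in the connection check of Example No. 2:

import tensorflow as tf


def build_optim(loss, lr, scope):
    # minimize the loss with Adam, updating only the variables under `scope`
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    return optimizer.minimize(loss, var_list=variables)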