class ActorCriticNet:
    """Actor and critic sub-networks built on one shared state placeholder.

    The constructor creates an ``ActorNetwork`` and a ``CriticNetwork`` under
    a common variable scope, wires the actor's output layer into the critic,
    and builds loss, gradient and summary ops for both.

    Attributes set by this class:
        x: float32 placeholder of shape ``(None, input_dim)`` for states.
        y: float32 placeholder of shape ``(None,)`` for critic targets.
        vars: trainable variables collected under this network's scope.
        actor_vars / critic_vars: results of each sub-network's ``get_params()``.
        actor_grad / critic_grad: ``tf.gradients`` of each loss w.r.t. those.
        actor_summary / critic_summary: merged scalar summaries.
    """

    def __init__(self, input_dim, action_dim, critic_layers, actor_layers,
                 actor_activation, scope='ac_network'):
        """Create the placeholders, both sub-networks and the training ops.

        Args:
            input_dim: Size of the second dimension of the state placeholder.
            action_dim: Forwarded to ``ActorNetwork``.
            critic_layers: Forwarded to ``CriticNetwork`` as ``hidden_layers``.
            actor_layers: Forwarded to ``ActorNetwork`` as ``hidden_layers``.
            actor_activation: Forwarded to ``ActorNetwork`` as ``activation``.
            scope: Name of the variable scope both sub-networks are built in.
        """
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.scope = scope

        self.x = tf.placeholder(shape=(None, input_dim), dtype=tf.float32,
                                name='x')
        self.y = tf.placeholder(shape=(None, ), dtype=tf.float32, name='y')

        with tf.variable_scope(scope):
            self.actor_network = ActorNetwork(self.x, action_dim,
                                              hidden_layers=actor_layers,
                                              activation=actor_activation)

            # The critic receives the actor's output layer, so by default it
            # evaluates the action the actor proposes for the same state.
            self.critic_network = CriticNetwork(
                self.x, self.actor_network.get_output_layer(),
                hidden_layers=critic_layers)

            # Collected while the scope is still open so that
            # get_variable_scope().name refers to this network's scope.
            self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
            self._build()

    def _build(self):
        """Build loss, gradient and summary ops for the actor and the critic."""
        value = self.critic_network.get_output_layer()

        # Actor: gradient of the negated mean critic output w.r.t. the
        # actor's parameters.
        actor_loss = -tf.reduce_mean(value)
        self.actor_vars = self.actor_network.get_params()
        self.actor_grad = tf.gradients(actor_loss, self.actor_vars)
        tf.summary.scalar("actor_loss", actor_loss, collections=['actor'])
        # NOTE(review): merge_all merges every summary in the 'actor'
        # collection of the graph, not only the scalar added just above.
        self.actor_summary = tf.summary.merge_all('actor')

        # Critic: half mean squared error against the targets fed through y.
        # NOTE(review): assumes `value` has the same shape as y, i.e. (None,);
        # a (None, 1) output would broadcast in the subtraction - confirm.
        critic_loss = 0.5 * tf.reduce_mean(tf.square((value - self.y)))
        self.critic_vars = self.critic_network.get_params()
        self.critic_grad = tf.gradients(critic_loss, self.critic_vars)
        tf.summary.scalar("critic_loss", critic_loss, collections=['critic'])
        self.critic_summary = tf.summary.merge_all('critic')

    def get_action(self, sess, state):
        """Delegate to ``actor_network.get_action(sess, state)``."""
        return self.actor_network.get_action(sess, state)

    def get_value(self, sess, state):
        """Delegate to ``critic_network.get_value(sess, state)``."""
        return self.critic_network.get_value(sess, state)

    def get_action_value(self, sess, state, action):
        """Delegate to ``critic_network.get_action_value(sess, state, action)``."""
        return self.critic_network.get_action_value(sess, state, action)

    def get_actor_feed_dict(self, state):
        """Return the feed dict for actor ops: only the state placeholder."""
        return {self.x: state}

    def get_critic_feed_dict(self, state, action, target):
        """Return the feed dict for critic ops.

        Feeds the state, the target, and ``action`` into the critic's
        ``input_action`` tensor.
        """
        return {
            self.x: state,
            self.y: target,
            self.critic_network.input_action: action
        }

    @staticmethod
    def _strip_scope(var_name, scope):
        """Return ``var_name`` with the owning network's scope removed.

        When the name starts with ``scope`` only that leading prefix is
        dropped, so a later occurrence of the same text inside the variable's
        own name is left intact. Names that do not start with ``scope`` keep
        the previous behaviour of removing every occurrence.
        """
        if var_name.startswith(scope):
            return var_name[len(scope):]
        return var_name.replace(scope, '')

    def get_clone_op(self, network, tau=0.9):
        """Build ops that move this network's variables towards ``network``'s.

        Each variable ``v`` in ``self.vars`` is assigned
        ``(1 - tau) * v + tau * w``, where ``w`` is the variable of
        ``network`` with the same scope-relative name.

        Args:
            network: Object exposing ``scope`` and ``vars`` like this class.
            tau: Weight of the source variable in the blend.

        Returns:
            A list with one ``tf.assign`` op per variable in ``self.vars``.
            Every call creates new ops.

        Raises:
            KeyError: If a variable has no counterpart in ``network``.
        """
        source_by_key = {
            self._strip_scope(v.name, network.scope): v for v in network.vars
        }

        update_ops = []
        for v in self.vars:
            key = self._strip_scope(v.name, self.scope)
            if key not in source_by_key:
                raise KeyError(
                    "no variable matching {!r} (from {!r}) under scope {!r}"
                    .format(key, v.name, network.scope))
            u = (1 - tau) * v + tau * source_by_key[key]
            update_ops.append(tf.assign(v, u))

        return update_ops