def _get_action(self, s, visual_s, cell_state):
    with tf.device(self.device):
        feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True)
        if self.is_continuous:
            mu, log_std = self.actor_net(feat)
            log_std = clip_nn_log_std(log_std, self.log_std_min, self.log_std_max)
            pi, _ = tsallis_squash_rsample(mu, log_std, self.entropic_index)
            mu = tf.tanh(mu)  # squash mu
        else:
            logits = self.actor_net(feat)
            mu = tf.argmax(logits, axis=1)
            cate_dist = tfp.distributions.Categorical(logits)
            pi = cate_dist.sample()
        return mu, pi, cell_state
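# NOTE: clip_nn_log_std, gaussian_entropy, and tsallis_squash_rsample are
# imported helpers that are not defined in this section. The sketches below
# show one plausible implementation of each, assuming the usual
# squashed-Gaussian machinery with a Tsallis (q-logarithm) likelihood
# correction; they are illustrative, not the repository's actual code.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

def clip_nn_log_std(log_std, _min=-20.0, _max=2.0):
    # Assumed semantics: map a tanh-bounded network output in [-1, 1]
    # onto the interval [_min, _max].
    return _min + 0.5 * (_max - _min) * (log_std + 1.0)

def gaussian_entropy(log_std):
    # Mean differential entropy of a diagonal Gaussian: 0.5 * log(2*pi*e*std^2).
    return tf.reduce_mean(0.5 * (1.0 + np.log(2.0 * np.pi)) + log_std)

def tsallis_squash_rsample(mu, log_std, entropic_index):
    # Reparameterized sample from N(mu, std), squashed through tanh.
    # The likelihood correction uses the q-logarithm, which recovers the
    # ordinary SAC tanh correction as entropic_index -> 1.
    dist = tfp.distributions.Normal(mu, tf.exp(log_std))
    raw = dist.sample()  # reparameterized for Normal in TFP
    pi = tf.tanh(raw)
    log_pi = tf.reduce_sum(
        dist.log_prob(raw) - tf.math.log(1.0 - tf.square(pi) + 1e-6),
        axis=-1, keepdims=True)
    if entropic_index != 1.0:
        # q-logarithm: log_q(p) = (p^(1-q) - 1) / (1 - q)
        log_pi = (tf.exp((1.0 - entropic_index) * log_pi) - 1.0) / (1.0 - entropic_index)
    return pi, log_pi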
def _get_action(self, s, visual_s, evaluation):
    s, visual_s = self.cast(s, visual_s)
    with tf.device(self.device):
        if self.is_continuous:
            mu, log_std = self.actor_net(s, visual_s)
            log_std = clip_nn_log_std(log_std, self.log_std_min, self.log_std_max)
            pi, _ = tsallis_squash_rsample(mu, log_std, self.entropic_index)
            mu = tf.tanh(mu)  # squash mu
        else:
            logits = self.actor_net(s, visual_s)
            mu = tf.argmax(logits, axis=1)
            cate_dist = tfp.distributions.Categorical(logits)
            pi = cate_dist.sample()
        if evaluation:
            return mu
        else:
            return pi
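# A tiny, self-contained illustration (made-up values) of the
# evaluation/exploration split above: evaluation returns the squashed mean
# of the policy Gaussian, exploration returns a squashed sample.
def _demo_eval_vs_sample():
    import tensorflow as tf
    import tensorflow_probability as tfp
    mu = tf.constant([[0.3, -0.8]])
    log_std = tf.constant([[-1.0, -1.0]])
    dist = tfp.distributions.Normal(mu, tf.exp(log_std))
    a_eval = tf.tanh(mu)             # deterministic: what evaluation=True returns
    a_expl = tf.tanh(dist.sample())  # stochastic: what evaluation=False returns
    return a_eval, a_expl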
def train_persistent(self, s, visual_s, a, r, s_, visual_s_, done):
    s, visual_s, a, r, s_, visual_s_, done = self.cast(s, visual_s, a, r, s_, visual_s_, done)
    with tf.device(self.device):
        with tf.GradientTape(persistent=True) as tape:
            if self.is_continuous:
                mu, log_std = self.actor_net(s, visual_s)
                log_std = clip_nn_log_std(log_std, self.log_std_min, self.log_std_max)
                pi, log_pi = tsallis_squash_rsample(mu, log_std, self.entropic_index)
                entropy = gaussian_entropy(log_std)
                target_mu, target_log_std = self.actor_net(s_, visual_s_)
                target_log_std = clip_nn_log_std(target_log_std, self.log_std_min, self.log_std_max)
                target_pi, target_log_pi = tsallis_squash_rsample(target_mu, target_log_std, self.entropic_index)
            else:
                logits = self.actor_net(s, visual_s)
                logp_all = tf.nn.log_softmax(logits)
                # Straight-through Gumbel-softmax: hard one-hot forward pass, soft gradients backward.
                gumbel_noise = tf.cast(self.gumbel_dist.sample([a.shape[0], self.a_counts]), dtype=tf.float32)
                _pi = tf.nn.softmax((logp_all + gumbel_noise) / self.discrete_tau)
                _pi_true_one_hot = tf.one_hot(tf.argmax(_pi, axis=-1), self.a_counts)
                _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
                pi = _pi_diff + _pi
                log_pi = tf.reduce_sum(tf.multiply(logp_all, pi), axis=1, keepdims=True)
                entropy = -tf.reduce_mean(tf.reduce_sum(tf.exp(logp_all) * logp_all, axis=1, keepdims=True))
                target_logits = self.actor_net(s_, visual_s_)
                target_cate_dist = tfp.distributions.Categorical(target_logits)
                target_pi = target_cate_dist.sample()
                target_log_pi = target_cate_dist.log_prob(target_pi)  # log-prob of the sampled index, before one-hot
                target_pi = tf.one_hot(target_pi, self.a_counts, dtype=tf.float32)
            q1 = self.q1_net(s, visual_s, a)
            q1_target = self.q1_target_net(s_, visual_s_, target_pi)
            q2 = self.q2_net(s, visual_s, a)
            q2_target = self.q2_target_net(s_, visual_s_, target_pi)
            q1_s_pi = self.q1_net(s, visual_s, pi)
            q2_s_pi = self.q2_net(s, visual_s, pi)
            # Soft Bellman targets: bootstrap with the entropy-regularized target Q.
            dc_r_q1 = tf.stop_gradient(r + self.gamma * (1 - done) * (q1_target - tf.exp(self.log_alpha) * target_log_pi))
            dc_r_q2 = tf.stop_gradient(r + self.gamma * (1 - done) * (q2_target - tf.exp(self.log_alpha) * target_log_pi))
            td_error1 = q1 - dc_r_q1
            td_error2 = q2 - dc_r_q2
            q1_loss = tf.reduce_mean(tf.square(td_error1) * self.IS_w)
            q2_loss = tf.reduce_mean(tf.square(td_error2) * self.IS_w)
            critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
            actor_loss = -tf.reduce_mean(tf.minimum(q1_s_pi, q2_s_pi) - tf.exp(self.log_alpha) * log_pi)
            if self.auto_adaption:
                alpha_loss = -tf.reduce_mean(self.log_alpha * tf.stop_gradient(log_pi - self.a_counts))
        critic_grads = tape.gradient(critic_loss, self.q1_net.trainable_variables + self.q2_net.trainable_variables)
        self.optimizer_critic.apply_gradients(
            zip(critic_grads, self.q1_net.trainable_variables + self.q2_net.trainable_variables)
        )
        actor_grads = tape.gradient(actor_loss, self.actor_net.trainable_variables)
        self.optimizer_actor.apply_gradients(
            zip(actor_grads, self.actor_net.trainable_variables)
        )
        if self.auto_adaption:
            alpha_grads = tape.gradient(alpha_loss, [self.log_alpha])
            self.optimizer_alpha.apply_gradients(
                zip(alpha_grads, [self.log_alpha])
            )
        self.global_step.assign_add(1)
        summaries = dict([
            ['LOSS/actor_loss', actor_loss],
            ['LOSS/q1_loss', q1_loss],
            ['LOSS/q2_loss', q2_loss],
            ['LOSS/critic_loss', critic_loss],
            ['Statistics/log_alpha', self.log_alpha],
            ['Statistics/alpha', tf.exp(self.log_alpha)],
            ['Statistics/entropy', entropy],
            ['Statistics/q_min', tf.reduce_mean(tf.minimum(q1, q2))],
            ['Statistics/q_mean', tf.reduce_mean(q1)],
            ['Statistics/q_max', tf.reduce_mean(tf.maximum(q1, q2))]
        ])
        if self.auto_adaption:
            summaries.update({
                'LOSS/alpha_loss': alpha_loss
            })
        return (td_error1 + td_error2) / 2, summaries
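# The discrete branch above relies on the straight-through Gumbel-softmax
# estimator. The standalone demo below (made-up logits, hypothetical helper
# name) shows the trick in isolation: the forward pass emits a hard one-hot
# action while gradients flow through the soft relaxation.
def _demo_straight_through_gumbel():
    import tensorflow as tf
    import tensorflow_probability as tfp
    logits = tf.Variable([[1.0, 0.5, -0.5]])
    tau = 1.0  # temperature; self.discrete_tau plays this role above
    gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)
    with tf.GradientTape() as tape:
        logp_all = tf.nn.log_softmax(logits)
        noise = gumbel_dist.sample([1, 3])
        soft = tf.nn.softmax((logp_all + noise) / tau)
        hard = tf.one_hot(tf.argmax(soft, axis=-1), 3)
        pi = tf.stop_gradient(hard - soft) + soft  # one-hot value, soft gradient
        loss = tf.reduce_sum(pi * logp_all)
    return tape.gradient(loss, logits)  # nonzero: gradients flow through `soft`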
def train_persistent(self, memories, isw, crsty_loss, cell_state):
    ss, vvss, a, r, done = memories
    batch_size = tf.shape(a)[0]
    with tf.device(self.device):
        with tf.GradientTape(persistent=True) as tape:
            feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True)
            if self.is_continuous:
                mu, log_std = self.actor_net(feat)
                log_std = clip_nn_log_std(log_std, self.log_std_min, self.log_std_max)
                pi, log_pi = tsallis_squash_rsample(mu, log_std, self.entropic_index)
                entropy = gaussian_entropy(log_std)
                target_mu, target_log_std = self.actor_net(feat_)
                target_log_std = clip_nn_log_std(target_log_std, self.log_std_min, self.log_std_max)
                target_pi, target_log_pi = tsallis_squash_rsample(target_mu, target_log_std, self.entropic_index)
            else:
                logits = self.actor_net(feat)
                logp_all = tf.nn.log_softmax(logits)
                # Straight-through Gumbel-softmax: hard one-hot forward pass, soft gradients backward.
                gumbel_noise = tf.cast(self.gumbel_dist.sample([batch_size, self.a_dim]), dtype=tf.float32)
                _pi = tf.nn.softmax((logp_all + gumbel_noise) / self.discrete_tau)
                _pi_true_one_hot = tf.one_hot(tf.argmax(_pi, axis=-1), self.a_dim)
                _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
                pi = _pi_diff + _pi
                log_pi = tf.reduce_sum(tf.multiply(logp_all, pi), axis=1, keepdims=True)
                entropy = -tf.reduce_mean(tf.reduce_sum(tf.exp(logp_all) * logp_all, axis=1, keepdims=True))
                target_logits = self.actor_net(feat_)
                target_cate_dist = tfp.distributions.Categorical(target_logits)
                target_pi = target_cate_dist.sample()
                target_log_pi = target_cate_dist.log_prob(target_pi)  # log-prob of the sampled index, before one-hot
                target_pi = tf.one_hot(target_pi, self.a_dim, dtype=tf.float32)
            q1, q2 = self.critic_net(feat, a)
            q1_target, q2_target = self.critic_target_net(feat_, target_pi)
            q_s_pi = self.critic_net.get_min(feat, pi)
            # Soft Bellman targets: bootstrap with the entropy-regularized target Q.
            dc_r_q1 = tf.stop_gradient(r + self.gamma * (1 - done) * (q1_target - self.alpha * target_log_pi))
            dc_r_q2 = tf.stop_gradient(r + self.gamma * (1 - done) * (q2_target - self.alpha * target_log_pi))
            td_error1 = q1 - dc_r_q1
            td_error2 = q2 - dc_r_q2
            q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
            q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
            critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + crsty_loss
            actor_loss = -tf.reduce_mean(q_s_pi - self.alpha * log_pi)
            if self.auto_adaption:
                alpha_loss = -tf.reduce_mean(self.alpha * tf.stop_gradient(log_pi - self.target_entropy))
        critic_grads = tape.gradient(critic_loss, self.critic_tv)
        self.optimizer_critic.apply_gradients(
            zip(critic_grads, self.critic_tv)
        )
        actor_grads = tape.gradient(actor_loss, self.actor_tv)
        self.optimizer_actor.apply_gradients(
            zip(actor_grads, self.actor_tv)
        )
        if self.auto_adaption:
            alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
            self.optimizer_alpha.apply_gradients(
                [(alpha_grad, self.log_alpha)]
            )
        self.global_step.assign_add(1)
        summaries = dict([
            ['LOSS/actor_loss', actor_loss],
            ['LOSS/q1_loss', q1_loss],
            ['LOSS/q2_loss', q2_loss],
            ['LOSS/critic_loss', critic_loss],
            ['Statistics/log_alpha', self.log_alpha],
            ['Statistics/alpha', self.alpha],
            ['Statistics/entropy', entropy],
            ['Statistics/q_min', tf.reduce_min(tf.minimum(q1, q2))],
            ['Statistics/q_mean', tf.reduce_mean(tf.minimum(q1, q2))],
            ['Statistics/q_max', tf.reduce_max(tf.maximum(q1, q2))]
        ])
        if self.auto_adaption:
            summaries.update({
                'LOSS/alpha_loss': alpha_loss
            })
        return (td_error1 + td_error2) / 2, summaries
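# The target critics used above (critic_target_net / q*_target_net) are
# normally refreshed with a Polyak (soft) update after each training step.
# A minimal sketch follows; the function name and coefficient value are
# assumptions, since the update itself is not shown in this section.
def update_target_net_weights(target_vars, source_vars, polyak=0.995):
    # Exponential moving average of the online weights into the target weights.
    for t, s in zip(target_vars, source_vars):
        t.assign(polyak * t + (1.0 - polyak) * s)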