def _get_action(self, s, visual_s, cell_state, options):
    """Sample an action under the active option and pick the next option.

    Args:
        s: Forwarded to ``self.net`` as its first positional input.
        visual_s: Forwarded to ``self.net`` as its second positional input.
        cell_state: State forwarded to ``self.net``; the state the network
            returns is handed back to the caller.
        options: Index of the active option for each batch item
            (presumably an integer tensor of shape [B] -- confirm with caller).

    Returns:
        Tuple ``(action, q_o, log_prob, beta_adv, new_options, max_options,
        cell_state)``.
    """
    with tf.device(self.device):
        # q: [B, P], pi: [B, P, A], beta: [B, P]
        (q, pi, beta), cell_state = self.net(
            s, visual_s, cell_state=cell_state)
        option_mask = tf.one_hot(
            options, self.options_num, dtype=tf.float32)  # [B, P]
        # Zero out every option except the active one, then collapse the
        # option axis: [B, P, A] * [B, P, 1] => [B, A]
        active_pi = tf.reduce_sum(
            pi * tf.expand_dims(option_mask, axis=-1), axis=1)

        if self.is_continuous:
            # One log-std entry is looked up per active option.
            option_log_std = tf.gather(self.log_std, options)
            action, _ = gaussian_clip_rsample(active_pi, option_log_std)
            log_prob = gaussian_likelihood_sum(
                action, active_pi, option_log_std)
        else:
            action_dist = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(active_pi))
            action = action_dist.sample()
            log_prob = action_dist.log_prob(action)

        q_o = tf.reduce_sum(q * option_mask, axis=-1)  # [B, ]
        # Baseline: (1 - eps) * max over options + eps * mean over options.
        max_term = (1 - self.eps) * tf.reduce_max(q, axis=-1)
        mean_term = self.eps * tf.reduce_mean(q, axis=-1)
        beta_adv = q_o - (max_term + mean_term)  # [B, ]

        max_options = tf.cast(
            tf.argmax(q, axis=-1), dtype=tf.int32)  # [B, P] => [B, ]
        beta_probs = tf.reduce_sum(
            beta * option_mask, axis=1)  # [B, P] => [B, ]
        terminate = tfp.distributions.Bernoulli(probs=beta_probs).sample()
        # A draw < 1 keeps the current option; a draw == 1 switches to the
        # option with the highest q.
        new_options = tf.where(terminate < 1, options, max_options)
        return (action, q_o, log_prob, beta_adv, new_options, max_options,
                cell_state)
def _get_action(self, s, visual_s, cell_state, options):
    """Sample an action for the active option and decide the next option.

    Args:
        s: Forwarded to ``self._representation_net`` (first positional input).
        visual_s: Forwarded to ``self._representation_net`` (second
            positional input).
        cell_state: State forwarded to the representation network; the
            state it returns is handed back to the caller.
        options: Index of the active option for each batch item
            (presumably an integer tensor of shape [B] -- confirm with caller).

    Returns:
        Tuple ``(action, new_options, cell_state)``.
    """
    with tf.device(self.device):
        feat, cell_state = self._representation_net(
            s, visual_s, cell_state=cell_state)
        q = self.q_net.value_net(feat)  # [B, P]
        pi = self.intra_option_net.value_net(feat)  # [B, P, A]
        beta = self.termination_net.value_net(feat)  # [B, P]

        option_mask = tf.one_hot(
            options, self.options_num, dtype=tf.float32)  # [B, P]
        # Keep only the active option's policy output: [B, P, A] => [B, A]
        active_pi = tf.reduce_sum(
            pi * tf.expand_dims(option_mask, axis=-1), axis=1)

        if self.is_continuous:
            option_log_std = tf.gather(self.log_std, options)
            mean = tf.math.tanh(active_pi)
            action, _ = gaussian_clip_rsample(mean, option_log_std)
        else:
            # Logits are divided by the Boltzmann temperature before sampling.
            scaled_logits = active_pi / self.boltzmann_temperature
            action = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(scaled_logits)).sample()  # [B, ]

        max_options = tf.cast(
            tf.argmax(q, axis=-1), dtype=tf.int32)  # [B, P] => [B, ]
        if self.use_eps_greedy:
            # In this mode the argmax option is always returned here.
            new_options = max_options
        else:
            beta_probs = tf.reduce_sum(
                beta * option_mask, axis=1)  # [B, P] => [B, ]
            terminate = tfp.distributions.Bernoulli(
                probs=beta_probs).sample()
            # A draw < 1 keeps the current option; otherwise switch to argmax.
            new_options = tf.where(terminate < 1, options, max_options)
        return action, new_options, cell_state
def _get_action(self, s, visual_s, cell_state, options):
    """Sample an action for the active option and sample the next option.

    Args:
        s: Forwarded to ``self._representation_net`` (first positional input).
        visual_s: Forwarded to ``self._representation_net`` (second
            positional input).
        cell_state: State forwarded to the representation network; the
            state it returns is handed back to the caller.
        options: Index of the active option for each batch item
            (presumably an integer tensor of shape [B] -- confirm with caller).

    Returns:
        Tuple ``(action, new_options, cell_state)``.
    """
    with tf.device(self.device):
        feat, cell_state = self._representation_net(
            s, visual_s, cell_state=cell_state)
        q = self.q_net.value_net(feat)  # [B, P]
        pi = self.intra_option_net.value_net(feat)  # [B, P, A]

        option_mask = tf.one_hot(
            options, self.options_num, dtype=tf.float32)  # [B, P]
        # Keep only the active option's policy output: [B, P, A] => [B, A]
        active_pi = tf.reduce_sum(
            pi * tf.expand_dims(option_mask, axis=-1), axis=1)

        if self.is_continuous:
            option_log_std = tf.gather(self.log_std, options)
            mean = tf.math.tanh(active_pi)
            action, _ = gaussian_clip_rsample(mean, option_log_std)
        else:
            # Logits are divided by the Boltzmann temperature before sampling.
            scaled_logits = active_pi / self.boltzmann_temperature
            action = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(scaled_logits)).sample()  # [B, ]

        interests = self.interest_net.value_net(feat)  # [B, P]
        # NOTE: the original comment mentions tf.nn.softmax(q) as an
        # alternative to raw q in this product.
        option_logits = interests * q  # [B, P]
        option_dist = tfp.distributions.Categorical(
            logits=tf.nn.log_softmax(option_logits))
        new_options = option_dist.sample()
        return action, new_options, cell_state
def _get_action(self, s, visual_s, cell_state):
    """Sample an action from the policy produced by ``self.net``.

    Args:
        s: Forwarded to ``self.net`` as its first positional input.
        visual_s: Forwarded to ``self.net`` as its second positional input.
        cell_state: State forwarded to ``self.net``; the state the network
            returns is handed back to the caller.

    Returns:
        Tuple ``(action, cell_state)``.
    """
    with tf.device(self.device):
        net_out, cell_state = self.net(s, visual_s, cell_state=cell_state)

        if not self.is_continuous:
            # Discrete case: the network output is used as logits.
            policy = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(net_out))
            return policy.sample(), cell_state

        # Continuous case: the network output unpacks to (mu, log_std).
        mu, log_std = net_out
        action, _ = gaussian_clip_rsample(mu, log_std)
        return action, cell_state
def _get_action(self, s, visual_s, cell_state):
    """Sample an action from ``self.actor_net`` on the extracted features.

    Args:
        s: Forwarded to ``self.get_feature`` (first positional input).
        visual_s: Forwarded to ``self.get_feature`` (second positional input).
        cell_state: State forwarded to ``self.get_feature``; the state it
            returns (requested via ``record_cs=True``) is handed back.

    Returns:
        Tuple ``(action, cell_state)``.
    """
    with tf.device(self.device):
        feat, cell_state = self.get_feature(
            s, visual_s, cell_state=cell_state, record_cs=True)
        actor_out = self.actor_net(feat)

        if self.is_continuous:
            # The actor output is the mean; the log-std comes from
            # ``self.log_std`` rather than from the network.
            action, _ = gaussian_clip_rsample(actor_out, self.log_std)
        else:
            # The actor output is passed to Categorical as logits.
            action = tfp.distributions.Categorical(
                logits=actor_out).sample()
        return action, cell_state
def _get_action(self, obs, cell_state):
    """Sample an action from ``self.net`` together with its log-probability.

    Args:
        obs: Forwarded to ``self.net`` as its positional input.
        cell_state: State forwarded to ``self.net``; the state the network
            returns is handed back to the caller.

    Returns:
        Tuple ``(action, log_prob, cell_state)``.
    """
    with tf.device(self.device):
        net_out, cell_state = self.net(obs, cell_state=cell_state)

        if not self.is_continuous:
            # Discrete case: the network output is used as logits.
            policy = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(net_out))
            action = policy.sample()
            return action, policy.log_prob(action), cell_state

        # Continuous case: the network output unpacks to (mu, log_std).
        mu, log_std = net_out
        action, _ = gaussian_clip_rsample(mu, log_std)
        log_prob = gaussian_likelihood_sum(action, mu, log_std)
        return action, log_prob, cell_state
def _get_action(self, s, visual_s, cell_state):
    """Sample an action and return it with the critic value and policy stats.

    Args:
        s: Forwarded to ``self.get_feature`` (first positional input).
        visual_s: Forwarded to ``self.get_feature`` (second positional input).
        cell_state: State forwarded to ``self.get_feature``; the state it
            returns (requested via ``record_cs=True``) is handed back.

    Returns:
        Tuple ``(action, value, log_prob, policy_info, cell_state)`` where
        ``policy_info`` is the actor mean in the continuous case and the
        log-softmax of the actor logits in the discrete case.
    """
    with tf.device(self.device):
        feat, cell_state = self.get_feature(
            s, visual_s, cell_state=cell_state, record_cs=True)
        value = self.critic_net(feat)
        actor_out = self.actor_net(feat)

        if self.is_continuous:
            # The actor output is the mean; the log-std comes from
            # ``self.log_std`` rather than from the network.
            action, _ = gaussian_clip_rsample(actor_out, self.log_std)
            log_prob = gaussian_likelihood_sum(
                action, actor_out, self.log_std)
            policy_info = actor_out
        else:
            policy_info = tf.nn.log_softmax(actor_out)
            # The distribution is built from the raw logits, not from
            # ``policy_info``.
            policy = tfp.distributions.Categorical(logits=actor_out)
            action = policy.sample()
            log_prob = policy.log_prob(action)
        return action, value, log_prob, policy_info, cell_state
def _get_action(self, obs, cell_state):
    """Sample an action and return it with the value and policy parameters.

    Args:
        obs: Forwarded to ``self._representation_net`` as its positional
            input.
        cell_state: State forwarded to the representation network; the
            state it returns is handed back to the caller.

    Returns:
        Tuple ``(action, value, log_prob, policy_info, cell_state)`` where
        ``policy_info`` is ``(mu, log_std)`` in the continuous case and the
        log-softmax of the logits in the discrete case.
    """
    with tf.device(self.device):
        feat, cell_state = self._representation_net(
            obs, cell_state=cell_state)
        value = self.net.value_net(feat)
        policy_out = self.net.policy_net(feat)

        if not self.is_continuous:
            # The log-softmax is both returned and used as the logits of
            # the sampling distribution.
            logp_all = tf.nn.log_softmax(policy_out)
            policy = tfp.distributions.Categorical(logits=logp_all)
            action = policy.sample()
            log_prob = policy.log_prob(action)
            return action, value, log_prob, logp_all, cell_state

        # Continuous case: the policy output unpacks to (mu, log_std).
        mu, log_std = policy_out
        action, _ = gaussian_clip_rsample(mu, log_std)
        log_prob = gaussian_likelihood_sum(action, mu, log_std)
        return action, value, log_prob, (mu, log_std), cell_state
def _get_action(self, obs, cell_state):
    """Sample an action and return it with its value and log-probability.

    When ``self.share_net`` is true, a single call to
    ``self.net.value_net`` yields both the policy outputs and the value;
    otherwise ``self.net.policy_net`` and ``self.net.value_net`` are called
    separately, policy first.

    Args:
        obs: Forwarded to ``self._representation_net`` as its positional
            input.
        cell_state: State forwarded to the representation network; the
            state it returns is handed back to the caller.

    Returns:
        Tuple ``(action, value, log_prob, cell_state)``.
    """
    with tf.device(self.device):
        feat, cell_state = self._representation_net(
            obs, cell_state=cell_state)

        if not self.is_continuous:
            if self.share_net:
                logits, value = self.net.value_net(feat)
            else:
                logits = self.net.policy_net(feat)
                value = self.net.value_net(feat)
            policy = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(logits))
            action = policy.sample()
            return action, value, policy.log_prob(action), cell_state

        if self.share_net:
            mu, log_std, value = self.net.value_net(feat)
        else:
            mu, log_std = self.net.policy_net(feat)
            value = self.net.value_net(feat)
        action, _ = gaussian_clip_rsample(mu, log_std)
        log_prob = gaussian_likelihood_sum(action, mu, log_std)
        return action, value, log_prob, cell_state
def _get_action(self, s, visual_s, cell_state, options):
    """Sample an action under the active option and sample the next option.

    Args:
        s: Forwarded to ``self.get_feature`` (first positional input).
        visual_s: Forwarded to ``self.get_feature`` (second positional input).
        cell_state: State forwarded to ``self.get_feature``; the state it
            returns (requested via ``record_cs=True``) is handed back.
        options: Index of the active option for each batch item
            (presumably an integer tensor of shape [B] -- confirm with caller).

    Returns:
        Tuple ``(action, q_o, log_prob, o_log_prob, beta_adv, new_options,
        cell_state)``.
    """
    with tf.device(self.device):
        feat, cell_state = self.get_feature(
            s, visual_s, cell_state=cell_state, record_cs=True)
        # q: [B, P], pi: [B, P, A], beta: [B, P], o: [B, P]
        q, pi, beta, o = self.net(feat)

        option_mask = tf.one_hot(
            options, self.options_num, dtype=tf.float32)  # [B, P]
        # Keep only the active option's policy output: [B, P, A] => [B, A]
        active_pi = tf.reduce_sum(
            pi * tf.expand_dims(option_mask, axis=-1), axis=1)

        if self.is_continuous:
            option_log_std = tf.gather(self.log_std, options)
            action, _ = gaussian_clip_rsample(active_pi, option_log_std)
            log_prob = gaussian_likelihood_sum(
                action, active_pi, option_log_std)
        else:
            action_dist = tfp.distributions.Categorical(logits=active_pi)
            action = action_dist.sample()
            log_prob = action_dist.log_prob(action)

        # ``o`` is exponentiated wherever probabilities are needed, so it
        # is presumably a log-probability over options -- confirm with the
        # network definition.
        o_log_prob = tf.reduce_sum(o * option_mask, axis=-1)  # [B, ]
        q_o = tf.reduce_sum(q * option_mask, axis=-1)  # [B, ]
        # Baseline: q weighted by exp(o), summed over the option axis.
        baseline = tf.reduce_sum(q * tf.math.exp(o), axis=-1)
        beta_adv = q_o - baseline  # [B, ]

        option_dist = tfp.distributions.Categorical(probs=tf.math.exp(o))
        sampled_options = option_dist.sample()

        beta_probs = tf.reduce_sum(
            beta * option_mask, axis=1)  # [B, P] => [B, ]
        terminate = tfp.distributions.Bernoulli(probs=beta_probs).sample()
        # A draw < 1 keeps the current option; a draw == 1 switches to the
        # freshly sampled option.
        new_options = tf.where(terminate < 1, options, sampled_options)
        return (action, q_o, log_prob, o_log_prob, beta_adv, new_options,
                cell_state)