def _create_dc_actor(self, encoded: tf.Tensor) -> None:
    """
    Builds the discrete-control policy head on top of an encoded observation.

    When ``self.use_recurrent`` is set, a ``prev_action`` placeholder is created,
    one-hot encoded per action branch, concatenated onto ``encoded`` and passed
    through a recurrent encoder fed by the ``recurrent_in`` placeholder; the
    resulting memory is exposed as ``self.memory_out``. Otherwise ``encoded`` is
    used directly.

    The method then creates the ``action_masks`` placeholder, builds a
    ``MultiCategoricalDistribution`` inside the "policy" variable scope, and
    stores its sample, log-probs, one-hot sample, entropy and total log-probs
    on ``self``.

    :param encoded: Encoded observation tensor that feeds the policy.
    """
    policy_input = encoded

    if self.use_recurrent:
        num_branches = len(self.act_size)
        self.prev_action = tf.placeholder(
            shape=[None, num_branches], dtype=tf.int32, name="prev_action"
        )
        # One one-hot block per action branch, laid side by side.
        branch_one_hots = [
            tf.one_hot(self.prev_action[:, branch], branch_size)
            for branch, branch_size in enumerate(self.act_size)
        ]
        prev_action_encoding = tf.concat(branch_one_hots, axis=1)
        policy_input = tf.concat([encoded, prev_action_encoding], axis=1)

        self.memory_in = tf.placeholder(
            shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
        )
        policy_input, next_memory = ModelUtils.create_recurrent_encoder(
            policy_input,
            self.memory_in,
            self.sequence_length_ph,
            name="lstm_policy",
        )
        self.memory_out = tf.identity(next_memory, name="recurrent_out")

    # One mask entry per possible action, across all branches.
    total_actions = sum(self.act_size)
    self.action_masks = tf.placeholder(
        shape=[None, total_actions], dtype=tf.float32, name="action_masks"
    )

    with tf.variable_scope("policy"):
        policy_dist = MultiCategoricalDistribution(
            policy_input, self.act_size, self.action_masks
        )

    # It's important that we are able to feed_dict a value into this tensor to get the
    # right one-hot encoding, so we can't do identity on it.
    self.output = policy_dist.sample
    self.all_log_probs = tf.identity(policy_dist.log_probs, name="action")
    # In discrete, these are onehot
    self.selected_actions = tf.stop_gradient(policy_dist.sample_onehot)
    self.entropy = policy_dist.entropy
    self.total_log_probs = policy_dist.total_log_probs
def test_multicategorical_distribution():
    """
    Checks sampling, log-probs, entropy and masking of MultiCategoricalDistribution.

    Builds the distribution from constant logits and an all-ones action mask,
    then verifies output shapes and sample ranges, and finally re-samples with
    a mask that leaves only the first action of every branch enabled.
    """
    with tf.Graph().as_default():
        logits = tf.Variable(initial_value=[[0, 0]], trainable=True, dtype=tf.float32)
        action_masks = tf.Variable(
            initial_value=[[1 for _ in range(sum(DISCRETE_ACTION_SPACE))]],
            trainable=True,
            dtype=tf.float32,
        )
        distribution = MultiCategoricalDistribution(
            logits, act_size=DISCRETE_ACTION_SPACE, action_masks=action_masks
        )
        # Use a single session, owned by the context manager, so it is always closed.
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            # Smoke-run a single sample before the checked loop.
            sess.run(distribution.sample)
            for _ in range(10):
                sample, log_probs, entropy = sess.run(
                    [distribution.sample, distribution.log_probs, distribution.entropy]
                )
                assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
                # One sampled action index per branch.
                assert len(sample[0]) == len(DISCRETE_ACTION_SPACE)
                # Each index must be valid for its branch: 0 <= act < branch size.
                for i, act in enumerate(sample[0]):
                    assert 0 <= act < DISCRETE_ACTION_SPACE[i]
            output = sess.run([distribution.total_log_probs])
            assert output[0].shape[0] == 1
            # Make sure entropy is correct
            # NOTE(review): the 3.8 threshold depends on DISCRETE_ACTION_SPACE,
            # which is defined elsewhere in the file.
            assert entropy[0] > 3.8

            # Test masks: enable only the first action of every branch.
            mask = []
            for space in DISCRETE_ACTION_SPACE:
                mask.extend([1] + [0] * (space - 1))
            for _ in range(10):
                sample, log_probs = sess.run(
                    [distribution.sample, distribution.log_probs],
                    feed_dict={action_masks: [mask]},
                )
                for act in sample[0]:
                    assert 0 <= act <= 1
            # Smoke-run the total log-probs once more after the masked sampling.
            sess.run([distribution.total_log_probs])