Example #1
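    # Method of ML-Agents' TF policy class (surrounding context assumed): it
    # relies on ModelUtils from mlagents.trainers.models and on instance
    # attributes (use_recurrent, act_size, m_size, sequence_length_ph) that
    # are set up elsewhere in the class.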
    def _create_dc_actor(self, encoded: tf.Tensor) -> None:
        """
        Creates Discrete control actor-critic model.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        """
        if self.use_recurrent:
            self.prev_action = tf.placeholder(shape=[None,
                                                     len(self.act_size)],
                                              dtype=tf.int32,
                                              name="prev_action")
            prev_action_oh = tf.concat(
                [
                    tf.one_hot(self.prev_action[:, i], self.act_size[i])
                    for i in range(len(self.act_size))
                ],
                axis=1,
            )
            hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)

            self.memory_in = tf.placeholder(shape=[None, self.m_size],
                                            dtype=tf.float32,
                                            name="recurrent_in")
            hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                hidden_policy,
                self.memory_in,
                self.sequence_length_ph,
                name="lstm_policy",
            )

            self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
        else:
            hidden_policy = encoded

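        # One flat mask per batch entry: the per-branch masks are concatenated,
        # so the placeholder width is sum(act_size).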
        self.action_masks = tf.placeholder(shape=[None,
                                                  sum(self.act_size)],
                                           dtype=tf.float32,
                                           name="action_masks")

        with tf.variable_scope("policy"):
            distribution = MultiCategoricalDistribution(
                hidden_policy, self.act_size, self.action_masks)
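        # MultiCategoricalDistribution builds one categorical head per action
        # branch and applies the action masks before sampling.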
        # It's important to be able to feed a value into this tensor via
        # feed_dict to get the right one-hot encoding, so we can't wrap it
        # in tf.identity.
        self.output = distribution.sample
        self.all_log_probs = tf.identity(distribution.log_probs, name="action")
        self.selected_actions = tf.stop_gradient(
            distribution.sample_onehot)  # In discrete, these are onehot
        self.entropy = distribution.entropy
        self.total_log_probs = distribution.total_log_probs
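
# Note: the action_masks placeholder above is one flat tensor of width
# sum(act_size) -- the per-branch masks concatenated along the feature axis.
# A minimal sketch (an assumed helper, not part of ML-Agents) of recovering
# the per-branch masks from that layout:
def split_action_masks(flat_masks: tf.Tensor, act_size: list) -> list:
    """Split a [batch, sum(act_size)] mask into one [batch, n] mask per branch."""
    # tf.split accepts a list of section sizes, yielding one tensor per branch.
    return tf.split(flat_masks, act_size, axis=1)
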
# Assumed imports and constant for this test, mirroring the ML-Agents test
# suite (the branch sizes below are illustrative):
from mlagents.tf_utils import tf
from mlagents.trainers.distributions import MultiCategoricalDistribution

DISCRETE_ACTION_SPACE = [3, 3, 3, 2]  # assumed per-branch action counts


def test_multicategorical_distribution():
    with tf.Graph().as_default():
        logits = tf.Variable(initial_value=[[0, 0]],
                             trainable=True,
                             dtype=tf.float32)
        action_masks = tf.Variable(
            initial_value=[[1 for _ in range(sum(DISCRETE_ACTION_SPACE))]],
            trainable=True,
            dtype=tf.float32,
        )
        distribution = MultiCategoricalDistribution(
            logits, act_size=DISCRETE_ACTION_SPACE, action_masks=action_masks)
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            output = sess.run(distribution.sample)
            for _ in range(10):
                sample, log_probs, entropy = sess.run([
                    distribution.sample, distribution.log_probs,
                    distribution.entropy
                ])
                assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
                # The sample should contain one action index per branch
                assert len(sample[0]) == len(DISCRETE_ACTION_SPACE)
                for i, act in enumerate(sample[0]):
                    assert 0 <= act < DISCRETE_ACTION_SPACE[i]
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
                # With uniform logits, entropy should be near its maximum,
                # i.e. the sum of ln(n) over the branch sizes (about 3.99 here)
                assert entropy[0] > 3.8

            # Test action masking: allow only the first action in each branch
            mask = []
            for space in DISCRETE_ACTION_SPACE:
                mask.append(1)
                for _action_space in range(1, space):
                    mask.append(0)
            for _ in range(10):
                sample, log_probs = sess.run(
                    [distribution.sample, distribution.log_probs],
                    feed_dict={action_masks: [mask]},
                )
                # Only the first action of each branch is unmasked, so each
                # sampled index should be (near-deterministically) 0
                for act in sample[0]:
                    assert 0 <= act <= 1
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
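
# Sanity check (not part of the original test) for the 3.8 entropy bound:
# with uniform logits, the joint entropy of independent categorical branches
# is the sum of the per-branch maxima ln(n). For the assumed branch sizes
# [3, 3, 3, 2] this gives 3 * ln(3) + ln(2), about 3.99 > 3.8.
import math

max_entropy = sum(math.log(n) for n in DISCRETE_ACTION_SPACE)
assert max_entropy > 3.8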