Example #1
import os

import numpy as np
import tensorflow as tf
    def actions_for(self,
                    observations,
                    n_action_samples=1,
                    reuse=tf.AUTO_REUSE):

        n_state_samples = tf.shape(observations)[0]

        if n_action_samples > 1:
            observations = observations[:, None, :]
            latent_shape = (n_state_samples, n_action_samples,
                            self._action_dim)
        else:
            latent_shape = (n_state_samples, self._action_dim)

        latents = tf.random_normal(latent_shape)

        with tf.variable_scope('policy', reuse=reuse):
            raw_actions = feedforward_net(
                observations,
                latents,
                layer_sizes=self._layer_sizes,
                activation_fn=tf.nn.relu,
                output_nonlinearity=None)

        return tf.tanh(raw_actions) if self._squash else raw_actions
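
# Hedged usage sketch of actions_for (not in the original listing; `policy`
# and `observations_ph` below are assumed names). With n_action_samples > 1
# the observations broadcast against the latents, so a single call yields
# K action samples per state:
#
#     observations_ph = tf.placeholder(tf.float32, [None, observation_dim])
#     actions = policy.actions_for(observations_ph, n_action_samples=16)
#     # actions: shape (batch_size, 16, action_dim); with the default
#     # n_action_samples=1 it would be (batch_size, action_dim)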
# NOTE: the enclosing `def` below is assumed -- the original listing shows
# only the body of this helper, which draws one latent and samples one
# action from the restored policy graph (built further down).
def sample_action(observation, squash=True):
    single_latent = np.random.normal(size=(1, action_dim))

    # evaluate both tensors in a single session call with the same latent
    raw_action, tan_action = sess.run(
        [raw_actions, tan_actions],
        feed_dict={observation_ph: observation, latent_ph: single_latent})

    return tan_action if squash else raw_action
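
# `feedforward_net` is never defined in this listing. Below is a minimal
# sketch, assuming the first layer applies a separate linear map to each
# input and sums the results -- which is what lets the (N, 1, obs_dim)
# observations broadcast against the (N, K, action_dim) latents inside
# actions_for above. The last entry of layer_sizes is taken as the output
# width (action_dim).
def feedforward_net(*inputs, layer_sizes, activation_fn=tf.nn.relu,
                    output_nonlinearity=None):
    # first hidden layer: per-input linear maps, summed with broadcasting
    hidden = sum(
        tf.layers.dense(x, layer_sizes[0], use_bias=(i == 0),
                        name='input_{}'.format(i))
        for i, x in enumerate(inputs))
    out = activation_fn(hidden)

    # remaining layers; the activation is omitted on the output layer
    for i, size in enumerate(layer_sizes[1:], start=1):
        out = tf.layers.dense(out, size, name='layer_{}'.format(i))
        if i < len(layer_sizes) - 1:
            out = activation_fn(out)

    return out if output_nonlinearity is None else output_nonlinearity(out)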


# forward computation graph of the stochastic policy network
with tf.variable_scope('policy', reuse=False):
    observation_ph = tf.placeholder(tf.float32, shape=[None, observation_dim])
    latent_ph = tf.placeholder(tf.float32, shape=[None, action_dim])

    # pass the inputs positionally, matching the call inside actions_for
    raw_actions = feedforward_net(observation_ph, latent_ph,
                                  layer_sizes=layer_size,
                                  activation_fn=tf.nn.relu,
                                  output_nonlinearity=None)
    tan_actions = tf.tanh(raw_actions)


saver = tf.train.Saver()
with tf.Session() as sess:

    # restore the trained model; model-465.ckpt is hard-coded here, while
    # the commented line would load the latest checkpoint found on disk
    ckpt = tf.train.get_checkpoint_state(SHARED_PARAMS['model_save_path'])
    # saver.restore(sess, ckpt.model_checkpoint_path)
    saver.restore(sess, os.path.join(SHARED_PARAMS['model_save_path'],
                                     'model-465.ckpt'))

    # roll out the trained policy
    observation = np.expand_dims(env.reset(), axis=0)       # s_0, shape (1, 18)

    for t in range(ENV_PARAMS['max_path_length']):
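        # Hedged completion -- the original listing is cut off at the line
        # above. Assuming a Gym-style env: query the restored policy for one
        # action, step the environment, and feed the next state back in.
        action = sample_action(observation)
        next_observation, reward, done, _ = env.step(action[0])
        observation = np.expand_dims(next_observation, axis=0)
        if done:
            break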