def get_fn_sample():
    """Build TF1 callables for sampling from, and evaluating, a diagonal Gaussian.

    NOTE(review): relies on module-level names `dim` and `batch_size` being
    defined elsewhere in this file — confirm before moving this function.

    Returns:
        fn_sample: callable(mu, logsigma, n) -> n draws from the Gaussian
            parameterized by (mu, logsigma), via ``DiagGaussianPd.sample``.
        fn_p: callable(mu, logsigma, actions) -> per-action values of
            ``DiagGaussianPd.p`` (presumably the density — verify against
            the DiagGaussianPd implementation).
    """
    ph_mu = tf.placeholder(shape=[dim], dtype=tf.float32)
    ph_logsigma = tf.placeholder(shape=[dim], dtype=tf.float32)
    ph_actions = tf.placeholder(shape=[batch_size, dim], dtype=tf.float32)
    ph_n = tf.placeholder(shape=(), dtype=tf.int32)

    # Single parameter vector [mu; logsigma] as DiagGaussianPd expects.
    gaussian = DiagGaussianPd(tf.concat((ph_mu, ph_logsigma), axis=0))

    fn_sample = U.function([ph_mu, ph_logsigma, ph_n], gaussian.sample(ph_n))
    fn_p = U.function([ph_mu, ph_logsigma, ph_actions], gaussian.p(ph_actions))
    return fn_sample, fn_p
def get_func_cons(batch_size):
    """Build a TF1 callable evaluating the KL-based constraint value.

    Args:
        batch_size: leading dimension of the (mu, logsigma) parameter batch.

    Returns:
        callable(mu_logsigma, delta) -> ``candidate.kl(standard) - delta``,
        where `standard` is a DiagGaussianPd built from all-zero parameters
        (KL direction follows DiagGaussianPd.kl's own convention).
    """
    ph_params = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32)
    ph_delta = tf.placeholder(shape=(1, ), dtype=tf.float32)

    standard = DiagGaussianPd(tf.zeros(shape=[batch_size, 2]))
    candidate = DiagGaussianPd(ph_params)
    constraint = candidate.kl(standard) - ph_delta

    return U.function([ph_params, ph_delta], constraint)
def get_com_batch(dim, sess, batch_size, share_size, sharelogsigma):
    """Build objective/constraint graph ops for a batch of diagonal Gaussians.

    Args:
        dim: action dimensionality; distribution parameters are (mu, logsigma),
            i.e. 2 * dim values per batch row.
        sess: TF session whose graph the ops are added to.
        batch_size: number of rows; must be concrete here — using None
            triggers a bug (original author's note).
        share_size: when `sharelogsigma`, how many consecutive rows share one
            logsigma row (assumes batch_size % share_size == 0 — TODO confirm).
        sharelogsigma: if True, optimize one logsigma per group of `share_size`
            rows instead of one per row.

    Returns:
        (f, con, ratio, x, ffs, pls):
            f     - per-row neglogp objective,
            con   - KL constraint minus delta,
            ratio - p / p0 density ratio,
            x     - the optimized (mu, logsigma) tensor,
            ffs   - DotMap with p and p0,
            pls   - DotMap of placeholders (x0, a, delta, plus the initializers).
    """
    with sess.as_default(), sess.graph.as_default():
        pls = DotMap()
        # Reference distribution parameters (fed, not optimized).
        x0 = tf.placeholder(dtype=tf.float32, shape=(batch_size, dim * 2), name='x0')
        if not sharelogsigma:
            # One independent (mu, logsigma) variable per batch row.
            # FIX: placeholder was misnamed 'x0' (copy-paste), colliding with the
            # real x0 above and getting silently uniquified by TF.
            x_initial = tf.placeholder(dtype=tf.float32, shape=x0.shape, name='x_initial')
            x = tf.Variable(x_initial, name='x')
            pls.x_initial = x_initial
        else:
            # Per-row mu, but logsigma shared across groups of `share_size` rows.
            # FIX: both initializer placeholders were misnamed 'x0' (copy-paste).
            mu_initial = tf.placeholder(dtype=tf.float32, shape=(batch_size, dim), name='mu_initial')
            independent_size = batch_size // share_size
            logsigma_initial = tf.placeholder(dtype=tf.float32, shape=(independent_size, dim),
                                              name='logsigma_initial')
            mu = tf.Variable(mu_initial, name='mu')
            logsigma = tf.Variable(logsigma_initial, name='logsigma')
            # Repeat each logsigma row share_size times:
            # (independent_size, dim) -> (independent_size, dim*share_size)
            # -> (independent_size*share_size, dim) == (batch_size, dim).
            logsigma_all = tf.tile(logsigma, [1, share_size])
            logsigma_all = tf.reshape(logsigma_all, [-1, dim])
            x = tf.concat((mu, logsigma_all), axis=-1)
            pls.mu_initial = mu_initial
            pls.logsigma_initial = logsigma_initial
        a = tf.placeholder(dtype=tf.float32, shape=(batch_size, dim), name='a')
        delta = tf.placeholder(dtype=tf.float32, shape=(), name='delta')

        # --- objective function
        dist = DiagGaussianPd(x)
        f = dist.neglogp(a)
        p = dist.p(a)
        dist0 = DiagGaussianPd(x0)
        # NOTE: KL direction is deliberately dist0.kl(dist), not dist.kl(dist0)
        # ("fit lambda" — original author's note).
        con = dist0.kl(dist) - delta
        p0 = dist0.p(a)
        ratio = p / p0

        pls_new = DotMap(x0=x0, a=a, delta=delta)
        pls.update(pls_new)
        ffs = DotMap(p=p, p0=p0)
        return f, con, ratio, x, ffs, pls
def get_fn_ratio():
    """Build a TF1 callable computing density ratios against a standard Gaussian.

    Returns:
        callable(action, mu_logstd_min, mu_logstd_max) ->
        (p_min(action) / p0(action), p_max(action) / p0(action)),
        where p0 comes from a DiagGaussianPd with all-zero parameters.
        All inputs have a variable leading (batch) dimension.
    """
    ph_action = tf.placeholder(shape=[None, 1], dtype=tf.float32)
    ph_params_min = tf.placeholder(shape=[None, 2], dtype=tf.float32)
    ph_params_max = tf.placeholder(shape=[None, 2], dtype=tf.float32)

    base = DiagGaussianPd(tf.zeros_like(ph_params_min))
    base_p = base.p(ph_action)
    ratios = (
        DiagGaussianPd(ph_params_min).p(ph_action) / base_p,
        DiagGaussianPd(ph_params_max).p(ph_action) / base_p,
    )
    return U.function([ph_action, ph_params_min, ph_params_max], ratios)