コード例 #1
0
# Example 1: draw a contour of the fitness landscape, then repeatedly sample
# kids from the search distribution (`make_kid`), score them with
# `get_fitness`, and run `train_op` on the (kids, fitness) batch. The current
# generation is scattered on top of the contour each step.
sess = tf.Session()
sess.run(tf.global_variables_initializer())                     # initialize tf variables

# Background contour of the fitness landscape (plotting only).
n = 300
x = np.linspace(-20, 20, n)
X, Y = np.meshgrid(x, x)
Z = np.zeros_like(X)
for i in range(n):
    for j in range(n):
        # np.meshgrid defaults to 'xy' indexing: the grid point at [i, j] is
        # (X[i, j], Y[i, j]) == (x[j], x[i]). Evaluating at that point keeps Z
        # aligned with the axes contourf draws and with the kid scatter
        # (kids[:, 0] on x, kids[:, 1] on y). `.item()` turns the size-1
        # fitness result into a scalar for the element assignment.
        Z[i, j] = np.asarray(get_fitness(np.array([[X[i, j], Y[i, j]]]))).item()
plt.contourf(X, Y, -Z, 100, cmap=plt.cm.rainbow)
plt.ylim(-20, 20)
plt.xlim(-20, 20)
plt.ion()

# Build the monitoring ops once. Calling mvn.mean() / mvn.covariance() inside
# the loop would add new nodes to the TF graph on every generation.
mvn_mean = mvn.mean()
mvn_cov = mvn.covariance()

# training
sca = None  # scatter artist for the current generation; replaced each step
for g in range(N_GENERATION):
    # if g % 10 == 0:
    #     LR = LR * pow(0.9, g / 10)
    #     print(g+1, 'LR', LR)
    kids = sess.run(make_kid)
    mean_val, cov_val = sess.run([mvn_mean, mvn_cov])
    print(mean_val)
    print(cov_val)
    kids_fit = get_fitness(kids)
    get_max_fit(kids_fit)
    kids_fits = get_minus_fit(kids_fit)
    sess.run(train_op, {tfkids_fit: kids_fits, tfkids: kids})    # update distribution parameters
    # plotting update: drop the previous generation's points before drawing
    if sca is not None:
        sca.remove()
    sca = plt.scatter(kids[:, 0], kids[:, 1], s=30, c='k')
    plt.pause(0.01)

print('Finished')
plt.ioff()
plt.show()
コード例 #2
0
# Example 2 (experiment): build two 5-D MultivariateNormalFullCovariance
# distributions backed by tf.Variables, sample from each, then rebind `mvn2`
# to a copy of `mvn1` and print means/samples of both names for comparison.
scope = "zhangxin"  # prefix used only in mvn1's name below

# mean1: 5 values from a truncated normal centred at 2 (stddev 0.1);
# cov1: 5x5 identity covariance.
mean1 = tf.Variable(tf.truncated_normal([5, ], stddev=0.1, mean=2,), dtype=tf.float32, name='mean1')
cov1 = tf.Variable(tf.eye(5), dtype=tf.float32, name='cov1')

mvn1 = MultivariateNormalFullCovariance(loc=mean1, covariance_matrix=cov1, name=scope + '_mvn1')
kids1 = mvn1.sample(5)  # 5 draws; expected shape (5, 5) — TODO confirm

# Second distribution, centred near 20 (presumably so its samples are easy to
# tell apart from mvn1's).
mean2 = tf.Variable(tf.truncated_normal([5, ], stddev=0.1, mean=20), dtype=tf.float32, name='mean2')
cov2 = tf.Variable(tf.eye(5), dtype=tf.float32, name='cov2')
# NOTE(review): unlike mvn1, this name carries no `scope` prefix.
mvn2 = MultivariateNormalFullCovariance(loc=mean2, covariance_matrix=cov2, name='mvn2', )
kids2 = mvn2.sample(5)  # sampling op built from the ORIGINAL mvn2 (mean2/cov2)
# NOTE(review): this rebinds `mvn2` to a copy of mvn1. kids2 above still samples
# from the original mean2/cov2 distribution, but every later use of `mvn2`
# refers to this copy. That copy is presumably built from mvn1's constructor
# parameters (loc=mean1, covariance_matrix=cov1) — confirm against the
# Distribution.copy() docs. If so, the "mvn2" prints below show mvn1-like
# values, not mean2.
mvn2 = mvn1.copy()

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(kids1))
print('====================================================')
print(sess.run(kids2))
print('====================================================')
# No call and no sess.run: these print the bound `mean` method objects,
# not the mean values.
print(mvn1.mean)
print(mvn2.mean)
print('====================================================')
print(sess.run(mvn1.mean()))
print(sess.run(mvn2.mean()))
print('====================================================')
# Each .sample(5) call here builds a new sampling op in the graph before it
# is run.
print(sess.run(mvn1.sample(5)))
print('====================================================')
print(sess.run(mvn2.sample(5)))
print(mvn1.name)