Beispiel #1
0
# Policy-gradient setup and play loop (TensorFlow 1.x graph mode + TensorLayer).
# NOTE(review): this fragment relies on names defined outside the visible
# source: network, t_states, learning_rate, decay_rate, model_file_name,
# render, env, observation, prev_x, prepro, D, time, np, tl — confirm they
# are all defined before this point.

# `network.outputs` feeds both softmax and the loss, so it is presumably the
# unnormalized logits despite the name `probs`.
probs = network.outputs
# Action probabilities, evaluated in the loop below.
sampling_prob = tf.nn.softmax(probs)

# 1-D, variable-length training inputs — presumably the action taken at each
# step and its discounted reward.
t_actions = tf.placeholder(tf.int32, shape=[None])
t_discount_rewards = tf.placeholder(tf.float32, shape=[None])
# Cross-entropy loss weighted by the discounted rewards (per the helper's name;
# TODO confirm against the tl.rein docs).
loss = tl.rein.cross_entropy_reward_loss(probs, t_actions, t_discount_rewards)
train_op = tf.train.RMSPropOptimizer(learning_rate, decay_rate).minimize(loss)

with tf.Session() as sess:
    tl.layers.initialize_global_variables(sess)
    # Commented-out alternative that restored weights only when `resume` is set:
    # if resume:
    #     load_params = tl.files.load_npz(name=model_file_name+'.npz')
    #     tl.files.assign_params(sess, load_params, network)
    # NOTE(review): unlike the guarded code above, this restore is called
    # regardless of any `resume` flag — confirm that is intended.
    tl.files.load_and_assign_npz(sess, model_file_name + '.npz', network)
    network.print_params()
    network.print_layers()

    start_time = time.time()
    game_number = 0
    # NOTE(review): the loop body below looks truncated. It never selects an
    # action, never steps `env`, never reassigns `observation` and has no
    # `break`, so as written it loops forever on the same observation.
    # `start_time`, `game_number` and `train_op` are also unused in the visible
    # code. Restore the rest of the loop before running this.
    while True:
        if render: env.render()

        # Network input is the difference between the current preprocessed
        # frame and the previous one, or zeros when there is no previous frame
        # (`prev_x is None`). It is reshaped into a single-row (1, D) batch.
        cur_x = prepro(observation)
        x = cur_x - prev_x if prev_x is not None else np.zeros(D)
        x = x.reshape(1, D)
        prev_x = cur_x

        # Forward pass: action probabilities for this single input row.
        prob = sess.run(sampling_prob, feed_dict={t_states: x})

        # action. 1: STOP  2: UP  3: DOWN
        # action = np.random.choice([1,2,3], p=prob.flatten())
import tensorflow as tf
from tensorlayer.layers import SubpixelConv1d, SubpixelConv2d, InputLayer, Conv1d, Conv2d

## 1D
# SubpixelConv1d smoke test.
#   input:                      (batch=10, width=100, channels=4)
#   Conv1d(32 filters, filter size 3, stride 1, SAME padding)
#                            -> (10, 100, 32)
#   SubpixelConv1d(scale=2) doubles the width and divides the channels by
#   the scale            -> (10, 100 * 2, 32 // 2) == (10, 200, 16)
t_signal = tf.placeholder('float32', [10, 100, 4], name='x')
n = InputLayer(t_signal, name='in')
n = Conv1d(n, 32, 3, 1, padding='SAME', name='conv1d')
n = SubpixelConv1d(n, scale=2, name='subpixel')
print(n.outputs.shape)
# ... (10, 200, 16)
n.print_layers()
n.print_params(False)

# Output shape must match the derivation above.
shape = n.outputs.get_shape().as_list()
if shape != [10, 200, 16]:
    raise Exception("shape dont match")

# Two layers are expected: conv1d and subpixel (the InputLayer is
# presumably not counted in all_layers).
if len(n.all_layers) != 2:
    raise Exception("layers dont match")

# Two parameter tensors, presumably Conv1d's weight and bias; the subpixel
# layer is expected to contribute none.
if len(n.all_params) != 2:
    raise Exception("params dont match")

# Conv1d weights 3 (filter size) * 4 (in channels) * 32 (filters) = 384,
# plus 32 biases = 416 scalars in total.
if n.count_params() != 416:
    raise Exception("params dont match")

## 2D
# SubpixelConv2d smoke test on a (10, 100, 100, 3) input (presumably NHWC:
# 10 images of 100x100 with 3 channels — TODO confirm).
# The Conv2d arguments follow the same order as the Conv1d call in the 1D
# section (filters, filter size, strides). So this is presumably 32 filters,
# a non-square (3, 2) kernel, (1, 1) strides and SAME padding.
# NOTE(review): the checks for this section are not visible here. By analogy
# with the 1D case, the output is presumably
# (10, 100 * 2, 100 * 2, 32 // (2 * 2)) == (10, 200, 200, 8) — confirm.
x = tf.placeholder('float32', [10, 100, 100, 3], name='x')
n = InputLayer(x, name='in')
n = Conv2d(n, 32, (3, 2), (1, 1), padding='SAME', name='conv2d')
n = SubpixelConv2d(n, scale=2, name='subpixel2d')