lstm.py
# NOTE: this script targets the pre-1.0 TensorFlow API (rnn_cell/seq2seq in
# tensorflow.python.ops, scalar_summary, SummaryWriter, and friends).
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn_cell, seq2seq

sess = tf.InteractiveSession()

# Toy copy task: the model learns to reproduce short random token sequences.
seq_length = 5      # tokens per sequence
batch_size = 64     # sequences per training batch
vocab_size = 7      # number of distinct token ids
embedding_dim = 50  # size of the learned token embeddings
memory_dim = 100    # GRU hidden-state size
# One placeholder per timestep, each holding a (batch_size,) vector of token ids.
enc_inp = [tf.placeholder(tf.int32, shape=(None,), name="inp%i" % t)
           for t in range(seq_length)]
labels = [tf.placeholder(tf.int32, shape=(None,), name="labels%i" % t)
          for t in range(seq_length)]

# Uniform per-timestep weights for the sequence loss.
weights = [tf.ones_like(labels_t, dtype=tf.float32)
           for labels_t in labels]

# Decoder input: prepend a "GO" token (id 0) and drop the final
# token of the encoder input.
dec_inp = ([tf.zeros_like(enc_inp[0], dtype=tf.int32, name="GO")]
           + enc_inp[:-1])

# Initial memory value for recurrence (unused below:
# embedding_rnn_seq2seq builds its own zero initial state).
prev_mem = tf.zeros((batch_size, memory_dim))
cell = rnn_cell.GRUCell(memory_dim)

# Embedding seq2seq model: embeds enc_inp/dec_inp, encodes with the GRU, and
# returns per-timestep output logits plus the decoder states.
dec_outputs, dec_memory = seq2seq.embedding_rnn_seq2seq(
    enc_inp, dec_inp, cell, vocab_size, vocab_size, embedding_dim)

# Cross-entropy loss averaged over timesteps and batch.
loss = seq2seq.sequence_loss(dec_outputs, labels, weights, vocab_size)
tf.scalar_summary("loss", loss)

# Track the norm of the decoder state at t=1 as a training diagnostic.
magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_memory[1])))
tf.scalar_summary("magnitude at t=1", magnitude)

summary_op = tf.merge_all_summaries()
learning_rate = 0.05
momentum = 0.9
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
train_op = optimizer.minimize(loss)

# Write summaries to a temp dir; inspect with `tensorboard --logdir <logdir>`.
logdir = tempfile.mkdtemp()
print(logdir)
summary_writer = tf.train.SummaryWriter(logdir, sess.graph)

sess.run(tf.initialize_all_variables())
def train_batch(batch_size):
    # Each example is a random draw of distinct token ids; the target is the
    # input itself (an autoencoding/copy task).
    X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
         for _ in range(batch_size)]
    Y = X[:]

    # Dimshuffle to seq_length x batch_size, matching the placeholder layout.
    X = np.array(X).T
    Y = np.array(Y).T

    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})

    _, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)
    return loss_t, summary

for t in range(500):
    loss_t, summary = train_batch(batch_size)
    print(loss_t)
    summary_writer.add_summary(summary, t)
summary_writer.flush()
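
# A minimal sanity check (a sketch, not part of the original script): run the
# trained model on a fresh batch and take the argmax of each timestep's
# logits. If training converged, the decoded ids should match the inputs.
X_batch = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
           for _ in range(10)]
X_batch = np.array(X_batch).T  # seq_length x batch, as above

feed_dict = {enc_inp[t]: X_batch[t] for t in range(seq_length)}
dec_outputs_batch = sess.run(dec_outputs, feed_dict)

# dec_outputs is a list of (batch, vocab_size) logit arrays, one per timestep.
predictions = np.array([logits_t.argmax(axis=1) for logits_t in dec_outputs_batch])
print(X_batch)      # inputs, seq_length x batch
print(predictions)  # decoded outputs, seq_length x batch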