-
Notifications
You must be signed in to change notification settings - Fork 0
/
calc_tsne.py
75 lines (62 loc) · 2.03 KB
/
calc_tsne.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import tensorflow as tf
from bhtsne.bhtsne import bh_tsne
import config
from model import SentenceCNN
import os
import preprocessing as pre
import numpy as np
# Settings forwarded to bh_tsne() by main().
VERBOSE = True        # passed as `verbose`
DEFAULT_THETA = 0.5   # passed as `theta` (Barnes-Hut approximation parameter)
EMPTY_SEED = -1       # passed as `randseed`; presumably means "no fixed seed" -- confirm in bhtsne

# Not referenced elsewhere in this script (bh_tsne is called without an
# initial-dimensions argument); kept for callers that may import it.
INITIAL_DIMENSIONS = 50
def resume_model():
    """Rebuild the SentenceCNN graph and restore its latest checkpoint.

    The data set is loaded only to size the model: the column count of ``x``
    gives the sequence length and ``vocabulary`` gives the vocabulary size.
    Both ``inference()`` and ``train()`` are called before the ``Saver`` is
    created, so the saver covers the same set of graph variables that
    existed when the checkpoint was written (presumably including optimizer
    variables created by ``train()`` -- confirm against ``model.SentenceCNN``).

    Returns:
        The restored ``SentenceCNN`` with ``cnn.sess`` holding the live
        session, or ``None`` when no checkpoint is found under
        ``<config.out_dir>/checkpoints``. In that case the session is closed.
    """
    x, _y, vocabulary, _vocabulary_inv = pre.load_data()

    sess = tf.Session()
    cnn = SentenceCNN(
        # Only the number of columns of x is needed here; shuffling rows or
        # slicing off a validation split would not change it.
        sequence_length=x.shape[1],
        num_classes=2,
        vocab_size=len(vocabulary),
        sess=sess
    )
    cnn.inference()
    cnn.train()

    saver = tf.train.Saver()
    checkpoint_dir = os.path.abspath(os.path.join(config.out_dir, "checkpoints"))
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print('No checkpoint file found. Cannot resume.')
        # Nothing else keeps a reference to this session; release it here
        # instead of leaking it.
        sess.close()
        return None

    saver.restore(sess, ckpt.model_checkpoint_path)
    cnn.sess = sess
    return cnn
def main(_):
    """Project the restored word embeddings to 2-D with Barnes-Hut t-SNE.

    The ``W_embed`` matrix is L2-normalised along axis 1 and passed to
    ``bh_tsne``. Each row it yields is written as one tab-separated line to
    ``<config.out_dir>/vocab/tsne.txt``. The ``vocab`` directory is created
    if it is missing. Line i presumably corresponds to vocabulary id i --
    confirm against bhtsne's output ordering.

    Args:
        _: argument list handed over by ``tf.app.run``; unused.
    """
    model = resume_model()
    if model is None:
        print("Model is None")
        return

    try:
        # Per the original author: W_embed is [vocab_size, embed_size], so
        # axis 1 normalises each word vector to unit length.
        norm_w_embed = tf.nn.l2_normalize(model.W_embed, 1)
        embeddings = model.sess.run(norm_w_embed)
    finally:
        # The session is not needed for t-SNE; free its resources early.
        model.sess.close()

    results = bh_tsne(embeddings, no_dims=2, perplexity=50, theta=DEFAULT_THETA,
                      randseed=EMPTY_SEED, verbose=VERBOSE)

    out_dir = os.path.join(config.out_dir, "vocab")
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, "tsne.txt"), "w") as f:
        for result in results:
            f.write('\t'.join('{}'.format(value) for value in result) + '\n')
if __name__ == "__main__":
    # tf.app.run forwards the command-line argument list to main(); naming
    # main explicitly rather than relying on lookup in the __main__ module.
    tf.app.run(main=main)