-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_embedd.py
52 lines (40 loc) · 1.87 KB
/
make_embedd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
# Script: embed two small batches of text with the Universal Sentence Encoder
# (TF Hub, TF1 graph/session API), print a short preview of every embedding,
# and print the result of the vector arithmetic "king - man + girl" next to
# the embedding of "queen" for manual comparison.

np.random.seed(10)
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3" #@param ["https://tfhub.dev/google/universal-sentence-encoder/2", "https://tfhub.dev/google/universal-sentence-encoder-large/3"]
# Import the Universal Sentence Encoder's TF Hub module
embed = hub.Module(module_url)

# First batch: four single words followed by three short sentences, to show
# that inputs of various lengths are supported.
word = "king"
word2 = "queen"
word3 = "man"
word4 = "girl"
sentence = "I am red"
sentence2 = "I am blue"
sentence3 = "I am green"
messages = [word, word2, word3, word4, sentence, sentence2, sentence3]

# Second batch: two unrelated sentences whose raw embeddings are printed as-is.
more_messages = ["That band rocks!", "That song is really cool."]

# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)

# Evaluate both batches in ONE session. The variable/table initializers cover
# the whole module, so opening a second session only to embed two more strings
# would repeat that initialization for no benefit.
with tf.Session() as session:
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])
    first_batch_embeddings, message_embeddings = session.run(
        [embed(messages), embed(more_messages)])

# Keep every first-batch embedding as a plain Python list of floats so the
# arithmetic below runs on the same values that are previewed here.
list_embedding = []
for message, message_embedding in zip(
        messages, np.array(first_batch_embeddings).tolist()):
    print("Message: {}".format(message))
    print("Embedding size: {}".format(len(message_embedding)))
    list_embedding.append(message_embedding)
    # Preview only the first three components of the vector.
    message_embedding_snippet = ", ".join(
        str(x) for x in message_embedding[:3])
    print("Embedding: [{}, ...]\n".format(message_embedding_snippet))

# Indices follow the order of `messages`: 0 = "king", 1 = "queen",
# 2 = "man", 3 = "girl".
temp = (np.array(list_embedding[0])
        - np.array(list_embedding[2])
        + np.array(list_embedding[3]))
print(temp)
print(np.array(list_embedding[1]))

# Raw embeddings of the second batch.
print(message_embeddings)

# Second dimension of the static shape of the module's 'default' output.
embed_size = embed.get_output_info_dict()['default'].get_shape()[1].value
print(embed_size)