def test_save_and_load(self):
    """Try to save and load encoder.

    The encoder variables are initialized and saved from one session, then
    a second, independently initialized session loads the checkpoint.
    The test passes when the variables under the ``enc`` scope hold equal
    values in both sessions.
    """
    vocabulary = Vocabulary()
    vocabulary.add_word("a")
    vocabulary.add_word("b")

    # Only the file name is used below, so the handle is closed right away.
    checkpoint_file = tempfile.NamedTemporaryFile(delete=False)
    checkpoint_file.close()

    try:
        # Use the vocabulary populated above rather than a fresh empty one.
        encoder = SentenceEncoder(
            "enc", vocabulary, "data_id", 10, 20, 30,
            save_checkpoint=checkpoint_file.name,
            load_checkpoint=checkpoint_file.name)

        encoders_variables = tf.get_collection(
            tf.GraphKeys.VARIABLES, scope="enc")

        # Both sessions must be alive at the same time so their variable
        # values can be compared; the context managers close them even if
        # a step in between raises.
        with tf.Session() as sess_1, tf.Session() as sess_2:
            sess_1.run(tf.initialize_all_variables())
            encoder.save(sess_1)

            # The second session gets its own initialization before
            # loading, so equality below is due to the checkpoint.
            sess_2.run(tf.initialize_all_variables())
            encoder.load(sess_2)

            values_in_sess_1 = sess_1.run(encoders_variables)
            values_in_sess_2 = sess_2.run(encoders_variables)

        self.assertTrue(
            all(np.all(v1 == v2)
                for v1, v2 in zip(values_in_sess_1, values_in_sess_2)))
    finally:
        # Remove the temporary file even when the test fails.
        os.remove(checkpoint_file.name)
class Word2Vec:
    """Vocabulary and embedding matrix loaded from a word2vec text file.

    The first line of the file is a header whose second
    whitespace-separated field is the embedding size. Every following
    non-empty line holds a word followed by its vector components.

    Rows 0-2 of the embedding matrix are zero vectors (padding, start and
    end token) unless the file provides vectors for those special tokens.
    Row 3 is the unknown token, whose vector must be present in the file.
    """

    def __init__(self, path: str, encoding: str = "utf-8") -> None:
        """Load the word2vec file.

        Args:
            path: Path to the word2vec file in text format.
            encoding: Encoding used to read the file.

        Raises:
            AssertionError: If a vector length differs from the size given
                in the header, or if the file contains no vector for the
                unknown token.
        """
        check_argument_types()

        # Create the vocabulary object, load the words and vectors from the
        # file
        self.vocab = Vocabulary()
        embedding_vectors = []  # type: List[np.ndarray]

        with open(path, encoding=encoding) as f_data:
            header = next(f_data)
            emb_size = int(header.split()[1])

            # Add zero embeddings for padding, start, and end token
            for _ in range(3):
                embedding_vectors.append(np.zeros(emb_size))
            # Add placeholder for embedding of the unknown symbol
            embedding_vectors.append(None)

            for line in f_data:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue

                word = fields[0]
                # The builtin ``float`` is what the removed ``np.float``
                # alias stood for, so the dtype is unchanged.
                vector = np.fromiter((float(x) for x in fields[1:]),
                                     dtype=float)

                assert vector.shape[0] == emb_size, (
                    "Vector of word '{}' has size {}, expected {}".format(
                        word, vector.shape[0], emb_size))

                # Embedding of unknown token should be at index 3 to match
                # the vocabulary implementation
                if is_special_token(word):
                    embedding_vectors[SPECIAL_TOKENS.index(word)] = vector
                else:
                    self.vocab.add_word(word)
                    embedding_vectors.append(vector)

        assert embedding_vectors[3] is not None, (
            "The word2vec file has no embedding for the unknown token")

        self.embedding_matrix = np.stack(embedding_vectors)

    @property
    def vocabulary(self) -> Vocabulary:
        """Get a vocabulary object generated from this word2vec instance."""
        return self.vocab

    @property
    def embeddings(self) -> np.ndarray:
        """Get the embedding matrix."""
        return self.embedding_matrix
def test_save_and_load(self):
    """Try to save and load encoder.

    The encoder variables are initialized and saved from one session, then
    a second, independently initialized session loads the checkpoint.
    The test passes when the variables under the ``enc`` scope hold equal
    values in both sessions.
    """
    vocabulary = Vocabulary()
    vocabulary.add_word("a")
    vocabulary.add_word("b")

    # Only the file name is used below, so the handle is closed right away.
    checkpoint_file = tempfile.NamedTemporaryFile(delete=False)
    checkpoint_file.close()

    try:
        # Use the vocabulary populated above rather than a fresh empty one.
        encoder = SentenceEncoder(name="enc",
                                  vocabulary=vocabulary,
                                  data_id="data_id",
                                  embedding_size=10,
                                  rnn_size=20,
                                  max_input_len=30,
                                  save_checkpoint=checkpoint_file.name,
                                  load_checkpoint=checkpoint_file.name)
        encoder.input_sequence.register_input()

        # NOTE: This assert needs to be here otherwise the model has
        # no parameters since the sentence encoder is initialized lazily
        self.assertIsInstance(encoder.temporal_states, tf.Tensor)

        encoders_variables = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope="enc")

        # Both sessions must be alive at the same time so their variable
        # values can be compared; the context managers close them even if
        # a step in between raises.
        with tf.Session() as sess_1, tf.Session() as sess_2:
            sess_1.run(tf.global_variables_initializer())
            encoder.save(sess_1)

            # The second session gets its own initialization before
            # loading, so equality below is due to the checkpoint.
            sess_2.run(tf.global_variables_initializer())
            encoder.load(sess_2)

            values_in_sess_1 = sess_1.run(encoders_variables)
            values_in_sess_2 = sess_2.run(encoders_variables)

        self.assertTrue(
            all(np.all(v1 == v2)
                for v1, v2 in zip(values_in_sess_1, values_in_sess_2)))
    finally:
        # Remove the temporary file even when the test fails.
        os.remove(checkpoint_file.name)
def test_reuse(self):
    """Check that ``reuse`` makes two sequences share embedding values.

    ``seq1`` and ``seq2`` are created independently, while ``seq3`` is
    created with ``reuse=seq1``.  After variable initialization the test
    expects the matrices of ``seq1`` and ``seq2`` to differ and those of
    ``seq1`` and ``seq3`` to be equal.
    """
    vocab = Vocabulary()
    for word in ("a", "b"):
        vocab.add_word(word)

    original = EmbeddedSequence(
        name="seq1", vocabulary=vocab, data_id="id", embedding_size=10)
    original.register_input()

    independent = EmbeddedSequence(
        name="seq2", vocabulary=vocab, embedding_size=10, data_id="id")
    independent.register_input()

    sharing = EmbeddedSequence(
        name="seq3", vocabulary=vocab, data_id="id", embedding_size=10,
        reuse=original)
    sharing.register_input()

    # blessing: touch every embedding matrix before the variables are
    # initialized (presumably this is what creates them -- TODO confirm)
    for sequence in (original, independent, sharing):
        self.assertIsNotNone(sequence.embedding_matrix)

    session = tf.Session()
    session.run(tf.global_variables_initializer())

    matrices = (original.embedding_matrix,
                independent.embedding_matrix,
                sharing.embedding_matrix)
    original_values, independent_values, sharing_values = session.run(
        matrices)

    # Independently created sequences must not have equal matrices ...
    with self.assertRaises(AssertionError):
        assert_array_equal(original_values, independent_values)

    # ... whereas the reusing sequence must match the one it reuses.
    assert_array_equal(original_values, sharing_values)