def _create_pretrained_emb_from_txt(vocab_file, embed_file,
                                    num_trainable_tokens=3, dtype=tf.float32,
                                    scope=None):
  """Load pretrained embedding from embed_file and return an embedding matrix.

  Args:
    vocab_file: Path to the vocabulary file, one token per line.
    embed_file: Path to a GloVe-formatted embedding txt file.
    num_trainable_tokens: Make the first n tokens in the vocab file trainable
      variables. Default is 3, which covers "<unk>", "<s>" and "</s>".
    dtype: Data type of the embedding matrix.
    scope: Optional variable scope name.
  """
  vocab, _ = vocab_utils.load_vocab(vocab_file)
  trainable_tokens = vocab[:num_trainable_tokens]

  utils.print_out("# Using pretrained embedding: %s." % embed_file)
  utils.print_out("  with trainable tokens: ")

  emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
  for token in trainable_tokens:
    utils.print_out("    %s" % token)
    if token not in emb_dict:
      emb_dict[token] = [0.0] * emb_size

  emb_mat = np.array(
      [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
  emb_mat = tf.constant(emb_mat)
  # Freeze every row except the first num_trainable_tokens ones.
  emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
  with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
    with tf.device(_get_embed_device(num_trainable_tokens)):
      emb_mat_var = tf.get_variable(
          "emb_mat_var", [num_trainable_tokens, emb_size])
  return tf.concat([emb_mat_var, emb_mat_const], 0)
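# A minimal usage sketch for _create_pretrained_emb_from_txt. The file paths
# and the lookup below are hypothetical illustrations; only the call signature
# comes from the function above.
def _example_pretrained_embedding_usage():
  # Hypothetical paths: a vocab file (one token per line) and a
  # GloVe-formatted embedding file ("token v1 v2 ... vN" per line).
  vocab_file = "/tmp/vocab.txt"
  embed_file = "/tmp/glove.6B.300d.txt"

  # Rows 0..2 (<unk>, <s>, </s>) stay trainable; the remaining rows are
  # frozen constants sliced from the pretrained matrix.
  embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file)

  # Look up embeddings for a batch of token ids, as an encoder typically would.
  source_ids = tf.placeholder(tf.int32, shape=[None, None])
  emb_inputs = tf.nn.embedding_lookup(embedding, source_ids)
  return emb_inputs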
def extend_hparams(hparams):
  """Extend training hparams."""
  hparams.add_hparam("input_emb_pretrain", hparams.input_emb_file is not None)

  # Check if the vocab has the unk and pad symbols as its first words. If not,
  # create a new vocab file with these symbols as the first two words.
  vocab_size, vocab_path = vocab_utils.check_vocab(
      hparams.vocab_path,
      hparams.out_dir,
      unk=hparams.unk,
      pad=hparams.pad)
  vocab, _ = vocab_utils.load_vocab(vocab_path)

  # Generate embeddings if the flag is set or the file is not present.
  if hparams.create_new_embeddings or not os.path.isfile(
      hparams.input_emb_file):
    embedding.save_embedding(vocab, hparams.embedding_path,
                             hparams.input_emb_file)

  hparams.add_hparam("vocab_size", vocab_size)
  hparams.set_hparam("vocab_path", vocab_path)

  if not tf.gfile.Exists(hparams.out_dir):
    tf.gfile.MakeDirs(hparams.out_dir)
  return hparams
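# A minimal sketch of how extend_hparams might be driven, assuming TF 1.x
# tf.contrib.training.HParams. All paths and values below are hypothetical
# placeholders, not part of the code above.
def _example_extend_hparams():
  hparams = tf.contrib.training.HParams(
      vocab_path="/tmp/vocab.txt",              # hypothetical vocab file
      out_dir="/tmp/out_dir",                   # output/model directory
      unk="<unk>",
      pad="<pad>",
      input_emb_file="/tmp/emb.txt",            # hypothetical embedding file
      embedding_path="/tmp/glove.6B.300d.txt",  # hypothetical GloVe file
      create_new_embeddings=False)
  # Adds input_emb_pretrain and vocab_size, rewrites vocab_path if check_vocab
  # had to prepend the special symbols, and creates out_dir if missing.
  return extend_hparams(hparams)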
def testCheckVocab(self):
  # Create a vocab file.
  vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir")
  os.makedirs(vocab_dir)
  vocab_file = os.path.join(vocab_dir, "vocab_file")
  vocab = ["a", "b", "c"]
  with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f:
    for word in vocab:
      f.write("%s\n" % word)

  # Call vocab_utils.
  out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir")
  os.makedirs(out_dir)
  vocab_size, new_vocab_file = vocab_utils.check_vocab(vocab_file, out_dir)

  # Assert: we expect the code to add <unk>, <s>, </s> and create a new
  # vocab file.
  self.assertEqual(len(vocab) + 3, vocab_size)
  self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file)
  new_vocab, _ = vocab_utils.load_vocab(new_vocab_file)
  self.assertEqual(
      [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab)
if __name__ == "__main__": import argparse from utils import vocab_utils parser = argparse.ArgumentParser() parser.add_argument("--vocab_path", type=str, default=None, help="Vocabulary input file path.") parser.add_argument("--embed_path", type=str, default=None, help="Input embedding file path.") parser.add_argument("--out_path", type=str, default=None, help="Output pat.") flags, unparsed = parser.parse_known_args() vocab, _ = vocab_utils.load_vocab(flags.vocab_path) unk = u"<unk>" pad = u"<pad>" if vocab[0] != pad or vocab[1] != unk: vocab = [pad, unk] + vocab save_embedding( vocab, flags.flags.embed_path, flags.out_path + flags.embed_path.split("/")[-1].replace(".txt", "_") + flags.vocab_path.split("/")[-1])