# Imports assumed for the preprocessing snippets below (url_marker is the
# project-local module providing WEB_URL_REGEX; basic_tokenizer is defined
# elsewhere in the repo).
import re
import codecs

from tqdm import tqdm

import url_marker


def _format_line(line):
    # Normalize a raw line of text before tokenization.
    line = re.sub(url_marker.WEB_URL_REGEX, "<link>", line)  # URLs -> <link>
    line = re.sub(r"[\.]+", ".", line)                       # collapse runs of dots
    line = re.sub(r"[0-9]*\,[0-9]+", "<num>", line)          # 1,234-style numbers -> <num>
    line = re.sub(r"[0-9]*\.[0-9]+", "<num>", line)          # decimals -> <num>
    line = re.sub(r"[0-9]+", "<num>", line)                  # remaining integers -> <num>
    line = re.sub(r"[\.\?\!]", " <eos> ", line)              # sentence punctuation -> <eos>
    return basic_tokenizer(line)
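# Illustrative example (a sketch; the exact tokens depend on basic_tokenizer,
# which is not shown here):
#
#   _format_line("Mua 2 ve gia 3.500 dong tai http://example.com.")
#   # -> roughly ["mua", "<num>", "ve", "gia", "<num>", "dong", "tai", "<link>", "<eos>"]
#
# Note the substitution order: URLs are rewritten before the dot/number rules so
# their dots and digits are not mangled, and sentence-ending punctuation is
# turned into " <eos> " last.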
# if line[0] == "\n":
#     continue
# line = line.replace(".", ".\n").replace("?", "?\n").replace("!", "!\n")
# result.append(line)
#
# with codecs.open(processed_path, "w", encoding="utf8") as f:
#     f.write(''.join(result))
if add_unk:
    # Count word frequencies over the processed corpus.
    with codecs.open(processed_path, encoding="utf8") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        if line[0] == "\n":
            continue
        line += " <eos>"
        for word in basic_tokenizer(line):
            if word.isdigit():
                word = "<num>"
            # Skip tokens that are neither special markers (<...>) nor alphabetic.
            if word[0] != "<" and not word.isalpha():
                continue
            if dictionary.get(word) is None:
                dictionary[word] = 0
            dictionary[word] += 1
    # Give <unk> a huge count so it always survives the vocabulary cut-off.
    dictionary["<unk>"] = 99999999

# Keep the vocab_size most frequent tokens.
count_pairs = sorted(dictionary.items(), key=lambda x: -x[1])
tokens, _ = zip(*count_pairs)
tokens = tokens[0:vocab_size]
tokens = list(tokens)
# tokens.append("<pad>")
print("Token count: {}".format(len(tokens)))
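# Hypothetical continuation (not shown in the snippet above): the trimmed token
# list is presumably written out one token per line so that load_vocab() in the
# inference script can rebuild word2idx / idx2word from it. A minimal sketch,
# assuming a vocab_path such as "ckpt_blog_td/vocab.txt":
#
#   with codecs.open(vocab_path, "w", encoding="utf8") as f:
#       f.write("\n".join(tokens))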
# Imports assumed for the inference script (load_vocab, basic_tokenizer,
# tone_utils, clear_all_marks, TransformerDecoder, RNN, beamsearch_transformer
# and beamsearch_rnn come from elsewhere in the repo; cPickle is six.moves.cPickle,
# or `import pickle as cPickle` on Python 3).
import argparse

import numpy as np
import tensorflow as tf
from six.moves import cPickle


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', type=str, default='ckpt_blog_td/vocab.txt',
                        help='the location of the vocabulary file')
    parser.add_argument('--ckpt_path', type=str, default='ckpt_blog_td2',
                        help='the location of the checkpoint files')
    parser.add_argument('--model', type=str, default='t-d-s2s',
                        help='model type: t-d, t-d-s2s or rnn')
    parser.add_argument('--beam_size', type=int, default=3,
                        help='beam search size')
    parser.add_argument('--prime', type=str, default='chia tay me di buc ca minh')
    parser.add_argument('--saved_args_path', type=str, default="./ckpt_blog_td2/args.pkl")
    args = parser.parse_args()

    # Restore the arguments the model was trained with.
    with open(args.saved_args_path, 'rb') as f:
        saved_args = cPickle.load(f)

    word2idx, idx2word = load_vocab(args.vocab_path)

    # Group accented vocabulary entries under their unaccented form, so decoding
    # only has to consider accent variants of each input word.
    accent_vocab = dict()
    for token, idx in word2idx.items():
        raw_token = tone_utils.clear_all_marks(token) if token[0] != "<" else token
        if raw_token not in accent_vocab:
            accent_vocab[raw_token] = [token]
        else:
            curr = accent_vocab[raw_token]
            if token not in curr:
                curr.append(token)

    unk_idx = word2idx["<unk>"]
    if args.model == "t-d" or args.model == "t-d-s2s":  # quick fix
        model = TransformerDecoder(False, saved_args)
    else:
        model = RNN(False, saved_args)

    with tf.Session(graph=model.graph if args.model != "rnn" else None) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        words = basic_tokenizer(args.prime)
        if args.model == "t-d":
            # Transformer decoder: start from a <sos>-prefixed, <pad>-filled
            # sequence and expand it with beam search.
            sos_idx = word2idx["<sos>"]
            pad_idx = word2idx["<pad>"]
            init_state = np.full(shape=(saved_args.maxlen + 1), fill_value=pad_idx)
            init_state[0] = sos_idx
            init_probs = sess.run(tf.nn.softmax(model.logits),
                                  feed_dict={model.x: np.atleast_2d(init_state)})[0]
            paths = beamsearch_transformer(sess, model, words, args.beam_size,
                                           saved_args.maxlen, init_probs,
                                           accent_vocab, word2idx)
        elif args.model == "rnn":
            # RNN: feed the first token to get an initial state and distribution,
            # then continue with beam search.
            x = np.zeros((1, 1))
            words = basic_tokenizer(args.prime)
            init_state = sess.run(model.cell.zero_state(1, tf.float32))
            if words[0] != "<eos>":
                words = ["<eos>"] + words
            out_state = init_state
            x[0, 0] = word2idx[words[0]] if words[0] in word2idx else unk_idx
            # print(x[0, 0])
            feed = {model.input_data: x, model.initial_state: out_state}
            [probs, out_state] = sess.run([model.probs, model.final_state], feed)
            paths = beamsearch_rnn(sess, model, words, args.beam_size, out_state,
                                   probs[0], accent_vocab, word2idx)
        else:
            # t-d-s2s: encode the unaccented input directly and take the model's
            # argmax predictions, no beam search.
            pad_idx = word2idx["<pad>"]
            ref = []
            for idx, token in idx2word.items():
                cleared = clear_all_marks(token)
                if cleared not in ref:
                    ref.append(cleared)
            words = basic_tokenizer(args.prime)
            feed_x = np.asarray([ref.index(w) if w in word2idx else ref.index("<unk>")
                                 for w in words])
            feed_x = np.atleast_2d(
                np.lib.pad(feed_x, [0, saved_args.maxlen - len(feed_x)],
                           'constant', constant_values=pad_idx))
            feed = {model.x: feed_x}
            paths = [sess.run(model.preds, feed_dict=feed)]
            paths[0][len(words):] = pad_idx

        # Map predicted indices back to words; where the model predicts <unk>,
        # fall back to the original unaccented input word.
        result = ""
        for path in paths:
            for idx, token in enumerate(path):
                result += idx2word[token] if token != unk_idx else words[idx if args.model != "rnn" else idx + 1]
                result += " "
            result += "\n"
        print(result)
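# Assumed entry point and example invocation (a sketch; the script file name is
# an assumption, the snippet does not name it):
#
#   python generate.py --model t-d-s2s \
#       --vocab_path ckpt_blog_td/vocab.txt \
#       --ckpt_path ckpt_blog_td2 \
#       --saved_args_path ./ckpt_blog_td2/args.pkl \
#       --beam_size 3 \
#       --prime "chia tay me di buc ca minh"
#
# The script prints the --prime string with Vietnamese diacritics restored.
if __name__ == "__main__":
    main()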