Example #1
import codecs
import re

from tqdm import tqdm


# url_marker and basic_tokenizer come from the project's own helper modules.
def _format_line(line):
    line = re.sub(url_marker.WEB_URL_REGEX, "<link>", line)  # mask URLs
    line = re.sub(r"\.+", ".", line)                          # collapse repeated dots
    line = re.sub(r"[0-9]*,[0-9]+", "<num>", line)            # decimals like 1,5 -> <num>
    line = re.sub(r"[0-9]*\.[0-9]+", "<num>", line)           # decimals like 1.5 -> <num>
    line = re.sub(r"[0-9]+", "<num>", line)                   # bare integers -> <num>
    line = re.sub(r"[.?!]", " <eos> ", line)                  # sentence enders -> <eos>
    return basic_tokenizer(line)
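
# A quick illustration of what _format_line produces (assuming basic_tokenizer
# simply splits the cleaned string on whitespace; the real tokenizer lives in
# the project's utils):
#   _format_line("gia vang tang 3.5 lan trong nam 2024.")
#   -> ["gia", "vang", "tang", "<num>", "lan", "trong", "nam", "<num>", "<eos>"]
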
#         if line[0] == "\n":
#             continue
#         line = line.replace(".",".\n").replace("?","?\n").replace("!","!\n")
#         result.append(line)
#
# with codecs.open(processed_path, "w", encoding="utf8") as f:
#     f.write(''.join(result))

# Build a frequency-ranked vocabulary from the preprocessed corpus.
# add_unk, processed_path, dictionary and vocab_size are defined earlier
# in this preprocessing script.
if add_unk:
    with codecs.open(processed_path, encoding="utf8") as f:
        lines = f.readlines()
        for line in tqdm(lines):
            if line[0] == "\n":          # skip blank lines
                continue
            line += " <eos>"             # count an <eos> token for every line
            for word in basic_tokenizer(line):
                if word.isdigit():
                    word = "<num>"
                # keep special tokens (<num>, <eos>, ...) and purely
                # alphabetic words; drop everything else
                if word[0] != "<" and not word.isalpha():
                    continue
                if dictionary.get(word) is None:
                    dictionary[word] = 0
                dictionary[word] += 1

    # force <unk> to the top of the frequency ranking so it always
    # survives the vocab_size cut-off
    dictionary["<unk>"] = 99999999
    count_pairs = sorted(dictionary.items(), key=lambda x: -x[1])
    tokens, _ = zip(*count_pairs)
    tokens = list(tokens[:vocab_size])   # keep the vocab_size most frequent tokens
    # tokens.append("<pad>")
    print("Token count: {}".format(len(tokens)))
import argparse
import pickle as cPickle  # the original likely used Python 2's cPickle

import numpy as np
import tensorflow as tf

# Project helpers assumed to be importable from the repo: load_vocab,
# basic_tokenizer, tone_utils / clear_all_marks, TransformerDecoder, RNN,
# beamsearch_transformer and beamsearch_rnn.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', type=str, default='ckpt_blog_td/vocab.txt',
                        help='location of the vocabulary file')
    parser.add_argument('--ckpt_path', type=str, default='ckpt_blog_td2',
                        help='directory holding the checkpoint files')
    parser.add_argument('--model', type=str, default='t-d-s2s',
                        help='model types: t-d, t-d-s2s or rnn')
    parser.add_argument('--beam_size', type=int, default=3,
                        help='beam search size')
    parser.add_argument('--prime', type=str,
                        default='chia tay me di buc ca minh',
                        help='unaccented input text to restore tone marks for')
    parser.add_argument('--saved_args_path', type=str, default="./ckpt_blog_td2/args.pkl",
                        help='pickled training arguments saved by the training script')

    args = parser.parse_args()
    with open(args.saved_args_path, 'rb') as f:
        saved_args = cPickle.load(f)
    word2idx, idx2word = load_vocab(args.vocab_path)
    accent_vocab = dict()
    for token, idx in word2idx.items():
        raw_token = tone_utils.clear_all_marks(token) if token[0] != "<" else token
        if raw_token not in accent_vocab:
            accent_vocab[raw_token] = [token]
        else:
            curr = accent_vocab[raw_token]
            if token not in curr:
                curr.append(token)
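    # accent_vocab maps each bare (mark-less) form to the accented variants present
    # in the vocabulary, e.g. accent_vocab["ma"] -> ["ma", "mà", "má", "mạ", "mả", "mã"]
    # (illustrative only; the actual entries depend on the vocab file).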

    unk_idx = word2idx["<unk>"]
    if args.model == "t-d" or args.model == "t-d-s2s":
        # quick fix: both Transformer variants are built with the same decoder class
        model = TransformerDecoder(False, saved_args)
    else:
        model = RNN(False, saved_args)

    with tf.Session(graph=model.graph if args.model != "rnn" else None) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            words = basic_tokenizer(args.prime)
            if args.model == "t-d":
                sos_idx = word2idx["<sos>"]
                pad_idx = word2idx["<pad>"]
                init_state = np.full(shape=(saved_args.maxlen + 1), fill_value=pad_idx)
                init_state[0] = sos_idx
                init_probs = sess.run(tf.nn.softmax(model.logits), feed_dict={
                    model.x: np.atleast_2d(init_state)})[0]
                paths = beamsearch_transformer(sess, model,
                                               words, args.beam_size,
                                               saved_args.maxlen, init_probs,
                                               accent_vocab, word2idx)
            elif args.model == "rnn":
                x = np.zeros((1, 1))
                words = basic_tokenizer(args.prime)
                init_state = sess.run(model.cell.zero_state(1, tf.float32))
                if words[0] != "<eos>":
                    words = ["<eos>"] + words
                out_state = init_state
                x[0, 0] = word2idx[words[0]] if words[0] in word2idx else unk_idx
                # print(x[0,0])
                feed = {model.input_data: x, model.initial_state: out_state}
                [probs, out_state] = sess.run([model.probs, model.final_state], feed)
                paths = beamsearch_rnn(sess, model,
                                       words, args.beam_size,
                                       out_state, probs[0],
                                       accent_vocab, word2idx)
            else:
                pad_idx = word2idx["<pad>"]
                ref = []
                for idx, token in idx2word.items():
                    cleared = clear_all_marks(token)
                    if cleared not in ref:
                        ref.append(cleared)
                words = basic_tokenizer(args.prime)
                # look each word up in the mark-less token list; fall back to <unk>
                feed_x = np.asarray([ref.index(w) if w in ref else ref.index("<unk>") for w in words])
                feed_x = np.atleast_2d(
                    np.lib.pad(feed_x, [0, saved_args.maxlen - len(feed_x)], 'constant', constant_values=pad_idx))
                feed = {model.x: feed_x}
                paths = [sess.run(model.preds, feed_dict=feed)]
                paths[0][len(words):] = pad_idx
            result = ""
            for path in paths:
                for idx, token in enumerate(path):
                    result += idx2word[token] if token != unk_idx else words[idx if args.model != "rnn" else idx + 1]
                    result += " "
                result += "\n"
            print(result)
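
# A hedged completion: the script presumably ends with the usual entry-point
# guard. Example invocation (the script name is an assumption):
#   python sample.py --model t-d-s2s --prime "chia tay me di buc ca minh"
if __name__ == "__main__":
    main()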