import os
from itertools import dropwhile

import sentencepiece as spm
import torch
from tqdm import tqdm

from models.util.lookup import Lookup

# Folder that receives the GPT-2 lookups' special-token files.
output_lookup_folder = os.path.join("lookup", "gpt2")

# create output folder; exist_ok=True replaces the racy
# "os.path.exists() then os.makedirs()" check-then-create sequence.
os.makedirs(output_lookup_folder, exist_ok=True)

# CREATE LOOKUPS
# Source and target sides both use a "gpt2" Lookup; each one writes its
# special tokens under its own file prefix ("src" / "tgt").
src_lookup = Lookup(type="gpt2")
src_lookup.save_special_tokens(
    file_prefix=os.path.join(output_lookup_folder, "src"))

tgt_lookup = Lookup(type="gpt2")
tgt_lookup.save_special_tokens(
    file_prefix=os.path.join(output_lookup_folder, "tgt"))

print("Done.")

# check everything is ok: reload the target-side lookup from disk and
# print an encode/decode round trip of a sample sentence for manual review.
lookup = Lookup(type="gpt2")
lookup.load(file_prefix=os.path.join(output_lookup_folder, "tgt"))
text = "This is a test."
token_ids = lookup.encode(text)
print(f"Encode: {token_ids}")
recreated_string = lookup.decode(token_ids)
print(f"Decode: {recreated_string}")
    # NOTE(review): this statement is truncated. Its opening is not present in
    # the file: presumably the source-side `spm.SentencePieceTrainer.Train(`
    # call and its '--input=... --model_prefix=' part, mirroring the
    # target-side call below. As a result, this line has an unexpected indent
    # and the ")" on its last line has no matching "(". The file appears to be
    # two scripts spliced together. input_raw_file, input_src_vocab_size and
    # input_tgt_vocab_size are also never defined in the visible code. Restore
    # the missing lines from the original BPE lookup script (TODO confirm).
    os.path.join(output_lookup_folder, "src-" + str(input_src_vocab_size)) +
    ' --character_coverage=1.0 --model_type=bpe --num_threads=8 --split_by_whitespace=true --shuffle_input_sentence=true --vocab_size='
    + str(input_src_vocab_size))
# Train the target-side BPE SentencePiece model on "<input_raw_file>.Xy.txt".
# The trainer options are passed as one concatenated flag string:
# --model_prefix is <output_lookup_folder>/tgt-<input_tgt_vocab_size>, and
# --vocab_size is input_tgt_vocab_size.
# NOTE(review): output_lookup_folder is assigned "lookup/gpt2" at the top of
# this file, so the BPE models land in the gpt2 folder. This looks like a
# splice artifact; verify the intended folder.
spm.SentencePieceTrainer.Train(
    '--input=' + input_raw_file + '.Xy.txt --model_prefix=' +
    os.path.join(output_lookup_folder, "tgt-" + str(input_tgt_vocab_size)) +
    ' --character_coverage=1.0 --model_type=bpe --num_threads=8 --split_by_whitespace=true --shuffle_input_sentence=true --vocab_size='
    + str(input_tgt_vocab_size))
# Special-token id flags kept for reference only. They are NOT passed to either
# Train call above; append them to the flag strings to pin these ids.
#--pad_id=0 --pad_piece=<PAD> --unk_id=1 --unk_piece=<UNK> --bos_id=2 --bos_piece=<BOS> --eos_id=3 --eos_piece=<EOS>

# CREATE LOOKUPS
def _make_bpe_lookup(side, vocab_size):
    """Create a "bpe" Lookup for one side and persist its special tokens.

    The lookup is loaded from, and its special tokens are saved under, the
    prefix ``<output_lookup_folder>/<side>-<vocab_size>``.

    Args:
        side: file-name prefix for the side, "src" or "tgt".
        vocab_size: vocabulary size embedded in the file-name prefix.

    Returns:
        The loaded Lookup instance.
    """
    bpe_lookup = Lookup(type="bpe")
    prefix = os.path.join(output_lookup_folder, "{}-{}".format(side, vocab_size))
    bpe_lookup.load(prefix)
    bpe_lookup.save_special_tokens(file_prefix=prefix)
    return bpe_lookup


src_lookup = _make_bpe_lookup("src", input_src_vocab_size)
tgt_lookup = _make_bpe_lookup("tgt", input_tgt_vocab_size)
print("Done.")

# check everything is ok: reload the target-side BPE lookup from disk.
lookup = Lookup(type="bpe")
lookup.load(file_prefix=os.path.join(
    output_lookup_folder, "tgt-{}".format(input_tgt_vocab_size)))