Example #1
    def __init__(
        self,
        TOK_PATH=Path('./senior_proj_itos'),
        BOS='xxbos',
        EOS='xxeos',
        FLD='xxfld',
        UNK='xxunk',
        PAD='xxpad',
        TK_REP='xxrep',
        TK_WREP='xxwrep',
        TK_NUM='xxnum',
        TK_LAUGH='xxlaugh',
        n_cpus=1,
    ):
        from senior_project_util import ThaiTokenizer, pre_rules_th, post_rules_th
        from fastai.text.transform import BaseTokenizer, Tokenizer, Vocab
        from fastai.text.data import TokenizeProcessor, NumericalizeProcessor

        with open(TOK_PATH / "bert_itos_80k_cleaned.pkl", 'rb') as f:
            itos = pickle.load(f)

        self.vocab = Vocab(itos)
        self.tokenizer = Tokenizer(tok_func=ThaiTokenizer,
                                   lang='th',
                                   pre_rules=pre_rules_th,
                                   post_rules=post_rules_th,
                                   n_cpus=n_cpus)

        self.cls_token_id = self.vocab.stoi[BOS]
        self.sep_token_id = self.vocab.stoi[EOS]
        self.pad_token_id = self.vocab.stoi[PAD]

        self.mask_token = FLD  # FLD is not otherwise used here and is already a special token, so reuse it as the mask token
        self._pad_token = PAD
Example #2
class CustomSeniorProjectTokenizer(object):
    def __init__(self, TOK_PATH = Path('./senior_proj_itos'), BOS='xxbos', EOS='xxeos', FLD = 'xxfld', UNK='xxunk', PAD='xxpad',
                 TK_REP='xxrep', TK_WREP='xxwrep', TK_NUM='xxnum', TK_LAUGH='xxlaugh'
                ):
        from senior_project_util import ThaiTokenizer, pre_rules_th, post_rules_th
        from fastai.text.transform import BaseTokenizer, Tokenizer, Vocab
        from fastai.text.data import TokenizeProcessor, NumericalizeProcessor

        with open(TOK_PATH/"bert_itos_80k_cleaned.pkl", 'rb') as f:
            itos = pickle.load(f)
            
        self.vocab = Vocab(itos)
        self.tokenizer = Tokenizer(tok_func = ThaiTokenizer, lang = 'th', 
                                   pre_rules = pre_rules_th, post_rules=post_rules_th, n_cpus=1)
        
        self.cls_token_id = self.vocab.stoi[BOS]
        self.sep_token_id = self.vocab.stoi[EOS]
        
#         tokenizer_processor = TokenizeProcessor(tokenizer=tt, chunksize=300000, mark_fields=False)
#         numbericalize_processor = NumericalizeProcessor(vocab=vocab)
        
    def num_special_tokens_to_add(self, pair=False):
        return 2
    def tokenize(self, text):
        return self.tokenizer._process_all_1([text])[0]
#         return self.tokenizer.process_all([text])[0]
    
    def convert_tokens_to_ids(self, token_list):
        return self.vocab.numericalize(token_list)
    
    def build_inputs_with_special_tokens(self, token_list):
        # From https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_bert.py#L235
        return [self.cls_token_id] + token_list + [self.sep_token_id]
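A hedged usage sketch for the class above, assuming the local `senior_proj_itos/bert_itos_80k_cleaned.pkl` file and the `senior_project_util` module are available; the Thai sample string is illustrative only.

# Illustrative usage only; requires the local senior_proj_itos files and senior_project_util module.
from pathlib import Path

tok = CustomSeniorProjectTokenizer(TOK_PATH=Path('./senior_proj_itos'))
tokens = tok.tokenize('ทดสอบการตัดคำภาษาไทย')             # list of sub-word tokens
ids = tok.convert_tokens_to_ids(tokens)                    # list of vocabulary ids
model_input = tok.build_inputs_with_special_tokens(ids)    # [cls_id] + ids + [sep_id]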
Example #3
def get_sentencepiece(path: PathOrStr,
                      trn_path: Path,
                      name: str,
                      pre_rules: ListRules = None,
                      post_rules: ListRules = None,
                      vocab_size: int = 30000,
                      model_type: str = 'unigram',
                      input_sentence_size: int = 1E7,
                      pad_idx: int = PAD_TOKEN_ID):
    try:
        import sentencepiece as spm
    except ImportError:
        raise Exception(
            'sentencepiece module is missing: run `pip install sentencepiece`')

    path = pathlib.Path(path)
    cache_name = 'tmp'
    os.makedirs(path / cache_name, exist_ok=True)
    os.makedirs(path / 'models', exist_ok=True)
    pre_rules = pre_rules if pre_rules is not None else []
    post_rules = post_rules if post_rules is not None else []

    # load the text from the train tokens file
    text = [line.rstrip('\n') for line in open(trn_path)]
    text = list(filter(None, text))

    if not os.path.isfile(path / 'models' / 'spm.model') or not os.path.isfile(
            path / 'models' / f'itos_{name}.pkl'):
        raw_text = reduce(lambda t, rule: rule(t), pre_rules, '\n'.join(text))
        raw_text_path = path / cache_name / 'all_text.txt'
        with open(raw_text_path, 'w') as f:
            f.write(raw_text)

        sp_params = f"--input={raw_text_path} --pad_id={pad_idx} --unk_id=0 " \
                    f"--character_coverage=1.0 --bos_id=-1 --eos_id=-1 " \
                    f"--input_sentence_size={int(input_sentence_size)} " \
                    f"--model_prefix={path / 'models' / 'spm'} " \
                    f"--vocab_size={vocab_size} --model_type={model_type} "
        spm.SentencePieceTrainer.Train(sp_params)

        with open(path / 'models' / 'spm.vocab', 'r') as f:
            vocab = [line.split('\t')[0] for line in f.readlines()]
            vocab[0] = UNK
            vocab[pad_idx] = PAD

        pickle.dump(vocab, open(path / 'models' / f'itos_{name}.pkl', 'wb'))
    # todo add post rules
    vocab = Vocab(pickle.load(open(path / 'models' / f'itos_{name}.pkl',
                                   'rb')))
    # We cannot use lambdas or local methods here, since `tok_func` needs to be
    # pickle-able in order to be called in subprocesses when multithread tokenizing
    tokenizer = Tokenizer(tok_func=SentencepieceTokenizer,
                          lang=str(path / 'models'),
                          pre_rules=pre_rules,
                          post_rules=post_rules)

    clear_cache_directory(path, cache_name)

    return {'tokenizer': tokenizer, 'vocab': vocab}
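A hedged call sketch for `get_sentencepiece`; the `data/wiki_th` paths, `name`, and vocabulary size are placeholders, and module-level names such as `PAD_TOKEN_ID`, `SentencepieceTokenizer`, and `clear_cache_directory` are assumed to come from the surrounding project.

# Placeholder paths and settings, for illustration only.
from pathlib import Path

sp = get_sentencepiece(
    path=Path('data/wiki_th'),                # root where tmp/ and models/ are created
    trn_path=Path('data/wiki_th/train.txt'),  # raw training text, one sentence per line (assumed)
    name='thwiki',
    vocab_size=30000,
    model_type='unigram',
)
tokenizer, vocab = sp['tokenizer'], sp['vocab']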
Example #4
def get_vocab():
    if conf['vocab_path'] is not None:
        vocab_obj = pickle.load(
            open(conf['local_project_path'] + conf['vocab_path'], 'rb'))
        vocab_class_obj = Vocab.load(conf['local_project_path'] +
                                     conf['vocab_path'])
        log.debug('vocab object loaded, len ' + str(len(vocab_obj)))
    else:
        vocab_obj = None
        vocab_class_obj = None
    return vocab_obj, vocab_class_obj
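This helper reads a module-level `conf` dict and `log` logger; a minimal sketch of the assumed setup, with placeholder paths:

# Assumed module-level configuration and logger (placeholders only).
import logging
log = logging.getLogger(__name__)

conf = {
    'local_project_path': './project/',   # note: paths are joined by string concatenation
    'vocab_path': 'models/itos.pkl',      # set to None to skip loading a vocabulary
}

vocab_obj, vocab_class_obj = get_vocab()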
Example #5
def get_datasets(dataset, dataset_dir, bptt, bs, lang, max_vocab, ds_pct, lm_type):
    tmp_dir = dataset_dir / 'tmp'
    tmp_dir.mkdir(exist_ok=True)
    vocab_file = tmp_dir / f'vocab_{lang}.pkl'
    if not (tmp_dir / f'{TRN}_{lang}_ids.npy').exists():
        print('Reading the data...')
        toks, lbls = read_clas_data(dataset_dir, dataset, lang)
        # create the vocabulary
        counter = Counter(word for example in toks[TRN]+toks[TST]+toks[VAL] for word in example)
        itos = [word for word, count in counter.most_common(n=max_vocab)]
        itos.insert(0, PAD)
        itos.insert(0, UNK)
        vocab = Vocab(itos)
        stoi = vocab.stoi
        with open(vocab_file, 'wb') as f:
            pickle.dump(vocab, f)

        ids = {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                   for s in toks[split]])
            np.save(tmp_dir / f'{split}_{lang}_ids.npy', ids[split])
            np.save(tmp_dir / f'{split}_{lang}_lbl.npy', lbls[split])
    else:
        print('Loading the pickled data...')
        ids, lbls = {}, {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.load(tmp_dir / f'{split}_{lang}_ids.npy')
            lbls[split] = np.load(tmp_dir / f'{split}_{lang}_lbl.npy')
        with open(vocab_file, 'rb') as f:
            vocab = pickle.load(f)
    print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
          f'Test size: {len(ids[TST])}.')
    if ds_pct < 1.0:
        print(f"Making the dataset smaller {ds_pct}")
    for split in [TRN, VAL, TST]:
        ids[split] = np.array([np.array(e, dtype=int) for e in ids[split]])
        #print([lbl for lbl in lbls[split] if not int(lbl) in [0,1,2]])          # debug by ak
        #print(f'First 10 lbls[split] labels: {lbls[split][:11]}')
        if split == TRN: print("processing TRN labels ...")
        lbls[split] = np.array([np.array(e, dtype=int) for e in lbls[split]])
        if split == TRN: print("Info: passed the train labels lbls[split] to np.array successfully")
    data_lm = TextLMDataBunch.from_ids(path=tmp_dir, vocab=vocab, train_ids=np.concatenate([ids[TRN],ids[TST]]),
                                       valid_ids=ids[VAL], bs=bs, bptt=bptt, lm_type=lm_type)
    #  TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
    data_clas = TextClasDataBunch.from_ids(
        path=tmp_dir, vocab=vocab, train_ids=ids[TRN], valid_ids=ids[VAL],
        train_lbls=lbls[TRN], valid_lbls=lbls[VAL], bs=bs, classes={l:l for l in lbls[TRN]})

    print(f"Sizes of train_ds {len(data_clas.train_ds)}, valid_ds {len(data_clas.valid_ds)}")
    return data_clas, data_lm
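A hedged call sketch; the dataset name, paths, and hyperparameters are placeholders, and `read_clas_data`, the split constants (`TRN`, `VAL`, `TST`), and the `lm_type` value are assumed to come from the surrounding project.

# Placeholder arguments, for illustration only.
from pathlib import Path

data_clas, data_lm = get_datasets(
    dataset='xnli',                 # assumed dataset name
    dataset_dir=Path('data/xnli'),  # assumed location
    bptt=70,
    bs=20,
    lang='th',
    max_vocab=30000,
    ds_pct=1.0,
    lm_type='lstm',                 # whatever lm_type the project's TextLMDataBunch extension expects
)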
Example #6
    def load_cls_data_old_for_xnli(self, bs):
        tmp_dir = self.cache_dir
        tmp_dir.mkdir(exist_ok=True)
        vocab_file = tmp_dir / f'vocab_{self.lang}.pkl'
        if not (tmp_dir / f'{TRN}_{self.lang}_ids.npy').exists():
            print('Reading the data...')
            toks, lbls = read_clas_data(self.dataset_dir, self.dataset_dir.name, self.lang)
            # create the vocabulary
            counter = Counter(word for example in toks[TRN] + toks[TST] + toks[VAL] for word in example)
            itos = [word for word, count in counter.most_common(n=self.max_vocab)]
            itos.insert(0, PAD)
            itos.insert(0, UNK)
            vocab = Vocab(itos)
            stoi = vocab.stoi
            with open(vocab_file, 'wb') as f:
                pickle.dump(vocab, f)
            ids = {}
            for split in [TRN, VAL, TST]:
                ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                       for s in toks[split]])
                np.save(tmp_dir / f'{split}_{self.lang}_ids.npy', ids[split])
                np.save(tmp_dir / f'{split}_{self.lang}_lbl.npy', lbls[split])
        else:
            print('Loading the pickled data...')
            ids, lbls = {}, {}
            for split in [TRN, VAL, TST]:
                ids[split] = np.load(tmp_dir / f'{split}_{self.lang}_ids.npy')
                lbls[split] = np.load(tmp_dir / f'{split}_{self.lang}_lbl.npy')
            with open(vocab_file, 'rb') as f:
                vocab = pickle.load(f)
        print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
              f'Test size: {len(ids[TST])}.')
        for split in [TRN, VAL, TST]:
            ids[split] = np.array([np.array(e, dtype=int) for e in ids[split]])
            lbls[split] = np.array([np.array(e, dtype=int) for e in lbls[split]])
        data_lm = TextLMDataBunch.from_ids(path=tmp_dir, vocab=vocab, train_ids=np.concatenate([ids[TRN], ids[TST]]),
                                           valid_ids=ids[VAL], bs=bs, bptt=self.bptt, lm_type=self.lm_type)
        #  TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
        data_clas = TextClasDataBunch.from_ids(
            path=tmp_dir, vocab=vocab, train_ids=ids[TRN], valid_ids=ids[VAL],
            train_lbls=lbls[TRN], valid_lbls=lbls[VAL], bs=bs, classes={l: l for l in lbls[TRN]})

        print(f"Sizes of train_ds {len(data_clas.train_ds)}, valid_ds {len(data_clas.valid_ds)}")
        return data_clas, data_lm
Example #7
class CustomSeniorProjectTokenizer(object):
    def __init__(self, TOK_PATH = Path('./senior_proj_itos'), BOS='xxbos', EOS='xxeos', FLD = 'xxfld', UNK='xxunk', PAD='xxpad',
                 TK_REP='xxrep', TK_WREP='xxwrep', TK_NUM='xxnum', TK_LAUGH='xxlaugh', n_cpus=1,
                ):
        from senior_project_util import ThaiTokenizer, pre_rules_th, post_rules_th
        from fastai.text.transform import BaseTokenizer, Tokenizer, Vocab
        from fastai.text.data import TokenizeProcessor, NumericalizeProcessor

        with open(TOK_PATH/"bert_itos_80k_cleaned.pkl", 'rb') as f:
            itos = pickle.load(f)
            
        self.vocab = Vocab(itos)
        self.tokenizer = Tokenizer(tok_func = ThaiTokenizer, lang = 'th', 
                                   pre_rules = pre_rules_th, post_rules=post_rules_th, n_cpus=n_cpus)
        
        self.cls_token_id = self.vocab.stoi[BOS]
        self.sep_token_id = self.vocab.stoi[EOS]
        self.pad_token_id = self.vocab.stoi[PAD]
        
        self.mask_token = FLD  # FLD is not otherwise used here and is already a special token, so reuse it as the mask token
        self._pad_token = PAD
        
#         tokenizer_processor = TokenizeProcessor(tokenizer=tt, chunksize=300000, mark_fields=False)
#         numbericalize_processor = NumericalizeProcessor(vocab=vocab)
        
    def num_special_tokens_to_add(self, pair=False):
        return 2
    def tokenize(self, text):
        return self.tokenizer._process_all_1([text])[0]
#         return self.tokenizer.process_all([text])[0]
    
    def convert_tokens_to_ids(self, token_list):
        #From https://huggingface.co/transformers/_modules/transformers/tokenization_utils_fast.html#PreTrainedTokenizerFast.convert_tokens_to_ids
        if token_list is None:
            return None

        if isinstance(token_list, str):
            return self.vocab.numericalize([token_list])[0]
        
        return self.vocab.numericalize(token_list)
    
    def build_inputs_with_special_tokens(self, token_list):
        # From https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_bert.py#L235
        return [self.cls_token_id] + token_list + [self.sep_token_id]
    
    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1 = None, already_has_special_tokens = False
    ):
        # From https://huggingface.co/transformers/_modules/transformers/tokenization_utils.html#PreTrainedTokenizer.get_special_tokens_mask
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
    
    def __len__(self):
        #https://huggingface.co/transformers/_modules/transformers/tokenization_utils_fast.html#PreTrainedTokenizerFast.__len__
        return len(self.vocab.itos)
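A hedged sketch exercising the extra methods of this fuller tokenizer, under the same local-file assumptions as above; the input string is illustrative only.

# Assumes the senior_proj_itos pickle and senior_project_util module are present.
tok = CustomSeniorProjectTokenizer()

ids = tok.convert_tokens_to_ids(tok.tokenize('ทดสอบ'))
inputs = tok.build_inputs_with_special_tokens(ids)

print(len(tok))                          # vocabulary size
print(tok.num_special_tokens_to_add())   # 2: one [cls] and one [sep]
print(tok.get_special_tokens_mask(ids))  # all zeros: `ids` contains no special tokens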
Example #8
import fastai
print(f"Running on Fastai version: {fastai.__version__}")


# In[15]:


with open(TOK_PATH/"bert_itos_80k_cleaned.pkl", 'rb') as f:
    itos = pickle.load(f)
# len(itos)


# In[16]:


vocab = Vocab(itos)


# In[18]:


# tt = Tokenizer(tok_func = ThaiTokenizer, lang = 'th', pre_rules = pre_rules_th, post_rules=post_rules_th, n_cpus=1)
# test_sample = tt._process_all_1([text[:100]])
# print(test_sample)
# test_sample = [vocab.numericalize(seq) for seq in test_sample]
# print(test_sample)




# In[21]:
Example #9
    - test_texts
    - test_labels
    :return:
    """
    pass


def tokenizer(texts):  # create a tokenizer function
    tok = TextTokenizer('en')
    return tok.process_all(texts)


if __name__ == "__main__":
    # 1. Download data
    # untar_data(URI)

    # 2. Read data and save with 'normal' format: text, label
    # texts, labels, label_index = parse_text_data()
    # df = pd.DataFrame.from_dict({'text': texts, 'label': labels})
    # df.to_csv('./data/20_newsgroup.csv', index=None)

    # 3. Tokenize text to create vocabulary
    df = pd.read_csv('./data/20_newsgroup.csv')

    tokens = tokenizer(df[:10]['text'].tolist())
    vocab = Vocab.create(tokens, max_vocab=1000, min_freq=2)
    print(vocab.itos)
    print(vocab.stoi)

    # 4. create embedding matrix from pretrained word vectors
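    # (sketch, not in the original) one way to do step 4: build an embedding
    # matrix aligned with vocab.itos. `word_vectors` (token -> vector dict,
    # e.g. parsed from a GloVe file) and `emb_dim` are assumptions.
    import numpy as np
    emb_dim = 300
    word_vectors = {}  # e.g. {token: np.ndarray of shape (emb_dim,)}
    emb_matrix = np.zeros((len(vocab.itos), emb_dim), dtype=np.float32)
    for i, token in enumerate(vocab.itos):
        if token in word_vectors:
            emb_matrix[i] = word_vectors[token]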
def new_train_clas(data_dir,
                   lang='en',
                   cuda_id=0,
                   pretrain_name='wt103',
                   model_dir='models',
                   qrnn=False,
                   fine_tune=True,
                   max_vocab=30000,
                   bs=20,
                   bptt=70,
                   name='imdb-clas',
                   dataset='imdb',
                   ds_pct=1.0):
    """
    :param data_dir: The path to the `data` directory
    :param lang: the language unicode
    :param cuda_id: The id of the GPU. Uses GPU 0 by default or no GPU when
                    run on CPU.
    :param pretrain_name: name of the pretrained model
    :param model_dir: The path to the directory where the pretrained model is saved
    :param qrnn: Use a QRNN. Requires installing cupy.
    :param fine_tune: Fine-tune the pretrained language model
    :param max_vocab: The maximum size of the vocabulary.
    :param bs: The batch size.
    :param bptt: The back-propagation-through-time sequence length.
    :param name: The name used for both the model and the vocabulary.
    :param dataset: The dataset used for evaluation. Currently only IMDb and
                    XNLI are implemented. Assumes dataset is located in `data`
                    folder and that name of folder is the same as dataset name.
    """
    results = {}
    if not torch.cuda.is_available():
        print('CUDA not available. Setting device=-1.')
        cuda_id = -1
    torch.cuda.set_device(cuda_id)

    print(f'Dataset: {dataset}. Language: {lang}.')
    assert dataset in DATASETS, f'Error: {dataset} processing is not implemented.'
    assert (dataset == 'imdb' and lang == 'en') or not dataset == 'imdb',\
        'Error: IMDb is only available in English.'

    data_dir = Path(data_dir)
    assert data_dir.name == 'data',\
        f'Error: Name of data directory should be data, not {data_dir.name}.'
    dataset_dir = data_dir / dataset
    model_dir = Path(model_dir)

    if qrnn:
        print('Using QRNNs...')
    model_name = 'qrnn' if qrnn else 'lstm'
    lm_name = f'{model_name}_{pretrain_name}'
    pretrained_fname = (lm_name, f'itos_{pretrain_name}')

    ensure_paths_exists(data_dir, dataset_dir, model_dir,
                        model_dir / f"{pretrained_fname[0]}.pth",
                        model_dir / f"{pretrained_fname[1]}.pkl")

    tmp_dir = dataset_dir / 'tmp'
    tmp_dir.mkdir(exist_ok=True)
    vocab_file = tmp_dir / f'vocab_{lang}.pkl'

    if not (tmp_dir / f'{TRN}_{lang}_ids.npy').exists():
        print('Reading the data...')
        toks, lbls = read_clas_data(dataset_dir, dataset, lang)

        # create the vocabulary
        counter = Counter(word for example in toks[TRN] for word in example)
        itos = [word for word, count in counter.most_common(n=max_vocab)]
        itos.insert(0, PAD)
        itos.insert(0, UNK)
        vocab = Vocab(itos)
        stoi = vocab.stoi
        with open(vocab_file, 'wb') as f:
            pickle.dump(vocab, f)

        ids = {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                   for s in toks[split]])
            np.save(tmp_dir / f'{split}_{lang}_ids.npy', ids[split])
            np.save(tmp_dir / f'{split}_{lang}_lbl.npy', lbls[split])
    else:
        print('Loading the pickled data...')
        ids, lbls = {}, {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.load(tmp_dir / f'{split}_{lang}_ids.npy')
            lbls[split] = np.load(tmp_dir / f'{split}_{lang}_lbl.npy')
        with open(vocab_file, 'rb') as f:
            vocab = pickle.load(f)

    print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
          f'Test size: {len(ids[TST])}.')

    if ds_pct < 1.0:
        print(f"Makeing the dataset smaller {ds_pct}")
        for split in [TRN, VAL, TST]:
            ids[split] = ids[split][:int(len(ids[split]) * ds_pct)]

    data_lm = TextLMDataBunch.from_ids(path=tmp_dir,
                                       vocab=vocab,
                                       train_ids=ids[TRN],
                                       valid_ids=ids[VAL],
                                       bs=bs,
                                       bptt=bptt)

    # TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
    data_clas = TextClasDataBunch.from_ids(path=tmp_dir,
                                           vocab=vocab,
                                           train_ids=ids[TRN],
                                           valid_ids=ids[VAL],
                                           train_lbls=lbls[TRN],
                                           valid_lbls=lbls[VAL],
                                           bs=bs)

    if qrnn:
        emb_sz, nh, nl = 400, 1550, 3
    else:
        emb_sz, nh, nl = 400, 1150, 3
    learn = language_model_learner(data_lm,
                                   bptt=bptt,
                                   emb_sz=emb_sz,
                                   nh=nh,
                                   nl=nl,
                                   qrnn=qrnn,
                                   pad_token=PAD_TOKEN_ID,
                                   pretrained_fnames=pretrained_fname,
                                   path=model_dir.parent,
                                   model_dir=model_dir.name)
    lm_enc_finetuned = f"{lm_name}_{dataset}_enc"
    if fine_tune and not (model_dir / f"{lm_enc_finetuned}.pth").exists():
        print('Fine-tuning the language model...')
        learn.unfreeze()
        learn.fit(2, slice(1e-4, 1e-2))

        # save encoder
        learn.save_encoder(lm_enc_finetuned)

    print("Starting classifier training")
    learn = text_classifier_learner(data_clas,
                                    bptt=bptt,
                                    pad_token=PAD_TOKEN_ID,
                                    path=model_dir.parent,
                                    model_dir=model_dir.name,
                                    qrnn=qrnn,
                                    emb_sz=emb_sz,
                                    nh=nh,
                                    nl=nl)

    learn.load_encoder(lm_enc_finetuned)

    learn.fit_one_cycle(1, 2e-2, moms=(0.8, 0.7), wd=1e-7)

    learn.freeze_to(-2)
    learn.fit_one_cycle(1,
                        slice(1e-2 / (2.6**4), 1e-2),
                        moms=(0.8, 0.7),
                        wd=1e-7)

    learn.freeze_to(-3)
    learn.fit_one_cycle(1,
                        slice(5e-3 / (2.6**4), 5e-3),
                        moms=(0.8, 0.7),
                        wd=1e-7)

    learn.unfreeze()
    learn.fit_one_cycle(2,
                        slice(1e-3 / (2.6**4), 1e-3),
                        moms=(0.8, 0.7),
                        wd=1e-7)
    results['accuracy'] = learn.validate()[1]
    print(f"Saving models at {learn.path / learn.model_dir}")
    learn.save(f'{model_name}_{name}')
    return results
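A hedged call sketch for `new_train_clas`, with placeholder arguments; it assumes the pretrained `lstm_wt103.pth` and `itos_wt103.pkl` files already sit in `models/` and that the `data/imdb` layout matches what `read_clas_data` expects.

# Placeholder arguments, for illustration only.
results = new_train_clas(
    data_dir='data',          # must literally be named `data` (asserted above)
    lang='en',
    cuda_id=0,
    pretrain_name='wt103',
    model_dir='models',
    qrnn=False,
    fine_tune=True,
    dataset='imdb',
    name='imdb-clas',
)
print(results['accuracy'])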
import fastai
print(f"Running on Fastai version: {fastai.__version__}")


# In[15]:


with open(TOK_PATH/"bert_itos_80k_cleaned.pkl", 'rb') as f:
    itos = pickle.load(f)
len(itos)


# In[16]:


vocab = Vocab(itos)


# In[17]:
pyThai_tt = ThaiTokenizer()


# In[18]:


tt = Tokenizer(tok_func = ThaiTokenizer, lang = 'th', pre_rules = pre_rules_th, post_rules=post_rules_th, n_cpus=1)
test_sample = tt._process_all_1([text[:100]])
print(test_sample)
test_sample = [vocab.numericalize(seq) for seq in test_sample]
print(test_sample)
import numpy
import torch
from torch.nn import functional
from torch import nn
from fastai.text.transform import Vocab
import unidecode
import string

# Taken from https://gist.github.com/jvns/b6dda36b2fdcc02b833ed5b0c7a09112
# Download Hans Christian Anderson's fairy tales
# !wget -O fairy-tales.txt https://www.gutenberg.org/cache/epub/27200/pg27200.txt > /dev/null 2>&1

file = unidecode.unidecode(open('fairy-tales.txt').read())
# Remove the table of contents & Gutenberg preamble
text = file[5000:]
v = Vocab.create((x for x in text), max_vocab=400, min_freq=1)
num_letters = len(v.itos)
# training_set = torch.Tensor(v.numericalize([x for x in text])).type(torch.LongTensor).cuda()
training_set = torch.Tensor(v.numericalize([x for x in text
                                            ])).type(torch.LongTensor)
training_set = training_set[:100000]
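
# (sketch, not in the original) quick round-trip check: Vocab.textify is the
# inverse of numericalize(); sep='' rejoins the character-level tokens.
print(v.textify(training_set[:80].tolist(), sep=''))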


class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.h2o = nn.Linear(hidden_size, input_size)
        self.input_size = input_size
        self.hidden = None