Example #1
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path,
                                             transfo_xl_dataset_file):
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES[
            "pretrained_vocab_file"]
        print(f"Save vocabulary to {pytorch_vocab_dump_path}")
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop("vocab", None)
        pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
        print(f"Save dataset to {pytorch_dataset_dump_path}")
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)

        print(
            f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}."
        )
        # Initialise PyTorch model
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
        print(f"Building PyTorch model from configuration: {config}")
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch-model
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path,
                                                 WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path,
                                                CONFIG_NAME)
        print(
            f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}"
        )
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(
            f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}"
        )
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
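A minimal driver for the converter above might look like the sketch below; the flag names simply mirror the function's parameters and are assumptions, not necessarily the original script's CLI.

# Hypothetical driver; flag names are assumptions that mirror the converter's parameters.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, default="")
    parser.add_argument("--transfo_xl_config_file", type=str, default="")
    parser.add_argument("--pytorch_dump_folder_path", type=str, required=True)
    parser.add_argument("--transfo_xl_dataset_file", type=str, default="")
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                             args.transfo_xl_config_file,
                                             args.pytorch_dump_folder_path,
                                             args.transfo_xl_dataset_file)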
Example #2
    def __init__(self):
        super().__init__()
        self.tokenizer = TransfoXLTokenizer.from_pretrained(
            "transfo-xl-wt103", eos_token='<eos>')
        self.tokenizer.add_special_tokens({'bos_token': '<sos>'})
        self.model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
        self.softmax = nn.Softmax(dim=0)
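A possible way to use the wrapper above is to pick the most likely next token from the last-position scores. This is only a sketch; the method name is made up, and it assumes torch is already imported and that the method is added to the same class.

# Hypothetical helper on the same class; the method name is an assumption.
def most_likely_next_token(self, text):
    input_ids = torch.tensor([self.tokenizer.encode(text, add_special_tokens=False)])
    with torch.no_grad():
        prediction_scores = self.model(input_ids)[0]      # (1, seq_len, vocab_size)
    scores = self.softmax(prediction_scores[0, -1, :])    # softmax was built with dim=0
    return self.tokenizer.decode([int(scores.argmax())])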
Example #3
        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
            model = TransfoXLLMHeadModel(config)
            model.eval()

            lm_logits_1, mems_1 = model(input_ids_1)
            loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
            lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
            loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)

            outputs = {
                "loss_1": loss_1,
                "mems_1": mems_1,
                "lm_logits_1": lm_logits_1,
                "loss_2": loss_2,
                "mems_2": mems_2,
                "lm_logits_2": lm_logits_2,
            }
            return outputs
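A companion check over the returned dictionary could assert the expected shapes. The sketch below follows the usual ModelTester pattern; attribute names such as self.parent, self.batch_size and self.seq_length are assumptions, not taken from the original code.

# Assumed companion check for the outputs above (attribute names are guesses).
def check_transfo_xl_lm_head_output(self, result, config):
    self.parent.assertEqual(result["lm_logits_1"].shape,
                            (self.batch_size, self.seq_length, config.vocab_size))
    self.parent.assertEqual(len(result["mems_1"]), config.n_layer)
    self.parent.assertEqual(len(result["mems_2"]), config.n_layer)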
Example #4
    def __init__(self, device='cpu'):
        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        model = model.to(device)

        self.tokenizer = tokenizer
        self.model = model.eval()
        self.device = device

        self.NUM_CLASSES = 267735
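For reference, the hard-coded NUM_CLASSES is presumably the size of the transfo-xl-wt103 vocabulary; under that assumption, a quick sanity check against the tokenizer could look like this sketch.

# Sanity-check sketch: the constant should equal the tokenizer's vocabulary size.
assert len(TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')) == 267735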
Example #5
    def create_transfo_xl_lm_head_trainer_incompatible_tuple(
            self, config, input_ids_1, input_ids_2, lm_labels):
        config.trainer_compatible = False
        model = TransfoXLLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        lm_logits_1 = model(input_ids_1, return_dict=False)[0]
        outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
        losses_1, _, mems_1 = outputs1[:3]
        loss_1 = outputs1[-1]
        lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
        outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False)
        losses_2, _, mems_2 = outputs2[:3]
        loss_2 = outputs2[-1]

        outputs = {
            "losses_1": losses_1,
            "mems_1": mems_1,
            "lm_logits_1": lm_logits_1,
            "loss_1": loss_1,
            "losses_2": losses_2,
            "mems_2": mems_2,
            "lm_logits_2": lm_logits_2,
            "loss_2": loss_2,
        }

        config.trainer_compatible = None
        return outputs
Example #6
    def test_lm_generate_transfo_xl_wt103(self):
        model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
        model.to(torch_device)

        # fmt: off
        input_ids = torch.tensor([[33,1297,2,1,1009,4,1109,11739,4762,358,5,25,245,22,1706,17,20098,5,3215,21,37,1110,3,13,1041,4,24,603,490,2,71477,20098,104447,2,20961,1,2604,4,1,329,3,6224,831,16002,2,8,603,78967,29546,23,803,20,25,416,5,8,232,4,277,6,1855,4601,3,29546,54,8,3609,5,57211,49,4,1,277,18,8,1755,15691,3,341,25,416,693,42573,71,17,401,94,31,17919,2,29546,7873,18,1,435,23,11011,755,5,5167,3,7983,98,84,2,29546,3267,8,3609,4,1,4865,1075,2,6087,71,6,346,8,5854,3,29546,824,1400,1868,2,19,160,2,311,8,5496,2,20920,17,25,15097,3,24,24,0]],dtype=torch.long,device=torch_device)  # noqa: E231
        # fmt: on
        #  In 1991 , the remains of Russian Tsar Nicholas II and his family
        #  ( except for Alexei and Maria ) are discovered .
        #  The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
        #  remainder of the story . 1883 Western Siberia ,
        #  a young Grigori Rasputin is asked by his father and a group of men to perform magic .
        #  Rasputin has a vision and denounces one of the men as a horse thief . Although his
        #  father initially slaps him for making such an accusation , Rasputin watches as the
        #  man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
        #  the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
        #  with people , even a bishop , begging for his blessing . <eod> </s> <eos>

        # fmt: off
        expected_output_ids = [33,1297,2,1,1009,4,1109,11739,4762,358,5,25,245,22,1706,17,20098,5,3215,21,37,1110,3,13,1041,4,24,603,490,2,71477,20098,104447,2,20961,1,2604,4,1,329,3,6224,831,16002,2,8,603,78967,29546,23,803,20,25,416,5,8,232,4,277,6,1855,4601,3,29546,54,8,3609,5,57211,49,4,1,277,18,8,1755,15691,3,341,25,416,693,42573,71,17,401,94,31,17919,2,29546,7873,18,1,435,23,11011,755,5,5167,3,7983,98,84,2,29546,3267,8,3609,4,1,4865,1075,2,6087,71,6,346,8,5854,3,29546,824,1400,1868,2,19,160,2,311,8,5496,2,20920,17,25,15097,3,24,24,0,33,1,142,1298,188,2,29546,113,8,3654,4,1,1109,7136,833,3,13,1645,4,29546,11,104,7,1,1109,532,7129,2,10,83507,2,1162,1123,2,6,7245,10,2,5,11,104,7,1,1109,532,7129,2,10,24,24,10,22,10,13,770,5863,4,7245,10]  # noqa: E231
        # fmt: on
        #  In 1991, the remains of Russian Tsar Nicholas II and his family ( except for
        #  Alexei and Maria ) are discovered. The voice of young son, Tsarevich Alexei
        #  Nikolaevich, narrates the remainder of the story. 1883 Western Siberia, a young
        #  Grigori Rasputin is asked by his father and a group of men to perform magic.
        #  Rasputin has a vision and denounces one of the men as a horse thief. Although
        #  his father initially slaps him for making such an accusation, Rasputin watches
        #  as the man is chased outside and beaten. Twenty years later, Rasputin sees a
        #  vision of the Virgin Mary, prompting him to become a priest. Rasputin quickly
        #  becomes famous, with people, even a bishop, begging for his blessing. In the
        #  early 20th century, Rasputin became a symbol of the Russian Orthodox Church.
        #  The image of Rasputin was used in the Russian national anthem, " Nearer, My God,
        #  to Heaven ", and was used in the Russian national anthem, " " ( " The Great Spirit
        #  of Heaven "

        output_ids = model.generate(input_ids, max_length=200, do_sample=False)
        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Example #7
def setup_transfo_xl(model_name):
    def _fix_tokenizer_encoding(tokenizer):
        import collections
        if '–' not in tokenizer.sym2idx:
            tokenizer.idx2sym = [sym.encode('latin1').decode(
                'utf-8') for sym in tokenizer.idx2sym]
            tokenizer.sym2idx = collections.OrderedDict((sym.encode('latin1').decode('utf-8'), idx)
                                                        for sym, idx in tokenizer.sym2idx.items())
        else:
            logger.info("No need to fix tokenizer encoding")
        return tokenizer

    model = TransfoXLLMHeadModel.from_pretrained(model_name)
    tokenizer = TransfoXLTokenizer.from_pretrained(model_name)
    tokenizer = _fix_tokenizer_encoding(tokenizer)

    def encode(lines):
        # TODO: tokenize removes the empty lines and add_eos is not applied.
        # TODO2: tokenize in Transformer-XL does not handle multiple lines correctly (removes <eos>)
        return tokenizer.convert_tokens_to_ids(
            [tok for l in lines for tok in tokenizer._tokenize(l.strip(), add_eos=True)])
    tokenizer.encode = encode
    
    return model, tokenizer
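A hypothetical use of the helper above; the input line is only illustrative.

# Hypothetical usage of setup_transfo_xl; the sentence is a placeholder.
model, tokenizer = setup_transfo_xl("transfo-xl-wt103")
ids = tokenizer.encode(["the quick brown fox jumps over the lazy dog"])
print(len(ids))  # token ids for the line, including the appended <eos>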
Example #8
    def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
        model = TransfoXLLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        lm_logits_1 = model(input_ids_1)["prediction_scores"]
        outputs1 = model(input_ids_1, labels=lm_labels)
        lm_logits_2 = model(input_ids_2, mems=outputs1["mems"])["prediction_scores"]
        outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])

        outputs = {
            "loss_1": outputs1["losses"],
            "mems_1": outputs1["mems"],
            "lm_logits_1": lm_logits_1,
            "loss_2": outputs2["losses"],
            "mems_2": outputs2["mems"],
            "lm_logits_2": lm_logits_2,
        }
        return outputs
Example #9
    def test_lm_generate_transfo_xl_wt103(self):
        model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
        input_ids = torch.Tensor([[
            33,
            1297,
            2,
            1,
            1009,
            4,
            1109,
            11739,
            4762,
            358,
            5,
            25,
            245,
            22,
            1706,
            17,
            20098,
            5,
            3215,
            21,
            37,
            1110,
            3,
            13,
            1041,
            4,
            24,
            603,
            490,
            2,
            71477,
            20098,
            104447,
            2,
            20961,
            1,
            2604,
            4,
            1,
            329,
            3,
            6224,
            831,
            16002,
            2,
            8,
            603,
            78967,
            29546,
            23,
            803,
            20,
            25,
            416,
            5,
            8,
            232,
            4,
            277,
            6,
            1855,
            4601,
            3,
            29546,
            54,
            8,
            3609,
            5,
            57211,
            49,
            4,
            1,
            277,
            18,
            8,
            1755,
            15691,
            3,
            341,
            25,
            416,
            693,
            42573,
            71,
            17,
            401,
            94,
            31,
            17919,
            2,
            29546,
            7873,
            18,
            1,
            435,
            23,
            11011,
            755,
            5,
            5167,
            3,
            7983,
            98,
            84,
            2,
            29546,
            3267,
            8,
            3609,
            4,
            1,
            4865,
            1075,
            2,
            6087,
            71,
            6,
            346,
            8,
            5854,
            3,
            29546,
            824,
            1400,
            1868,
            2,
            19,
            160,
            2,
            311,
            8,
            5496,
            2,
            20920,
            17,
            25,
            15097,
            3,
            24,
            24,
            0,
        ]]).long()
        #  In 1991 , the remains of Russian Tsar Nicholas II and his family
        #  ( except for Alexei and Maria ) are discovered .
        #  The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
        #  remainder of the story . 1883 Western Siberia ,
        #  a young Grigori Rasputin is asked by his father and a group of men to perform magic .
        #  Rasputin has a vision and denounces one of the men as a horse thief . Although his
        #  father initially slaps him for making such an accusation , Rasputin watches as the
        #  man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
        #  the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
        #  with people , even a bishop , begging for his blessing . <eod> </s> <eos>

        expected_output_ids = [
            33,
            1297,
            2,
            1,
            1009,
            4,
            1109,
            11739,
            4762,
            358,
            5,
            25,
            245,
            22,
            1706,
            17,
            20098,
            5,
            3215,
            21,
            37,
            1110,
            3,
            13,
            1041,
            4,
            24,
            603,
            490,
            2,
            71477,
            20098,
            104447,
            2,
            20961,
            1,
            2604,
            4,
            1,
            329,
            3,
            6224,
            831,
            16002,
            2,
            8,
            603,
            78967,
            29546,
            23,
            803,
            20,
            25,
            416,
            5,
            8,
            232,
            4,
            277,
            6,
            1855,
            4601,
            3,
            29546,
            54,
            8,
            3609,
            5,
            57211,
            49,
            4,
            1,
            277,
            18,
            8,
            1755,
            15691,
            3,
            341,
            25,
            416,
            693,
            42573,
            71,
            17,
            401,
            94,
            31,
            17919,
            2,
            29546,
            7873,
            18,
            1,
            435,
            23,
            11011,
            755,
            5,
            5167,
            3,
            7983,
            98,
            84,
            2,
            29546,
            3267,
            8,
            3609,
            4,
            1,
            4865,
            1075,
            2,
            6087,
            71,
            6,
            346,
            8,
            5854,
            3,
            29546,
            824,
            1400,
            1868,
            2,
            19,
            160,
            2,
            311,
            8,
            5496,
            2,
            20920,
            17,
            25,
            15097,
            3,
            24,
            24,
            0,
            29546,
            40,
            1092,
            18,
            8,
            5854,
            7,
            1143,
            2,
            7,
            1,
            159,
            99,
            16,
            1,
            1009,
            4,
            1109,
            11739,
            4762,
            358,
            5,
            25,
            245,
            28,
            1110,
            3,
            57,
            629,
            38,
            3493,
            47,
            1094,
            7,
            1297,
            3,
            0,
        ]
        #  In 1991, the remains of Russian Tsar Nicholas II and his family (
        #  except for Alexei and Maria ) are discovered. The voice of young son,
        #  Tsarevich Alexei Nikolaevich, narrates the remainder of the story.
        #  1883 Western Siberia, a young Grigori Rasputin is asked by his father
        #  and a group of men to perform magic. Rasputin has a vision and
        #  denounces one of the men as a horse thief. Although his father initially
        #  slaps him for making such an accusation, Rasputin watches as the man
        #  is chased outside and beaten. Twenty years later, Rasputin sees a vision
        #  of the Virgin Mary, prompting him to become a priest.
        #  Rasputin quickly becomes famous, with people, even a bishop, begging for
        #  his blessing. Rasputin first appears as a priest in 1996, in the same year
        #  that the remains of Russian Tsar Nicholas II and his family were discovered. H

        torch.manual_seed(0)

        output_ids = model.generate(
            input_ids,
            eos_token_ids=self.special_tokens["eos_token_id"],
            max_length=200)

        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Example #10
    gpt2Tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    gpt2LMHeadModel = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
    gpt2Tokenizer.add_special_tokens({'pad_token': "[PAD]"})
    gpt2LMHeadModel.resize_token_embeddings(len(gpt2Tokenizer))
    assert gpt2Tokenizer.pad_token == '[PAD]'
elif "GPT-2":
    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    gpt2Tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    gpt2LMHeadModel = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    gpt2Tokenizer.pad_token = gpt2Tokenizer.eos_token
elif "Transformer-XL":
    from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

    txlTokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
    txlLMHeadModel = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
    txlTokenizer.pad_token = txlTokenizer.eos_token
else:
    raise NotImplementedError


def get_losses_from_gpt_lm(this_sents: "list[str]", gpt2LMHeadModel, gpt2Tokenizer, device):
    this_input_ids = gpt2Tokenizer.batch_encode_plus(this_sents, add_special_tokens=True, pad_to_max_length=True,
                                                     add_space_before_punct_symbol=True)["input_ids"]
    this_labels = torch.tensor(
        [[i if i != gpt2Tokenizer.pad_token_id else -100 for i in row] for row in this_input_ids]).to(device)
    this_input_ids = torch.tensor(this_input_ids).to(device)
    this_outputs = gpt2LMHeadModel(input_ids=this_input_ids)
    this_lm_logits = this_outputs[0]
    # Shift so that tokens < n predict n
    shift_logits2 = this_lm_logits[:, :-1, :]
Example #11
def main():
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                        help='pretrained model name')
    parser.add_argument('--split', type=str, default='test',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=128,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=1600,
                        help='length of the retained previous heads')
    parser.add_argument('--clamp_len', type=int, default=1000,
                        help='max positional embedding index')
    parser.add_argument('--no_cuda', action='store_true',
                        help='Do not use CUDA even though CUDA is available')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset
    # You can also build the corpus yourself using TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies to prepare the adaptive input and softmax,
    # and tokenizing the dataset.
    # The pre-processed corpus is a conversion (produced with the conversion script) of the original corpus.
    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                ret = model(data, lm_labels=target, mems=mems)
                loss, _, mems = ret
                loss = loss.mean()
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
                total_time, 1000 * total_time / (idx+1)))
        return total_loss / total_len

    # Run on test data.
    if args.split == 'all':
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == 'valid':
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == 'test':
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
        return log_str

    log_str = ''
    if valid_loss is not None:
        log_str += format_log(valid_loss, 'valid')
    if test_loss is not None:
        log_str += format_log(test_loss, 'test')

    logger.info('=' * 100)
    logger.info(log_str)
    logger.info('=' * 100)
Example #12
def main():
    print('start of main')
    parser = argparse.ArgumentParser(
        description='''This script computes probabilities for a masked token
                         with words from the words file, and
                         stores result in csv format to the output file ''')

    parser.add_argument("-s",
                        type=str,
                        required=True,
                        dest="sent_type",
                        help='class name: "sv_agreement" or "anaphora"')
    parser.add_argument("-t",
                        type=str,
                        required=True,
                        dest="template",
                        help='template name (see templates.txt)')
    parser.add_argument("-g",
                        type=int,
                        required=False,
                        default=None,
                        dest="gpu_num",
                        help='which gpu to run this on')
    parser.add_argument("-m",
                        type=str,
                        required=False,
                        default='transfo-xl-wt103',
                        dest="model_path_or_name",
                        help='path to the model or name of the model')

    args = parser.parse_args()

    if args.sent_type not in ['sv_agreement', 'anaphora']:
        parser.error("invalid sent_type argument for -s")

    print('creating results path')
    use_wug = args.model_path_or_name != 'transfo-xl-wt103'

    number = None

    if use_wug:
        model_type = args.model_path_or_name.split('/')
        if model_type[-1] == '':
            model_type = model_type[:-1]
        number = model_type[-3].lower()
        model_path = '/'.join(model_type[-3:])

        results_path = FINE_TUNE_RESULTS_PATH[:-7] % model_path
        if not os.path.isdir(results_path):
            print('creating directory %s' % results_path)
            os.mkdir(results_path)
        results_path = FINE_TUNE_RESULTS_PATH[:-4] % (model_path,
                                                      args.sent_type)
        if not os.path.isdir(results_path):
            print('creating directory %s' % results_path)
            os.mkdir(results_path)
        results_path = FINE_TUNE_RESULTS_PATH % (model_path, args.sent_type,
                                                 args.template)
    else:
        results_path = RESULTS_PATH[:-4] % args.sent_type
        if not os.path.isdir(results_path):
            print('creating directory %s' % results_path)
            os.mkdir(results_path)
        results_path = RESULTS_PATH % (args.sent_type, args.template)

    results_filename = RESULTS_FILENAME % args.template

    outfilename = os.path.join(str(ABS_PATH), results_path, results_filename)

    if not os.path.isdir(results_path):
        print('creating directory %s' % results_path)
        os.mkdir(results_path)

    print('getting consts')

    sent_types = csp_consts.SENT_TYPES[args.sent_type]
    batch_sizes = csp_consts.BATCH_SIZES[args.sent_type]

    try:
        template_name = sent_types[args.template]
        batch_size_dict = batch_sizes[args.template]
    except KeyError:
        parser.error("Incompatible template for the given sentence type")
        sys.exit()

    print('loading model at', datetime.now())

    txl_tokenizer = TransfoXLTokenizer.from_pretrained(MODEL_NAME)
    txl_tokenizer.add_special_tokens({
        'bos_token': BOS_TOKEN,
        'pad_token': PAD_TOKEN
    })
    txl_model = TransfoXLLMHeadModel.from_pretrained(MODEL_NAME)
    txl_model.eval()

    if args.gpu_num is not None:
        device = torch.device(
            'cuda:' +
            str(args.gpu_num) if torch.cuda.is_available() else 'cpu')
        print('running on GPU: %d' % args.gpu_num)
    else:
        device = torch.device('cpu')

    txl_model.to(device)

    PADDING_TEXT_TXL_TOKENIZED = txl_tokenizer.encode(PADDING_TEXT,
                                                      add_eos=True)
    PADDING_TEXT_TENSOR = torch.tensor(PADDING_TEXT_TXL_TOKENIZED,
                                       dtype=torch.long,
                                       device=device).unsqueeze(0)
    global PADDING_MEMS
    _, PADDING_MEMS = txl_model(PADDING_TEXT_TENSOR)

    batch_size = batch_size_dict['pairs']
    num_sents = batch_size_dict['sents']
    if use_wug:
        batch_size *= 2
        num_sents //= 2

    print('starting all computations at', datetime.now())
    eval_from_file(txl_model,
                   txl_tokenizer,
                   template_name,
                   outfilename,
                   batch_size,
                   num_sents,
                   device=device,
                   use_wug=use_wug,
                   number=number)
    print('completed all computations at', datetime.now())
Example #13
#!/usr/bin/env python3
from transformers import TFTransfoXLLMHeadModel, TransfoXLLMHeadModel
import torch
import tensorflow as tf

#tf_model = TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103', sample_softmax=3)
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103',
                                             sample_softmax=3)

tf_input = tf.convert_to_tensor([[0, 1]])
pt_input = torch.tensor([[0, 1]])

pt_output = model(pt_input)
#tf_output = tf_model(tf_input, training=True)
Example #14
def main():
    args = parse_args()
    utils.create_logger(args)
    logger = logging.getLogger(args.logger_name)
    logger.info(f"Passed args: {' '.join(sys.argv)}")
    for key, value in vars(args).items():
        if value is not None:
            logger.info(f"{key: <25}{value}")

    torch.cuda.set_device(args.local_rank)
    device = torch.device(args.device, args.local_rank)
    if args.distributed:
        dist.init_process_group("nccl",
                                rank=args.local_rank,
                                world_size=args.world_size,
                                init_method='env://')

    # Init tensorboard
    board_writer = None
    if not args.debug and args.local_rank == 0:
        bsname = os.path.basename(args.save_dir)
        if bsname == '':
            bsname = os.path.basename(args.save_dir[:-1])
        board_path = os.path.join(args.tb_log_dir, bsname, args.train_name)
        board_writer = utils.TensorboardWrapper(board_path)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = build_corpus(args)
    if args.truncate_examples:
        logger.info(
            f"Truncating examples: sources to {args.max_src_len}, targets to {args.max_tgt_len}"
        )
        corpus.truncate_all_examples(args.max_src_len, args.max_tgt_len)

    args.corpus = corpus  # for debugging purposes

    n_gpus = torch.cuda.device_count()
    if args.multi_gpu and not args.distributed:
        global_batch_size = args.batch_size * n_gpus
    else:
        global_batch_size = args.batch_size

    train_iter = LMOrderedIteratorHuggFace(
        corpus.get_data_flat(Splits.TRAIN),
        global_batch_size,
        args.bptt,
        summary_mask=corpus.get_summary_mask_flat(Splits.TRAIN),
        device=device)

    # wt-103
    # wt_103_path = 'prepro_temp/test.txt'
    # assert os.path.exists(wt_103_path), "Test data path does not exist"
    # corpus.data[Splits.TEST] = corpus.tokenizer.encode_file(wt_103_path, ordered=True)
    test_iter = LMOrderedIteratorHuggFace(
        corpus.get_data_flat(Splits.TEST),
        args.eval_batch_size,
        args.eval_bptt,
        summary_mask=corpus.get_summary_mask_flat(Splits.TEST),
        device=device)
    val_iter = LMOrderedIteratorHuggFace(
        corpus.get_data_flat(Splits.VALID),
        args.eval_batch_size,
        args.eval_bptt,
        summary_mask=corpus.get_summary_mask_flat(Splits.VALID),
        device=device)

    train_summary = {
        'loss': [],
        'lr': [],
        'valid_loss': [],
        'test_loss': -1.0,
        'epochs': []
    }

    if args.max_epochs > 0:
        args.max_steps = train_iter.n_batch * args.max_epochs

    ###########################################################################
    # Load Model
    ###########################################################################
    # resized_model_path = os.path.join(args.data_dir, 'resized_model')
    if not args.eval_only:
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103',
                                                     dropout=args.dropout,
                                                     dropatt=args.dropout_att)
        if model.config.vocab_size != len(corpus.tokenizer):
            logger.info(
                f"Model and tokenizer have not same length / vocab_size (model: {model.config.vocab_size}, "
                f"tokenizer: {len(corpus.tokenizer)})")

            new_token_layer = 0 if args.mode != "base" else -1
            logger.info(f"Resizing the embedding layer {new_token_layer}")
            model.resize_token_embeddings(len(corpus.tokenizer),
                                          layer=new_token_layer)
            if args.copy_emb_weights:
                logger.info(
                    "Copying embedding weights from <eos> token to <cls> token."
                )
                copy_embedding_weights(model, new_token_layer,
                                       args.corpus.tokenizer.eos_token_id,
                                       args.corpus.tokenizer.cls_token_id)
        # model.save_pretrained(resized_model_path)
        #
        # logger.info(f"Loading model from {resized_model_path}.")
        # model = TransfoXLLMHeadModel.from_pretrained(resized_model_path,
        #                                              dropout=args.dropout,
        #                                              dropatt=args.dropout_att)

        assert corpus.tokenizer.cls_token == '<cls>'
        assert model.config.vocab_size == len(corpus.tokenizer)

        model.reset_length(args.bptt, model.config.ext_len, args.mem_len)
        model.config.tgt_len = args.bptt
        model.config.mem_len = args.mem_len
        model.to(device)

        n_all_param = sum([p.nelement() for p in model.parameters()])
        n_nonemb_param = sum(
            [p.nelement() for p in model.transformer.layers.parameters()])
        logger.info(f"#Params: {n_all_param}")
        logger.info(f"#Non Emb Params: {n_nonemb_param}")

        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          eps=args.adam_epsilon)
        scheduler = get_cosine_schedule_with_warmup(optimizer, args.warm_steps,
                                                    args.max_steps)

        if args.fp16:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              max_loss_scale=2**16)

        if args.multi_gpu and not args.distributed:
            para_model = DataParallel(model).to(device)
            logger.info(
                f"Using {n_gpus} GPUs with a global batch size of {global_batch_size}!"
            )
        elif args.distributed:
            para_model = DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                broadcast_buffers=False,
                find_unused_parameters=True,
            )
        else:
            para_model = model

        ###########################################################################
        # Train
        ###########################################################################

        train_step = 0
        best_val_loss = None

        # Loop over epochs.
        # At any point you can hit Ctrl + C to break out of training early.
        epoch_str = f"{args.max_epochs} epochs aka " if args.max_epochs > 0 else ""
        logger.info(
            f"Starting training for {epoch_str}{args.max_steps} steps.")
        start_time = time.time()
        try:
            for epoch in itertools.count(start=1):
                train_step, train_summary, best_val_loss = train(
                    args, train_iter, val_iter, model, para_model, optimizer,
                    scheduler, epoch, train_step, train_summary, board_writer,
                    best_val_loss)

                if train_step == args.max_steps or epoch == args.max_epochs:
                    logger.info('-' * 100)
                    logger.info('End of training')
                    break

                train_summary['epochs'].append(train_step)

        except KeyboardInterrupt:
            logger.info('-' * 100)
            logger.info('Exiting from training early')
            return

        elapsed = time.time() - start_time
        logger.info(f"Elapsed time: {timedelta(0, elapsed)}")

    ###########################################################################
    # Test
    ###########################################################################
    del train_iter
    del val_iter
    test_path = os.path.join(args.save_dir, 'model')
    if not os.path.exists(test_path):
        logger.error(
            f"Tried to load model for testing from {test_path}. Path does not exist!"
        )
        return

    # Reload model checkpoint
    if not args.eval_only:
        del model
    torch.cuda.empty_cache()
    time.sleep(5)
    logger.info(
        f"GPU memory allocated: {torch.cuda.memory_allocated(device) / (1024 ** 3):.2f} GB"
    )
    logger.info(
        f"GPU memory cached:    {torch.cuda.memory_cached(device) / (1024 ** 3):.2f} GB"
    )
    model = TransfoXLLMHeadModel.from_pretrained(test_path, clamp_len=0)

    # Increase mem_len for evaluation and generation
    model.reset_length(args.eval_bptt,
                       model.config.ext_len,
                       mem_len=args.eval_mem_len)
    model.to(device)
    logger.info('-' * 100)

    if not args.generate_only:
        logger.info('Starting evaluation on test data.')
        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(test_iter, model, args)
        # Predict on test data.
        test_acc = evaluate_predictions(test_iter, model, args)
        test_elapsed = time.time() - test_start_time
        train_summary['test_loss'] = test_loss

        logger.info('=' * 100)
        logger.info(
            '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f} | test accuracy {:5.2f}'
            .format(test_elapsed, test_loss, math.exp(test_loss), test_acc))
        logger.info('=' * 100)

    # Generate summaries
    logger.info("Starting with generation of summaries from test data.")
    test_start_time = time.time()
    generated = generate_summaries(args, model, corpus, device, Splits.TEST)
    test_elapsed = time.time() - test_start_time
    logger.info("Time for generation: {:.2f} min".format((test_elapsed / 60)))
    rouge_scores = calc_rouge_new(args, generated, Splits.TEST)
    # rouge_scores = get_dummy_rouge_scores()

    metrics = {} if args.generate_only else {
        'test_loss': test_loss,
        'test_accuracy': test_acc
    }
    for i, k1 in enumerate(rouge_scores.keys()):
        metrics[k1] = rouge_scores[k1]['f']
    if not args.debug:
        delattr(args,
                "corpus")  # Remove corpus before writing it to tensorboard
        utils.metrics2hparams(metrics)
        tb_args = utils.without_args(vars(args), utils.TENSORBOARD_IGNORE_ARGS)
        board_writer.w().add_hparams(hparam_dict=tb_args, metric_dict=metrics)

    # Save statistics
    if not args.eval_only:
        path = os.path.join(args.save_dir, f'train_summary.pt')
        torch.save(train_summary, path)

    return
Example #15
def train(args, train_iter, val_iter, model: TransfoXLLMHeadModel, para_model,
          optimizer, scheduler, epoch, train_step, train_summary,
          board_writer: utils.TensorboardWrapper, best_val_loss):
    logger = logging.getLogger(args.logger_name)
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()
    log_interval = 10
    skipped = 0
    estimate_train_time = True
    avg_elapsed_s = []

    mems = [None for _ in range(args.batch_chunk)]
    # logfile = open(os.path.join(args.save_dir, "log_batches.txt"), 'w')

    for batch, (data, target, seq_len, summary_mask) in enumerate(train_iter):
        target = data.clone()  # hotfix for new version of `transformers` where target is shifted inside the model
        log_step += 1
        target_tokens += target.numel()

        # logfile.write(f"Batch {batch}:\n{data}\n\n")

        if seq_len != args.bptt:
            skipped += 1
            logger.warning(
                f"Batch #{batch} has seq_len={seq_len} not matching tgt_len={args.bptt}. "
                f"This batch is skipped!")
            logger.warning(f"# Skipped batches: {skipped}")
            continue

        model.zero_grad()
        curr_loss = 0

        # Chunk batch into mini-batches
        data_chunks = torch.chunk(data, args.batch_chunk)
        target_chunks = torch.chunk(target, args.batch_chunk)
        if args.loss_over_tgt_only:
            summary_mask_chunks = torch.chunk(summary_mask, args.batch_chunk)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            if args.loss_over_tgt_only:
                summary_mask_i = summary_mask_chunks[i].contiguous()
                target_i, target_numel = mask_source_tokens(
                    target_i, summary_mask_i)

            outputs = para_model(input_ids=data_i,
                                 labels=target_i,
                                 mems=mems[i],
                                 return_tuple=True)
            loss, _, mems[i] = outputs[:3]

            if args.loss_over_tgt_only:
                if target_numel <= 0:
                    continue
                loss = loss.view(-1)[:target_numel]

            loss = loss.float().mean().type_as(loss) / args.batch_chunk
            if loss.item() == 0:
                continue

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.float().item()
            curr_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        # step-wise learning rate annealing
        train_step += 1
        optimizer.step()
        scheduler.step()

        lr = optimizer.param_groups[0]['lr']
        train_summary['loss'].append(curr_loss)
        train_summary['lr'].append(lr)

        if train_step % log_interval == 0:
            mean_loss = train_loss / log_step  # mean loss over n steps (n = log_interval)
            mean_loss = utils.dist_all_reduce_item(mean_loss, 'mean')

            train_loss = 0
            if not args.debug and args.local_rank == 0:
                board_writer.w().add_scalar('train_loss', mean_loss,
                                            train_step)

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.dist_all_reduce_item(avg_elapsed, 'max')
            log_start_time = time.time()
            log_step = 0
            if estimate_train_time:
                avg_elapsed_s.append(avg_elapsed)

            throughput = target_tokens / elapsed
            throughput = utils.dist_all_reduce_item(throughput, 'sum')
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch+1,
                    train_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    mean_loss
                    )

            logger.info(log_str)

        if train_step % args.eval_interval == 0:
            # Estimate training time
            if estimate_train_time:
                avg_time_per_batch = sum(
                    avg_elapsed_s[1:]) / (len(avg_elapsed_s) - 1)
                avg_train_time = avg_time_per_batch * (args.max_steps - batch)
                # train time for whole data is ~50min
                eval_int = args.eval_interval if not args.debug else 20_000
                total_eval_time = (args.max_steps // eval_int) * (50 * 60)
                avg_train_time += total_eval_time
                logger.info('-' * 100)
                logger.info(
                    f"Estimated time for training on whole data: {timedelta(0, avg_train_time)}"
                )
                estimate_train_time = False

            logger.info('-' * 100)
            logger.info('Starting evaluation on validation data.')

            # Run on validation data.
            val_start_time = time.time()
            val_loss = evaluate(val_iter, model, args)
            val_loss = utils.dist_all_reduce_item(val_loss, 'mean')

            # Predict on validation data.
            val_acc = evaluate_predictions(val_iter, model, args)
            val_acc = utils.dist_all_reduce_item(val_acc, 'mean')
            val_elapsed = time.time() - val_start_time
            train_summary['valid_loss'].append(val_loss)
            if not args.debug and args.local_rank == 0:
                board_writer.w().add_scalar('valid_loss', val_loss, train_step)
                board_writer.w().add_scalar('valid_accuracy', val_acc,
                                            train_step)

            log_str = '| End of validation | validation time: {:5.2f}s | validation loss {:5.2f} ' \
                      '| validation ppl {:9.3f} | validation accuracy {:5.2f}'.format(
                    val_elapsed, val_loss, math.exp(val_loss), val_acc
                )
            logger.info('=' * len(log_str))
            logger.info(log_str)
            logger.info('=' * len(log_str))

            # Save checkpoint if validation loss is the best so far
            if (not best_val_loss or val_loss < best_val_loss
                ) and not args.debug and args.local_rank == 0:
                best_val_loss = val_loss
                path = os.path.join(args.save_dir, 'model')
                if not os.path.exists(path):
                    os.mkdir(path)
                logger.info(f'Saving checkpoint to {path}')
                model.save_pretrained(path)

                path = os.path.join(args.save_dir, 'train_summary.pt')
                torch.save(train_summary, path)

        if train_step % 500 == 0 and args.local_rank == 0:
            utils.plot_curve(train_summary, args, "plots")
            logger.info(f"Saved loss curve plot in {args.save_dir}")

        if train_step == args.max_steps:
            break

    # logfile.close()
    return train_step, train_summary, best_val_loss
Example #16
def train(datapath, outpath, seed, batch_size, epochs, save_steps, use_gpt, use_cuda = True):
    #set up model and device (hopefully cuda)
    device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")

    if use_gpt:
        model = GPT2LMHeadModel.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    else:
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), betas=(.9,.98), eps=1e-09)
    
    #setup rng seeds on all devices to ensure repeatable results
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    num_batches = len(os.listdir(datapath)) / batch_size
    batch_list = getBatch(datapath, batch_size, tokenizer)

    avg_losses = []
    avg_loss = 0
    
    model.zero_grad()
    timestamp = datetime.datetime.now().strftime('%y%m%d%H%M%S')

    for _ in trange(epochs, desc="Epochs"):
        for batch_num in tqdm(range(0,int(num_batches), batch_size), desc="Batches"):
            #setup this batch.
            batch = torch.tensor(next(batch_list), dtype=torch.long, device=device)
            inputs, labels = batch, batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            #feed input to model to train
            model.train()
            outputs = model(input_ids=inputs, labels=labels)

            if not use_gpt:
                # the loss returned from TransfoXL was broken, so mask out padding before averaging
                first_pad = get_first_occ(inputs[0], -1)
                loss = outputs[0][0][:first_pad].mean()
            else:
                loss = outputs[0]
            avg_loss += loss
            
            #update parameters
            loss.backward()
            optimizer.step()
            model.zero_grad()

            if batch_num % (batch_size * save_steps) == 0:
                print('CHECKPOINT')
                checkpoint_path = f"{fixpath(outpath)}{timestamp}/e{epochs}-num{batch_num}-size{batch_size}"
                if not os.path.exists(checkpoint_path):
                    os.makedirs(checkpoint_path)
                model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                model_to_save.save_pretrained(checkpoint_path)
                tokenizer.save_pretrained(checkpoint_path)

                avg = avg_loss / save_steps
                print(f"average loss: {avg}")
                avg_losses += [avg]
                print('finished')
    
    print(avg_losses)
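A hypothetical invocation of the fine-tuning routine above; the paths and hyper-parameters are placeholders.

# Hypothetical call; datapath/outpath are placeholders, not paths from the original code.
train(datapath="data/train_files", outpath="checkpoints/", seed=42,
      batch_size=4, epochs=1, save_steps=100, use_gpt=False)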
Example #17
import torch
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate pre-trained model-specific tokenizer and the model itself
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103').to(device)

# Initial input sequence
text = "The company was founded in"
tokens_tensor = \
    torch.tensor(tokenizer.encode(text)) \
        .unsqueeze(0) \
        .to(device)

mems = None  # recurrence mechanism

predicted_tokens = list()
for i in range(50):  # stop at 50 predicted tokens
    # Generate predictions
    predictions, mems = model(tokens_tensor, mems=mems)

    # Get most probable word index
    predicted_index = torch.topk(predictions[0, -1, :], 1)[1]

    # Extract the word from the index
    predicted_token = tokenizer.decode(predicted_index)

    # break if [EOS] reached
    if predicted_token == tokenizer.eos_token:
        break

    # Assumed continuation (the original example is truncated here):
    # keep the prediction and feed only the new token back, since `mems` carry the context
    predicted_tokens.append(predicted_token)
    tokens_tensor = predicted_index.unsqueeze(0)

Example #18
# The opening line of this definition is reconstructed by analogy with the entries below.
Roberta = ModelInfo(
    RobertaForMaskedLM.from_pretrained('roberta-base', return_dict=True),
    RobertaTokenizer.from_pretrained('roberta-base'), "_", vocab, "Roberta")

XLM = ModelInfo(
    XLMWithLMHeadModel.from_pretrained('xlm-mlm-xnli15-1024',
                                       return_dict=True),
    XLMTokenizer.from_pretrained('xlm-mlm-xnli15-1024'), "_", vocab, "XLM")

T5 = ModelInfo(
    T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True),
    T5Tokenizer.from_pretrained("t5-base"), "_", vocab, "T5")

Albert = ModelInfo(
    AlbertForMaskedLM.from_pretrained('albert-base-v2', return_dict=True),
    AlbertTokenizer.from_pretrained('albert-base-v2'), "_", vocab, "Albert")

TXL = ModelInfo(TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103'),
                TransfoXLTokenizer.from_pretrained('transfo-xl-wt103'), "_",
                vocab, "TXL")

if __name__ == "__main__":

    sentences = [sample_sentences("sentences4lara.txt") for i in range(11)]

    sent_dict = dict(zip([str(x) for x in range(1, 11)], sentences))

    sentence = sent_dict[sys.argv[2]]

    batch_size = 100
    convergence_criterion = int(sys.argv[4])
    model_list = [GPT2, Roberta, Albert, XLM, T5]
    max_length = 8
Example #19
    'clamp_len': 400,
}

from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel, TransfoXLConfig

# Initializing a Transformer XL configuration
configuration = TransfoXLConfig.from_dict(model_config_base)
# To match with pre-trained model
configuration.d_embed, configuration.d_head = 512, 64
configuration.d_inner, configuration.d_model = 2048, 512
configuration.mem_len, configuration.n_head = 192, 8
configuration.n_layer, configuration.tgt_len = 16, 192
configuration.vocab_size = 32000

model = TransfoXLLMHeadModel.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=ckpt['model_state'],
    config=configuration)
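A quick verification sketch for the model loaded from the raw state_dict above, assuming the usual attribute layout of TransfoXLLMHeadModel (transformer.word_emb).

# Sketch: the adaptive embedding should cover the configured vocabulary.
assert model.transformer.word_emb.n_token == configuration.vocab_size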

from transformers import PreTrainedTokenizer
from utils.tokenization_sentencepiece import FullTokenizer
from collections import Counter, OrderedDict
from os.path import join, exists


class Vocab(TransfoXLTokenizer):
    def __init__(self,
                 special=None,
                 min_freq=0,
                 max_size=None,
                 lower_case=False,
                 delimiter=None,