def create_and_check_xlnet_lm_head(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Exercise XLNetLMHeadModel on two chained forward passes plus one
    permutation-LM query pass, and verify the loss/logits/memory shapes.

    The second pass feeds ``mems_1`` back in, so its memory tensors are
    expected to be ``mem_len`` long rather than ``seq_length``.
    """
    model = XLNetLMHeadModel(config)
    model.eval()

    # First pass: no memories yet.
    loss_1, all_logits_1, mems_1 = model(
        input_ids_1, token_type_ids=segment_ids, labels=lm_labels
    )
    # Second pass: reuse the memories produced by the first pass.
    loss_2, all_logits_2, mems_2 = model(
        input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1
    )
    # Query pass with permutation mask / target mapping (no labels, no loss).
    logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)

    result = {
        "loss_1": loss_1,
        "mems_1": mems_1,
        "all_logits_1": all_logits_1,
        "loss_2": loss_2,
        "mems_2": mems_2,
        "all_logits_2": all_logits_2,
    }

    logits_shape = [self.batch_size, self.seq_length, self.vocab_size]
    mem_shape_1 = [self.seq_length, self.batch_size, self.hidden_size]
    mem_shape_2 = [self.mem_len, self.batch_size, self.hidden_size]

    # Losses are scalars (0-dim tensors).
    self.parent.assertListEqual(list(result["loss_1"].size()), [])
    self.parent.assertListEqual(list(result["all_logits_1"].size()), logits_shape)
    self.parent.assertListEqual(
        [list(mem.size()) for mem in result["mems_1"]],
        [mem_shape_1] * self.num_hidden_layers,
    )

    self.parent.assertListEqual(list(result["loss_2"].size()), [])
    self.parent.assertListEqual(list(result["all_logits_2"].size()), logits_shape)
    self.parent.assertListEqual(
        [list(mem.size()) for mem in result["mems_2"]],
        [mem_shape_2] * self.num_hidden_layers,
    )
# --- Example #2 ---
    def __init__(self, model_path='xlnet-base-cased', padding_text=None, device='cuda'):
        """Load an XLNet LM-head model and its tokenizer.

        Args:
            model_path: model name or checkpoint path for ``from_pretrained``.
            padding_text: optional priming text; falls back to the
                class-level ``PADDING_TEXT`` when ``None`` or empty.
            device: torch device the model is moved to.
        """
        super().__init__()
        self.model_path = model_path
        self.device = device

        self.tokenizer = XLNetTokenizer.from_pretrained(model_path)

        # Encode the padding/priming text once, up front.
        self.padding_text_idxes = self.tokenizer.encode(padding_text or self.PADDING_TEXT)

        # Build the model last, then move it to the target device in eval mode.
        model = XLNetLMHeadModel.from_pretrained(model_path)
        model.to(device)
        model.eval()
        self.model = model
# --- Example #3 ---
def get_xlnet(xlnet_model):
    """Load an XLNet tokenizer and LM-head model.

    Args:
        xlnet_model: model name or path understood by ``from_pretrained``
            (e.g. ``'xlnet-base-cased'`` or a local checkpoint directory).

    Returns:
        tuple: ``(tokenizer, model)`` — an ``XLNetTokenizer`` and an
        ``XLNetLMHeadModel``.
    """
    # Import lazily so pytorch_transformers is only a hard dependency when
    # XLNet is actually used.  (Fix: only import the two names that are
    # used; the original pulled in seven unused XLM/classification names,
    # and its comment wrongly said the dependency being avoided was BERT.)
    from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer

    print(xlnet_model)
    tokenizer = XLNetTokenizer.from_pretrained(xlnet_model)
    xlnet = XLNetLMHeadModel.from_pretrained(xlnet_model)
    return tokenizer, xlnet
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path,
                                        bert_config_file,
                                        pytorch_dump_folder_path,
                                        finetuning_task=None):
    """Convert a TensorFlow XLNet checkpoint into PyTorch weights + config.

    Args:
        tf_checkpoint_path: path of the TF checkpoint to load weights from.
        bert_config_file: JSON config file for the model (parameter name kept
            for backward compatibility — it is an XLNet config, not BERT).
        pytorch_dump_folder_path: output directory for the weights/config
            files; created if it does not exist.
        finetuning_task: optional task name (case-insensitive). A GLUE task
            selects a sequence-classification head, a name containing
            ``'squad'`` selects a QA head, anything else the plain LM head.
    """
    # Initialise PyTorch model configuration from the JSON file.
    config = XLNetConfig.from_json_file(bert_config_file)

    finetuning_task = finetuning_task.lower(
    ) if finetuning_task is not None else ""
    if finetuning_task in GLUE_TASKS_NUM_LABELS:
        # GLUE task: classification head with the task-specific label count.
        print(
            "Building PyTorch XLNetForSequenceClassification model from configuration: {}"
            .format(str(config)))
        config.finetuning_task = finetuning_task
        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
        model = XLNetForSequenceClassification(config)
    elif 'squad' in finetuning_task:
        config.finetuning_task = finetuning_task
        model = XLNetForQuestionAnswering(config)
    else:
        model = XLNetLMHeadModel(config)

    # Load weights from the TF checkpoint into the PyTorch model.
    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

    # Fix: ensure the dump directory exists — torch.save/open do not create
    # intermediate directories and would otherwise raise FileNotFoundError.
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # Save pytorch-model (weights + JSON config).
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path,
                                             WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path,
                                            CONFIG_NAME)
    print("Save PyTorch model to {}".format(
        os.path.abspath(pytorch_weights_dump_path)))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(
        os.path.abspath(pytorch_config_dump_path)))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
#!/usr/bin/python3
# Psycholinguistic stimulus scoring script: reads stimuli from a CSV-like
# file and prints per-word softmax/surprisal values from XLNet to a text file.
import torch
from pytorch_transformers import XLNetLMHeadModel, XLNetConfig, XLNetTokenizer
import sys
import numpy as np

# Load pre-trained model and tokenizer
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

# Read items from file
# NOTE(review): assumes one stimulus per line; downstream code splits each
# line on ',' and reads field 1 (sentence) and field 4 (1-based mask
# position) — confirm against the actual file format.
with open('items_fg_emb_sai_combined_wordfinal.csv', encoding='utf8') as f:
    text = f.read().splitlines()

# Write to file
# Redirect stdout so every subsequent print() lands in the output file.
# `f` is rebound to the output handle and intentionally left open for the
# remainder of the script (presumably closed/restored after the main loop).
orig_stdout = sys.stdout
f = open('out_fg_emb_sai_combined_wordfinal.txt', 'w')
sys.stdout = f

# Write Column Headers
print("SentenceID, MaskedWord, Softmax, Surprisal, Condition, EmbeddingLevel")
for s in text:
    splits = s.split(',')
    item = splits[1]

    # Find index of the token to mask, and mask it
    masked_index = int(splits[4]) - 1
    maskedword = splits[1].split(' ')[masked_index]
    word_id = tokenizer.convert_tokens_to_ids([maskedword])
    item = item.split(' ')